|
|
|
import os
|
|
|
|
import torch
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
import numpy as np
|
|
|
|
import pickle
|
|
|
|
from utils.logger import setup_logger
|
|
|
|
|
|
|
|
# Set up the module-level logger via the project's setup_logger helper.
logger = setup_logger('dataset')
|
|
|
|
|
|
|
|
class BRepSDFDataset(Dataset):
    """Dataset pairing per-sample B-rep features with SDF arrays.

    Each sample is a sub-directory of ``<brep_dir>/<split>``; ``__getitem__``
    reads ``model.step`` and ``sdf.npy`` from that sub-directory and returns
    both as float32 tensors.
    """

    def __init__(self, brep_dir: str, sdf_dir: str, split: str = 'train'):
        """Initialize the dataset.

        Args:
            brep_dir: Root B-rep directory; its ``split`` sub-directory holds
                one sub-directory per sample. (Originally documented as the
                .pkl directory, but ``__getitem__`` reads ``model.step``.)
            sdf_dir: Root SDF directory; its ``split`` sub-directory must
                exist. (Originally documented as the .npz directory.)
            split: Dataset split ('train', 'val', 'test').

        Raises:
            ValueError: If either split directory is missing or is not a
                directory, or if the B-rep split contains no sample
                sub-directories.
        """
        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split

        # isdir rather than exists: a regular file at this path would pass
        # exists() and then fail later with NotADirectoryError in listdir().
        if not os.path.isdir(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.isdir(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")

        self.data_list = self._load_data_list()
        if not self.data_list:
            raise ValueError(f"No valid data found in {split} set")

        logger.info(f"Loaded {split} dataset with {len(self.data_list)} samples")

    def _load_data_list(self):
        """Return the paths of all sample sub-directories of ``self.brep_dir``.

        Plain files are skipped. Entries are sorted by name because
        ``os.listdir`` returns them in arbitrary order, which would make the
        index -> sample mapping differ between runs and machines.
        """
        candidates = (
            os.path.join(self.brep_dir, name)
            for name in sorted(os.listdir(self.brep_dir))
        )
        return [path for path in candidates if os.path.isdir(path)]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load the sample at ``idx``.

        Returns:
            dict with keys 'brep_features' (output of ``_parse_step_file``)
            and 'sdf' (contents of ``sdf.npy``), both as float32 tensors.
        """
        sample_path = self.data_list[idx]

        step_file = os.path.join(sample_path, 'model.step')
        brep_features = self._parse_step_file(step_file)

        # NOTE(review): the SDF is read from the sample directory under
        # brep_dir; self.sdf_dir is validated in __init__ but never used.
        # Confirm which location is intended.
        sdf_file = os.path.join(sample_path, 'sdf.npy')
        sdf = np.load(sdf_file)

        return {
            'brep_features': torch.tensor(brep_features, dtype=torch.float32),
            'sdf': torch.tensor(sdf, dtype=torch.float32),
        }

    def _parse_step_file(self, step_file):
        """Parse a .step file and return its B-rep features.

        Not implemented yet. Raises rather than returning None, so that
        ``__getitem__`` fails with a clear message instead of an opaque
        ``torch.tensor(None)`` error.

        Raises:
            NotImplementedError: Always, until STEP parsing is implemented.
        """
        raise NotImplementedError(f"STEP parsing is not implemented yet: {step_file}")
|
|
|
|
|
|
|
|
def test_dataset():
    """Smoke-test BRepSDFDataset against a local test-data directory.

    Logs the configuration, builds the dataset and logs its size. Any
    exception is logged with its traceback and then re-raised so the
    failure is not hidden.
    """
    try:
        # Machine-specific test-data locations.
        brep_dir = '/home/wch/myDeepSDF/test_data/pkl'
        sdf_dir = '/home/wch/myDeepSDF/test_data/sdf'
        split = 'train'

        logger.info("="*50)
        logger.info("Testing dataset")
        logger.info(f"B-rep directory: {brep_dir}")
        logger.info(f"SDF directory: {sdf_dir}")
        logger.info(f"Split: {split}")

        dataset = BRepSDFDataset(brep_dir, sdf_dir, split)
        logger.info(f"Dataset size: {len(dataset)}")

        # NOTE: the remaining test steps were elided in this copy of the
        # file ("rest of the test code unchanged"); restore them here.
    except Exception:
        # Top-level boundary: record the traceback, then propagate.
        logger.exception("Dataset test failed")
        raise
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the dataset smoke test only when this file is executed as a script.
    test_dataset()
|