import os
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset

from utils.logger import setup_logger

# Module-level logger for dataset loading diagnostics.
logger = setup_logger('dataset')


class BRepSDFDataset(Dataset):
    """Paired B-rep / SDF dataset, one sub-directory per sample.

    Each sample directory under ``<brep_dir>/<split>/`` is expected to
    contain a ``model.step`` B-rep file and an ``sdf.npy`` SDF array
    (NOTE(review): the original docstring claimed pkl/npz files, but the
    loading code reads .step/.npy — documented per the code).
    """

    def __init__(self, brep_dir: str, sdf_dir: str, split: str = 'train'):
        """
        Initialize the dataset.

        Args:
            brep_dir: root directory of B-rep data (one sub-dir per sample).
            sdf_dir: root directory of SDF data.
            split: dataset split ('train', 'val', 'test').

        Raises:
            ValueError: if either split directory is missing or empty.
        """
        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split

        # Fail fast on misconfigured paths rather than at first __getitem__.
        if not os.path.exists(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.exists(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")

        # List of absolute per-sample directory paths.
        self.data_list = self._load_data_list()

        if len(self.data_list) == 0:
            raise ValueError(f"No valid data found in {split} set")

        logger.info(f"Loaded {split} dataset with {len(self.data_list)} samples")

    def _load_data_list(self):
        """Return the per-sample directories under ``self.brep_dir``.

        Sorted so sample ordering is deterministic across filesystems.
        """
        return sorted(
            os.path.join(self.brep_dir, name)
            for name in os.listdir(self.brep_dir)
            if os.path.isdir(os.path.join(self.brep_dir, name))
        )

    def __len__(self):
        """Number of samples in this split."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load one sample: B-rep features and its SDF array as tensors.

        Returns:
            dict with keys 'brep_features' and 'sdf' (float32 tensors).

        Raises:
            NotImplementedError: STEP parsing is not implemented yet.
        """
        sample_path = self.data_list[idx]

        # Parse the .step file into B-rep features.
        step_file = os.path.join(sample_path, 'model.step')
        brep_features = self._parse_step_file(step_file)

        # Load the precomputed SDF samples.
        sdf_file = os.path.join(sample_path, 'sdf.npy')
        sdf = np.load(sdf_file)

        # Convert to torch tensors for the training loop.
        brep_features = torch.tensor(brep_features, dtype=torch.float32)
        sdf = torch.tensor(sdf, dtype=torch.float32)

        return {
            'brep_features': brep_features,
            'sdf': sdf,
        }

    def _parse_step_file(self, step_file):
        """Parse a .step file and return its B-rep features.

        BUG FIX: the original stub returned None, which made __getitem__
        fail later with a cryptic ``torch.tensor(None)`` TypeError.
        Raise explicitly until the parser exists.
        """
        raise NotImplementedError(
            f"STEP parsing not implemented (requested: {step_file})"
        )


def test_dataset():
    """Smoke-test dataset construction and one batch of loading."""
    try:
        # Test paths (hard-coded for the local test fixture).
        brep_dir = '/home/wch/myDeepSDF/test_data/pkl'
        sdf_dir = '/home/wch/myDeepSDF/test_data/sdf'
        split = 'train'

        logger.info("=" * 50)
        logger.info("Testing dataset")
        logger.info(f"B-rep directory: {brep_dir}")
        logger.info(f"SDF directory: {sdf_dir}")
        logger.info(f"Split: {split}")

        # BUG FIX: the original passed brep_data_dir=/sdf_data_dir=,
        # which are not the constructor's parameter names (TypeError).
        dataset = BRepSDFDataset(brep_dir=brep_dir, sdf_dir=sdf_dir, split=split)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

        for batch in dataloader:
            print(batch['brep_features'].shape)
            print(batch['sdf'].shape)
            break
    except Exception:
        # BUG FIX: the original 'try' had no matching except/finally,
        # which is a SyntaxError. Log the failure instead of crashing.
        logger.exception("Dataset test failed")


if __name__ == '__main__':
    test_dataset()