import os
import warnings

import numpy as np
import torch
from torch.utils.data import Dataset


class BRepSDFDataset(Dataset):
    """Dataset pairing B-rep features parsed from STEP files with SDF arrays.

    Samples are one sub-directory each and are matched by directory name::

        <brep_data_dir>/<split>/<sample>/model.step
        <sdf_data_dir>/<split>/<sample>/sdf.npy

    ``_parse_step_file`` is a placeholder that subclasses must implement; it
    should return an array-like of B-rep features convertible to a float
    tensor.
    """

    def __init__(self, brep_data_dir: str, sdf_data_dir: str, split: str = 'train'):
        """Index the samples of one split.

        Args:
            brep_data_dir: Root directory containing ``<split>/<sample>/`` B-rep dirs.
            sdf_data_dir: Root directory containing ``<split>/<sample>/`` SDF dirs.
            split: Name of the split sub-directory (e.g. ``'train'``).

        Raises:
            FileNotFoundError: If ``<brep_data_dir>/<split>`` or
                ``<sdf_data_dir>/<split>`` does not exist.
        """
        self.brep_data_dir = brep_data_dir
        self.sdf_data_dir = sdf_data_dir
        self.split = split

        brep_paths = self._load_brep_data_list()
        sdf_by_name = {os.path.basename(p): p for p in self._load_sdf_data_list()}

        # Pair samples by directory name so that index i refers to the same
        # sample in both lists; directory listing order alone guarantees nothing.
        missing = [p for p in brep_paths if os.path.basename(p) not in sdf_by_name]
        if missing:
            warnings.warn(
                f"{len(missing)} B-rep sample(s) in split {split!r} have no matching "
                f"SDF directory under {sdf_data_dir!r} and were skipped",
                stacklevel=2,
            )
        self.brep_data_list = [
            p for p in brep_paths if os.path.basename(p) in sdf_by_name
        ]
        self.sdf_data_list = [
            sdf_by_name[os.path.basename(p)] for p in self.brep_data_list
        ]

    @staticmethod
    def _list_sample_dirs(root: str, split: str) -> list:
        """Return the sorted full paths of the sub-directories of ``root/split``."""
        split_dir = os.path.join(root, split)
        with os.scandir(split_dir) as entries:
            # Sorted so sample order is deterministic across filesystems.
            return sorted(entry.path for entry in entries if entry.is_dir())

    def _load_brep_data_list(self) -> list:
        """Return the sorted B-rep sample directories for this split."""
        return self._list_sample_dirs(self.brep_data_dir, self.split)

    def _load_sdf_data_list(self) -> list:
        """Return the sorted SDF sample directories for this split."""
        return self._list_sample_dirs(self.sdf_data_dir, self.split)

    def __len__(self) -> int:
        """Number of samples present in both the B-rep and SDF directories."""
        return len(self.brep_data_list)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict with ``'brep_features'`` and ``'sdf'``, both float32 tensors.
        """
        # Parse the .step file from the B-rep sample directory.
        step_file = os.path.join(self.brep_data_list[idx], 'model.step')
        brep_features = self._parse_step_file(step_file)

        # Load the SDF array from the *matching* SDF sample directory.
        sdf_file = os.path.join(self.sdf_data_list[idx], 'sdf.npy')
        sdf = np.load(sdf_file)

        return {
            'brep_features': torch.tensor(brep_features, dtype=torch.float32),
            # The freshly loaded array is not shared, so as_tensor may reuse it.
            'sdf': torch.as_tensor(sdf, dtype=torch.float32),
        }

    def _parse_step_file(self, step_file):
        """Parse a STEP file and return its B-rep features.

        Placeholder: subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _parse_step_file "
            f"(called for {step_file!r})"
        )


# Usage example
if __name__ == '__main__':
    dataset = BRepSDFDataset(
        brep_data_dir='path/to/brep_data',
        sdf_data_dir='path/to/sdf_data',
        split='train',
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    for batch in dataloader:
        print(batch['brep_features'].shape)
        print(batch['sdf'].shape)
        break