1 changed files with 68 additions and 0 deletions
			
			
		@ -0,0 +1,68 @@ | 
				
			|||||
 | 
					import os | 
				
			||||
 | 
					import torch | 
				
			||||
 | 
					from torch.utils.data import Dataset | 
				
			||||
 | 
					import numpy as np | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					class BRepSDFDataset(Dataset): | 
				
			||||
 | 
					    def __init__(self, brep_data_dir:str, sdf_data_dir:str, split:str='train'): | 
				
			||||
 | 
					        self.brep_data_dir = brep_data_dir | 
				
			||||
 | 
					        self.sdf_data_dir = sdf_data_dir | 
				
			||||
 | 
					        self.split = split | 
				
			||||
 | 
					        self.brep_data_list = self._load_brep_data_list() | 
				
			||||
 | 
					        self.sdf_data_list = self._load_sdf_data_list() | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def _load_brep_data_list(self): | 
				
			||||
 | 
					        data_list = [] | 
				
			||||
 | 
					        split_dir = os.path.join(self.brep_data_dir, self.split) | 
				
			||||
 | 
					        for sample_dir in os.listdir(split_dir): | 
				
			||||
 | 
					            sample_path = os.path.join(split_dir, sample_dir) | 
				
			||||
 | 
					            if os.path.isdir(sample_path): | 
				
			||||
 | 
					                data_list.append(sample_path) | 
				
			||||
 | 
					        return data_list | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def _load_sdf_data_list(self): | 
				
			||||
 | 
					        data_list = [] | 
				
			||||
 | 
					        split_dir = os.path.join(self.sdf_data_dir, self.split) | 
				
			||||
 | 
					        for sample_dir in os.listdir(split_dir): | 
				
			||||
 | 
					            sample_path = os.path.join(split_dir, sample_dir) | 
				
			||||
 | 
					            if os.path.isdir(sample_path): | 
				
			||||
 | 
					                data_list.append(sample_path) | 
				
			||||
 | 
					        return data_list | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def __len__(self): | 
				
			||||
 | 
					        return len(self.brep_data_list) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def __getitem__(self, idx): | 
				
			||||
 | 
					        sample_path = self.brep_data_list[idx] | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 解析 .step 文件 | 
				
			||||
 | 
					        step_file = os.path.join(sample_path, 'model.step') | 
				
			||||
 | 
					        brep_features = self._parse_step_file(step_file) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 加载 .sdf 文件 | 
				
			||||
 | 
					        sdf_file = os.path.join(sample_path, 'sdf.npy') | 
				
			||||
 | 
					        sdf = np.load(sdf_file) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 转换为 torch 张量 | 
				
			||||
 | 
					        brep_features = torch.tensor(brep_features, dtype=torch.float32) | 
				
			||||
 | 
					        sdf = torch.tensor(sdf, dtype=torch.float32) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        return { | 
				
			||||
 | 
					            'brep_features': brep_features, | 
				
			||||
 | 
					            'sdf': sdf | 
				
			||||
 | 
					        } | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def _parse_step_file(self, step_file): | 
				
			||||
 | 
					        # 解析 .step 文件的逻辑 | 
				
			||||
 | 
					        # 返回 B-rep 特征 | 
				
			||||
 | 
					        pass | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					# 使用示例 | 
				
			||||
 | 
					if __name__ == '__main__': | 
				
			||||
 | 
					    dataset = BRepSDFDataset(brep_data_dir='path/to/brep_data', sdf_data_dir='path/to/sdf_data', split='train') | 
				
			||||
 | 
					    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    for batch in dataloader: | 
				
			||||
 | 
					        print(batch['brep_features'].shape) | 
				
			||||
 | 
					        print(batch['sdf'].shape) | 
				
			||||
 | 
					        break | 
				
			||||
					Loading…
					
					
				
		Reference in new issue