| 
						
						
							
								
							
						
						
					 | 
					@ -29,49 +29,161 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            raise ValueError(f"SDF directory not found: {self.sdf_dir}") | 
					 | 
					 | 
					            raise ValueError(f"SDF directory not found: {self.sdf_dir}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 加载数据列表 | 
					 | 
					 | 
					        # 加载数据列表 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.data_list = self._load_data_list() | 
					 | 
					 | 
					        self.brep_data_list = self._load_data_list(self.brep_dir) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.sdf_data_list = self._load_data_list(self.sdf_dir) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 检查数据集是否为空 | 
					 | 
					 | 
					        # 检查数据集是否为空 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        if len(self.data_list) == 0: | 
					 | 
					 | 
					        if len(self.brep_data_list) == 0 : | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            raise ValueError(f"No valid data found in {split} set") | 
					 | 
					 | 
					            raise ValueError(f"No valid brep data found in {split} set") | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        if len(self.sdf_data_list) == 0: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            raise ValueError(f"No valid sdf data found in {split} set") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        logger.info(f"Loaded {split} dataset with {len(self.data_list)} samples") | 
					 | 
					 | 
					        logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples") | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def _load_data_list(self): | 
					 | 
					 | 
					    # data_dir 为 self.brep_dir or sdf_dir | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    def _load_data_list(self, data_dir): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        data_list = [] | 
					 | 
					 | 
					        data_list = [] | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        for sample_dir in os.listdir(self.brep_dir): | 
					 | 
					 | 
					        for sample_file in os.listdir(data_dir): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            sample_path = os.path.join(self.brep_dir, sample_dir) | 
					 | 
					 | 
					            path = os.path.join(data_dir, sample_file) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            if os.path.isdir(sample_path): | 
					 | 
					 | 
					            data_list.append(path) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                data_list.append(sample_path) | 
					 | 
					 | 
					
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        #logger.info(data_list) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return data_list | 
					 | 
					 | 
					        return data_list | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def __len__(self): | 
					 | 
					 | 
					    def __len__(self): | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        return len(self.data_list) | 
					 | 
					 | 
					        return len(self.brep_data_list) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def __getitem__(self, idx): | 
					 | 
					 | 
					    def __getitem__(self, idx): | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        sample_path = self.data_list[idx] | 
					 | 
					 | 
					        """获取单个数据样本""" | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					         | 
					 | 
					 | 
					        brep_path = self.brep_data_list[idx] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 解析 .step 文件 | 
					 | 
					 | 
					        sdf_path = self.sdf_data_list[idx] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        step_file = os.path.join(sample_path, 'model.step') | 
					 | 
					 | 
					         | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        brep_features = self._parse_step_file(step_file) | 
					 | 
					 | 
					        try: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					         | 
					 | 
					 | 
					            # 获取文件名(不含扩展名)作为sample name | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 加载 .sdf 文件 | 
					 | 
					 | 
					            name = os.path.splitext(os.path.basename(brep_path))[0] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        sdf_file = os.path.join(sample_path, 'sdf.npy') | 
					 | 
					 | 
					             | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        sdf = np.load(sdf_file) | 
					 | 
					 | 
					            # 加载B-rep和SDF数据 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					         | 
					 | 
					 | 
					            brep_data = self._load_brep_file(brep_path) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 转换为 torch 张量 | 
					 | 
					 | 
					            sdf_data = self._load_sdf_file(sdf_path) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        brep_features = torch.tensor(brep_features, dtype=torch.float32) | 
					 | 
					 | 
					             | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        sdf = torch.tensor(sdf, dtype=torch.float32) | 
					 | 
					 | 
					            # 修改返回格式,将sdf_data作为一个键值对添加 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					         | 
					 | 
					 | 
					            return { | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        return { | 
					 | 
					 | 
					                'name': name, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            'brep_features': brep_features, | 
					 | 
					 | 
					                **brep_data,  # 解包B-rep特征 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            'sdf': sdf | 
					 | 
					 | 
					                'sdf': sdf_data  # 添加SDF数据作为一个键 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        } | 
					 | 
					 | 
					            } | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					
 | 
					 | 
					 | 
					             | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def _parse_step_file(self, step_file): | 
					 | 
					 | 
					        except Exception as e: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 解析 .step 文件的逻辑 | 
					 | 
					 | 
					            logger.error(f"Error loading sample from {brep_path}: {str(e)}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 返回 B-rep 特征 | 
					 | 
					 | 
					            logger.error("Data structure:") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        pass | 
					 | 
					 | 
					            if 'brep_data' in locals(): | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                for key, value in brep_data.items(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    if isinstance(value, np.ndarray): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.error(f"  {key}: type={type(value)}, dtype={value.dtype}, shape={value.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            raise | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    def _load_brep_file(self, brep_path): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        """加载B-rep特征文件""" | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        try: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            with open(brep_path, 'rb') as f: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                brep_data = pickle.load(f) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            features = {} | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 1. 处理几何数据(不等长序列) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            for key in ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                if key in brep_data: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    try: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        features[key] = [ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                            torch.from_numpy(np.array(x, dtype=np.float32))  | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                            for x in brep_data[key] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        ] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    except Exception as e: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.error(f"Error converting {key}:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.error(f"  Type: {type(brep_data[key])}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        if isinstance(brep_data[key], list): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                            logger.error(f"  List length: {len(brep_data[key])}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                            if len(brep_data[key]) > 0: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                                logger.error(f"  First element type: {type(brep_data[key][0])}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                                logger.error(f"  First element shape: {brep_data[key][0].shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                                logger.error(f"  First element dtype: {brep_data[key][0].dtype}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        raise ValueError(f"Failed to convert {key}: {str(e)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 2. 处理固定形状的数据 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            for key in ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs']: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                if key in brep_data: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    try: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        data = np.array(brep_data[key], dtype=np.float32) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        features[key] = torch.from_numpy(data) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    except Exception as e: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.error(f"Error converting {key}:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.error(f"  Type: {type(brep_data[key])}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        if isinstance(brep_data[key], np.ndarray): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                            logger.error(f"  Shape: {brep_data[key].shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                            logger.error(f"  dtype: {brep_data[key].dtype}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        raise ValueError(f"Failed to convert {key}: {str(e)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 3. 处理邻接矩阵 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            for key in ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                if key in brep_data: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    try: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        data = np.array(brep_data[key], dtype=np.int32) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        features[key] = torch.from_numpy(data) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    except Exception as e: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.error(f"Error converting {key}:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.error(f"  Type: {type(brep_data[key])}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        if isinstance(brep_data[key], np.ndarray): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                            logger.error(f"  Shape: {brep_data[key].shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                            logger.error(f"  dtype: {brep_data[key].dtype}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        raise ValueError(f"Failed to convert {key}: {str(e)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            return features | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        except Exception as e: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"\nError loading B-rep file: {brep_path}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"Error message: {str(e)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 打印完整的数据结构信息 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            if 'brep_data' in locals(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.error("\nComplete data structure:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                for key, value in brep_data.items(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.error(f"\n{key}:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.error(f"  Type: {type(value)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    if isinstance(value, np.ndarray): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.error(f"  Shape: {value.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.error(f"  dtype: {value.dtype}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    elif isinstance(value, list): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.error(f"  List length: {len(value)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        if len(value) > 0: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                            logger.error(f"  First element type: {type(value[0])}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                            if isinstance(value[0], np.ndarray): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                                logger.error(f"  First element shape: {value[0].shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                                logger.error(f"  First element dtype: {value[0].dtype}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            raise | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    def _load_sdf_file(self, sdf_path): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        """加载和处理SDF数据""" | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        try: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 加载SDF值 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            sdf_data = np.load(sdf_path) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            if 'pos' not in sdf_data or 'neg' not in sdf_data: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                raise ValueError("Missing pos/neg data in SDF file") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            sdf_pos = sdf_data['pos']  # (N1, 4) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            sdf_neg = sdf_data['neg']  # (N2, 4) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 添加数据验证 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            return torch.from_numpy(sdf_np.astype(np.float32)) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        except Exception as e: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"Error loading SDF from {sdf_path}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"Error type: {type(e).__name__}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"Error message: {str(e)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            raise | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					def test_dataset(): | 
					 | 
					 | 
					def test_dataset(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    """测试数据集功能""" | 
					 | 
					 | 
					    """测试数据集功能""" | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -88,11 +200,10 @@ def test_dataset(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info(f"Split: {split}") | 
					 | 
					 | 
					        logger.info(f"Split: {split}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # ... (其余测试代码保持不变) ... | 
					 | 
					 | 
					        # ... (其余测试代码保持不变) ... | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        dataset = BRepSDFDataset(brep_data_dir='/home/wch/myDeepSDF/test_data/pkl', sdf_data_dir='/home/wch/myDeepSDF/test_data/sdf', split='train') | 
					 | 
					 | 
					        dataset = BRepSDFDataset(brep_dir='/home/wch/brep2sdf/test_data/pkl', sdf_dir='/home/wch/brep2sdf/test_data/sdf', split='train') | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True) | 
					 | 
					 | 
					        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        for batch in dataloader: | 
					 | 
					 | 
					        for batch in dataloader: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            print(batch['brep_features'].shape) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            print(batch['sdf'].shape) | 
					 | 
					 | 
					            print(batch['sdf'].shape) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            break | 
					 | 
					 | 
					            break | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    except Exception as e: | 
					 | 
					 | 
					    except Exception as e: | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |