| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -121,6 +121,8 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 解包处理后的特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf_points = sdf_data[:, :3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf_values = sdf_data[:, 3:]   | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 构建返回字典 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                return { | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -131,8 +133,8 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'surf_ncs': surf_ncs,      # [max_face, 100, 3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'surf_pos': surf_pos,      # [max_face, 6] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'vertex_pos': vertex_pos,  # [max_face, max_edge, 6] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'points': sdf_data[:, :3], # [num_queries, 3] 所有点的xyz坐标 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'sdf': sdf_data[:, 3:]     # [num_queries, 1] 所有点的sdf值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'points': sdf_points, # [num_queries, 3] 所有点的xyz坐标 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'sdf': sdf_values     # [num_queries, 1] 所有点的sdf值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            except Exception as e: | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -161,115 +163,6 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _load_brep_file(self, brep_path): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """加载B-rep特征文件""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 1. 加载原始数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            with open(brep_path, 'rb') as f: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raw_data = pickle.load(f) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            brep_data = {} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 2. 处理几何数据(不等长序列) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            geom_keys = ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key in geom_keys: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if key in raw_data: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        # 确保数据是列表 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if not isinstance(raw_data[key], list): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            raise ValueError(f"{key} is not a list") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        # 转换每个元素为张量 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        tensors = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        for i, x in enumerate(raw_data[key]): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                # 先转换为numpy数组 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                arr = np.array(x, dtype=np.float32) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                # 再转换为张量 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                tensor = torch.from_numpy(arr) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                tensors.append(tensor) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                logger.error(f"Error converting {key}[{i}]:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                logger.error(f"  Data type: {type(x)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                if isinstance(x, np.ndarray): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                    logger.error(f"  Shape: {x.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                    logger.error(f"  dtype: {x.dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                raise ValueError(f"Failed to convert {key}[{i}]: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        brep_data[key] = tensors | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"Error processing {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  Raw data type: {type(raw_data[key])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        raise ValueError(f"Failed to process {key}: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 3. 处理固定形状的数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            fixed_keys = ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key in fixed_keys: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if key in raw_data: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        # 直接从原始数据转换 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        arr = np.array(raw_data[key], dtype=np.float32) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        brep_data[key] = torch.from_numpy(arr) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"Error converting fixed shape data {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  Raw data type: {type(raw_data[key])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if isinstance(raw_data[key], np.ndarray): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  Shape: {raw_data[key].shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  dtype: {raw_data[key].dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        raise ValueError(f"Failed to convert {key}: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 4. 处理邻接矩阵 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            adj_keys = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key in adj_keys: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if key in raw_data: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        # 转换为整型数组 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        arr = np.array(raw_data[key], dtype=np.int32) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        brep_data[key] = torch.from_numpy(arr) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"Error converting adjacency matrix {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  Raw data type: {type(raw_data[key])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if isinstance(raw_data[key], np.ndarray): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  Shape: {raw_data[key].shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  dtype: {raw_data[key].dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        raise ValueError(f"Failed to convert {key}: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 5. 验证必要的键是否存在 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            required_keys = {'surf_wcs', 'edge_wcs', 'corner_wcs'} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            missing_keys = required_keys - set(brep_data.keys()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if missing_keys: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raise ValueError(f"Missing required keys: {missing_keys}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 6. 使用process_brep_data处理数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                features = process_brep_data( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    data=brep_data, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    max_face=self.max_face, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    max_edge=self.max_edge, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    bbox_scaled=self.bbox_scaled | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                return features | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error("Error in process_brep_data:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error(f"  Error message: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 打印数据形状信息 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error("\nInput data shapes:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                for key, value in brep_data.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if isinstance(value, list): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        shapes = [t.shape for t in value] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  {key}: list of tensors with shapes {shapes}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    elif isinstance(value, torch.Tensor): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  {key}: tensor of shape {value.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"\nError loading B-rep file: {brep_path}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error message: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _load_sdf_file(self, sdf_path): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """加载和处理SDF数据,并进行随机采样""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        try: | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -321,52 +214,7 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error type: {type(e).__name__}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error message: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    @staticmethod | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def collate_fn(batch): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """自定义批处理函数""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 收集所有样本的名称 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        names = [item['name'] for item in batch] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 处理固定大小的张量数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        tensor_keys = ['edge_ncs', 'edge_pos', 'edge_mask',  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                      'surf_ncs', 'surf_pos', 'vertex_pos'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        tensors = { | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            key: torch.stack([item[key] for item in batch]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key in tensor_keys | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 处理变长的SDF数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        sdf_data = [item['sdf'] for item in batch] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        max_sdf_len = max(sdf.size(0) for sdf in sdf_data) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 填充SDF数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        padded_sdfs = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        sdf_masks = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for sdf in sdf_data: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pad_len = max_sdf_len - sdf.size(0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if pad_len > 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                padding = torch.zeros(pad_len, sdf.size(1),  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                   dtype=sdf.dtype, device=sdf.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                padded_sdf = torch.cat([sdf, padding], dim=0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                mask = torch.cat([ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    torch.ones(sdf.size(0), dtype=torch.bool), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    torch.zeros(pad_len, dtype=torch.bool) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                padded_sdf = sdf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                mask = torch.ones(sdf.size(0), dtype=torch.bool) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            padded_sdfs.append(padded_sdf) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf_masks.append(mask) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 合并所有数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        batch_data = { | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'name': names, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'sdf': torch.stack(padded_sdfs), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'sdf_mask': torch.stack(sdf_masks), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            **tensors | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return batch_data | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def test_dataset(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """测试数据集功能""" | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -381,6 +229,7 @@ def test_dataset(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        max_edge = config.data.max_edge | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        num_edge_points = config.model.num_edge_points | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        num_surf_points = config.model.num_surf_points | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        num_query_points = config.data.num_query_points | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 定义预期的数据维度,使用配置中的参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        expected_shapes = { | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -390,7 +239,8 @@ def test_dataset(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'surf_ncs': (max_face, num_surf_points, 3),   | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'surf_pos': (max_face, 6),             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'vertex_pos': (max_face, max_edge, 2, 3),    | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'sdf': (2097152, 4)              | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'points': (num_query_points, 3), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'sdf': (num_query_points, 1)              | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("="*50) | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -424,6 +274,7 @@ def test_dataset(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    Expected: {expected_shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    Match: {shape_match}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    dtype: {value.dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    grad: {value.requires_grad}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if not shape_match: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.warning(f"    Shape mismatch for {key}!") | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |