| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -12,7 +12,7 @@ from brep2sdf.config.default_config import get_default_config | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, brep_dir:str, sdf_dir:str, valid_data_dir:str, split:str='train'): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, brep_dir:str, sdf_dir:str, valid_data_dir:str, use_filter: bool=True, split:str='train'): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        初始化数据集 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -52,6 +52,9 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.brep_data_list = self._load_data_list(self.brep_dir) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.sdf_data_list = self._load_data_list(self.sdf_dir) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if use_filter: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self._filter_num_faces_and_num_edges() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 检查数据集是否为空 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if len(self.brep_data_list) == 0 : | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise ValueError(f"No valid brep data found in {split} set") | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -76,6 +79,20 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.info(data_list) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return data_list | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _filter_num_faces_and_num_edges(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ''' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Filter the data if their face_num or edge_num > max_face or max_edge. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ''' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # Collect indices of elements that satisfy the condition | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        filtered_indices = [ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            idx for idx in range(len(self.brep_data_list)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if (self._get_brep_face_and_edge(self.brep_data_list[idx]) <= (self.max_face, self.max_edge)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # Use filtered_indices to update brep_data_list and sdf_data_list | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.brep_data_list = [self.brep_data_list[idx] for idx in filtered_indices] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.sdf_data_list = [self.sdf_data_list[idx] for idx in filtered_indices] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __len__(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return len(self.brep_data_list) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -87,8 +104,7 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            name = os.path.splitext(os.path.basename(brep_path))[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 加载B-rep和SDF数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            with open(brep_path, 'rb') as f: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                brep_raw = pickle.load(f) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            brep_raw = self._load_brep_file(brep_path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf_data = self._load_sdf_file(sdf_path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            try: | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -162,7 +178,10 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error loading sample from {brep_path}: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error("Data structure:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _load_brep_file(self, brep_path): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        with open(brep_path, 'rb') as f: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            brep_raw = pickle.load(f) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return brep_raw | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _load_sdf_file(self, sdf_path): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """加载和处理SDF数据,并进行随机采样""" | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -222,7 +241,12 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error type: {type(e).__name__}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error message: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _get_brep_face_and_edge(self, brep_path: str) -> tuple[int,int]: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        brep: dict = self._load_brep_file(brep_path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        face_edge_adj = brep["faceEdge_adj"] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        num_faces, num_edges = face_edge_adj.shape | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return num_faces, num_edges | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def test_dataset(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """测试数据集功能""" | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |