| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -271,7 +271,7 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _load_sdf_file(self, sdf_path): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """加载和处理SDF数据""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """加载和处理SDF数据,并进行随机采样""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 加载SDF值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf_data = np.load(sdf_path) | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -285,7 +285,35 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 随机采样 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            max_points = self.config.data.num_query_points  # 例如4096 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 确保正负样本均衡 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            num_pos = min(max_points // 2, sdf_pos.shape[0]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            num_neg = min(max_points // 2, sdf_neg.shape[0]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 随机采样正样本 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if sdf_pos.shape[0] > num_pos: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf_pos = sdf_pos[pos_indices] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 随机采样负样本 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if sdf_neg.shape[0] > num_neg: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf_neg = sdf_neg[neg_indices] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 合并数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 再次随机打乱 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            np.random.shuffle(sdf_np) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 如果总点数仍然超过最大限制,再次采样 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if sdf_np.shape[0] > max_points: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                indices = np.random.choice(sdf_np.shape[0], max_points, replace=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf_np = sdf_np[indices] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return torch.from_numpy(sdf_np.astype(np.float32)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        except Exception as e: | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |