| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -3,10 +3,8 @@ import torch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from torch.utils.data import Dataset | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import numpy as np | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import pickle | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.utils.logger import setup_logger | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from .utils import process_brep_data | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					# 设置日志记录器 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					logger = setup_logger('dataset') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.utils.logger import logger | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.data.utils import process_brep_data | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -22,10 +20,16 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf_dir: npz文件目录 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            split: 数据集分割('train', 'val', 'test') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        super().__init__() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.brep_dir = os.path.join(brep_dir, split) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.sdf_dir = os.path.join(sdf_dir, split) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.split = split | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 设置固定参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.max_face = 70 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.max_edge = 70 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.bbox_scaled = 1.0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 检查目录是否存在 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if not os.path.exists(self.brep_dir): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise ValueError(f"B-rep directory not found: {self.brep_dir}") | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -59,112 +63,187 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __getitem__(self, idx): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """获取单个数据样本""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        brep_path = self.brep_data_list[idx] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        sdf_path = self.sdf_data_list[idx] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取文件名(不含扩展名)作为sample name | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            brep_path = self.brep_data_list[idx] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf_path = self.sdf_data_list[idx] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            name = os.path.splitext(os.path.basename(brep_path))[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 加载B-rep和SDF数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            brep_data = self._load_brep_file(brep_path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            with open(brep_path, 'rb') as f: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                brep_raw = pickle.load(f) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf_data = self._load_sdf_file(sdf_path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 修改返回格式,将sdf_data作为一个键值对添加 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return { | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                'name': name, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                **brep_data,  # 解包B-rep特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                'sdf': sdf_data  # 添加SDF数据作为一个键 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 处理B-rep数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                brep_features = process_brep_data( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    data=brep_raw, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    max_face=self.max_face, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    max_edge=self.max_edge, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    bbox_scaled=self.bbox_scaled | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 检查返回值的类型和数量 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if not isinstance(brep_features, tuple): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.error(f"process_brep_data returned {type(brep_features)}, expected tuple") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    raise ValueError("Invalid return type from process_brep_data") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if len(brep_features) != 6: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.error(f"Expected 6 features, got {len(brep_features)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.error("Features returned:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    for i, feat in enumerate(brep_features): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if isinstance(feat, torch.Tensor): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  {i}: Tensor of shape {feat.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  {i}: {type(feat)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    raise ValueError(f"Incorrect number of features: {len(brep_features)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 解包处理后的特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 构建返回字典 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                return { | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'name': name, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'edge_ncs': edge_ncs,      # [max_face, max_edge, 10, 3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'edge_pos': edge_pos,      # [max_face, max_edge, 6] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'edge_mask': edge_mask,    # [max_face, max_edge] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'surf_ncs': surf_ncs,      # [max_face, 100, 3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'surf_pos': surf_pos,      # [max_face, 6] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'vertex_pos': vertex_pos,  # [max_face, max_edge, 6] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'sdf': sdf_data           # [N, 4] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error(f"\nError processing B-rep data for file: {brep_path}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error(f"Error type: {type(e).__name__}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error(f"Error message: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 打印原始数据的结构 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error("\nRaw data structure:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                for key, value in brep_raw.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if isinstance(value, list): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  {key}: list of length {len(value)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if value: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"    First element type: {type(value[0])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            if hasattr(value[0], 'shape'): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                logger.error(f"    First element shape: {value[0].shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    elif hasattr(value, 'shape'): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  {key}: shape {value.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  {key}: {type(value)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error loading sample from {brep_path}: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error("Data structure:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if 'brep_data' in locals(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                for key, value in brep_data.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if isinstance(value, np.ndarray): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  {key}: type={type(value)}, dtype={value.dtype}, shape={value.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _load_brep_file(self, brep_path): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """加载B-rep特征文件""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 1. 加载原始数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            with open(brep_path, 'rb') as f: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                brep_data = pickle.load(f) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raw_data = pickle.load(f) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            brep_data = {} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 1. 处理几何数据(不等长序列) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key in ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if key in brep_data: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 2. 处理几何数据(不等长序列) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            geom_keys = ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key in geom_keys: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if key in raw_data: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        brep_data[key] = [ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            torch.from_numpy(np.array(x, dtype=np.float32))  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            for x in brep_data[key] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        ] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        # 确保数据是列表 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if not isinstance(raw_data[key], list): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            raise ValueError(f"{key} is not a list") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        # 转换每个元素为张量 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        tensors = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        for i, x in enumerate(raw_data[key]): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                # 先转换为numpy数组 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                arr = np.array(x, dtype=np.float32) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                # 再转换为张量 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                tensor = torch.from_numpy(arr) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                tensors.append(tensor) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                logger.error(f"Error converting {key}[{i}]:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                logger.error(f"  Data type: {type(x)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                if isinstance(x, np.ndarray): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                    logger.error(f"  Shape: {x.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                    logger.error(f"  dtype: {x.dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                raise ValueError(f"Failed to convert {key}[{i}]: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        brep_data[key] = tensors | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"Error converting {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  Type: {type(brep_data[key])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if isinstance(brep_data[key], list): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  List length: {len(brep_data[key])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            if len(brep_data[key]) > 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                logger.error(f"  First element type: {type(brep_data[key][0])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                logger.error(f"  First element shape: {brep_data[key][0].shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                logger.error(f"  First element dtype: {brep_data[key][0].dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        raise ValueError(f"Failed to convert {key}: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"Error processing {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  Raw data type: {type(raw_data[key])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        raise ValueError(f"Failed to process {key}: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 2. 处理固定形状的数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key in ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs']: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if key in brep_data: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 3. 处理固定形状的数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            fixed_keys = ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key in fixed_keys: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if key in raw_data: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        data = np.array(brep_data[key], dtype=np.float32) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        brep_data[key] = torch.from_numpy(data) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        # 直接从原始数据转换 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        arr = np.array(raw_data[key], dtype=np.float32) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        brep_data[key] = torch.from_numpy(arr) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"Error converting {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  Type: {type(brep_data[key])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if isinstance(brep_data[key], np.ndarray): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  Shape: {brep_data[key].shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  dtype: {brep_data[key].dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"Error converting fixed shape data {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  Raw data type: {type(raw_data[key])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if isinstance(raw_data[key], np.ndarray): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  Shape: {raw_data[key].shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  dtype: {raw_data[key].dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        raise ValueError(f"Failed to convert {key}: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 3. 处理邻接矩阵 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key in ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if key in brep_data: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 4. 处理邻接矩阵 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            adj_keys = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key in adj_keys: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if key in raw_data: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        data = np.array(brep_data[key], dtype=np.int32) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        brep_data[key] = torch.from_numpy(data) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        # 转换为整型数组 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        arr = np.array(raw_data[key], dtype=np.int32) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        brep_data[key] = torch.from_numpy(arr) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"Error converting {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  Type: {type(brep_data[key])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if isinstance(brep_data[key], np.ndarray): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  Shape: {brep_data[key].shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  dtype: {brep_data[key].dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"Error converting adjacency matrix {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  Raw data type: {type(raw_data[key])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if isinstance(raw_data[key], np.ndarray): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  Shape: {raw_data[key].shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  dtype: {raw_data[key].dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        raise ValueError(f"Failed to convert {key}: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            feature_embedder = process_brep_data(brep_data, 70,70,1,) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return feature_embedder | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 5. 验证必要的键是否存在 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            required_keys = {'surf_wcs', 'edge_wcs', 'corner_wcs'} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            missing_keys = required_keys - set(brep_data.keys()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if missing_keys: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raise ValueError(f"Missing required keys: {missing_keys}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 6. 使用process_brep_data处理数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                features = process_brep_data( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    data=brep_data, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    max_face=self.max_face, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    max_edge=self.max_edge, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    bbox_scaled=self.bbox_scaled | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                return features | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error("Error in process_brep_data:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error(f"  Error message: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 打印数据形状信息 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error("\nInput data shapes:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                for key, value in brep_data.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if isinstance(value, list): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        shapes = [t.shape for t in value] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  {key}: list of tensors with shapes {shapes}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    elif isinstance(value, torch.Tensor): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  {key}: tensor of shape {value.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"\nError loading B-rep file: {brep_path}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error message: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 打印完整的数据结构信息 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if 'brep_data' in locals(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.error("\nComplete data structure:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                for key, value in brep_data.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.error(f"\n{key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.error(f"  Type: {type(value)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if isinstance(value, np.ndarray): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  Shape: {value.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  dtype: {value.dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    elif isinstance(value, list): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.error(f"  List length: {len(value)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        if len(value) > 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            logger.error(f"  First element type: {type(value[0])}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            if isinstance(value[0], np.ndarray): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                logger.error(f"  First element shape: {value[0].shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                logger.error(f"  First element dtype: {value[0].dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _load_sdf_file(self, sdf_path): | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -190,6 +269,52 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error type: {type(e).__name__}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error message: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    @staticmethod | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def collate_fn(batch): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """自定义批处理函数""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 收集所有样本的名称 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        names = [item['name'] for item in batch] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 处理固定大小的张量数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        tensor_keys = ['edge_ncs', 'edge_pos', 'edge_mask',  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                      'surf_ncs', 'surf_pos', 'vertex_pos'] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        tensors = { | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            key: torch.stack([item[key] for item in batch]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key in tensor_keys | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 处理变长的SDF数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        sdf_data = [item['sdf'] for item in batch] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        max_sdf_len = max(sdf.size(0) for sdf in sdf_data) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 填充SDF数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        padded_sdfs = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        sdf_masks = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for sdf in sdf_data: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pad_len = max_sdf_len - sdf.size(0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if pad_len > 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                padding = torch.zeros(pad_len, sdf.size(1),  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                   dtype=sdf.dtype, device=sdf.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                padded_sdf = torch.cat([sdf, padding], dim=0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                mask = torch.cat([ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    torch.ones(sdf.size(0), dtype=torch.bool), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    torch.zeros(pad_len, dtype=torch.bool) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                padded_sdf = sdf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                mask = torch.ones(sdf.size(0), dtype=torch.bool) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            padded_sdfs.append(padded_sdf) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf_masks.append(mask) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 合并所有数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        batch_data = { | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'name': names, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'sdf': torch.stack(padded_sdfs), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'sdf_mask': torch.stack(sdf_masks), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            **tensors | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return batch_data | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def test_dataset(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """测试数据集功能""" | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |