From 2e124b7193eb8911aa86737039eeb97c374413c4 Mon Sep 17 00:00:00 2001 From: mckay Date: Sat, 30 Nov 2024 18:38:19 +0800 Subject: [PATCH] fix: del collate and load_brep --- brep2sdf/data/data.py | 169 +++--------------------------------------- 1 file changed, 10 insertions(+), 159 deletions(-) diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index a893de3..5651aa4 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -121,7 +121,9 @@ class BRepSDFDataset(Dataset): # 解包处理后的特征 edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features - + sdf_points = sdf_data[:, :3] + sdf_values = sdf_data[:, 3:] + # 构建返回字典 return { 'name': name, @@ -131,8 +133,8 @@ class BRepSDFDataset(Dataset): 'surf_ncs': surf_ncs, # [max_face, 100, 3] 'surf_pos': surf_pos, # [max_face, 6] 'vertex_pos': vertex_pos, # [max_face, max_edge, 6] - 'points': sdf_data[:, :3], # [num_queries, 3] 所有点的xyz坐标 - 'sdf': sdf_data[:, 3:] # [num_queries, 1] 所有点的sdf值 + 'points': sdf_points, # [num_queries, 3] 所有点的xyz坐标 + 'sdf': sdf_values # [num_queries, 1] 所有点的sdf值 } except Exception as e: @@ -161,115 +163,6 @@ class BRepSDFDataset(Dataset): raise - - def _load_brep_file(self, brep_path): - """加载B-rep特征文件""" - try: - # 1. 加载原始数据 - with open(brep_path, 'rb') as f: - raw_data = pickle.load(f) - - brep_data = {} - - # 2. 处理几何数据(不等长序列) - geom_keys = ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs'] - for key in geom_keys: - if key in raw_data: - try: - # 确保数据是列表 - if not isinstance(raw_data[key], list): - raise ValueError(f"{key} is not a list") - - # 转换每个元素为张量 - tensors = [] - for i, x in enumerate(raw_data[key]): - try: - # 先转换为numpy数组 - arr = np.array(x, dtype=np.float32) - # 再转换为张量 - tensor = torch.from_numpy(arr) - tensors.append(tensor) - except Exception as e: - logger.error(f"Error converting {key}[{i}]:") - logger.error(f" Data type: {type(x)}") - if isinstance(x, np.ndarray): - logger.error(f" Shape: {x.shape}") - logger.error(f" dtype: {x.dtype}") - raise ValueError(f"Failed to convert {key}[{i}]: {str(e)}") - - brep_data[key] = tensors - - except Exception as e: - logger.error(f"Error processing {key}:") - logger.error(f" Raw data type: {type(raw_data[key])}") - raise ValueError(f"Failed to process {key}: {str(e)}") - - # 3. 处理固定形状的数据 - fixed_keys = ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs'] - for key in fixed_keys: - if key in raw_data: - try: - # 直接从原始数据转换 - arr = np.array(raw_data[key], dtype=np.float32) - brep_data[key] = torch.from_numpy(arr) - except Exception as e: - logger.error(f"Error converting fixed shape data {key}:") - logger.error(f" Raw data type: {type(raw_data[key])}") - if isinstance(raw_data[key], np.ndarray): - logger.error(f" Shape: {raw_data[key].shape}") - logger.error(f" dtype: {raw_data[key].dtype}") - raise ValueError(f"Failed to convert {key}: {str(e)}") - - # 4. 处理邻接矩阵 - adj_keys = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj'] - for key in adj_keys: - if key in raw_data: - try: - # 转换为整型数组 - arr = np.array(raw_data[key], dtype=np.int32) - brep_data[key] = torch.from_numpy(arr) - except Exception as e: - logger.error(f"Error converting adjacency matrix {key}:") - logger.error(f" Raw data type: {type(raw_data[key])}") - if isinstance(raw_data[key], np.ndarray): - logger.error(f" Shape: {raw_data[key].shape}") - logger.error(f" dtype: {raw_data[key].dtype}") - raise ValueError(f"Failed to convert {key}: {str(e)}") - - # 5. 验证必要的键是否存在 - required_keys = {'surf_wcs', 'edge_wcs', 'corner_wcs'} - missing_keys = required_keys - set(brep_data.keys()) - if missing_keys: - raise ValueError(f"Missing required keys: {missing_keys}") - - # 6. 使用process_brep_data处理数据 - try: - features = process_brep_data( - data=brep_data, - max_face=self.max_face, - max_edge=self.max_edge, - bbox_scaled=self.bbox_scaled - ) - return features - - except Exception as e: - logger.error("Error in process_brep_data:") - logger.error(f" Error message: {str(e)}") - # 打印数据形状信息 - logger.error("\nInput data shapes:") - for key, value in brep_data.items(): - if isinstance(value, list): - shapes = [t.shape for t in value] - logger.error(f" {key}: list of tensors with shapes {shapes}") - elif isinstance(value, torch.Tensor): - logger.error(f" {key}: tensor of shape {value.shape}") - raise - - except Exception as e: - logger.error(f"\nError loading B-rep file: {brep_path}") - logger.error(f"Error message: {str(e)}") - raise - def _load_sdf_file(self, sdf_path): """加载和处理SDF数据,并进行随机采样""" try: @@ -321,52 +214,7 @@ class BRepSDFDataset(Dataset): logger.error(f"Error type: {type(e).__name__}") logger.error(f"Error message: {str(e)}") raise - @staticmethod - def collate_fn(batch): - """自定义批处理函数""" - # 收集所有样本的名称 - names = [item['name'] for item in batch] - - # 处理固定大小的张量数据 - tensor_keys = ['edge_ncs', 'edge_pos', 'edge_mask', - 'surf_ncs', 'surf_pos', 'vertex_pos'] - tensors = { - key: torch.stack([item[key] for item in batch]) - for key in tensor_keys - } - - # 处理变长的SDF数据 - sdf_data = [item['sdf'] for item in batch] - max_sdf_len = max(sdf.size(0) for sdf in sdf_data) - - # 填充SDF数据 - padded_sdfs = [] - sdf_masks = [] - for sdf in sdf_data: - pad_len = max_sdf_len - sdf.size(0) - if pad_len > 0: - padding = torch.zeros(pad_len, sdf.size(1), - dtype=sdf.dtype, device=sdf.device) - padded_sdf = torch.cat([sdf, padding], dim=0) - mask = torch.cat([ - torch.ones(sdf.size(0), dtype=torch.bool), - torch.zeros(pad_len, dtype=torch.bool) - ]) - else: - padded_sdf = sdf - mask = torch.ones(sdf.size(0), dtype=torch.bool) - padded_sdfs.append(padded_sdf) - sdf_masks.append(mask) - - # 合并所有数据 - batch_data = { - 'name': names, - 'sdf': torch.stack(padded_sdfs), - 'sdf_mask': torch.stack(sdf_masks), - **tensors - } - - return batch_data + def test_dataset(): """测试数据集功能""" @@ -381,6 +229,7 @@ def test_dataset(): max_edge = config.data.max_edge num_edge_points = config.model.num_edge_points num_surf_points = config.model.num_surf_points + num_query_points = config.data.num_query_points # 定义预期的数据维度,使用配置中的参数 expected_shapes = { @@ -390,7 +239,8 @@ def test_dataset(): 'surf_ncs': (max_face, num_surf_points, 3), 'surf_pos': (max_face, 6), 'vertex_pos': (max_face, max_edge, 2, 3), - 'sdf': (2097152, 4) + 'points': (num_query_points, 3), + 'sdf': (num_query_points, 1) } logger.info("="*50) @@ -424,6 +274,7 @@ def test_dataset(): logger.info(f" Expected: {expected_shape}") logger.info(f" Match: {shape_match}") logger.info(f" dtype: {value.dtype}") + logger.info(f" grad: {value.requires_grad}") if not shape_match: logger.warning(f" Shape mismatch for {key}!")