diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index 8761992..49bfd28 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -9,250 +9,6 @@ from brep2sdf.config.default_config import get_default_config - - -class BRepSDFDataset(Dataset): - def __init__(self, brep_dir:str, sdf_dir:str, valid_data_dir:str, use_filter: bool=True, split:str='train'): - """ - 初始化数据集 - - 参数: - brep_dir: pkl文件目录 - sdf_dir: npz文件目录 - split: 数据集分割('train', 'val', 'test') - """ - super().__init__() - # 使用配置文件 - self.config = get_default_config() - - self.brep_dir = os.path.join(brep_dir, split) - self.sdf_dir = os.path.join(sdf_dir, split) - self.split = split - - # 使用配置文件中的参数替换固定参数 - self.max_face = self.config.data.max_face - self.max_edge = self.config.data.max_edge - self.bbox_scaled = self.config.data.bbox_scaled - - # 检查目录是否存在 - if not os.path.exists(self.brep_dir): - raise ValueError(f"B-rep directory not found: {self.brep_dir}") - if not os.path.exists(self.sdf_dir): - raise ValueError(f"SDF directory not found: {self.sdf_dir}") - - # 加载数据列表 - # 如果存在valid_data_file,则加载valid_list - valid_data_file = os.path.join(valid_data_dir, f'{split}_success.txt') - if valid_data_file: - valid_data_file = os.path.join(self.brep_dir, valid_data_file) - self.valid_data_list = self._load_valid_list(valid_data_file) - else: - raise ValueError(f"Valid data file not found: {valid_data_file}") - - self.brep_data_list = self._load_data_list(self.brep_dir) - self.sdf_data_list = self._load_data_list(self.sdf_dir) - - if use_filter: - self._filter_num_faces_and_num_edges() - - # 检查数据集是否为空 - if len(self.brep_data_list) == 0 : - raise ValueError(f"No valid brep data found in {split} set") - if len(self.sdf_data_list) == 0: - raise ValueError(f"No valid sdf data found in {split} set") - - logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples") - - def _load_valid_list(self,valid_data_file:str): - with open(valid_data_file, 'r') as f: - valid_list = [line.strip() for line in f.readlines()] - return valid_list - - # data_dir 为 self.brep_dir or sdf_dir - def _load_data_list(self, data_dir): - data_list = [] - for sample_file in os.listdir(data_dir): - if sample_file.split('.')[0] in self.valid_data_list: - path = os.path.join(data_dir, sample_file) - data_list.append(path) - - #logger.info(data_list) - return data_list - - def _filter_num_faces_and_num_edges(self): - ''' - Filter the data if their face_num or edge_num > max_face or max_edge. 
- ''' - # Collect indices of elements that satisfy the condition - filtered_indices = [ - idx for idx in range(len(self.brep_data_list)) - if (self._get_brep_face_and_edge(self.brep_data_list[idx]) <= (self.max_face, self.max_edge)) - ] - - #filtered_indices = filtered_indices[0:8] # TODO rm - - # Use filtered_indices to update brep_data_list and sdf_data_list - self.brep_data_list = [self.brep_data_list[idx] for idx in filtered_indices] - self.sdf_data_list = [self.sdf_data_list[idx] for idx in filtered_indices] - - def __len__(self): - return len(self.brep_data_list) - - def __getitem__(self, idx): - """获取单个数据样本""" - try: - brep_path = self.brep_data_list[idx] - sdf_path = self.sdf_data_list[idx] - name = os.path.splitext(os.path.basename(brep_path))[0] - - # 加载B-rep和SDF数据 - brep_raw = self._load_brep_file(brep_path) - sdf_data = self._load_sdf_file(sdf_path) - - try: - # 处理B-rep数据 - brep_features = process_brep_data( - data=brep_raw, - max_face=self.max_face, - max_edge=self.max_edge, - bbox_scaled=self.bbox_scaled - ) - ''' - # 打印数据形状 - logger.debug(f"Processed data shapes for {os.path.basename(brep_path)}:") - for value in brep_features: - if isinstance(value, torch.Tensor): - logger.debug(f" {value.shape}") - # 检查返回值的类型和数量 - if not isinstance(brep_features, tuple): - logger.error(f"process_brep_data returned {type(brep_features)}, expected tuple") - raise ValueError("Invalid return type from process_brep_data") - - if len(brep_features) != 6: - logger.error(f"Expected 6 features, got {len(brep_features)}") - logger.error("Features returned:") - for i, feat in enumerate(brep_features): - if isinstance(feat, torch.Tensor): - logger.error(f" {i}: Tensor of shape {feat.shape}") - else: - logger.error(f" {i}: {type(feat)}") - raise ValueError(f"Incorrect number of features: {len(brep_features)}") - ''' - # 解包处理后的特征 - edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features - sdf_points = sdf_data[:, :3] - sdf_values = sdf_data[:, 3:] - - # 构建返回字典 - return { - 'name': name, - 'edge_ncs': edge_ncs, # [max_face, max_edge, 10, 3] - 'edge_pos': edge_pos, # [max_face, max_edge, 6] - 'edge_mask': edge_mask, # [max_face, max_edge] - 'surf_ncs': surf_ncs, # [max_face, 100, 3] - 'surf_pos': surf_pos, # [max_face, 6] - 'vertex_pos': vertex_pos, # [max_face, max_edge, 6] - 'points': sdf_points, # [num_queries, 3] 所有点的xyz坐标 - 'sdf': sdf_values # [num_queries, 1] 所有点的sdf值 - } - - except Exception as e: - logger.error(f"\nError processing B-rep data for file: {brep_path}") - logger.error(f"Error type: {type(e).__name__}") - logger.error(f"Error message: {str(e)}") - - # 打印原始数据的结构 - logger.error("\nRaw data structure:") - for key, value in brep_raw.items(): - if isinstance(value, list): - logger.error(f" {key}: list of length {len(value)}") - if value: - logger.error(f" First element type: {type(value[0])}") - if hasattr(value[0], 'shape'): - logger.error(f" First element shape: {value[0].shape}") - elif hasattr(value, 'shape'): - logger.error(f" {key}: shape {value.shape}") - else: - logger.error(f" {key}: {type(value)}") - raise - - except Exception as e: - logger.error(f"Error loading sample from {brep_path}: {str(e)}") - logger.error("Data structure:") - raise - def _load_brep_file(self, brep_path): - with open(brep_path, 'rb') as f: - brep_raw = pickle.load(f) - return brep_raw - - def _load_sdf_file(self, sdf_path): - """加载和处理SDF数据,并进行随机采样""" - try: - # 加载SDF值 - sdf_data = np.load(sdf_path) - if 'pos' not in sdf_data or 'neg' not in sdf_data: - raise ValueError("Missing pos/neg 
data in SDF file") - - sdf_pos = sdf_data['pos'] # (N1, 4) - sdf_neg = sdf_data['neg'] # (N2, 4) - - # 添加数据验证 - if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4: - raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}") - - # 随机采样 - max_points = self.config.data.num_query_points # 例如4096 - - # 确保正负样本均衡 - if max_points // 2 > sdf_pos.shape[0]: - logger.warning(f"正样本过少,期望>{max_points // 2},实际:{sdf_pos.shape[0]}") - - if max_points // 2 > sdf_neg.shape[0]: - num_neg = sdf_neg.shape[0] - else: - num_neg = max_points // 2 - - num_pos = max_points - num_neg - - # 随机采样正样本 - if sdf_pos.shape[0] > num_pos: - pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False) - sdf_pos = sdf_pos[pos_indices] - - # 随机采样负样本 - if sdf_neg.shape[0] > num_neg: - neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False) - sdf_neg = sdf_neg[neg_indices] - - # 合并数据 - sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0) - - # 再次随机打乱 - np.random.shuffle(sdf_np) - - # 如果总点数仍然超过最大限制,再次采样 - if sdf_np.shape[0] > max_points: - indices = np.random.choice(sdf_np.shape[0], max_points, replace=False) - sdf_np = sdf_np[indices] - - #logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})") - return torch.from_numpy(sdf_np.astype(np.float32)) - - except Exception as e: - logger.error(f"Error loading SDF from {sdf_path}") - logger.error(f"Error type: {type(e).__name__}") - logger.error(f"Error message: {str(e)}") - raise - - def _get_brep_face_and_edge(self, brep_path: str) -> tuple[int,int]: - brep: dict = self._load_brep_file(brep_path) - face_edge_adj = brep["faceEdge_adj"] - num_faces, num_edges = face_edge_adj.shape - return num_faces, num_edges - - - - def load_brep_file(brep_path): with open(brep_path, 'rb') as f: brep_raw = pickle.load(f) @@ -457,270 +213,42 @@ def check_tensor(tensor: torch.Tensor | None, name: str, epoch: int, step: int = - - - - -def test_dataset(): - """测试数据集功能""" - try: - # 获取配置 - config = get_default_config() - - # 定义预期的数据维度 - expected_shapes = { - 'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3), - 'edge_pos': (config.data.max_face, config.data.max_edge, 6), - 'edge_mask': (config.data.max_face, config.data.max_edge), - 'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3), - 'surf_pos': (config.data.max_face, 6), - 'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3), - 'points': (config.data.num_query_points, 3), - 'sdf': (config.data.num_query_points, 1) - } - - logger.info("="*50) - logger.info("测试数据集") - logger.info(f"预期形状:") - for key, shape in expected_shapes.items(): - logger.info(f" {key}: {shape}") - - # 初始化数据集 - dataset = BRepSDFDataset( - brep_dir=config.data.brep_dir, - sdf_dir=config.data.sdf_dir, - valid_data_dir=config.data.valid_data_dir, - split='train' - ) - - # 测试数据加载 - logger.info("\n测试数据加载...") - sample = dataset[0] - - # 检查数据类型和形状 - logger.info("\n数据类型和形状检查:") - for key, value in sample.items(): - if isinstance(value, torch.Tensor): - actual_shape = tuple(value.shape) - expected_shape = expected_shapes.get(key) - shape_match = "✓" if actual_shape == expected_shape else "✗" - - logger.info(f"\n{key}:") - logger.info(f" 实际形状: {actual_shape}") - logger.info(f" 预期形状: {expected_shape}") - logger.info(f" 匹配状态: {shape_match}") - logger.info(f" 数据类型: {value.dtype}") - - # 仅对浮点类型计算数值范围、均值和标准差 - if value.dtype.is_floating_point: - logger.info(f" 数值范围: [{value.min():.3f}, {value.max():.3f}]") - logger.info(f" 均值: {value.mean():.3f}") - 
logger.info(f" 标准差: {value.std():.3f}") - - if shape_match == "✗": - logger.warning(f" 形状不匹配: {key}") - if key in ['points', 'sdf']: - logger.warning(f" 查询点数量不一致,预期 {expected_shape[0]},实际 {actual_shape[0]}") - elif key in ['edge_ncs', 'edge_pos', 'edge_mask']: - logger.warning(f" 边数量不一致,预期 {expected_shape[1]},实际 {actual_shape[1]}") - elif key in ['surf_ncs', 'surf_pos']: - logger.warning(f" 面数量不一致,预期 {expected_shape[0]},实际 {actual_shape[0]}") - - # 测试批处理 - logger.info("\n测试批处理...") - batch_size = 4 - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=batch_size, - shuffle=True, - num_workers=0 - ) - - batch = next(iter(dataloader)) - logger.info("\n批处理形状检查:") - for key, value in batch.items(): - if isinstance(value, torch.Tensor): - batch_shape = tuple(value.shape) - expected_batch_shape = (batch_size,) + expected_shapes[key] - shape_match = "✓" if batch_shape == expected_batch_shape else "✗" - - logger.info(f"\n{key}:") - logger.info(f" 实际形状: {batch_shape}") - logger.info(f" 预期形状: {expected_batch_shape}") - logger.info(f" 匹配状态: {shape_match}") - logger.info(f" 数据类型: {value.dtype}") - - # 仅对浮点类型计算数值范围、均值和标准差 - if value.dtype.is_floating_point: - logger.info(f" 数值范围: [{value.min():.3f}, {value.max():.3f}]") - logger.info(f" 均值: {value.mean():.3f}") - logger.info(f" 标准差: {value.std():.3f}") - - if shape_match == "✗": - logger.warning(f" 批处理形状不匹配: {key}") - - logger.info("\n测试完成!") - logger.info("="*50) - - except Exception as e: - logger.error(f"测试过程中出错: {str(e)}") - raise -from collections import defaultdict -from tqdm import tqdm -def validate_dataset(split: str = 'train', num_samples: int = None): - """全面验证数据集 +def check_data_format(data, step_file): + """检查数据格式是否正确""" + required_keys = [ + 'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs', 'corner_wcs', + 'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj', + 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique' + ] - Args: - split: 数据集分割 ('train', 'val', 'test') - num_samples: 要检查的样本数量,None表示检查所有样本 - """ - try: - config = get_default_config() - logger.info(f"开始验证{split}数据集...") - - # 初始化数据集 - dataset = BRepSDFDataset( - brep_dir=config.data.brep_dir, - sdf_dir=config.data.sdf_dir, - valid_data_dir=config.data.valid_data_dir, - split='train' - ) - - total_samples = len(dataset) if num_samples is None else min(num_samples, len(dataset)) - logger.info(f"总样本数: {total_samples}") - - # 初始化统计信息 - stats = { - 'face_counts': [], - 'edge_counts': [], - 'vertex_counts': [], - 'sdf_point_counts': [], - 'invalid_samples': [], - 'shape_mismatches': defaultdict(int), - 'value_ranges': defaultdict(lambda: {'min': float('inf'), 'max': float('-inf')}), - 'nan_counts': defaultdict(int), - 'inf_counts': defaultdict(int) - } - - # 遍历数据集 - for idx in tqdm(range(total_samples), desc="验证数据"): - try: - sample = dataset[idx] - - # 1. 检查数据完整性 - required_keys = ['surf_ncs', 'surf_pos', 'edge_ncs', 'edge_pos', - 'vertex_pos', 'points', 'sdf', 'edge_mask'] - missing_keys = [key for key in required_keys if key not in sample] - if missing_keys: - stats['invalid_samples'].append((idx, f"缺少键: {missing_keys}")) - continue - - # 2. 
检查形状 - expected_shapes = { - 'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3), - 'surf_pos': (config.data.max_face, 6), - 'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3), - 'edge_pos': (config.data.max_face, config.data.max_edge, 6), - 'edge_mask': (config.data.max_face, config.data.max_edge), - 'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3), - 'points': (config.data.num_query_points, 3), - 'sdf': (config.data.num_query_points, 1) - } - - for key, expected_shape in expected_shapes.items(): - if key in sample: - actual_shape = tuple(sample[key].shape) - if actual_shape != expected_shape: - stats['shape_mismatches'][key] += 1 - stats['invalid_samples'].append( - (idx, f"{key} 形状不匹配: 预期 {expected_shape}, 实际 {actual_shape}") - ) - - # 3. 检查数值范围和无效值 - for key, tensor in sample.items(): - if isinstance(tensor, torch.Tensor) and tensor.dtype.is_floating_point: - # 更新值范围 - stats['value_ranges'][key]['min'] = min(stats['value_ranges'][key]['min'], - tensor.min().item()) - stats['value_ranges'][key]['max'] = max(stats['value_ranges'][key]['max'], - tensor.max().item()) - - # 检查NaN和Inf - nan_count = torch.isnan(tensor).sum().item() - inf_count = torch.isinf(tensor).sum().item() - if nan_count > 0: - stats['nan_counts'][key] += nan_count - if inf_count > 0: - stats['inf_counts'][key] += inf_count - - # 4. 收集统计信息 - stats['face_counts'].append(sample['surf_ncs'].shape[0]) - stats['edge_counts'].append(sample['edge_ncs'].shape[1]) - stats['vertex_counts'].append(len(torch.unique(sample['vertex_pos'].reshape(-1, 3), dim=0))) - stats['sdf_point_counts'].append(sample['points'].shape[0]) - - except Exception as e: - stats['invalid_samples'].append((idx, str(e))) - - # 输出统计结果 - logger.info("\n=== 数据集验证结果 ===") - - # 1. 基本统计信息 - logger.info("\n基本统计信息:") - logger.info(f"总样本数: {total_samples}") - logger.info(f"有效样本数: {total_samples - len(stats['invalid_samples'])}") - logger.info(f"无效样本数: {len(stats['invalid_samples'])}") - - # 2. 形状不匹配统计 - if stats['shape_mismatches']: - logger.info("\n形状不匹配统计:") - for key, count in stats['shape_mismatches'].items(): - logger.info(f" {key}: {count}个样本不匹配") - - # 3. 数值范围统计 - logger.info("\n数值范围统计:") - for key, ranges in stats['value_ranges'].items(): - logger.info(f" {key}:") - logger.info(f" 最小值: {ranges['min']:.3f}") - logger.info(f" 最大值: {ranges['max']:.3f}") - - # 4. 无效值统计 - if sum(stats['nan_counts'].values()) > 0 or sum(stats['inf_counts'].values()) > 0: - logger.info("\n无效值统计:") - for key in stats['nan_counts'].keys(): - if stats['nan_counts'][key] > 0: - logger.info(f" {key} 包含 {stats['nan_counts'][key]} 个 NaN 值") - for key in stats['inf_counts'].keys(): - if stats['inf_counts'][key] > 0: - logger.info(f" {key} 包含 {stats['inf_counts'][key]} 个 Inf 值") - - # 5. 几何特征统计 - logger.info("\n几何特征统计:") - for name, values in [ - ('面数', stats['face_counts']), - ('边数', stats['edge_counts']), - ('顶点数', stats['vertex_counts']), - ('SDF采样点数', stats['sdf_point_counts']) - ]: - values = np.array(values) - logger.info(f" {name}:") - logger.info(f" 最小值: {np.min(values)}") - logger.info(f" 最大值: {np.max(values)}") - logger.info(f" 平均值: {np.mean(values):.2f}") - logger.info(f" 中位数: {np.median(values):.2f}") - logger.info(f" 标准差: {np.std(values):.2f}") - - # 6. 
输出无效样本详情 - if stats['invalid_samples']: - logger.info("\n无效样本详情:") - for idx, error in stats['invalid_samples']: - logger.info(f" 样本 {idx}: {error}") - - return stats - - except Exception as e: - logger.error(f"验证过程出错: {str(e)}") - raise + # 检查所有必需的键是否存在 + for key in required_keys: + if key not in data: + return False, f"Missing key: {key}" + + # 检查几何数据 + geometry_arrays = ['surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs'] + for key in geometry_arrays: + if not isinstance(data[key], np.ndarray): + return False, f"{key} should be a numpy array" + # 允许对象数组 + if data[key].dtype != object: + return False, f"{key} should be a numpy array with dtype=object" + + # 检查其他数组 + float32_arrays = ['corner_wcs', 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'] + for key in float32_arrays: + if not isinstance(data[key], np.ndarray): + return False, f"{key} should be a numpy array" + if data[key].dtype != np.float32: + return False, f"{key} should be a numpy array with dtype=float32" + + int32_arrays = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj'] + for key in int32_arrays: + if not isinstance(data[key], np.ndarray): + return False, f"{key} should be a numpy array" + if data[key].dtype != np.int32: + return False, f"{key} should be a numpy array with dtype=int32" + + return True, "" -if __name__ == '__main__': - validate_dataset(split='train', num_samples=None) # 先测试100个样本 \ No newline at end of file diff --git a/brep2sdf/data/pre_process_by_mesh.py b/brep2sdf/data/pre_process_by_mesh.py index 9629e4c..7d804d0 100644 --- a/brep2sdf/data/pre_process_by_mesh.py +++ b/brep2sdf/data/pre_process_by_mesh.py @@ -34,6 +34,8 @@ from OCC.Core.Bnd import Bnd_Box # 包围盒 from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构 from OCC.Core.StlAPI import StlAPI_Writer +from brep2sdf.data.sampler import sample_sdf_points_and_normals +from brep2sdf.data.data import check_data_format # 导入配置 from brep2sdf.config.default_config import get_default_config config = get_default_config() @@ -533,244 +535,8 @@ def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3): return normals_output -def sample_sdf_points_and_normals( - trimesh_mesh_ncs: trimesh.Trimesh, - surf_bbox_ncs: np.ndarray, - num_sdf_samples: int = 4096, - sdf_sampling_std_dev: float = 0.01 -) -> np.ndarray | None: - """ - 在归一化坐标系(NCS)下采样固定数量的点,并计算它们的SDF值和最近表面法线。 - 采用均匀采样和近表面采样的混合策略。 - - 参数: - trimesh_mesh_ncs: 归一化的 Trimesh 对象。 - surf_bbox_ncs: 归一化坐标系下各面的包围盒 [num_faces, 6]。 - num_sdf_samples: 要采样的总点数。 - sdf_sampling_std_dev: 近表面采样时沿法线扰动的标准差。 - - 返回: - np.ndarray | None: 形状为 (N, 7) 的数组 [x, y, z, nx, ny, nz, sdf], - 如果采样或计算失败则返回 None。 - """ - logger.debug("为 SDF 计算采样点 (固定数量策略)...") - if trimesh_mesh_ncs is None or not isinstance(trimesh_mesh_ncs, trimesh.Trimesh): - logger.error("无效的 Trimesh 对象提供给 SDF 采样。") - return None - if num_sdf_samples <= 0: - logger.warning("请求的 SDF 采样点数为零或负数,不进行采样。") - return None - - # 使用固定的采样边界 [-0.5, 0.5],因为我们已经做过归一化 - min_bound_ncs = np.array([-0.5, -0.5, -0.5], dtype=np.float32) - max_bound_ncs = np.array([0.5, 0.5, 0.5], dtype=np.float32) - bbox_size_ncs = max_bound_ncs - min_bound_ncs - - # --- 使用固定的总样本数分配点数 --- - num_uniform_samples = num_sdf_samples // 2 - num_near_surface_samples = num_sdf_samples - num_uniform_samples - logger.debug(f"固定 SDF 采样点数: {num_sdf_samples} (均匀: {num_uniform_samples}, 近表面: {num_near_surface_samples})") - - # --- 执行采样 --- - sampled_points_list = [] - - # 均匀采样 (在 [-0.5, 0.5] 范围内) - if num_uniform_samples > 0: - uniform_points = np.random.uniform(-0.5, 0.5, 
(num_uniform_samples, 3)) - sampled_points_list.append(uniform_points) - # 近表面采样 - if num_near_surface_samples > 0: - if trimesh_mesh_ncs.faces.shape[0] > 0: - try: - near_points_on_surface = trimesh_mesh_ncs.sample(num_near_surface_samples) - proximity_query_near = ProximityQuery(trimesh_mesh_ncs) - closest_points_near, distances_near, face_indices_near = proximity_query_near.on_surface(near_points_on_surface) - if np.any(face_indices_near >= len(trimesh_mesh_ncs.face_normals)): - raise IndexError("Face index out of bounds during near-surface normal lookup") - normals_near = trimesh_mesh_ncs.face_normals[face_indices_near] - perturbations = np.random.randn(num_near_surface_samples, 1) * sdf_sampling_std_dev - near_points = near_points_on_surface + normals_near * perturbations - # 确保近表面点也在 [-0.5, 0.5] 范围内 - near_points = np.clip(near_points, -0.5, 0.5) - sampled_points_list.append(near_points) - except Exception as e: - logger.warning(f"近表面采样失败: {e}。将替换为均匀采样点。") - fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3)) - sampled_points_list.append(fallback_uniform) - else: - logger.warning("网格没有面,无法执行近表面采样。将替换为均匀采样点。") - fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3)) - sampled_points_list.append(fallback_uniform) - - # --- 合并采样点 --- - if not sampled_points_list: - logger.warning("没有为SDF采样到任何点。") - return None - - sampled_points_ncs = np.vstack(sampled_points_list).astype(np.float32) - - try: - proximity_query = ProximityQuery(trimesh_mesh_ncs) - - # 分批计算SDF以避免内存问题 - batch_size = 1000 - sdf_values = [] - closest_points = [] - face_indices = [] - - for i in range(0, len(sampled_points_ncs), batch_size): - batch_points = sampled_points_ncs[i:i + batch_size] - - # 计算当前批次的最近点和面 - batch_closest, batch_distances, batch_faces = proximity_query.on_surface(batch_points) - - # 计算点到最近面的向量 - direction_vectors = batch_points - batch_closest - - # 使用batch_compute_normals计算最近点的法向量 - # 将batch_closest重新组织为所需的格式 (N,) 数组,每个元素是 (M, 3) 数组 - closest_points_reshaped = np.array([batch_closest], dtype=object) - closest_points_reshaped[0] = batch_closest - - # 计算法向量 - normals_batch = batch_compute_normals( - trimesh_mesh_ncs, - closest_points_reshaped, - normal_type='vertex', # 使用顶点法向量 - k_neighbors=3 - )[0] # 取第一个元素因为我们只传入了一个批次 - - # 计算方向向量与法向量的点积 - dot_products = np.sum(direction_vectors * normals_batch, axis=1) - signs = np.sign(dot_products) - - # 确保零点处的符号处理 - zero_mask = np.abs(batch_distances) < 1e-6 - signs[zero_mask] = 0.0 - - # 计算带符号距离 - batch_sdf = batch_distances * signs - - # 限制SDF值的范围 - batch_sdf = np.clip(batch_sdf, -1.414, 1.414) # sqrt(2) - - # 添加调试信息 - if i == 0: # 只打印第一个批次的统计信息 - logger.debug(f"批次统计 (首批次):") - logger.debug(f" 法向量范围: [{normals_batch.min():.4f}, {normals_batch.max():.4f}]") - logger.debug(f" 法向量长度: {np.linalg.norm(normals_batch, axis=1).mean():.4f}") - logger.debug(f" 距离范围: [{batch_distances.min():.4f}, {batch_distances.max():.4f}]") - logger.debug(f" 点积范围: [{dot_products.min():.4f}, {dot_products.max():.4f}]") - logger.debug(f" 符号分布: 正={np.sum(signs > 0)}, 负={np.sum(signs < 0)}, 零={np.sum(signs == 0)}") - logger.debug(f" SDF范围: [{batch_sdf.min():.4f}, {batch_sdf.max():.4f}]") - - sdf_values.append(batch_sdf) - closest_points.append(batch_closest) - face_indices.append(batch_faces) - - # 合并批次结果 - sdf_values = np.concatenate(sdf_values) - closest_points = np.concatenate(closest_points) - - # 为所有点计算法向量 - all_points_reshaped = np.array([closest_points], dtype=object) - all_points_reshaped[0] = closest_points - sampled_normals = 
batch_compute_normals( - trimesh_mesh_ncs, - all_points_reshaped, - normal_type='vertex', - k_neighbors=3 - )[0] - - # 验证法向量 - normal_lengths = np.linalg.norm(sampled_normals, axis=1) - logger.debug(f"最终法向量统计:") - logger.debug(f" 形状: {sampled_normals.shape}") - logger.debug(f" 长度: min={normal_lengths.min():.4f}, max={normal_lengths.max():.4f}, mean={normal_lengths.mean():.4f}") - logger.debug(f" 分量范围: x=[{sampled_normals[:,0].min():.4f}, {sampled_normals[:,0].max():.4f}]") - logger.debug(f" 分量范围: y=[{sampled_normals[:,1].min():.4f}, {sampled_normals[:,1].max():.4f}]") - logger.debug(f" 分量范围: z=[{sampled_normals[:,2].min():.4f}, {sampled_normals[:,2].max():.4f}]") - - # 添加验证 - valid_mask = ( - ~np.isnan(sdf_values) & ~np.isinf(sdf_values) & - ~np.isnan(sampled_normals).any(axis=1) & ~np.isinf(sampled_normals).any(axis=1) & - ~np.isnan(sampled_points_ncs).any(axis=1) & ~np.isinf(sampled_points_ncs).any(axis=1) - ) - - if not np.all(valid_mask): - num_invalid = sampled_points_ncs.shape[0] - np.sum(valid_mask) - logger.warning(f"在 SDF/法线计算中发现 {num_invalid} 个无效条目。将它们过滤掉。") - sampled_points_ncs = sampled_points_ncs[valid_mask] - sampled_normals = sampled_normals[valid_mask] - sdf_values = sdf_values[valid_mask] - - if sampled_points_ncs.shape[0] > 0: - combined_data = np.hstack(( - sampled_points_ncs, - sampled_normals, - sdf_values[:, np.newaxis] - )).astype(np.float32) - - # 添加SDF分布验证 - final_sdf = combined_data[:, -1] - logger.debug(f"最终SDF分布验证:") - logger.debug(f" 正值点数: {np.sum(final_sdf > 0)}") - logger.debug(f" 负值点数: {np.sum(final_sdf < 0)}") - logger.debug(f" 零值点数: {np.sum(np.abs(final_sdf) < 1e-6)}") - logger.debug(f" SDF范围: [{final_sdf.min():.4f}, {final_sdf.max():.4f}]") - - # 验证分布是否合理 - if np.sum(final_sdf > 0) == 0 or np.sum(final_sdf < 0) == 0: - logger.warning("警告:SDF值分布异常,没有正值或负值!") - - return combined_data - else: - logger.warning("过滤 SDF/法线结果后没有剩余有效点。") - return None - except Exception as e: - logger.error(f"计算 SDF 或法线时失败: {str(e)}") - return None -def check_data_format(data, step_file): - """检查数据格式是否正确""" - required_keys = [ - 'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs', 'corner_wcs', - 'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj', - 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique' - ] - - # 检查所有必需的键是否存在 - for key in required_keys: - if key not in data: - return False, f"Missing key: {key}" - - # 检查几何数据 - geometry_arrays = ['surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs'] - for key in geometry_arrays: - if not isinstance(data[key], np.ndarray): - return False, f"{key} should be a numpy array" - # 允许对象数组 - if data[key].dtype != object: - return False, f"{key} should be a numpy array with dtype=object" - - # 检查其他数组 - float32_arrays = ['corner_wcs', 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'] - for key in float32_arrays: - if not isinstance(data[key], np.ndarray): - return False, f"{key} should be a numpy array" - if data[key].dtype != np.float32: - return False, f"{key} should be a numpy array with dtype=float32" - - int32_arrays = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj'] - for key in int32_arrays: - if not isinstance(data[key], np.ndarray): - return False, f"{key} should be a numpy array" - if data[key].dtype != np.int32: - return False, f"{key} should be a numpy array with dtype=int32" - - return True, "" def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, sample_sdf_points=False, timeout:int=300) -> dict: """处理单个STEP文件, 从 brep 2 pkl diff --git a/brep2sdf/data/sampler.py b/brep2sdf/data/sampler.py new file mode 100644 
index 0000000..215f0aa
--- /dev/null
+++ b/brep2sdf/data/sampler.py
@@ -0,0 +1,300 @@
+"""
+SDF sampling utilities.
+
+Samples query points for SDF supervision in the normalized coordinate
+system (NCS), combining:
+- uniform samples inside the [-0.5, 0.5]^3 cube
+- near-surface samples perturbed along face normals
+and computes signed distances and nearest-surface normals for them.
+"""
+
+import numpy as np
+import trimesh
+from trimesh.proximity import ProximityQuery
+
+from brep2sdf.utils.logger import logger
+
+# batch_compute_normals is imported lazily inside
+# sample_sdf_points_and_normals to avoid a circular import with
+# brep2sdf.data.pre_process_by_mesh, which imports this module.
+
+
+def _sample_uniform_points(num_points: int) -> np.ndarray:
+    """Sample points uniformly within [-0.5, 0.5].
+
+    Args:
+        num_points: number of points to sample
+
+    Returns:
+        np.ndarray: array of sampled points with shape (num_points, 3)
+    """
+    return np.random.uniform(-0.5, 0.5, (num_points, 3))
+
+
+def _sample_near_surface_points(
+    mesh: trimesh.Trimesh,
+    num_points: int,
+    std_dev: float
+) -> np.ndarray:
+    """Sample points near the mesh surface.
+
+    Args:
+        mesh: input trimesh mesh
+        num_points: number of points to sample
+        std_dev: standard deviation of the perturbation along the normal
+
+    Returns:
+        np.ndarray: array of sampled points with shape (num_points, 3)
+    """
+    if mesh.faces.shape[0] == 0:
+        logger.warning("Mesh has no faces; near-surface sampling is impossible. Falling back to uniform samples.")
+        return _sample_uniform_points(num_points)
+
+    try:
+        near_points_on_surface = mesh.sample(num_points)
+        proximity_query_near = ProximityQuery(mesh)
+        closest_points_near, _, face_indices_near = proximity_query_near.on_surface(near_points_on_surface)
+
+        if np.any(face_indices_near >= len(mesh.face_normals)):
+            raise IndexError("Face index out of bounds during near-surface normal lookup")
+
+        normals_near = mesh.face_normals[face_indices_near]
+        perturbations = np.random.randn(num_points, 1) * std_dev
+        near_points = near_points_on_surface + normals_near * perturbations
+        # Keep near-surface points inside the [-0.5, 0.5] sampling cube
+        return np.clip(near_points, -0.5, 0.5)
+
+    except Exception as e:
+        logger.warning(f"Near-surface sampling failed: {e}. Falling back to uniform samples.")
+        return _sample_uniform_points(num_points)
+
+
+def sample_points(
+    trimesh_mesh_ncs: trimesh.Trimesh,
+    num_uniform_samples: int,
+    num_near_surface_samples: int,
+    sdf_sampling_std_dev: float
+) -> np.ndarray | None:
+    """Combine uniform and near-surface point samples.
+
+    Args:
+        trimesh_mesh_ncs: normalized trimesh mesh
+        num_uniform_samples: number of uniform samples
+        num_near_surface_samples: number of near-surface samples
+        sdf_sampling_std_dev: standard deviation for near-surface sampling
+
+    Returns:
+        np.ndarray | None: combined sample array, or None if nothing was sampled
+    """
+    sampled_points_list = []
+
+    # Uniform sampling
+    if num_uniform_samples > 0:
+        uniform_points = _sample_uniform_points(num_uniform_samples)
+        sampled_points_list.append(uniform_points)
+
+    # Near-surface sampling
+    if num_near_surface_samples > 0:
+        near_points = _sample_near_surface_points(
+            trimesh_mesh_ncs,
+            num_near_surface_samples,
+            sdf_sampling_std_dev
+        )
+        sampled_points_list.append(near_points)
+
+    # Merge the sampled points
+    if not sampled_points_list:
+        logger.warning("No points were sampled.")
+        return None
+
+    return np.vstack(sampled_points_list).astype(np.float32)
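+
+# Illustrative usage of sample_points (a sketch, not part of the pipeline;
+# assumes `mesh` is a watertight trimesh already normalized into the
+# [-0.5, 0.5]^3 cube -- the file name below is hypothetical):
+#
+#   import trimesh
+#   mesh = trimesh.load("model_ncs.stl")
+#   pts = sample_points(mesh,
+#                       num_uniform_samples=2048,
+#                       num_near_surface_samples=2048,
+#                       sdf_sampling_std_dev=0.01)
+#   # pts: (4096, 3) float32 array in NCS, or None if sampling failed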
+
+
+# Use the new sampling helpers in the original sample_sdf_points_and_normals
+def sample_sdf_points_and_normals(
+    trimesh_mesh_ncs: trimesh.Trimesh,
+    surf_bbox_ncs: np.ndarray,
+    num_sdf_samples: int = 4096,
+    sdf_sampling_std_dev: float = 0.01
+) -> np.ndarray | None:
+    """
+    Sample a fixed number of points in the normalized coordinate system (NCS)
+    and compute their SDF values and nearest-surface normals, using a mix of
+    uniform and near-surface sampling.
+
+    Args:
+        trimesh_mesh_ncs: normalized Trimesh object.
+        surf_bbox_ncs: per-face bounding boxes in NCS, shape [num_faces, 6].
+        num_sdf_samples: total number of points to sample.
+        sdf_sampling_std_dev: standard deviation of the normal-direction
+            perturbation for near-surface samples.
+
+    Returns:
+        np.ndarray | None: array of shape (N, 7) laid out as
+        [x, y, z, nx, ny, nz, sdf], or None if sampling or computation fails.
+    """
+    # Imported lazily: pre_process_by_mesh imports this module and also
+    # defines batch_compute_normals, so a top-level import would be circular.
+    from brep2sdf.data.pre_process_by_mesh import batch_compute_normals
+
+    logger.debug("Sampling points for SDF computation (fixed-count strategy)...")
+    if trimesh_mesh_ncs is None or not isinstance(trimesh_mesh_ncs, trimesh.Trimesh):
+        logger.error("Invalid Trimesh object passed to SDF sampling.")
+        return None
+    if num_sdf_samples <= 0:
+        logger.warning("Requested zero or negative SDF samples; skipping sampling.")
+        return None
+
+    # Sampling bounds are fixed to [-0.5, 0.5] because the mesh is already normalized.
+    # --- Split the fixed sample budget ---
+    num_uniform_samples = num_sdf_samples // 2
+    num_near_surface_samples = num_sdf_samples - num_uniform_samples
+    logger.debug(f"Fixed SDF sample count: {num_sdf_samples} (uniform: {num_uniform_samples}, near-surface: {num_near_surface_samples})")
+
+    # --- Perform sampling ---
+    sampled_points_ncs = sample_points(
+        trimesh_mesh_ncs,
+        num_uniform_samples,
+        num_near_surface_samples,
+        sdf_sampling_std_dev
+    )
+    if sampled_points_ncs is None:
+        logger.warning("Point sampling returned no points; aborting SDF computation.")
+        return None
+
+    try:
+        proximity_query = ProximityQuery(trimesh_mesh_ncs)
+
+        # Compute SDF values in batches to limit memory use
+        batch_size = 1000
+        sdf_values = []
+        closest_points = []
+
+        for i in range(0, len(sampled_points_ncs), batch_size):
+            batch_points = sampled_points_ncs[i:i + batch_size]
+
+            # Closest surface points, distances, and faces for this batch
+            batch_closest, batch_distances, batch_faces = proximity_query.on_surface(batch_points)
+
+            # Vectors from the closest surface points to the query points
+            direction_vectors = batch_points - batch_closest
+
+            # Compute normals at the closest points with batch_compute_normals,
+            # which expects an (N,) object array whose elements are (M, 3) arrays
+            closest_points_reshaped = np.array([batch_closest], dtype=object)
+            closest_points_reshaped[0] = batch_closest
+
+            normals_batch = batch_compute_normals(
+                trimesh_mesh_ncs,
+                closest_points_reshaped,
+                normal_type='vertex',  # use vertex normals
+                k_neighbors=3
+            )[0]  # single batch in, single batch out
+
+            # Sign from the dot product of direction vectors and normals
+            dot_products = np.sum(direction_vectors * normals_batch, axis=1)
+            signs = np.sign(dot_products)
+
+            # Points essentially on the surface get sign 0
+            zero_mask = np.abs(batch_distances) < 1e-6
+            signs[zero_mask] = 0.0
+
+            # Signed distance
+            batch_sdf = batch_distances * signs
+
+            # Clamp the SDF range
+            batch_sdf = np.clip(batch_sdf, -1.414, 1.414)  # sqrt(2)
+
+            # Debug statistics for the first batch only
+            if i == 0:
+                logger.debug("Batch statistics (first batch):")
+                logger.debug(f"  normal range: [{normals_batch.min():.4f}, {normals_batch.max():.4f}]")
+                logger.debug(f"  mean normal length: {np.linalg.norm(normals_batch, axis=1).mean():.4f}")
+                logger.debug(f"  distance range: [{batch_distances.min():.4f}, {batch_distances.max():.4f}]")
+                logger.debug(f"  dot product range: [{dot_products.min():.4f}, {dot_products.max():.4f}]")
+                logger.debug(f"  sign distribution: pos={np.sum(signs > 0)}, neg={np.sum(signs < 0)}, zero={np.sum(signs == 0)}")
+                logger.debug(f"  SDF range: [{batch_sdf.min():.4f}, {batch_sdf.max():.4f}]")
+
+            sdf_values.append(batch_sdf)
+            closest_points.append(batch_closest)
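+
+        # Worked example of the sign convention above (illustrative numbers):
+        # for query point p = (0, 0, 0.2) with closest surface point
+        # c = (0, 0, 0.1) and outward normal n = (0, 0, 1):
+        #   dot(p - c, n) = 0.1 > 0  -> sign +1 (outside), sdf = +|p - c| = +0.1
+        # for p = (0, 0, 0.0): dot(p - c, n) = -0.1 < 0 -> sign -1 (inside), sdf = -0.1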
+
+        # Merge batch results
+        sdf_values = np.concatenate(sdf_values)
+        closest_points = np.concatenate(closest_points)
+
+        # Compute normals for all points
+        all_points_reshaped = np.array([closest_points], dtype=object)
+        all_points_reshaped[0] = closest_points
+        sampled_normals = batch_compute_normals(
+            trimesh_mesh_ncs,
+            all_points_reshaped,
+            normal_type='vertex',
+            k_neighbors=3
+        )[0]
+
+        # Validate the normals
+        normal_lengths = np.linalg.norm(sampled_normals, axis=1)
+        logger.debug("Final normal statistics:")
+        logger.debug(f"  shape: {sampled_normals.shape}")
+        logger.debug(f"  length: min={normal_lengths.min():.4f}, max={normal_lengths.max():.4f}, mean={normal_lengths.mean():.4f}")
+        logger.debug(f"  component range: x=[{sampled_normals[:,0].min():.4f}, {sampled_normals[:,0].max():.4f}]")
+        logger.debug(f"  component range: y=[{sampled_normals[:,1].min():.4f}, {sampled_normals[:,1].max():.4f}]")
+        logger.debug(f"  component range: z=[{sampled_normals[:,2].min():.4f}, {sampled_normals[:,2].max():.4f}]")
+
+        # Filter out invalid entries
+        valid_mask = (
+            ~np.isnan(sdf_values) & ~np.isinf(sdf_values) &
+            ~np.isnan(sampled_normals).any(axis=1) & ~np.isinf(sampled_normals).any(axis=1) &
+            ~np.isnan(sampled_points_ncs).any(axis=1) & ~np.isinf(sampled_points_ncs).any(axis=1)
+        )
+
+        if not np.all(valid_mask):
+            num_invalid = sampled_points_ncs.shape[0] - np.sum(valid_mask)
+            logger.warning(f"Found {num_invalid} invalid entries during SDF/normal computation; filtering them out.")
+            sampled_points_ncs = sampled_points_ncs[valid_mask]
+            sampled_normals = sampled_normals[valid_mask]
+            sdf_values = sdf_values[valid_mask]
+
+        if sampled_points_ncs.shape[0] > 0:
+            combined_data = np.hstack((
+                sampled_points_ncs,
+                sampled_normals,
+                sdf_values[:, np.newaxis]
+            )).astype(np.float32)
+
+            # Validate the SDF distribution
+            final_sdf = combined_data[:, -1]
+            logger.debug("Final SDF distribution check:")
+            logger.debug(f"  positive points: {np.sum(final_sdf > 0)}")
+            logger.debug(f"  negative points: {np.sum(final_sdf < 0)}")
+            logger.debug(f"  zero points: {np.sum(np.abs(final_sdf) < 1e-6)}")
+            logger.debug(f"  SDF range: [{final_sdf.min():.4f}, {final_sdf.max():.4f}]")
+
+            # Sanity-check the distribution
+            if np.sum(final_sdf > 0) == 0 or np.sum(final_sdf < 0) == 0:
+                logger.warning("Abnormal SDF distribution: no positive or no negative values!")
+
+            return combined_data
+        else:
+            logger.warning("No valid points remain after filtering SDF/normal results.")
+            return None
+    except Exception as e:
+        logger.error(f"Failed to compute SDF or normals: {str(e)}")
+        return None
\ No newline at end of file
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index 802056e..f835b17 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -1,5 +1,6 @@
 import torch
 from torch.serialization import add_safe_globals
+from torch.utils.mobile_optimizer import optimize_for_mobile
 import torch.optim as optim
 import time
 import os
@@ -64,7 +65,7 @@ class Trainer:
         surface_sdf_data = prepare_sdf_data(
             surfs,
             normals=self.data["surf_pnt_normals"],
-            max_points=4096,
+            max_points=50000,
             device=self.device
         )
         # 如果不是仅使用零表面,则合并采样点数据
@@ -343,12 +344,36 @@ class Trainer:
             sdfs= model(example_input)
             logger.debug(f"sdfs:{sdfs}")
 
-    def _tracing_model(self):
+    def _tracing_model_by_script(self):
         """保存模型"""
         self.model.eval()
         # 确保模型中的所有逻辑都兼容 TorchScript
         scripted_model = torch.jit.script(self.model)
-        torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
+        optimized_model = optimize_for_mobile(scripted_model)
+        torch.jit.save(optimized_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
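+
+    # Note: torch.jit.trace (below) records the single execution path taken by
+    # the example input, so any input-dependent control flow in the model is
+    # frozen at trace time; torch.jit.script (above) preserves control flow but
+    # requires the whole model to be TorchScript-compatible.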
f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt" + torch.jit.save(traced_model, save_path) + + # 验证保存的模型 + try: + loaded_model = torch.jit.load(save_path) + test_input = torch.rand(1, 3, device=self.device) + _ = loaded_model(test_input) + logger.info(f"模型已保存并验证成功:{save_path}") + except Exception as e: + logger.error(f"模型验证失败:{e}") def _load_checkpoint(self, checkpoint_path): """从检查点恢复训练状态"""