diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index 5a2de6f..117fc69 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -251,6 +251,8 @@ class BRepSDFDataset(Dataset): return num_faces, num_edges + + def load_brep_file(brep_path): with open(brep_path, 'rb') as f: brep_raw = pickle.load(f) diff --git a/brep2sdf/data/pre_process.py b/brep2sdf/data/pre_process.py new file mode 100644 index 0000000..903a2ae --- /dev/null +++ b/brep2sdf/data/pre_process.py @@ -0,0 +1,540 @@ +""" +CAD模型处理脚本 +功能:将STEP格式的CAD模型转换为结构化数据,包括: +- 几何信息:面、边、顶点的坐标数据 +- 拓扑信息:面-边-顶点的邻接关系 +- 空间信息:包围盒数据 +""" + +import os +import pickle # 用于数据序列化 +import argparse # 命令行参数解析 +import numpy as np +from tqdm import tqdm # 进度条显示 +from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError # 并行处理 +import logging +from datetime import datetime +from brep2sdf.utils.logger import logger + + +# 导入OpenCASCADE相关库 +from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器 +from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历 +from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类型定义 +from OCC.Core.BRep import BRep_Tool # B-rep工具 +from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分 +from OCC.Core.TopLoc import TopLoc_Location # 位置变换 +from OCC.Core.IFSelect import IFSelect_RetDone,IFSelect_RetError, IFSelect_RetFail, IFSelect_RetVoid # 操作状态码 +from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射 +from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算 +from OCC.Core.Bnd import Bnd_Box # 包围盒 +from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构 + +# 导入配置 +from brep2sdf.config.default_config import get_default_config +config = get_default_config() + +# 设置最大面数阈值,用于加速处理 +MAX_FACE = config.data.max_face + +def normalize(surfs, edges, corners): + """ + 将CAD模型归一化到单位立方体空间 + + 参数: + surfs: 面的点集列表 + edges: 边的点集列表 + corners: 顶点坐标数组 [num_edges, 2, 3] + + 返回: + surfs_wcs: 原始坐标系下的面点集 + edges_wcs: 原始坐标系下的边点集 + surfs_ncs: 归一化坐标系下的面点集 + edges_ncs: 归一化坐标系下的边点集 + corner_wcs: 归一化后的顶点坐标 [num_edges, 2, 3] + center: 使用的中心点坐标 [3,] + scale: 使用的缩放系数 (float) + """ + if len(corners) == 0: + return None, None, None, None, None, None, None + + # 计算包围盒和缩放因子 + corners_array = corners.reshape(-1, 3) # [num_edges*2, 3] + center = (corners_array.max(0) + corners_array.min(0)) / 2 # 计算中心点 + scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max() # 计算缩放系数 + + # 归一化面的点集 + surfs_wcs = [] # 原始世界坐标系下的面点集 + surfs_ncs = [] # 归一化坐标系下的面点集 + for surf in surfs: + surf_wcs = np.array(surf) + surf_ncs = (surf_wcs - center) * scale # 归一化变换 + surfs_wcs.append(surf_wcs) + surfs_ncs.append(surf_ncs) + + # 归一化边的点集 + edges_wcs = [] # 原始世界坐标系下的边点集 + edges_ncs = [] # 归一化坐标系下的边点集 + for edge in edges: + edge_wcs = np.array(edge) + edge_ncs = (edge_wcs - center) * scale # 归一化变换 + edges_wcs.append(edge_wcs) + edges_ncs.append(edge_ncs) + + # 归一化顶点坐标 - 保持[num_edges, 2, 3]的形状 + corner_wcs = (corners - center) * scale # 广播操作会保持原有维度 + + return (np.array(surfs_wcs, dtype=object), + np.array(edges_wcs, dtype=object), + np.array(surfs_ncs, dtype=object), + np.array(edges_ncs, dtype=object), + corner_wcs.astype(np.float32), + center.astype(np.float32), + scale + ) + +def get_adjacency_info(shape): + """ + 获取CAD模型中面、边、顶点之间的邻接关系 + + 参数: + shape: CAD模型的形状对象 + + 返回: + edgeFace_adj: 边-面邻接矩阵 (num_edges × num_faces) + faceEdge_adj: 面-边邻接矩阵 (num_faces × num_edges) + edgeCorner_adj: 边-顶点邻接矩阵 (num_edges × 2) + """ + # 创建边-面映射关系 + edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape() + topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map) + + # 获取所有几何元素 + faces = [] # 存储所有面 + edges = [] # 存储所有边 + vertices = [] # 存储所有顶点 + + # 创建拓扑结构探索器 + face_explorer = TopExp_Explorer(shape, TopAbs_FACE) + edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE) + vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX) + + # 收集所有几何元素 + while face_explorer.More(): + faces.append(topods.Face(face_explorer.Current())) + face_explorer.Next() + + while edge_explorer.More(): + edges.append(topods.Edge(edge_explorer.Current())) + edge_explorer.Next() + + while vertex_explorer.More(): + vertices.append(topods.Vertex(vertex_explorer.Current())) + vertex_explorer.Next() + + # 创建邻接矩阵 + num_faces = len(faces) + num_edges = len(edges) + num_vertices = len(vertices) + + edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32) + faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32) + edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32) + + # 填充边-面邻接矩阵 + for i, edge in enumerate(edges): + # 检查每个面是否与当前边相连 + for j, face in enumerate(faces): + edge_explorer = TopExp_Explorer(face, TopAbs_EDGE) + while edge_explorer.More(): + if edge.IsSame(edge_explorer.Current()): + edgeFace_adj[i, j] = 1 + faceEdge_adj[j, i] = 1 + break + edge_explorer.Next() + + # 获取边的两个端点 + v1 = TopoDS_Vertex() + v2 = TopoDS_Vertex() + topexp.Vertices(edge, v1, v2) + + # 记录边的端点索引 + if not v1.IsNull() and not v2.IsNull(): + v1_vertex = topods.Vertex(v1) + v2_vertex = topods.Vertex(v2) + + for k, vertex in enumerate(vertices): + if v1_vertex.IsSame(vertex): + edgeCorner_adj[i, 0] = k + if v2_vertex.IsSame(vertex): + edgeCorner_adj[i, 1] = k + + return edgeFace_adj, faceEdge_adj, edgeCorner_adj + +def get_bbox(shape, subshape): + """ + 计算形状的包围盒 + + 参数: + shape: 完整的CAD模型形状 + subshape: 需要计算包围盒的子形状(面或边) + + 返回: + 包围盒的六个参数 [xmin, ymin, zmin, xmax, ymax, zmax] + """ + bbox = Bnd_Box() + brepbndlib.Add(subshape, bbox) + xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get() + return np.array([xmin, ymin, zmin, xmax, ymax, zmax]) + + + +def parse_solid(step_path): + """ + 解析STEP文件中的CAD模型数据 + + 返回: + dict: 包含以下键值对的字典: + # 几何数据 + 'surf_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示面的点云坐标 + 'edge_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示边的采样点坐标 + 'surf_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示归一化后的面点云 + 'edge_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示归一化后的边采样点 + 'corner_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 2, 3)的数组,表示每条边的两个端点坐标 + 'corner_unique': np.ndarray(dtype=float32) # 形状为(num_vertices, 3)的数组,表示所有顶点的唯一坐标,num_vertices <= num_edges * 2 + + # 拓扑关系 + 'edgeFace_adj': np.ndarray(dtype=int32) # 形状为(num_edges, num_faces)的数组,表示边-面邻接关系 + 'edgeCorner_adj': np.ndarray(dtype=int32) # 形状为(num_edges, 2)的数组,表示边-顶点邻接关系 + 'faceEdge_adj': np.ndarray(dtype=int32) # 形状为(num_faces, num_edges)的数组,表示面-边邻接关系 + + # 包围盒数据 + 'surf_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_faces, 6)的数组,表示每个面的包围盒[xmin,ymin,zmin,xmax,ymax,zmax] + 'edge_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 6)的数组,表示每条边的包围盒[xmin,ymin,zmin,xmax,ymax,zmax] + """ + # Load STEP file + reader = STEPControl_Reader() + status = reader.ReadFile(step_path) + if status != IFSelect_RetDone: + if status == IFSelect_RetError: + print("Error: An error occurred while reading the file.") + elif status == IFSelect_RetFail: + print("Error: Failed to read the file.") + elif status == IFSelect_RetVoid: + print("Error: No data was read from the file.") + else: + print(f"Unexpected status code: {status}") + raise Exception(f"Failed to read STEP file. {status}") + + reader.TransferRoots() + shape = reader.OneShape() + + # Create mesh + mesh = BRepMesh_IncrementalMesh(shape, 0.01) + mesh.Perform() + + # Initialize explorers + face_explorer = TopExp_Explorer(shape, TopAbs_FACE) + edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE) + vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX) + + face_pnts = [] + edge_pnts = [] + corner_pnts = [] + surf_bbox_wcs = [] + edge_bbox_wcs = [] + + # Extract face points + while face_explorer.More(): + face = topods.Face(face_explorer.Current()) + loc = TopLoc_Location() + triangulation = BRep_Tool.Triangulation(face, loc) + + if triangulation is not None: + points = [] + for i in range(1, triangulation.NbNodes() + 1): + node = triangulation.Node(i) + pnt = node.Transformed(loc.Transformation()) + points.append([pnt.X(), pnt.Y(), pnt.Z()]) + + if points: + points = np.array(points, dtype=np.float32) + if len(points.shape) == 2 and points.shape[1] == 3: + # 确保每个面至少有一些点 + if len(points) < 3: # 如果点数太少,跳过这个面 + continue + face_pnts.append(points) + surf_bbox_wcs.append(get_bbox(shape, face)) + + face_explorer.Next() + + # Extract edge points + num_samples = config.model.num_edge_points # 使用配置中的边采样点数 + while edge_explorer.More(): + edge = topods.Edge(edge_explorer.Current()) + curve_info = BRep_Tool.Curve(edge) + if curve_info is None: + continue # 跳过无效边 + + try: + if len(curve_info) == 3: + curve, first, last = curve_info + elif len(curve_info) == 2: + continue + curve, location = curve_info + logger.info(curve) + first, last = BRep_Tool.Range(edge) # 显式获取参数范围 + else: + raise ValueError(f"Unexpected curve info: {curve_info}") + except Exception as e: + logger.error(f"Failed to process edge {edge}: {str(e)}") + continue + + if curve is not None: + points = [] + for i in range(num_samples): + param = first + (last - first) * float(i) / (num_samples - 1) + pnt = curve.Value(param) + points.append([pnt.X(), pnt.Y(), pnt.Z()]) + + if points: + points = np.array(points, dtype=np.float32) + if len(points.shape) == 2 and points.shape[1] == 3: + edge_pnts.append(points) # 现在points是(num_edge_points, 3)形状 + edge_bbox_wcs.append(get_bbox(shape, edge)) + + edge_explorer.Next() + + # Extract vertex points + while vertex_explorer.More(): + vertex = topods.Vertex(vertex_explorer.Current()) + pnt = BRep_Tool.Pnt(vertex) + corner_pnts.append([pnt.X(), pnt.Y(), pnt.Z()]) + vertex_explorer.Next() + + # 获取邻接信息 + edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(shape) + + # 转换为numpy数组时确保类型正确 + face_pnts = [np.array(points, dtype=np.float32) for points in face_pnts] + edge_pnts = [np.array(points, dtype=np.float32) for points in edge_pnts] + + # 转换为对象数组 + face_pnts = np.array(face_pnts, dtype=object) + edge_pnts = np.array(edge_pnts, dtype=object) + corner_pnts = np.array(corner_pnts, dtype=np.float32) + + # 重组顶点数据为每条边两个端点的形式 + corner_pairs = [] + for edge_idx in range(len(edge_pnts)): + v1_idx, v2_idx = edgeCorner_adj[edge_idx] + v1_pos = corner_pnts[v1_idx] + v2_pos = corner_pnts[v2_idx] + # 按坐标排序确保顺序一致 + if (v1_pos > v2_pos).any(): + v1_pos, v2_pos = v2_pos, v1_pos + corner_pairs.append(np.stack([v1_pos, v2_pos])) + + corner_pairs = np.stack(corner_pairs).astype(np.float32) # [num_edges, 2, 3] + + # 确保所有数组都有正确的类型 + surf_bbox_wcs = np.array(surf_bbox_wcs, dtype=np.float32) + edge_bbox_wcs = np.array(edge_bbox_wcs, dtype=np.float32) + + # Normalize the CAD model + surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs,center, scale = normalize( + face_pnts, edge_pnts, corner_pairs) + + # 计算归一化后的包围盒 + surf_bbox_ncs = np.empty_like(surf_bbox_wcs) + edge_bbox_ncs = np.empty_like(edge_bbox_wcs) + + # 转换曲面包围盒到归一化坐标系 + surf_bbox_ncs[:, :3] = (surf_bbox_wcs[:, :3] - center) * scale # 最小点 + surf_bbox_ncs[:, 3:] = (surf_bbox_wcs[:, 3:] - center) * scale # 最大点 + + # 转换边包围盒到归一化坐标系 + edge_bbox_ncs[:, :3] = (edge_bbox_wcs[:, :3] - center) * scale # 最小点 + edge_bbox_ncs[:, 3:] = (edge_bbox_wcs[:, 3:] - center) * scale # 最大点 + + + # 验证归一化后的数据 + if any(x is None for x in [surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs]): + logger.error(f"Normalization failed for {step_path}") + return None + + # 创建结果字典并确保所有数组都有正确的类型 + data = { + 'surf_wcs': np.array(surfs_wcs, dtype=object), # 保持对象数组 + 'edge_wcs': np.array(edges_wcs, dtype=object), # 保持对象数组 + 'surf_ncs': np.array(surfs_ncs, dtype=object), # 保持对象数组 + 'edge_ncs': np.array(edges_ncs, dtype=object), # 保持对象数组 + 'corner_wcs': corner_wcs.astype(np.float32), # [num_edges, 2, 3] + 'edgeFace_adj': edgeFace_adj.astype(np.int32), + 'edgeCorner_adj': edgeCorner_adj.astype(np.int32), + 'faceEdge_adj': faceEdge_adj.astype(np.int32), + 'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32), + 'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32), + 'surf_bbox_ncs': surf_bbox_ncs.astype(np.float32), # 归一化坐标系 [num_faces, 6] + 'edge_bbox_ncs': edge_bbox_ncs.astype(np.float32), # 归一化坐标系 [num_edges, 6] + 'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32), # 先展平再去重 + 'normalization_params': { + 'center': center.astype(np.float32), # 归一化中心点 [3,] + 'scale': float(scale), # 归一化缩放系数 + } + } + + return data + +def load_step(step_path): + """Load STEP file and return solids""" + reader = STEPControl_Reader() + reader.ReadFile(step_path) + reader.TransferRoots() + return [reader.OneShape()] + +def check_data_format(data, step_file): + """检查数据格式是否正确""" + required_keys = [ + 'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs', 'corner_wcs', + 'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj', + 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique' + ] + + # 检查所有必需的键是否存在 + for key in required_keys: + if key not in data: + return False, f"Missing key: {key}" + + # 检查几何数据 + geometry_arrays = ['surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs'] + for key in geometry_arrays: + if not isinstance(data[key], np.ndarray): + return False, f"{key} should be a numpy array" + # 允许对象数组 + if data[key].dtype != object: + return False, f"{key} should be a numpy array with dtype=object" + + # 检查其他数组 + float32_arrays = ['corner_wcs', 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'] + for key in float32_arrays: + if not isinstance(data[key], np.ndarray): + return False, f"{key} should be a numpy array" + if data[key].dtype != np.float32: + return False, f"{key} should be a numpy array with dtype=float32" + + int32_arrays = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj'] + for key in int32_arrays: + if not isinstance(data[key], np.ndarray): + return False, f"{key} should be a numpy array" + if data[key].dtype != np.int32: + return False, f"{key} should be a numpy array with dtype=int32" + + return True, "" + +def process_single_step(step_path:str, output_path:str=None, timeout:int=300) -> dict: + """处理单个STEP文件, 从 brep 2 pkl + return data = { + 'surf_wcs': np.array(surfs_wcs, dtype=object), # 世界坐标系下的曲面几何数据(对象数组) + 'edge_wcs': np.array(edges_wcs, dtype=object), # 世界坐标系下的边几何数据(对象数组) + 'surf_ncs': np.array(surfs_ncs, dtype=object), # 归一化坐标系下的曲面几何数据(对象数组) 面归一化点云 [num_faces, num_surf_sample_points, 3] + 'edge_ncs': np.array(edges_ncs, dtype=object), # 归一化坐标系下的边几何数据(对象数组) 边归一化点云 [num_edges, num_edge_sample_points, 3] + 'corner_wcs': corner_wcs.astype(np.float32), # 世界坐标系下的角点数据 [num_edges, 2, 3] + 'edgeFace_adj': edgeFace_adj.astype(np.int32), # 边-面的邻接关系矩阵 + 'edgeCorner_adj': edgeCorner_adj.astype(np.int32),# 边-角点的邻接关系矩阵 + 'faceEdge_adj': faceEdge_adj.astype(np.int32), # 面-边的邻接关系矩阵 + 'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),# 曲面在世界坐标系下的包围盒 + 'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),# 边在世界坐标系下的包围盒 + 'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32) # 去重后的唯一角点坐标 + }""" + try: + logger.info("数据预处理……") + if not os.path.exists(step_path): + logger.error(f"STEP文件不存在: {step_path}") + return None + if not step_path.lower().endswith('.step') and not step_path.lower().endswith('.stp'): + logger.error(f"文件格式不支持,必须是.step或.stp文件: {step_path}") + return None + # 解析STEP文件 + data = parse_solid(step_path) + if data is None: + logger.error(f"Failed to parse STEP file: {step_path}") + return None + + # 检查数据格式 + is_valid, msg = check_data_format(data, step_path) + if not is_valid: + logger.error(f"Data format check failed for {step_path}: {msg}") + return None + + # 保存结果 + if output_path: + try: + logger.info(f"Saving results to: {output_path}") + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'wb') as f: + pickle.dump(data, f) + logger.info("数据预处理完成") + logger.info(f"Results saved successfully: {output_path}") + return data + except Exception as e: + logger.error(f'Failed to save {output_path}: {str(e)}') + return None + logger.info("数据预处理完成") + return data + + except Exception as e: + logger.error(f'Error processing {step_path}: {str(e)}') + return None + +def test(step_file_path, output_path=None): + """ + 测试函数:转换单个STEP文件并保存结果 + """ + try: + logger.info(f"Processing STEP file: {step_file_path}") + + # 解析STEP文件 + data = parse_solid(step_file_path) + if data is None: + logger.error(f"Failed to parse STEP file: {step_file_path}") + return None + + # 检查数据格式 + is_valid, msg = check_data_format(data, step_file_path) + if not is_valid: + logger.error(f"Data format check failed for {step_file_path}: {msg}") + return None + + # 打印统计信息 + logger.info("\nStatistics:") + logger.info(f"Number of surfaces: {len(data['surf_wcs'])}") + logger.info(f"Number of edges: {len(data['edge_wcs'])}") + logger.info(f"Number of corners: {len(data['corner_wcs'])}") # 修正为corner_wcs + + # 保存结果 + if output_path: + try: + logger.info(f"Saving results to: {output_path}") + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'wb') as f: + pickle.dump(data, f) + logger.info(f"Results saved successfully: {output_path}") + except Exception as e: + logger.error(f"Failed to save {output_path}: {str(e)}") + return None + + return data + + except Exception as e: + logger.error(f"Error processing {step_file_path}: {str(e)}") + return None + + + +if __name__ == '__main__': + # main() + test("/home/wch/brep2sdf/data/step/00000000/00000000_290a9120f9f249a7a05cfe9c_step_000.step","/home/wch/brep2sdf/test_data/pkl/train/00000031xx.pkl") + #test("/home/wch/brep2sdf/00000031_ad34a3f60c4a4caa99646600_step_011.step", "/home/wch/brep2sdf/test_data/pkl/train/00000031.pkl") + #test("/mnt/mynewdisk/dataset/furniture/step/furniture_dataset_step/train/bathtub_0004.step", "/home/wch/brep2sdf/test_data/pkl/train/0004.pkl") + #reader = STEPControl_Reader() diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index c562a71..c433a29 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -10,28 +10,33 @@ from brep2sdf.utils.logger import logger import numpy as np class Encoder: - def __init__(self, surf_bbox_wcs: torch.Tensor, origin_bbox_wcs: torch.Tensor, max_depth: int, feature_dim:int = 64): + def __init__(self, surf_bbox: torch.Tensor, origin_bbox: torch.Tensor, max_depth: int, feature_dim:int = 64): """ 初始化表面八叉树管理器 参数: - surf_bbox_wcs: 表面包围盒的世界坐标,形状为 (num_edges, 6), dtype=float32 - origin_bbox_wcs: 原点包围盒的世界坐标,形状为 (6), dtype=float32 + surf_bbox: 表面包围盒的世界坐标,形状为 (num_edges, 6), dtype=float32 + origin_bbox: 原点包围盒的世界坐标,形状为 (6), dtype=float32 max_depth: 八叉树的最大深度 """ self.max_depth = max_depth - # 初始化根节点,包含所有面索引。NOTE: 实际上需要判断 aabb 包含问题,但是这里默认 origin_bbox_wcs 由这些 face 计算,所以不再重复判断 - num_faces = surf_bbox_wcs.shape[0] - print(f"surf_bbox_wcs: {surf_bbox_wcs.shape}") - print(f"origin_bbox_wcs: {origin_bbox_wcs.shape}") + # 初始化根节点,包含所有面索引。NOTE: 实际上需要判断 aabb 包含问题,但是这里默认 origin_bbox 由这些 face 计算,所以不再重复判断 + num_faces = surf_bbox.shape[0] + + #print(f"surf_bbox: {surf_bbox.shape}") + #print(f"origin_bbox: {origin_bbox.shape}") self.root = OctreeNode( - bbox=origin_bbox_wcs, + bbox=origin_bbox, face_indices=np.arange(num_faces), # 初始包含所有面 max_depth=self.max_depth, feature_dim=feature_dim, - surf_bbox_wcs=surf_bbox_wcs + surf_bbox=surf_bbox ) + #print(surf_bbox) + logger.info("starting octree conduction") self.root.conduct_tree() + logger.info("complete octree conduction") + #self.root.print_tree(0) def get_feature_vector(self, query_point): return self.root.get_feature_vector(query_point) diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index cb4f968..fd0de8a 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -7,7 +7,7 @@ class GridNet: def __init__(self, surf_wcs, edge_wcs, surf_ncs, edge_ncs, corner_wcs, corner_unique, edgeFace_adj, edgeCorner_adj, faceEdge_adj, - surf_bbox_wcs, edge_bbox_wcs): + surf_bbox, edge_bbox_wcs): """ 初始化 GridNet @@ -26,7 +26,7 @@ class GridNet: 'faceEdge_adj': np.ndarray(dtype=int32) # 形状为(num_faces, num_edges)的数组,表示面-边邻接关系 # 包围盒数据 - 'surf_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_faces, 6)的数组,表示每个面的包围盒[xmin,ymin,zmin,xmax,ymax,zmax] + 'surf_bbox': np.ndarray(dtype=float32) # 形状为(num_faces, 6)的数组,表示每个面的包围盒[xmin,ymin,zmin,xmax,ymax,zmax] 'edge_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 6)的数组,表示每条边的包围盒[xmin,ymin,zmin,xmax,ymax,zmax] """ self.surf_wcs = surf_wcs @@ -38,7 +38,7 @@ class GridNet: self.edgeFace_adj = edgeFace_adj self.edgeCorner_adj = edgeCorner_adj self.faceEdge_adj = faceEdge_adj - self.surf_bbox_wcs = surf_bbox_wcs + self.surf_bbox = surf_bbox self.edge_bbox_wcs = edge_bbox_wcs # net @@ -53,8 +53,8 @@ from .decoder import Decoder class Net(nn.Module): def __init__(self, - surf_bbox_wcs, - origin_bbox_wcs, + surf_bbox, + origin_bbox, max_depth=4, feature_dim=64, decoder_input_dim=64, @@ -68,8 +68,8 @@ class Net(nn.Module): # 初始化 Encoder self.encoder = Encoder( - surf_bbox_wcs=surf_bbox_wcs, # 使用传入的bbox作为表面包围盒 - origin_bbox_wcs=origin_bbox_wcs, # 使用相同的bbox作为原点包围盒 + surf_bbox=surf_bbox, # 使用传入的bbox作为表面包围盒 + origin_bbox=origin_bbox, # 使用相同的bbox作为原点包围盒 max_depth=max_depth, feature_dim=feature_dim ) diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index dccb74e..4816bac 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -7,18 +7,32 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np -def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool: - """判断两个包围盒是否相交""" - return not (bbox1[3] < bbox2[0] or bbox1[0] > bbox2[3] or - bbox1[4] < bbox2[1] or bbox1[1] > bbox2[4] or - bbox1[5] < bbox2[2] or bbox1[2] > bbox2[5]) +from brep2sdf.utils.logger import logger +def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool: + """判断两个轴对齐包围盒(AABB)是否相交 + + 参数: + bbox1: 形状为 (6,) 的张量,格式 [min_x, min_y, min_z, max_x, max_y, max_z] + bbox2: 同bbox1格式 + + 返回: + bool: 两包围盒是否相交(包括刚好接触的情况) + """ + assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的张量" + + # 提取min和max坐标 + min1, max1 = bbox1[:3], bbox1[3:] + min2, max2 = bbox2[:3], bbox2[3:] + + # 向量化比较 + return torch.all((max1 >= min2) & (max2 >= min1)) class OctreeNode: feature_dim=None device=None - surf_bbox_wcs = None - def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, feature_dim:int = None, surf_bbox_wcs = None): + surf_bbox = None + def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, feature_dim:int = None, surf_bbox:torch.Tensor = None): self.bbox = bbox # 节点的边界框 self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 self.children: List['OctreeNode'] = [] # 子节点列表 @@ -29,8 +43,16 @@ class OctreeNode: if feature_dim is not None: OctreeNode.feature_dim = feature_dim - if surf_bbox_wcs is not None: - OctreeNode.surf_bbox_wcs = surf_bbox_wcs # NOTE: 只在根节点时创建 + if surf_bbox is not None: + if not isinstance(surf_bbox, torch.Tensor): + raise TypeError( + f"surf_bbox 必须是 torch.Tensor 类型,但得到 {type(surf_bbox)}" + ) + if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6: + raise ValueError( + f"surf_bbox 应为二维张量且形状为 (N,6),但得到 {surf_bbox.shape}" + ) + OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建 OctreeNode.device = bbox.device def is_leaf(self): @@ -72,10 +94,10 @@ class OctreeNode: # 找到与子包围盒相交的面 intersecting_faces = [] for face_idx in self.face_indices: - face_bbox = OctreeNode.surf_bbox_wcs[face_idx] + face_bbox = OctreeNode.surf_bbox[face_idx] if bbox_intersect(bbox, face_bbox): intersecting_faces.append(face_idx) - + #print(f"{bbox}: {intersecting_faces}") if intersecting_faces: child_node = OctreeNode( bbox=bbox, @@ -95,6 +117,8 @@ class OctreeNode: """ #print(query_point) x, y, z = query_point + #logger.info(f"query_point: {query_point}") + #logger.info(f"box: {self.bbox}") min_x, min_y, min_z, max_x, max_y, max_z = self.bbox # 计算中间点 @@ -109,7 +133,7 @@ class OctreeNode: index += 2 if z >= mid_z: # 修正变量名 index += 4 - + #logger.info(f"index: {index}") return index def get_feature_vector(self, query_point:torch.Tensor): @@ -125,7 +149,18 @@ class OctreeNode: return self.trilinear_interpolation(query_point) else: index = self.get_child_index(query_point) - return self.children[index].get_feature_vector(query_point) + try: + if index < 0 or index >= len(self.children): + raise IndexError( + f"Child index {index} out of range (0-{len(self.children)-1}) " + f"for query point {query_point.cpu().numpy().tolist()}. " + f"Node bbox: {self.bbox.cpu().numpy().tolist()}" + f"dept info: {self.max_depth}" + ) + return self.children[index].get_feature_vector(query_point) + except IndexError as e: + logger.error(str(e)) + raise e def trilinear_interpolation(self, query_point:torch.Tensor) -> torch.Tensor: """ @@ -166,4 +201,31 @@ class OctreeNode: c0 = c00 * (1 - y) + c10 * y c1 = c01 * (1 - y) + c11 * y - return c0 * (1 - z) + c1 * z \ No newline at end of file + return c0 * (1 - z) + c1 * z + + + def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: + """ + 递归打印八叉树结构 + + 参数: + depth: 当前深度 (内部使用) + max_print_depth: 最大打印深度 (None表示打印全部) + """ + if max_print_depth is not None and depth > max_print_depth: + return + + # 打印当前节点信息 + indent = " " * depth + node_type = "Leaf" if self._is_leaf else "Internal" + print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.cpu().numpy().tolist()}") + + # 打印面片信息(如果有) + if self.face_indices is not None: + print(f"{indent} Face indices: {self.face_indices.tolist()}") + print(f"{indent} len children: {len(self.children)}") + + # 递归打印子节点 + for i, child in enumerate(self.children): + print(f"{indent} Child {i}:") + child.print_tree(depth + 1, max_print_depth) \ No newline at end of file diff --git a/brep2sdf/scripts/process_brep.py b/brep2sdf/scripts/process_brep.py index ebe8544..9c121b4 100644 --- a/brep2sdf/scripts/process_brep.py +++ b/brep2sdf/scripts/process_brep.py @@ -24,7 +24,7 @@ from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类 from OCC.Core.BRep import BRep_Tool # B-rep工具 from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分 from OCC.Core.TopLoc import TopLoc_Location # 位置变换 -from OCC.Core.IFSelect import IFSelect_RetDone # 操作状态码 +from OCC.Core.IFSelect import IFSelect_RetDone,IFSelect_RetError, IFSelect_RetFail, IFSelect_RetVoid # 操作状态码 from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射 from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算 from OCC.Core.Bnd import Bnd_Box # 包围盒 @@ -211,7 +211,15 @@ def parse_solid(step_path): reader = STEPControl_Reader() status = reader.ReadFile(step_path) if status != IFSelect_RetDone: - raise Exception("Failed to read STEP file") + if status == IFSelect_RetError: + print("Error: An error occurred while reading the file.") + elif status == IFSelect_RetFail: + print("Error: Failed to read the file.") + elif status == IFSelect_RetVoid: + print("Error: No data was read from the file.") + else: + print(f"Unexpected status code: {status}") + raise Exception(f"Failed to read STEP file. {status}") reader.TransferRoots() shape = reader.OneShape() @@ -385,8 +393,14 @@ def check_data_format(data, step_file): return True, "" def process_single_step(step_path:str, output_path:str=None, timeout:int=300) -> dict: - """处理单个STEP文件""" + """处理单个STEP文件, 从 brep 2 pkl""" try: + if not os.path.exists(step_path): + logger.error(f"STEP文件不存在: {step_path}") + return None + if not step_path.lower().endswith('.step') and not step_path.lower().endswith('.stp'): + logger.error(f"文件格式不支持,必须是.step或.stp文件: {step_path}") + return None # 解析STEP文件 data = parse_solid(step_path) if data is None: @@ -596,5 +610,8 @@ def main(): logger.warning("No files were processed") if __name__ == '__main__': - main() - #test("/mnt/disk2/dataset/furniture/step/furniture_dataset_step/train/bathtub_0004.step", "/home/wch/brep2sdf/test_data/pkl/train/bathtub_0004.pkl") + # main() + test("/home/wch/brep2sdf/data/step/00000000/00000000_290a9120f9f249a7a05cfe9c_step_000.step","/home/wch/brep2sdf/test_data/pkl/train/00000031xx.pkl") + #test("/home/wch/brep2sdf/00000031_ad34a3f60c4a4caa99646600_step_011.step", "/home/wch/brep2sdf/test_data/pkl/train/00000031.pkl") + #test("/mnt/mynewdisk/dataset/furniture/step/furniture_dataset_step/train/bathtub_0004.step", "/home/wch/brep2sdf/test_data/pkl/train/0004.pkl") + #reader = STEPControl_Reader() diff --git a/brep2sdf/scripts/process_furniture.py b/brep2sdf/scripts/process_furniture.py index 6a4cab4..1fb3c1b 100644 --- a/brep2sdf/scripts/process_furniture.py +++ b/brep2sdf/scripts/process_furniture.py @@ -301,10 +301,12 @@ def main(): success_rate = (valid_conversions / total_files) * 100 # 这个变量在日志中被使用但未定义 logger.info(f"处理完成: {set_name} 集合, 成功率: {success_rate:.2f}% = {valid_conversions}/{total_files}个") - +def test(step_file: str, set_name:str): + process(step_file, set_name) if __name__ == "__main__": - main() + # main() + test("/home/wch/brep2sdf/00000031.step","train") diff --git a/brep2sdf/train.py b/brep2sdf/train.py index c0dabc0..609eb1b 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -2,26 +2,70 @@ import torch import torch.optim as optim import time import os +import numpy as np +import argparse from brep2sdf.config.default_config import get_default_config -from brep2sdf.data.data import load_brep_file,load_sdf_file +from brep2sdf.data.data import load_brep_file,load_sdf_file +from brep2sdf.data.pre_process import process_single_step from brep2sdf.networks.network import Net from brep2sdf.utils.logger import logger +def prepare_sdf_data(surf_data, max_points=100000, device='cuda'): + total_points = sum(len(s) for s in surf_data) + + # 降采样逻辑(修复版) + if total_points > max_points: + # 先随机打乱所有点 + all_points = np.concatenate(surf_data) + np.random.shuffle(all_points) + # 直接取前max_points个点 + sampled_points = all_points[:max_points] + sdf_array = np.zeros((max_points, 4), dtype=np.float32) + sdf_array[:, :3] = sampled_points + else: + sdf_array = np.zeros((total_points, 4), dtype=np.float32) + sdf_array[:, :3] = np.concatenate(surf_data) + + return torch.tensor(sdf_array, dtype=torch.float32, device=device) + + class Trainer: - def __init__(self, config): + def __init__(self, config, input_step): self.config = config self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - + + self.model_name = os.path.basename(input_step).replace(".step", "") + self.base_name = self.model_name + ".xyz" + data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name) + if os.path.exists(data_path): + self.data = load_brep_file(data_path) + else: + self.data = process_single_step(step_path=input_step, output_path=data_path) + + # 将曲面点云列表转换为 (N*M, 4) 数组 + surfs = self.data["surf_ncs"] + self.sdf_data = prepare_sdf_data( + surfs, + max_points=4096, + device=self.device + ) # 初始化数据集 - self.brep_data = load_brep_file(self.config.data.pkl_path) - self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device) + #self.brep_data = load_brep_file(self.config.data.pkl_path) + #logger.info( self.brep_data ) + #self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device) # 初始化网络 - bbox = self._calculate_global_bbox() + + surf_bbox=torch.tensor( + self.data['surf_bbox_ncs'], + dtype=torch.float32, + device=self.device + ) + bbox = self._calculate_global_bbox(surf_bbox) self.model = Net( - surf_bbox_wcs=self.brep_data['surf_bbox_wcs'], - origin_bbox_wcs=bbox, + surf_bbox=surf_bbox, + origin_bbox=bbox, feature_dim=64 ).to(self.device) @@ -34,28 +78,30 @@ class Trainer: - def _calculate_global_bbox(self) -> torch.Tensor: + def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor: """ - 计算整个数据集的全局边界框 + 计算整个数据集的全局边界框,综合考虑表面包围盒和采样点 + + 参数: + surf_bbox: 形状为 (num_edges, 6) 的Tensor,表示每条边的包围盒 + [xmin, ymin, zmin, xmax, ymax, zmax] 返回: - bbox_tensor: 形状为(6,)的Tensor,格式为[x_min, y_min, z_min, x_max, y_max, z_max] + 形状为 (6,) 的Tensor,格式为 [x_min, y_min, z_min, x_max, y_max, z_max] """ - # 获取所有点的坐标 - points = self.sdf_data[:, 0:3] # 假设sdf_data的前三列是点的坐标 - - # 计算最小点和最大点 - min_point = torch.min(points, dim=0).values - max_point = torch.max(points, dim=0).values - - # 确保在正确设备上 - min_point = min_point.to(self.device) - max_point = max_point.to(self.device) - - # 将最小点和最大点合并成一个(6,)的Tensor - bbox_tensor = torch.cat([min_point, max_point], dim=0) - #print(f"bbox_tensor shape: {bbox_tensor.shape}") - return bbox_tensor + # 验证输入 + if not isinstance(surf_bbox, torch.Tensor): + raise TypeError(f"surf_bbox 必须是 torch.Tensor,但得到 {type(surf_bbox)}") + if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6: + raise ValueError(f"surf_bbox 形状应为 (num_edges, 6),但得到 {surf_bbox.shape}") + + # 计算表面包围盒的全局范围 + global_min = surf_bbox[:, :3].min(dim=0).values + global_max = surf_bbox[:, 3:].max(dim=0).values + + + # 返回合并后的边界框 + return torch.cat([global_min, global_max]) def train_epoch(self, epoch: int) -> float: self.model.train() @@ -154,13 +200,12 @@ class Trainer: def _save_checkpoint(self, epoch: int, train_loss: float): """保存训练检查点""" - checkpoint_path = os.path.join( + checkpoint_dir = os.path.join( self.config.train.checkpoint_dir, - self.config.train.checkpoint_format.format( - model_name=self.config.train.model_name, - epoch=epoch - ) + self.model_name ) + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_path = os.path.join(checkpoint_dir,f"epoch_{epoch:03d}.pth") torch.save({ 'epoch': epoch, 'model_state_dict': self.model.state_dict(), @@ -171,10 +216,16 @@ class Trainer: def main(): # 这里需要初始化配置 + # 配置命令行参数 + parser = argparse.ArgumentParser(description='STEP文件批量处理工具') + parser.add_argument('-i', '--input', required=True, + help='待处理 brep (.step) 路径') + + args = parser.parse_args() config = get_default_config() # 初始化训练器并开始训练 - trainer = Trainer(config) + trainer = Trainer(config, input_step=args.input) trainer.train() if __name__ == '__main__':