From cc88dfb798ec6914e84ecf0044dac196cceebc01 Mon Sep 17 00:00:00 2001
From: mckay
Date: Sun, 6 Apr 2025 20:53:14 +0800
Subject: [PATCH] Support normals
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/data/pre_process_by_mesh.py | 621 +++++++++++++++++++++++++++
 brep2sdf/networks/loss.py            | 254 +++++++++--
 brep2sdf/networks/network.py         |  11 +
 brep2sdf/train.py                    |  97 +++--
 4 files changed, 928 insertions(+), 55 deletions(-)
 create mode 100644 brep2sdf/data/pre_process_by_mesh.py

diff --git a/brep2sdf/data/pre_process_by_mesh.py b/brep2sdf/data/pre_process_by_mesh.py
new file mode 100644
index 0000000..fcae98b
--- /dev/null
+++ b/brep2sdf/data/pre_process_by_mesh.py
@@ -0,0 +1,621 @@
+"""
+CAD model preprocessing script.
+Converts STEP-format CAD models into structured data:
+- geometry: point sets for faces, edges and vertices
+- topology: face-edge-vertex adjacency
+- spatial data: bounding boxes
+"""
+
+import os
+import pickle    # data serialization
+import argparse  # command-line parsing
+import numpy as np
+from tqdm import tqdm  # progress bars
+from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError  # parallel processing
+import logging
+from datetime import datetime
+from scipy.spatial import cKDTree
+from brep2sdf.utils.logger import logger
+import tempfile
+import trimesh
+
+# OpenCASCADE imports
+from OCC.Core.STEPControl import STEPControl_Reader            # STEP file reader
+from OCC.Core.TopExp import TopExp_Explorer, topexp            # topology traversal
+from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX  # topology types
+from OCC.Core.BRep import BRep_Tool                            # B-rep tools
+from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh         # meshing
+from OCC.Core.TopLoc import TopLoc_Location                    # location transforms
+from OCC.Core.IFSelect import IFSelect_RetDone, IFSelect_RetError, IFSelect_RetFail, IFSelect_RetVoid  # status codes
+from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape  # shape maps
+from OCC.Core.BRepBndLib import brepbndlib                     # bounding-box computation
+from OCC.Core.Bnd import Bnd_Box                               # bounding box
+from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex  # topology data structures
+from OCC.Core.StlAPI import StlAPI_Writer
+
+# Load configuration
+from brep2sdf.config.default_config import get_default_config
+config = get_default_config()
+
+# Maximum face count; models above this are skipped to keep processing fast
+MAX_FACE = config.data.max_face
+
+def normalize(surfs, edges, corners):
+    """
+    Normalize a CAD model into the unit cube.
+
+    Args:
+        surfs: list of per-face point sets
+        edges: list of per-edge point sets
+        corners: corner coordinates, shape [num_edges, 2, 3]
+
+    Returns:
+        surfs_wcs: face point sets in world coordinates
+        edges_wcs: edge point sets in world coordinates
+        surfs_ncs: face point sets in normalized coordinates
+        edges_ncs: edge point sets in normalized coordinates
+        corner_wcs: normalized corner coordinates [num_edges, 2, 3]
+        center: center point used for normalization [3,]
+        scale: scale factor used (float)
+    """
+    if len(corners) == 0:
+        return None, None, None, None, None, None, None
+
+    # Bounding box and scale factor
+    corners_array = corners.reshape(-1, 3)  # [num_edges*2, 3]
+    center = (corners_array.max(0) + corners_array.min(0)) / 2  # box center
+    scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max()  # uniform scale
+
+    # Normalize the face point sets
+    surfs_wcs = []  # world coordinates
+    surfs_ncs = []  # normalized coordinates
+    for surf in surfs:
+        surf_wcs = np.array(surf)
+        surf_ncs = (surf_wcs - center) * scale
+        surfs_wcs.append(surf_wcs)
+        surfs_ncs.append(surf_ncs)
+
+    # Normalize the edge point sets
+    edges_wcs = []  # world coordinates
+    edges_ncs = []  # normalized coordinates
+    for edge in edges:
+        edge_wcs = np.array(edge)
+        edge_ncs = (edge_wcs - center) * scale
+        edges_wcs.append(edge_wcs)
+        edges_ncs.append(edge_ncs)
+
+    # Normalize the corner coordinates; broadcasting keeps the [num_edges, 2, 3] shape
+    corner_wcs = (corners - center) * scale
+
+    return (np.array(surfs_wcs, dtype=object),
+            np.array(edges_wcs, dtype=object),
+            np.array(surfs_ncs, dtype=object),
+            np.array(edges_ncs, dtype=object),
+            corner_wcs.astype(np.float32),
+            center.astype(np.float32),
+            scale
+            )
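+
+# Round-trip sketch for the values returned by normalize() (illustrative helper,
+# not used by the pipeline): inverting p_ncs = (p_wcs - center) * scale gives
+# p_wcs = p_ncs / scale + center.
+def denormalize_points(points_ncs, center, scale):
+    """Map normalized points back to world coordinates."""
+    return points_ncs / scale + center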
+
+def get_adjacency_info(shape, faces, edges, vertices):
+    """
+    Adjacency computation, reusing the already-collected geometry directly.
+
+    Args:
+        faces: collected face list
+        edges: collected edge list
+        vertices: collected vertex list
+    """
+    logger.debug("Get adjacency infos...")
+    # Edge-to-face map
+    edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape()
+    topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map)
+
+    # Use the geometry lists passed in
+    num_faces = len(faces)
+    num_edges = len(edges)
+    num_vertices = len(vertices)
+    logger.debug(f"num_faces: {num_faces}, num_edges: {num_edges}, num_vertices: {num_vertices}")
+
+    edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32)
+    faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32)
+    edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32)
+
+    # Fill the edge-face adjacency matrix
+    for i, edge in enumerate(edges):
+        # Check whether each face contains the current edge
+        for j, face in enumerate(faces):
+            edge_explorer = TopExp_Explorer(face, TopAbs_EDGE)
+            while edge_explorer.More():
+                if edge.IsSame(edge_explorer.Current()):
+                    edgeFace_adj[i, j] = 1
+                    faceEdge_adj[j, i] = 1
+                    break
+                edge_explorer.Next()
+
+        # Get the two end vertices of the edge
+        v1 = TopoDS_Vertex()
+        v2 = TopoDS_Vertex()
+        topexp.Vertices(edge, v1, v2)
+
+        # Record the vertex indices for this edge
+        if not v1.IsNull() and not v2.IsNull():
+            v1_vertex = topods.Vertex(v1)
+            v2_vertex = topods.Vertex(v2)
+
+            for k, vertex in enumerate(vertices):
+                if v1_vertex.IsSame(vertex):
+                    edgeCorner_adj[i, 0] = k
+                if v2_vertex.IsSame(vertex):
+                    edgeCorner_adj[i, 1] = k
+
+    return edgeFace_adj, faceEdge_adj, edgeCorner_adj
+
+def get_bbox(shape, subshape):
+    """
+    Compute the bounding box of a shape.
+
+    Args:
+        shape: the full CAD shape
+        subshape: the sub-shape (face or edge) to bound
+
+    Returns:
+        the six box parameters [xmin, ymin, zmin, xmax, ymax, zmax]
+    """
+    bbox = Bnd_Box()
+    brepbndlib.Add(subshape, bbox)
+    xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
+    return np.array([xmin, ymin, zmin, xmax, ymax, zmax])
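+
+# Small access sketch for the adjacency matrices returned above (illustrative;
+# edgeFace_adj is the 0/1 matrix from get_adjacency_info):
+def faces_adjacent_to_edge(edgeFace_adj, edge_idx):
+    """Return the indices of all faces incident to edge `edge_idx`."""
+    return np.nonzero(edgeFace_adj[edge_idx])[0]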
+
+def parse_solid(step_path, sample_normal_vector=False):
+    """
+    Parse the CAD model contained in a STEP file.
+
+    Returns:
+        dict with the following keys:
+        # geometry
+        'surf_wcs': np.ndarray(dtype=object)    # shape (N,); each element is an (M, 3) float32 array of face point coordinates
+        'edge_wcs': np.ndarray(dtype=object)    # shape (N,); each element is a (num_edge_sample_points, 3) float32 array of edge samples
+        'surf_ncs': np.ndarray(dtype=object)    # shape (N,); each element is an (M, 3) float32 array of normalized face points
+        'edge_ncs': np.ndarray(dtype=object)    # shape (N,); each element is a (num_edge_sample_points, 3) float32 array of normalized edge samples
+        'corner_wcs': np.ndarray(dtype=float32)     # shape (num_edges, 2, 3); the two endpoints of each edge
+        'corner_unique': np.ndarray(dtype=float32)  # shape (num_vertices, 3); unique vertex coordinates, num_vertices <= num_edges * 2
+
+        # topology
+        'edgeFace_adj': np.ndarray(dtype=int32)    # shape (num_edges, num_faces); edge-face adjacency
+        'edgeCorner_adj': np.ndarray(dtype=int32)  # shape (num_edges, 2); edge-vertex adjacency
+        'faceEdge_adj': np.ndarray(dtype=int32)    # shape (num_faces, num_edges); face-edge adjacency
+
+        # bounding boxes
+        'surf_bbox_wcs': np.ndarray(dtype=float32)  # shape (num_faces, 6); per-face [xmin,ymin,zmin,xmax,ymax,zmax]
+        'edge_bbox_wcs': np.ndarray(dtype=float32)  # shape (num_edges, 6); per-edge [xmin,ymin,zmin,xmax,ymax,zmax]
+    """
+    # Load STEP file
+    reader = STEPControl_Reader()
+    status = reader.ReadFile(step_path)
+    if status != IFSelect_RetDone:
+        if status == IFSelect_RetError:
+            print("Error: An error occurred while reading the file.")
+        elif status == IFSelect_RetFail:
+            print("Error: Failed to read the file.")
+        elif status == IFSelect_RetVoid:
+            print("Error: No data was read from the file.")
+        else:
+            print(f"Unexpected status code: {status}")
+        raise Exception(f"Failed to read STEP file. {status}")
+
+    reader.TransferRoots()
+    shape = reader.OneShape()
+
+    # Create mesh
+    mesh = BRepMesh_IncrementalMesh(shape, 0.01)
+    mesh.Perform()
+
+    # Initialize explorers
+    face_explorer = TopExp_Explorer(shape, TopAbs_FACE)
+    edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE)
+    vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX)
+
+    face_pnts = []
+    edge_pnts = []
+    corner_pnts = []
+    surf_bbox_wcs = []
+    edge_bbox_wcs = []
+
+    faces, edges, vertices = [], [], []
+
+    # Extract face points
+    logger.debug("Extract face points...")
+    while face_explorer.More():
+        face = topods.Face(face_explorer.Current())
+        face_explorer.Next()  # advance first so `continue` below cannot loop forever
+        loc = TopLoc_Location()
+        triangulation = BRep_Tool.Triangulation(face, loc)
+
+        if triangulation is not None:
+            points = []
+            for i in range(1, triangulation.NbNodes() + 1):
+                node = triangulation.Node(i)
+                pnt = node.Transformed(loc.Transformation())
+                points.append([pnt.X(), pnt.Y(), pnt.Z()])
+
+            if points:
+                points = np.array(points, dtype=np.float32)
+                if len(points.shape) == 2 and points.shape[1] == 3:
+                    if len(points) < 3:  # too few points, skip this face
+                        continue
+                    # Append the face only when its points are kept, so `faces`,
+                    # `face_pnts` and `surf_bbox_wcs` stay index-aligned
+                    faces.append(face)
+                    face_pnts.append(points)
+                    surf_bbox_wcs.append(get_bbox(shape, face))
+
+    face_count = len(faces)
+    if face_count > MAX_FACE:
+        logger.error(f"step has {face_count} faces, which exceeds MAX_FACE {MAX_FACE}")
+        return None
+
+    # Extract edge points
+    logger.debug("Extract edge points...")
+    num_samples = config.model.num_edge_points  # edge sample count from the config
+    while edge_explorer.More():
+        edge = topods.Edge(edge_explorer.Current())
+        edge_explorer.Next()  # advance first so `continue` below cannot loop forever
+        curve_info = BRep_Tool.Curve(edge)
+        if curve_info is None:
+            continue  # skip invalid edges
+
+        try:
+            if len(curve_info) == 3:
+                curve, first, last = curve_info
+            elif len(curve_info) == 2:
+                curve = None  # degenerate edge, skip sampling
+            else:
+                raise ValueError(f"Unexpected curve info: {curve_info}")
+        except Exception as e:
+            logger.error(f"Failed to process edge {edge}: {str(e)}")
+            curve = None
+
+        if curve is not None:
+            points = []
+            for i in range(num_samples):
+                param = first + (last - first) * float(i) / (num_samples - 1)
+                pnt = curve.Value(param)
+                points.append([pnt.X(), pnt.Y(), pnt.Z()])
+
+            if points:
+                points = np.array(points, dtype=np.float32)
+                if len(points.shape) == 2 and points.shape[1] == 3:
+                    # Append the edge only when sampling succeeded, so `edges`,
+                    # `edge_pnts` and `edge_bbox_wcs` stay index-aligned
+                    edges.append(edge)
+                    edge_pnts.append(points)  # (num_edge_points, 3)
+                    edge_bbox_wcs.append(get_bbox(shape, edge))
+
+    logger.debug(f"collected {len(edges)} edges")
+
+    # Extract vertex points
+    logger.debug("Extract vertex points...")
+    while vertex_explorer.More():
+        vertex = topods.Vertex(vertex_explorer.Current())
+        vertices.append(vertex)
+        pnt = BRep_Tool.Pnt(vertex)
+        corner_pnts.append([pnt.X(), pnt.Y(), pnt.Z()])
+        vertex_explorer.Next()
+
+    # Compute adjacency information
+    edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(
+        shape,
+        faces=faces,        # collected face list
+        edges=edges,        # collected edge list
+        vertices=vertices   # collected vertex list
+    )
+    logger.debug("complete.")
+
+    # Enforce dtypes when converting to numpy arrays
+    face_pnts = [np.array(points, dtype=np.float32) for points in face_pnts]
+    edge_pnts = [np.array(points, dtype=np.float32) for points in edge_pnts]
+
+    # Convert to object arrays
+    face_pnts = np.array(face_pnts, dtype=object)
+    edge_pnts = np.array(edge_pnts, dtype=object)
+    corner_pnts = np.array(corner_pnts, dtype=np.float32)
+
+    # Regroup vertex data as two endpoints per edge
+    corner_pairs = []
+    for edge_idx in range(len(edge_pnts)):
+        v1_idx, v2_idx = edgeCorner_adj[edge_idx]
+        v1_pos = corner_pnts[v1_idx]
+        v2_pos = corner_pnts[v2_idx]
+        # Sort by coordinate so endpoint order is deterministic
+        if (v1_pos > v2_pos).any():
+            v1_pos, v2_pos = v2_pos, v1_pos
+        corner_pairs.append(np.stack([v1_pos, v2_pos]))
+
+    corner_pairs = np.stack(corner_pairs).astype(np.float32)  # [num_edges, 2, 3]
+
+    # Make sure all arrays have the right dtype
+    surf_bbox_wcs = np.array(surf_bbox_wcs, dtype=np.float32)
+    edge_bbox_wcs = np.array(edge_bbox_wcs, dtype=np.float32)
+
+    # Normalize the CAD model
+    surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs, center, scale = normalize(
+        face_pnts, edge_pnts, corner_pairs)
+
+    # Validate the normalized data before using center/scale below
+    if any(x is None for x in [surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs]):
+        logger.error(f"Normalization failed for {step_path}")
+        return None
+
+    # Bounding boxes in the normalized coordinate system
+    surf_bbox_ncs = np.empty_like(surf_bbox_wcs)
+    edge_bbox_ncs = np.empty_like(edge_bbox_wcs)
+
+    # Surface boxes to normalized coordinates
+    surf_bbox_ncs[:, :3] = (surf_bbox_wcs[:, :3] - center) * scale  # min corner
+    surf_bbox_ncs[:, 3:] = (surf_bbox_wcs[:, 3:] - center) * scale  # max corner
+
+    # Edge boxes to normalized coordinates
+    edge_bbox_ncs[:, :3] = (edge_bbox_wcs[:, :3] - center) * scale  # min corner
+    edge_bbox_ncs[:, 3:] = (edge_bbox_wcs[:, 3:] - center) * scale  # max corner
+
+    # Assemble the result dict, enforcing dtypes
+    data = {
+        'surf_wcs': np.array(surfs_wcs, dtype=object),   # object array
+        'edge_wcs': np.array(edges_wcs, dtype=object),   # object array
+        'surf_ncs': np.array(surfs_ncs, dtype=object),   # object array
+        'edge_ncs': np.array(edges_ncs, dtype=object),   # object array
+        'corner_wcs': corner_wcs.astype(np.float32),     # [num_edges, 2, 3]
+        'edgeFace_adj': edgeFace_adj.astype(np.int32),
+        'edgeCorner_adj': edgeCorner_adj.astype(np.int32),
+        'faceEdge_adj': faceEdge_adj.astype(np.int32),
+        'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),
+        'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),
+        'surf_bbox_ncs': surf_bbox_ncs.astype(np.float32),  # normalized coordinates [num_faces, 6]
+        'edge_bbox_ncs': edge_bbox_ncs.astype(np.float32),  # normalized coordinates [num_edges, 6]
+        'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32),  # flatten, then deduplicate
+        'normalization_params': {
+            'center': center.astype(np.float32),  # normalization center [3,]
+            'scale': float(scale),                # normalization scale factor
+        }
+    }
+
+    if sample_normal_vector:
+        # Read normals from the mesh by exporting a temporary STL
+        stl_writer = StlAPI_Writer()
+        stl_writer.SetASCIIMode(False)
+        with tempfile.NamedTemporaryFile(suffix='.stl') as tmp:
+            stl_writer.Write(shape, tmp.name)
+            trimesh_mesh = trimesh.load(tmp.name)
+        data['surf_pnt_normals'] = batch_compute_normals(trimesh_mesh, surfs_wcs)
+
+    return data
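+
+# Consumption sketch for the dictionary above (illustrative; key names follow the
+# documented schema):
+def summarize_parsed_solid(data):
+    """Pull a few headline statistics out of a parse_solid() result."""
+    return {
+        "num_faces": len(data["surf_wcs"]),
+        "num_edges": len(data["edge_wcs"]),
+        "num_unique_corners": len(data["corner_unique"]),
+        "scale": data["normalization_params"]["scale"],
+    }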
+
+def load_step(step_path):
+    """Load STEP file and return solids"""
+    reader = STEPControl_Reader()
+    reader.ReadFile(step_path)
+    reader.TransferRoots()
+    return [reader.OneShape()]
+
+
+def preprocess_mesh(mesh, normal_type='vertex'):
+    """
+    Preprocess mesh data into a KDTree plus a normal source.
+
+    Args:
+        mesh: trimesh.Trimesh with vertices and normals
+        normal_type: 'vertex' or 'face'
+
+    Returns:
+        tree: cKDTree for fast nearest-neighbour queries
+        normals_source: np.ndarray of vertex or face normals
+    """
+    if normal_type == 'vertex':
+        tree = cKDTree(mesh.vertices)
+        normals_source = mesh.vertex_normals
+    elif normal_type == 'face':
+        # Use each face's centroid as the query site
+        face_centers = np.mean(mesh.vertices[mesh.faces], axis=1)
+        tree = cKDTree(face_centers)
+        normals_source = mesh.face_normals
+    else:
+        raise ValueError(f"Unsupported normal type: {normal_type}")
+
+    return tree, normals_source
+
+
+def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3):
+    """
+    Compute normals for nested point-cloud data, preserving the nesting.
+
+    Args:
+        mesh: trimesh.Trimesh with vertices and normals
+        surf_wcs: np.ndarray(dtype=object) of shape (N,); each element is an (M, 3) float32 array
+        normal_type: 'vertex' or 'face'
+        k_neighbors: number of nearest neighbours used for smoothing
+
+    Returns:
+        normals: np.ndarray(dtype=object) of shape (N,); each element is an (M, 3) float32 array
+    """
+    # Build the KDTree and normal source
+    tree, normals_source = preprocess_mesh(mesh, normal_type=normal_type)
+
+    # Flatten all point clouds into one [P, 3] array, remembering the split sizes
+    lengths = [len(point_cloud) for point_cloud in surf_wcs]
+    query_points = np.concatenate(surf_wcs, axis=0).astype(np.float32)  # single allocation
+
+    # Batched nearest-neighbour query
+    distances, indices = tree.query(query_points, k=k_neighbors)
+
+    # Special case: k = 1
+    if k_neighbors == 1:
+        nearest_normals = normals_source[indices]
+    else:
+        # Weighted average with inverse-distance weights
+        inv_distances = 1 / (distances + 1e-8)  # avoid division by zero
+        sum_inv_distances = inv_distances.sum(axis=1, keepdims=True)
+        valid_mask = sum_inv_distances > 1e-6
+        weights = np.divide(inv_distances, sum_inv_distances, out=np.zeros_like(inv_distances), where=valid_mask)
+        nearest_normals = np.einsum('ijk,ij->ik', normals_source[indices], weights)
+
+    # Renormalize the results
+    norms = np.linalg.norm(nearest_normals, axis=1)
+    valid_mask = norms > 1e-6
+    nearest_normals[valid_mask] /= norms[valid_mask, None]
+
+    # Split back into the original nested structure
+    start = 0
+    normals_output = np.empty(len(surf_wcs), dtype=object)
+    for i, length in enumerate(lengths):
+        end = start + length
+        normals_output[i] = nearest_normals[start:end]
+        start = end
+
+    return normals_output
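+
+# Self-contained smoke test for batch_compute_normals() (illustrative sketch; a
+# trimesh primitive stands in for real data so it runs without any STEP input):
+def _normals_smoke_test():
+    sphere = trimesh.creation.icosphere(subdivisions=2, radius=1.0)
+    pts = sphere.sample(64).astype(np.float32)  # split into two fake "face point clouds"
+    surf_wcs = np.empty(2, dtype=object)
+    surf_wcs[0], surf_wcs[1] = pts[:32], pts[32:]
+    normals = batch_compute_normals(sphere, surf_wcs, normal_type='vertex', k_neighbors=3)
+    assert normals[0].shape == (32, 3) and normals[1].shape == (32, 3)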
+
+def check_data_format(data, step_file):
+    """Validate the structure and dtypes of a parsed result."""
+    required_keys = [
+        'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs', 'corner_wcs',
+        'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj',
+        'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'
+    ]
+
+    # All required keys present?
+    for key in required_keys:
+        if key not in data:
+            return False, f"Missing key: {key}"
+
+    # Geometry arrays
+    geometry_arrays = ['surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs']
+    for key in geometry_arrays:
+        if not isinstance(data[key], np.ndarray):
+            return False, f"{key} should be a numpy array"
+        # object arrays are expected here
+        if data[key].dtype != object:
+            return False, f"{key} should be a numpy array with dtype=object"
+
+    # Remaining float arrays
+    float32_arrays = ['corner_wcs', 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique']
+    for key in float32_arrays:
+        if not isinstance(data[key], np.ndarray):
+            return False, f"{key} should be a numpy array"
+        if data[key].dtype != np.float32:
+            return False, f"{key} should be a numpy array with dtype=float32"
+
+    int32_arrays = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
+    for key in int32_arrays:
+        if not isinstance(data[key], np.ndarray):
+            return False, f"{key} should be a numpy array"
+        if data[key].dtype != np.int32:
+            return False, f"{key} should be a numpy array with dtype=int32"
+
+    return True, ""
+
+def process_single_step(step_path: str, output_path: str = None, sample_normal_vector=False, timeout: int = 300) -> dict:
+    """Process a single STEP file: B-rep -> pkl.
+
+    Returns a dict of the form:
+        'surf_wcs': object array, per-face geometry in world coordinates
+        'edge_wcs': object array, per-edge geometry in world coordinates
+        'surf_ncs': object array, normalized face point clouds [num_faces, num_surf_sample_points, 3]
+        'edge_ncs': object array, normalized edge point clouds [num_edges, num_edge_sample_points, 3]
+        'corner_wcs': float32 [num_edges, 2, 3], corner coordinates in world space
+        'edgeFace_adj': int32 edge-face adjacency matrix
+        'edgeCorner_adj': int32 edge-corner adjacency matrix
+        'faceEdge_adj': int32 face-edge adjacency matrix
+        'surf_bbox_wcs': float32 face bounding boxes in world space
+        'edge_bbox_wcs': float32 edge bounding boxes in world space
+        'corner_unique': float32 deduplicated corner coordinates
+    """
+    try:
+        logger.info("Preprocessing data...")
+        if not os.path.exists(step_path):
+            logger.error(f"STEP file does not exist: {step_path}")
+            return None
+        if not step_path.lower().endswith('.step') and not step_path.lower().endswith('.stp'):
+            logger.error(f"Unsupported file format, expected .step or .stp: {step_path}")
+            return None
+        # Parse the STEP file
+        data = parse_solid(step_path, sample_normal_vector)
+        if data is None:
+            logger.error(f"Failed to parse STEP file: {step_path}")
+            return None
+
+        # Validate the result
+        is_valid, msg = check_data_format(data, step_path)
+        if not is_valid:
+            logger.error(f"Data format check failed for {step_path}: {msg}")
+            return None
+
+        # Save the result
+        if output_path:
+            try:
+                logger.info(f"Saving results to: {output_path}")
+                os.makedirs(os.path.dirname(output_path), exist_ok=True)
+                with open(output_path, 'wb') as f:
+                    pickle.dump(data, f)
+                logger.info("Preprocessing finished")
+                logger.info(f"Results saved successfully: {output_path}")
+                return data
+            except Exception as e:
+                logger.error(f'Failed to save {output_path}: {str(e)}')
+                return None
+        logger.info("Preprocessing finished")
+        return data
+
+    except Exception as e:
+        logger.error(f'Error processing {step_path}: {str(e)}')
+        return None
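+
+# Batch-driving sketch (illustrative; the directory layout is assumed, and the
+# ProcessPoolExecutor imported at the top of this file would be the natural next
+# step for parallelizing it):
+def process_directory(step_dir, out_dir):
+    """Run process_single_step over every .step/.stp file in step_dir."""
+    count = 0
+    for name in sorted(os.listdir(step_dir)):
+        if not name.lower().endswith(('.step', '.stp')):
+            continue
+        out_path = os.path.join(out_dir, os.path.splitext(name)[0] + '.pkl')
+        if process_single_step(os.path.join(step_dir, name), out_path) is not None:
+            count += 1
+    return count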
+
+def test(step_file_path, output_path=None):
+    """
+    Test helper: convert a single STEP file and save the result.
+    """
+    try:
+        logger.info(f"Processing STEP file: {step_file_path}")
+
+        # Parse the STEP file
+        data = parse_solid(step_file_path)
+        if data is None:
+            logger.error(f"Failed to parse STEP file: {step_file_path}")
+            return None
+
+        # Validate the result
+        is_valid, msg = check_data_format(data, step_file_path)
+        if not is_valid:
+            logger.error(f"Data format check failed for {step_file_path}: {msg}")
+            return None
+
+        # Print statistics
+        logger.info("\nStatistics:")
+        logger.info(f"Number of surfaces: {len(data['surf_wcs'])}")
+        logger.info(f"Number of edges: {len(data['edge_wcs'])}")
+        logger.info(f"Number of corners: {len(data['corner_wcs'])}")  # one endpoint pair per edge
+
+        # Save the result
+        if output_path:
+            try:
+                logger.info(f"Saving results to: {output_path}")
+                os.makedirs(os.path.dirname(output_path), exist_ok=True)
+                with open(output_path, 'wb') as f:
+                    pickle.dump(data, f)
+                logger.info(f"Results saved successfully: {output_path}")
+            except Exception as e:
+                logger.error(f"Failed to save {output_path}: {str(e)}")
+                return None
+
+        return data
+
+    except Exception as e:
+        logger.error(f"Error processing {step_file_path}: {str(e)}")
+        return None
+
+
+
+if __name__ == '__main__':
+    # main()
+    test("/home/wch/brep2sdf/data/step/00000000/00000000_290a9120f9f249a7a05cfe9c_step_000.step", "/home/wch/brep2sdf/test_data/pkl/train/00000031xx.pkl")
+    #test("/home/wch/brep2sdf/00000031_ad34a3f60c4a4caa99646600_step_011.step", "/home/wch/brep2sdf/test_data/pkl/train/00000031.pkl")
+    #test("/mnt/mynewdisk/dataset/furniture/step/furniture_dataset_step/train/bathtub_0004.step", "/home/wch/brep2sdf/test_data/pkl/train/0004.pkl")
+    #reader = STEPControl_Reader()
diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py
index 034ec62..7b4b394 100644
--- a/brep2sdf/networks/loss.py
+++ b/brep2sdf/networks/loss.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-
+from .network import gradient
 from brep2sdf.config.default_config import get_default_config
 from brep2sdf.utils.logger import logger
 
@@ -66,35 +66,229 @@ class Brep2SDFLoss(nn.Module):
 
         return grad_loss
 
-def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
-    """SDF loss"""
-    # make sure points require gradients
-    if not points.requires_grad:
-        points = points.detach().requires_grad_(True)
-
-    # L1 loss
-    l1_loss = F.l1_loss(pred_sdf, gt_sdf)
-
-    try:
-        # gradient-constraint loss
-        grad = torch.autograd.grad(
-            pred_sdf.sum(),
-            points,
-            create_graph=True,
-            retain_graph=True,
-            allow_unused=True
-        )[0]
-
-        if grad is not None:
-            grad_constraint = F.mse_loss(
-                torch.norm(grad, dim=-1),
-                torch.ones_like(pred_sdf.squeeze(-1))
-            )
-        else:
-            grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
-
-    except Exception as e:
-        logger.warning(f"Gradient computation failed: {str(e)}")
-        grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
-
-    return l1_loss + grad_weight * grad_constraint
\ No newline at end of file
+
+
+class LossManager:
+    def __init__(self, ablation, **condition_kwargs):
+        self.weights = {
+            "manifold": 1,
+            "feature_manifold": 1,  # same weight as manifold in the original paper
+            "normals": 1,
+            "eikonal": 1,
+            "offsurface": 1,
+            "consistency": 1,
+            "correction": 1,
+        }
+        self.condition_kwargs = condition_kwargs
+        self.ablation = ablation  # ablation-study switch
+
+    def _get_condition_kwargs(self, key):
+        """
+        Fetch a condition argument. Expected keys:
+        ab: loss type, one of [overall, patch, off, cons, cc, cor]
+        siren: whether SIREN is used
+        epoch: current epoch
+        baseline: whether this is the baseline
+        """
+        if key in self.condition_kwargs:
+            return self.condition_kwargs[key]
+        else:
+            raise ValueError(f"Key {key} not found in condition_kwargs")
+
+    def pre_process(self, mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
+        """
+        Pre-processing shared by the loss terms.
+        """
+        mnfld_pred_h = mnfld_pred_all[:, 0]        # overall (h) prediction at manifold points
+        nonmnfld_pred_h = nonmnfld_pred_all[:, 0]  # overall prediction at non-manifold points
+        mnfld_grad = gradient(mnfld_pnts, mnfld_pred_h)  # gradient at manifold points
+
+        all_fi = torch.zeros([n_batchsize, 1], device='cuda')  # per-branch manifold predictions
+        for i in range(n_branch - 1):
+            all_fi[i * n_patch_batch:(i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch:(i + 1) * n_patch_batch, i + 1]
+        # last patch
+        all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch]
+
+        return mnfld_pred_h, nonmnfld_pred_h, mnfld_grad, all_fi
+
+    def position_loss(self, pred_sdfs: torch.Tensor, gt_sdfs: torch.Tensor) -> torch.Tensor:
+        """
+        Manifold (position) loss.
+
+        :param pred_sdfs: predicted SDF values, shape (N, 1)
+        :param gt_sdfs: ground-truth SDF values, shape (N, 1)
+        :return: scalar manifold loss
+        """
+        # difference between prediction and ground truth
+        diff = pred_sdfs - gt_sdfs
+
+        # squared error
+        squared_diff = torch.pow(diff, 2)
+
+        # mean
+        manifold_loss = torch.mean(squared_diff)
+
+        return manifold_loss
+
+    def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, pred_sdfs) -> torch.Tensor:
+        """
+        Normal-consistency loss.
+
+        :param normals: ground-truth normals
+        :param mnfld_pnts: manifold points
+        :param pred_sdfs: SDF predictions at those points
+        :return: the normal loss
+        """
+        # NOTE: the reference implementation carries more logic here
+        # gradient of the prediction at the manifold points
+        branch_grad = gradient(mnfld_pnts, pred_sdfs)
+
+        # L2 norm of the difference to the ground-truth normals
+        normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean()
+
+        return normals_loss
+
+    def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred):
+        """
+        Eikonal loss: encourages unit-norm gradients off the surface.
+        """
+        single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred)  # gradient at non-manifold points
+        eikonal_loss = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
+        return eikonal_loss
+
+    def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred):
+        """
+        Eo: penalizes points far from the surface whose prediction is close to 0.
+        """
+        offsurface_loss = torch.zeros(1).cuda()
+        if not self.ablation == 'off':
+            offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred)).mean()
+        return offsurface_loss
+
+    def consistency_loss(self, mnfld_pnts, mnfld_pred, all_fi):
+        """
+        Penalizes manifold points where the overall and per-patch predictions disagree.
+        """
+        mnfld_consistency_loss = torch.zeros(1).cuda()
+        if not (self.ablation == 'cons' or self.ablation == 'cc'):
+            mnfld_consistency_loss = (mnfld_pred - all_fi[:, 0]).abs().mean()
+        return mnfld_consistency_loss
+
+    def correction_loss(self, mnfld_pnts, mnfld_pred, all_fi, th_closeness=1e-5, a_correction=100):
+        """
+        Correction loss.
+        """
+        correction_loss = torch.zeros(1).cuda()
+        if not (self.ablation == 'cor' or self.ablation == 'cc'):
+            mismatch_id = torch.abs(mnfld_pred - all_fi[:, 0]) > th_closeness  # mismatched points
+            if mismatch_id.sum() != 0:
+                correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:, 0])[mismatch_id]).mean()
+        return correction_loss
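+
+    # Summary of the terms in this class (descriptive note; alpha = 100 and the
+    # unit weights mirror the defaults in __init__):
+    #   manifold:    mean (f_h(x) - s(x))^2           on surface samples
+    #   normals:     mean || grad f(x) - n(x) ||_2    on surface samples
+    #   eikonal:     mean (|| grad f(x) ||_2 - 1)^2   off the surface
+    #   offsurface:  mean exp(-alpha * |f(x)|)        off the surface
+    #   consistency: mean | f_h(x) - f_i(x) |         overall vs. per-patch head
+    #   correction:  a * mean | f_h - f_i |           where the mismatch exceeds th_closeness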
+
+    def compute_loss(self, points,
+                     normals,
+                     gt_sdfs,
+                     pred_sdfs):
+        """
+        Combine the manifold and normal terms.
+
+        :return: (total_loss, loss_details)
+        """
+        # manifold (position) loss
+        manifold_loss = self.position_loss(pred_sdfs, gt_sdfs)
+
+        # normal-consistency loss
+        normals_loss = self.normals_loss(normals, points, pred_sdfs)
+
+        # weighted per-term details
+        loss_details = {
+            "manifold": self.weights["manifold"] * manifold_loss,
+            "normals": self.weights["normals"] * normals_loss,
+        }
+
+        # total loss
+        total_loss = sum(loss_details.values())
+
+        return total_loss, loss_details
+
+    def _compute_loss(self, mnfld_pnts, normals, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
+        """
+        Full multi-branch loss: manifold, patch, normal, Eikonal, off-surface,
+        consistency and correction terms.
+        """
+        mnfld_pred, nonmnfld_pred, mnfld_grad, all_fi = self.pre_process(mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last)
+        manifold_loss = torch.zeros(1).cuda()
+        # manifold loss (zero-level-set residual of the overall head)
+        if not self.ablation == 'overall':
+            manifold_loss = (mnfld_pred.abs()).mean()
+        '''
+        if args.feature_sample:  # if feature sampling is enabled
+            feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda()  # random feature points
+            feature_pnts = self.feature_data[feature_indices]
+            feature_mask_pair = self.feature_data_mask_pair[feature_indices]
+            feature_pred_all = self.network(feature_pnts)  # forward pass on the feature points
+            feature_pred = feature_pred_all[:, 0]
+            feature_mnfld_loss = feature_pred.abs().mean()  # feature manifold loss
+            loss = loss + weight_mnfld_h * feature_mnfld_loss  # add the weighted feature manifold loss
+
+        # patch loss:
+            feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:, 0].tolist()]
+            feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:, 1].tolist()]
+            feature_fis_left = feature_pred_all[feature_id_left]
+            feature_fis_right = feature_pred_all[feature_id_right]
+            feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean()
+            loss += feature_loss_patch
+
+        # consistency loss:
+            feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean()
+        '''
+        manifold_loss_patch = torch.zeros(1).cuda()
+        if self.ablation == 'patch':
+            manifold_loss_patch = all_fi[:, 0].abs().mean()
+
+        # normal loss, evaluated on the per-branch predictions (the original call
+        # passed a patch_sup flag that normals_loss() does not accept)
+        normals_loss = self.normals_loss(normals, mnfld_pnts, all_fi)
+
+        # Eikonal loss
+        eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred_all)
+
+        # off-surface loss
+        offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred_all)
+
+        # consistency loss
+        consistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi)
+
+        # correction loss
+        correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
+
+        loss_details = {
+            "manifold": self.weights["manifold"] * manifold_loss,
+            "manifold_patch": manifold_loss_patch,
+            "normals": self.weights["normals"] * normals_loss,
+            "eikonal": self.weights["eikonal"] * eikonal_loss,
+            "offsurface": self.weights["offsurface"] * offsurface_loss,
+            "consistency": self.weights["consistency"] * consistency_loss,
+            "correction": self.weights["correction"] * correction_loss,
+        }
+
+        # total loss
+        total_loss = sum(loss_details.values())
+
+        return total_loss, loss_details
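+
+# Minimal usage sketch for the simple path (descriptive, not executed here;
+# `model` is a hypothetical SDF network returning an (N, 1) tensor):
+#
+#     lm = LossManager(ablation="none")
+#     pts = torch.rand(128, 3, device="cuda", requires_grad=True)
+#     nrm = torch.nn.functional.normalize(torch.rand(128, 3, device="cuda"), dim=1)
+#     gt = torch.zeros(128, 1, device="cuda")
+#     total, details = lm.compute_loss(pts, nrm, gt, model(pts))
+#     total.backward()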

diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py
index 7477418..9faeff7 100644
--- a/brep2sdf/networks/network.py
+++ b/brep2sdf/networks/network.py
@@ -48,6 +48,7 @@ class GridNet:
 
 import torch
 import torch.nn as nn
+from torch.autograd import grad
 
 from .encoder import Encoder
 from .decoder import Decoder
@@ -90,3 +91,13 @@ class Net(nn.Module):
 
         return output
 
+
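+# Sanity note for gradient() below (descriptive sketch, not part of the module's
+# API): for f(x) = ||x|| with inputs of shape (N, 3), the returned rows should be
+# x/||x||, i.e. have unit norm. The [:, -3:] slice assumes the spatial coordinates
+# are the last three input dimensions.
+#
+#     x = torch.rand(16, 3, requires_grad=True)
+#     g = gradient(x, x.norm(dim=1, keepdim=True))
+#     assert torch.allclose(g.norm(dim=1), torch.ones(16), atol=1e-5)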
+def gradient(inputs, outputs):
+    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
+    points_grad = grad(
+        outputs=outputs,
+        inputs=inputs,
+        grad_outputs=d_points,
+        create_graph=True,
+        retain_graph=True,
+        only_inputs=True)[0][:, -3:]
+    return points_grad
\ No newline at end of file
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index 4031e46..a77a35c 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -8,26 +8,64 @@ import argparse
 
 from brep2sdf.config.default_config import get_default_config
 from brep2sdf.data.data import load_brep_file,load_sdf_file
-from brep2sdf.data.pre_process import process_single_step
+from brep2sdf.data.pre_process_by_mesh import process_single_step
 from brep2sdf.networks.network import Net
 from brep2sdf.networks.octree import OctreeNode
+from brep2sdf.networks.loss import LossManager
 from brep2sdf.utils.logger import logger
 
-def prepare_sdf_data(surf_data, max_points=100000, device='cuda'):
+
+# Command-line arguments
+parser = argparse.ArgumentParser(description='Batch-processing tool for STEP files')
+parser.add_argument('-i', '--input', required=True,
+                    help='path to the input B-rep (.step) file')
+parser.add_argument(
+    '--use-normal',
+    action='store_true',  # defaults to False; True when the flag is given
+    help='require sampled points to carry normals'
+)
+parser.add_argument(
+    '--force-reprocess',
+    action='store_true',  # defaults to False; True when the flag is given
+    help='force preprocessing to rerun, ignoring caches and existing results'
+)
+args = parser.parse_args()
+
+
+def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
     total_points = sum(len(s) for s in surf_data)
 
     # downsampling logic (fixed version)
     if total_points > max_points:
-        # shuffle all points first
-        all_points = np.concatenate(surf_data)
-        np.random.shuffle(all_points)
-        # keep the first max_points points
-        sampled_points = all_points[:max_points]
-        sdf_array = np.zeros((max_points, 4), dtype=np.float32)
-        sdf_array[:, :3] = sampled_points
+        # build (surface, point) index pairs
+        indices = []
+        for i, points in enumerate(surf_data):
+            indices.extend([(i, j) for j in range(len(points))])
+
+        # shuffle the index pairs
+        np.random.shuffle(indices)
+
+        # keep the first max_points indices
+        selected_indices = indices[:max_points]
+        if normals is not None:
+            # 7 columns: xyz, normal, SDF label (stays 0: surface samples lie on the zero level set)
+            sdf_array = np.zeros((max_points, 7), dtype=np.float32)
+            for idx, (i, j) in enumerate(selected_indices):
+                sdf_array[idx, :3] = surf_data[i][j]
+                sdf_array[idx, 3:6] = normals[i][j]
+        else:
+            # 4 columns: xyz, SDF label
+            sdf_array = np.zeros((max_points, 4), dtype=np.float32)
+            for idx, (i, j) in enumerate(selected_indices):
+                sdf_array[idx, :3] = surf_data[i][j]
     else:
-        sdf_array = np.zeros((total_points, 4), dtype=np.float32)
-        sdf_array[:, :3] = np.concatenate(surf_data)
+        if normals is not None:
+            sdf_array = np.zeros((total_points, 7), dtype=np.float32)
+            sdf_array[:, :3] = np.concatenate(surf_data)
+            sdf_array[:, 3:6] = np.concatenate(normals)
+        else:
+            sdf_array = np.zeros((total_points, 4), dtype=np.float32)
+            sdf_array[:, :3] = np.concatenate(surf_data)
 
     return torch.tensor(sdf_array, dtype=torch.float32, device=device)
 
@@ -40,15 +78,16 @@ class Trainer:
         self.model_name = os.path.basename(input_step).split('_')[0]
         self.base_name = self.model_name + ".xyz"
         data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name)
-        if os.path.exists(data_path):
+        if os.path.exists(data_path) and not args.force_reprocess:
             self.data = load_brep_file(data_path)
         else:
-            self.data = process_single_step(step_path=input_step, output_path=data_path)
+            self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
 
         # Flatten the list of surface point clouds into an (N*M, 4) array
         surfs = self.data["surf_ncs"]
         self.sdf_data = prepare_sdf_data(
             surfs,
+            normals=self.data.get("surf_pnt_normals"),  # None unless --use-normal produced them
             max_points=4096,
             device=self.device
         )
@@ -81,6 +120,8 @@ class Trainer:
             weight_decay=config.train.weight_decay
         )
 
+        self.loss_manager = LossManager(ablation="none")
+
     def build_tree(self,surf_bbox, max_depth=6):
         num_faces = surf_bbox.shape[0]
@@ -131,14 +172,27 @@ class Trainer:
         # Fetch the data and move it to the device
         points = self.sdf_data[:,0:3]
         points.requires_grad_(True)
-        gt_sdf = self.sdf_data[:,3]
+        if args.use_normal:
+            normals = self.sdf_data[:,3:6]
+            gt_sdf = self.sdf_data[:,6]
+        else:
+            gt_sdf = self.sdf_data[:,3]
 
         # Forward pass
         self.optimizer.zero_grad()
         pred_sdf = self.model(points)
 
         # Compute the loss
-        loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
+        if args.use_normal:
+            loss, loss_details = self.loss_manager.compute_loss(
+                points,
+                normals,
+                gt_sdf,
+                pred_sdf
+            )
+        else:
+            loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
 
         # Backward pass and optimizer step
         loss.backward()
@@ -173,7 +227,7 @@ class Trainer:
         best_val_loss = float('inf')
         logger.info("Starting training...")
         start_time = time.time()
-        """
+
         for epoch in range(1, self.config.train.num_epochs + 1):
             # train one epoch
             train_loss = self.train_epoch(epoch)
@@ -202,8 +256,8 @@ class Trainer:
         logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
         logger.info(f'Best validation loss: {best_val_loss:.6f}')
         self._tracing_model()
-        """
-        self.test_load()
+
+        #self.test_load()
 
     def test_load(self):
         model = self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt")
@@ -250,14 +304,7 @@ class Trainer:
 
 def main():
     # configuration is initialized here
-    # command-line arguments
-    parser = argparse.ArgumentParser(description='Batch-processing tool for STEP files')
-    parser.add_argument('-i', '--input', required=True,
-                        help='path to the input B-rep (.step) file')
-
-    args = parser.parse_args()
     config = get_default_config()
-
     # build the trainer and start training
     trainer = Trainer(config, input_step=args.input)
     trainer.train()
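
# Invocation sketch (the flags are the ones defined by the argparse block in
# brep2sdf/train.py above; the input path is a placeholder):
#
#     python brep2sdf/train.py -i data/step/00000000/model.step --use-normal --force-reprocess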