From 2bb08641904c70df48e2a5b0bb10a94611370d38 Mon Sep 17 00:00:00 2001 From: mckay Date: Mon, 21 Apr 2025 13:25:07 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=A0=E5=85=A5=E9=82=BB=E6=8E=A5=E5=9B=BE?= =?UTF-8?q?=E4=BD=9C=E4=B8=BA=E5=85=AB=E5=8F=89=E6=A0=91subdivide=E4=BE=9D?= =?UTF-8?q?=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/data/data.py | 1 - brep2sdf/data/pre_process_by_mesh.py | 295 ++---- brep2sdf/data/sampler.py | 133 +-- brep2sdf/data/utils.py | 1400 ++++---------------------- brep2sdf/networks/octree.py | 31 +- brep2sdf/networks/patch_graph.py | 182 ++++ brep2sdf/test.py | 161 +-- brep2sdf/train.py | 59 +- 8 files changed, 576 insertions(+), 1686 deletions(-) create mode 100644 brep2sdf/networks/patch_graph.py diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index 49bfd28..6e5b844 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -4,7 +4,6 @@ from torch.utils.data import Dataset import numpy as np import pickle from brep2sdf.utils.logger import logger -from brep2sdf.data.utils import process_brep_data from brep2sdf.config.default_config import get_default_config diff --git a/brep2sdf/data/pre_process_by_mesh.py b/brep2sdf/data/pre_process_by_mesh.py index 7d804d0..1e69a1a 100644 --- a/brep2sdf/data/pre_process_by_mesh.py +++ b/brep2sdf/data/pre_process_by_mesh.py @@ -15,7 +15,6 @@ from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError import logging from datetime import datetime from scipy.spatial import cKDTree -from brep2sdf.utils.logger import logger import tempfile import trimesh from trimesh.proximity import ProximityQuery @@ -36,6 +35,8 @@ from OCC.Core.StlAPI import StlAPI_Writer from brep2sdf.data.sampler import sample_sdf_points_and_normals from brep2sdf.data.data import check_data_format +from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,load_step, preprocess_mesh,batch_compute_normals +from brep2sdf.utils.logger import logger # 导入配置 from brep2sdf.config.default_config import get_default_config config = get_default_config() @@ -43,133 +44,6 @@ config = get_default_config() # 设置最大面数阈值,用于加速处理 MAX_FACE = config.data.max_face -def normalize(surfs, edges, corners): - """ - 将CAD模型归一化到单位立方体空间 - - 参数: - surfs: 面的点集列表 - edges: 边的点集列表 - corners: 顶点坐标数组 [num_edges, 2, 3] - - 返回: - surfs_wcs: 原始坐标系下的面点集 - edges_wcs: 原始坐标系下的边点集 - surfs_ncs: 归一化坐标系下的面点集 - edges_ncs: 归一化坐标系下的边点集 - corner_wcs: 归一化后的顶点坐标 [num_edges, 2, 3] - center: 使用的中心点坐标 [3,] - scale: 使用的缩放系数 (float) - """ - if len(corners) == 0: - return None, None, None, None, None, None, None - - # 计算包围盒和缩放因子 - corners_array = corners.reshape(-1, 3) # [num_edges*2, 3] - center = (corners_array.max(0) + corners_array.min(0)) / 2 # 计算中心点 - scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max() # 计算缩放系数 - - # 归一化面的点集 - surfs_wcs = [] # 原始世界坐标系下的面点集 - surfs_ncs = [] # 归一化坐标系下的面点集 - for surf in surfs: - surf_wcs = np.array(surf) - surf_ncs = (surf_wcs - center) * scale # 归一化变换 - surfs_wcs.append(surf_wcs) - surfs_ncs.append(surf_ncs) - - # 归一化边的点集 - edges_wcs = [] # 原始世界坐标系下的边点集 - edges_ncs = [] # 归一化坐标系下的边点集 - for edge in edges: - edge_wcs = np.array(edge) - edge_ncs = (edge_wcs - center) * scale # 归一化变换 - edges_wcs.append(edge_wcs) - edges_ncs.append(edge_ncs) - - # 归一化顶点坐标 - 保持[num_edges, 2, 3]的形状 - corner_wcs = (corners - center) * scale # 广播操作会保持原有维度 - - return (np.array(surfs_wcs, dtype=object), - np.array(edges_wcs, dtype=object), - np.array(surfs_ncs, dtype=object), - np.array(edges_ncs, dtype=object), - corner_wcs.astype(np.float32), - center.astype(np.float32), - scale - ) - -def get_adjacency_info(shape, faces, edges, vertices): - """ - 优化后的邻接关系计算函数,直接使用已收集的几何元素 - - 参数新增: - faces: 已收集的面列表 - edges: 已收集的边列表 - vertices: 已收集的顶点列表 - """ - logger.debug("Get adjacency infos...") - # 创建边-面映射关系 - edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape() - topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map) - - # 直接使用传入的几何元素列表 - num_faces = len(faces) - num_edges = len(edges) - num_vertices = len(vertices) - logger.debug(f"num_faces: {num_faces}, num_edges: {num_edges}, num_vertices: {num_vertices}") - - edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32) - faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32) - edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32) - - # 填充边-面邻接矩阵 - for i, edge in enumerate(edges): - # 检查每个面是否与当前边相连 - for j, face in enumerate(faces): - edge_explorer = TopExp_Explorer(face, TopAbs_EDGE) - while edge_explorer.More(): - if edge.IsSame(edge_explorer.Current()): - edgeFace_adj[i, j] = 1 - faceEdge_adj[j, i] = 1 - break - edge_explorer.Next() - - # 获取边的两个端点 - v1 = TopoDS_Vertex() - v2 = TopoDS_Vertex() - topexp.Vertices(edge, v1, v2) - - # 记录边的端点索引 - if not v1.IsNull() and not v2.IsNull(): - v1_vertex = topods.Vertex(v1) - v2_vertex = topods.Vertex(v2) - - for k, vertex in enumerate(vertices): - if v1_vertex.IsSame(vertex): - edgeCorner_adj[i, 0] = k - if v2_vertex.IsSame(vertex): - edgeCorner_adj[i, 1] = k - - return edgeFace_adj, faceEdge_adj, edgeCorner_adj - -def get_bbox(shape, subshape): - """ - 计算形状的包围盒 - - 参数: - shape: 完整的CAD模型形状 - subshape: 需要计算包围盒的子形状(面或边) - - 返回: - 包围盒的六个参数 [xmin, ymin, zmin, xmax, ymax, zmax] - """ - bbox = Bnd_Box() - brepbndlib.Add(subshape, bbox) - xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get() - return np.array([xmin, ymin, zmin, xmax, ymax, zmax]) - - def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False): """ @@ -360,7 +234,74 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False): if any(x is None for x in [surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs]): logger.error(f"Normalization failed for {step_path}") return None + # --- 计算边的类型 --- + logger.debug("计算边的类型...") + edge_types = [] # 0:凹边 1:凸边 + + # 对每条边进行处理 + for edge_idx in range(len(edges)): + # 获取与当前边相邻的面 + adjacent_faces = np.where(edgeFace_adj[edge_idx] == 1)[0] + + # 如果边只有一个相邻面或没有相邻面,默认为凹边 + if len(adjacent_faces) < 2: + edge_types.append(0) + continue + + # 获取两个相邻面 + face1, face2 = faces[adjacent_faces[0]], faces[adjacent_faces[1]] + + # 使用BRep_Tool获取面的几何信息 + from OCC.Core.BRepAdaptor import BRepAdaptor_Surface + from OCC.Core.gp import gp_Pnt, gp_Vec + + # 获取第一个面的法向量 + surf1 = BRepAdaptor_Surface(face1) + u1 = (surf1.FirstUParameter() + surf1.LastUParameter()) / 2 + v1 = (surf1.FirstVParameter() + surf1.LastVParameter()) / 2 + pnt1 = gp_Pnt() + du1 = gp_Vec() + dv1 = gp_Vec() + surf1.D1(u1, v1, pnt1, du1, dv1) + normal1 = du1.Crossed(dv1) + normal1.Normalize() + normal1_np = np.array([normal1.X(), normal1.Y(), normal1.Z()]) + # 获取第二个面的法向量 + surf2 = BRepAdaptor_Surface(face2) + u2 = (surf2.FirstUParameter() + surf2.LastUParameter()) / 2 + v2 = (surf2.FirstVParameter() + surf2.LastVParameter()) / 2 + pnt2 = gp_Pnt() + du2 = gp_Vec() + dv2 = gp_Vec() + surf2.D1(u2, v2, pnt2, du2, dv2) + normal2 = du2.Crossed(dv2) + normal2.Normalize() + normal2_np = np.array([normal2.X(), normal2.Y(), normal2.Z()]) + + # 获取边的方向向量 + edge = edges[edge_idx] + curve_info = BRep_Tool.Curve(edge) + if curve_info is None or len(curve_info) < 3: + edge_types.append(0) + continue + + curve, first, last = curve_info + # 计算边的方向向量 + start_point = np.array([curve.Value(first).X(), curve.Value(first).Y(), curve.Value(first).Z()]) + end_point = np.array([curve.Value(last).X(), curve.Value(last).Y(), curve.Value(last).Z()]) + edge_vector = end_point - start_point + edge_vector = edge_vector / np.linalg.norm(edge_vector) + + # 使用混合积判断凹凸性 + # 如果混合积为正,说明是凸边;为负,说明是凹边 + mixed_product = np.dot(np.cross(normal1_np, normal2_np), edge_vector) + + # 根据混合积的符号确定边的类型 + edge_types.append(1 if mixed_product > 0 else 0) + + edge_types = np.array(edge_types, dtype=np.int32) + # 创建结果字典并确保所有数组都有正确的类型 data = { 'surf_wcs': np.array(surfs_wcs, dtype=object), # 保持对象数组 @@ -368,9 +309,10 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False): 'surf_ncs': np.array(surfs_ncs, dtype=object), # 保持对象数组 'edge_ncs': np.array(edges_ncs, dtype=object), # 保持对象数组 'corner_wcs': corner_wcs.astype(np.float32), # [num_edges, 2, 3] - 'edgeFace_adj': edgeFace_adj.astype(np.int32), + 'edgeFace_adj': edgeFace_adj.astype(np.int32), # [num_edges, num_faces], 1 表示边与面相邻 'edgeCorner_adj': edgeCorner_adj.astype(np.int32), 'faceEdge_adj': faceEdge_adj.astype(np.int32), + 'edge_types': np.array(edge_types, dtype=np.int32), # [num_edges] 'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32), 'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32), 'surf_bbox_ncs': surf_bbox_ncs.astype(np.float32), # 归一化坐标系 [num_faces, 6] @@ -450,94 +392,6 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False): logger.warning("请求了 SDF 点采样,但 Trimesh 加载失败。") return data - - -def load_step(step_path): - """Load STEP file and return solids""" - reader = STEPControl_Reader() - reader.ReadFile(step_path) - reader.TransferRoots() - return [reader.OneShape()] - - -def preprocess_mesh(mesh, normal_type='vertex'): - """ - 预处理网格数据,生成 KDTree 和法向量源。 - - 参数: - mesh: trimesh.Trimesh 对象,包含顶点和法向量信息 - normal_type: str 法向量类型,可选 'vertex' 或 'face' - - 返回: - tree: cKDTree 用于加速最近邻查询 - normals_source: np.ndarray 包含顶点法向量或面法向量 - """ - if normal_type == 'vertex': - tree = cKDTree(mesh.vertices) - normals_source = mesh.vertex_normals - elif normal_type == 'face': - # 计算每个面的中心点 - face_centers = np.mean(mesh.vertices[mesh.faces], axis=1) - tree = cKDTree(face_centers) - normals_source = mesh.face_normals - else: - raise ValueError(f"Unsupported normal type: {normal_type}") - - return tree, normals_source - - -def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3): - """ - 为嵌套点云数据计算法向量,并保持嵌套格式。 - - 参数: - mesh: trimesh.Trimesh 对象,包含顶点和法向量信息 - surf_wcs: np.ndarray(dtype=object) 形状为 (N,) 的数组,每个元素是形状为 (M, 3) 的 float32 数组 - normal_type: str 法向量类型,可选 'vertex' 或 'face' - k_neighbors: int 用于平滑的最近邻数量 - - 返回: - normals: np.ndarray(dtype=object) 形状为 (N,) 的数组,每个元素是形状为 (M, 3) 的 float32 数组 - """ - # 预处理网格数据 - tree, normals_source = preprocess_mesh(mesh, normal_type=normal_type) - - # 展平所有点云为一个二维数组 [P, 3],并记录分割索引 - lengths = [len(point_cloud) for point_cloud in surf_wcs] - query_points = np.concatenate(surf_wcs, axis=0).astype(np.float32) # 避免多次内存分配 - - # 批量查询最近邻 - distances, indices = tree.query(query_points, k=k_neighbors) - - # 处理k=1的特殊情况 - if k_neighbors == 1: - nearest_normals = normals_source[indices] - else: - # 加权平均(权重为距离倒数) - inv_distances = 1 / (distances + 1e-8) # 防止除以零 - sum_inv_distances = inv_distances.sum(axis=1, keepdims=True) - valid_mask = sum_inv_distances > 1e-6 - weights = np.divide(inv_distances, sum_inv_distances, out=np.zeros_like(inv_distances), where=valid_mask) - nearest_normals = np.einsum('ijk,ij->ik', normals_source[indices], weights) - - # 标准化结果 - norms = np.linalg.norm(nearest_normals, axis=1) - valid_mask = norms > 1e-6 - nearest_normals[valid_mask] /= norms[valid_mask, None] - - # 按原始嵌套结构分割法向量 - start = 0 - normals_output = np.empty(len(surf_wcs), dtype=object) - for i, length in enumerate(lengths): - end = start + length - normals_output[i] = nearest_normals[start:end] - start = end - - return normals_output - - - - def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, sample_sdf_points=False, timeout:int=300) -> dict: """处理单个STEP文件, 从 brep 2 pkl return data = { @@ -549,9 +403,16 @@ def process_single_step(step_path:str, output_path:str=None, sample_normal_vecto 'edgeFace_adj': edgeFace_adj.astype(np.int32), # 边-面的邻接关系矩阵 'edgeCorner_adj': edgeCorner_adj.astype(np.int32),# 边-角点的邻接关系矩阵 'faceEdge_adj': faceEdge_adj.astype(np.int32), # 面-边的邻接关系矩阵 + 'edge_types': np.array(edge_types, dtype=np.int32)# [num_edges] 'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),# 曲面在世界坐标系下的包围盒 'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),# 边在世界坐标系下的包围盒 'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32) # 去重后的唯一角点坐标 + 'normalization_params': { # 归一化参数 + 'center': center.astype(np.float32), # 归一化中心点 [3,] + 'scale': float(scale), # 归一化缩放系数 + }, + 'surf_pnt_normals': np.array(dtype=object), # 表面点的法线数据 [num_faces, num_surf_sample_points, 3],仅当 sample_normal_vector=True + 'sampled_points_normals_sdf': np.array(dtype=float32), # 采样点的位置、法线和SDF值 [num_samples, 7],仅当 sample_sdf_points=True }""" try: logger.info("数据预处理……") diff --git a/brep2sdf/data/sampler.py b/brep2sdf/data/sampler.py index 215f0aa..a8eb1f2 100644 --- a/brep2sdf/data/sampler.py +++ b/brep2sdf/data/sampler.py @@ -36,100 +36,14 @@ from OCC.Core.StlAPI import StlAPI_Writer # 导入配置 from brep2sdf.config.default_config import get_default_config +from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,load_step, preprocess_mesh,batch_compute_normals + config = get_default_config() # 设置最大面数阈值,用于加速处理 MAX_FACE = config.data.max_face -def _sample_uniform_points(num_points: int) -> np.ndarray: - """在 [-0.5, 0.5] 范围内均匀采样点 - - 参数: - num_points: 要采样的点数 - - 返回: - np.ndarray: 形状为 (num_points, 3) 的采样点数组 - """ - return np.random.uniform(-0.5, 0.5, (num_points, 3)) - -def _sample_near_surface_points( - mesh: trimesh.Trimesh, - num_points: int, - std_dev: float -) -> np.ndarray: - """在网格表面附近采样点 - - 参数: - mesh: 输入的trimesh网格 - num_points: 要采样的点数 - std_dev: 沿法线方向的扰动标准差 - - 返回: - np.ndarray: 形状为 (num_points, 3) 的采样点数组 - """ - if mesh.faces.shape[0] == 0: - logger.warning("网格没有面,无法执行近表面采样。将替换为均匀采样点。") - return _sample_uniform_points(num_points) - - try: - near_points_on_surface = mesh.sample(num_points) - proximity_query_near = ProximityQuery(mesh) - closest_points_near, _, face_indices_near = proximity_query_near.on_surface(near_points_on_surface) - - if np.any(face_indices_near >= len(mesh.face_normals)): - raise IndexError("Face index out of bounds during near-surface normal lookup") - - normals_near = mesh.face_normals[face_indices_near] - perturbations = np.random.randn(num_points, 1) * std_dev - near_points = near_points_on_surface + normals_near * perturbations - return np.clip(near_points, -0.5, 0.5) - - except Exception as e: - logger.warning(f"近表面采样失败: {e}。将替换为均匀采样点。") - return _sample_uniform_points(num_points) - -def sample_points( - trimesh_mesh_ncs: trimesh.Trimesh, - num_uniform_samples: int, - num_near_surface_samples: int, - sdf_sampling_std_dev: float -) -> np.ndarray | None: - """组合均匀采样和近表面采样的点 - - 参数: - trimesh_mesh_ncs: 归一化的trimesh网格 - num_uniform_samples: 均匀采样点数 - num_near_surface_samples: 近表面采样点数 - sdf_sampling_std_dev: 近表面采样的标准差 - - 返回: - np.ndarray | None: 合并后的采样点数组,失败时返回None - """ - sampled_points_list = [] - - # 均匀采样 - if num_uniform_samples > 0: - uniform_points = _sample_uniform_points(num_uniform_samples) - sampled_points_list.append(uniform_points) - - # 近表面采样 - if num_near_surface_samples > 0: - near_points = _sample_near_surface_points( - trimesh_mesh_ncs, - num_near_surface_samples, - sdf_sampling_std_dev - ) - sampled_points_list.append(near_points) - - # 合并采样点 - if not sampled_points_list: - logger.warning("没有采样到任何点。") - return None - - return np.vstack(sampled_points_list).astype(np.float32) - -# 在原始的sample_sdf_points_and_normals函数中使用新的采样函数 def sample_sdf_points_and_normals( trimesh_mesh_ncs: trimesh.Trimesh, surf_bbox_ncs: np.ndarray, @@ -169,12 +83,43 @@ def sample_sdf_points_and_normals( logger.debug(f"固定 SDF 采样点数: {num_sdf_samples} (均匀: {num_uniform_samples}, 近表面: {num_near_surface_samples})") # --- 执行采样 --- - sampled_points_ncs = sample_points( - trimesh_mesh_ncs, - num_uniform_samples, - num_near_surface_samples, - sdf_sampling_std_dev - ) + sampled_points_list = [] + + # 均匀采样 (在 [-0.5, 0.5] 范围内) + if num_uniform_samples > 0: + uniform_points = np.random.uniform(-0.5, 0.5, (num_uniform_samples, 3)) + sampled_points_list.append(uniform_points) + + # 近表面采样 + if num_near_surface_samples > 0: + if trimesh_mesh_ncs.faces.shape[0] > 0: + try: + near_points_on_surface = trimesh_mesh_ncs.sample(num_near_surface_samples) + proximity_query_near = ProximityQuery(trimesh_mesh_ncs) + closest_points_near, distances_near, face_indices_near = proximity_query_near.on_surface(near_points_on_surface) + if np.any(face_indices_near >= len(trimesh_mesh_ncs.face_normals)): + raise IndexError("Face index out of bounds during near-surface normal lookup") + normals_near = trimesh_mesh_ncs.face_normals[face_indices_near] + perturbations = np.random.randn(num_near_surface_samples, 1) * sdf_sampling_std_dev + near_points = near_points_on_surface + normals_near * perturbations + # 确保近表面点也在 [-0.5, 0.5] 范围内 + near_points = np.clip(near_points, -0.5, 0.5) + sampled_points_list.append(near_points) + except Exception as e: + logger.warning(f"近表面采样失败: {e}。将替换为均匀采样点。") + fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3)) + sampled_points_list.append(fallback_uniform) + else: + logger.warning("网格没有面,无法执行近表面采样。将替换为均匀采样点。") + fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3)) + sampled_points_list.append(fallback_uniform) + + # --- 合并采样点 --- + if not sampled_points_list: + logger.warning("没有为SDF采样到任何点。") + return None + + sampled_points_ncs = np.vstack(sampled_points_list).astype(np.float32) try: proximity_query = ProximityQuery(trimesh_mesh_ncs) diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py index 00f3271..0ae7c54 100644 --- a/brep2sdf/data/utils.py +++ b/brep2sdf/data/utils.py @@ -1,1211 +1,225 @@ -import numpy as np -import math -import torch -import torch.nn as nn -import random -import string -import argparse -from chamferdist import ChamferDistance -from mpl_toolkits.mplot3d.art3d import Poly3DCollection -from typing import List, Optional, Tuple, Union -from brep2sdf.utils.logger import logger -from brep2sdf.config.default_config import get_default_config - - -from OCC.Core.gp import gp_Pnt, gp_Pnt -from OCC.Core.TColgp import TColgp_Array2OfPnt -from OCC.Core.GeomAPI import GeomAPI_PointsToBSplineSurface, GeomAPI_PointsToBSpline -from OCC.Core.GeomAbs import GeomAbs_C2 -from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_MakeWire, BRepBuilderAPI_MakeFace, BRepBuilderAPI_MakeEdge -from OCC.Extend.TopologyUtils import TopologyExplorer, WireExplorer -from OCC.Core.TColgp import TColgp_Array1OfPnt -from OCC.Core.gp import gp_Pnt -from OCC.Core.ShapeFix import ShapeFix_Face, ShapeFix_Wire, ShapeFix_Edge -from OCC.Core.ShapeAnalysis import ShapeAnalysis_Wire -from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_Sewing, BRepBuilderAPI_MakeSolid - - -def generate_random_string(length): - characters = string.ascii_letters + string.digits # You can include other characters if needed - random_string = ''.join(random.choice(characters) for _ in range(length)) - return random_string - - -def get_bbox_norm(point_cloud): - # Find the minimum and maximum coordinates along each axis - min_x = np.min(point_cloud[:, 0]) - max_x = np.max(point_cloud[:, 0]) - - min_y = np.min(point_cloud[:, 1]) - max_y = np.max(point_cloud[:, 1]) - - min_z = np.min(point_cloud[:, 2]) - max_z = np.max(point_cloud[:, 2]) - - # Create the 3D bounding box using the min and max values - min_point = np.array([min_x, min_y, min_z]) - max_point = np.array([max_x, max_y, max_z]) - return np.linalg.norm(max_point - min_point) - - -def compute_bbox_center_and_size(min_corner, max_corner): - # Calculate the center - center_x = (min_corner[0] + max_corner[0]) / 2 - center_y = (min_corner[1] + max_corner[1]) / 2 - center_z = (min_corner[2] + max_corner[2]) / 2 - center = np.array([center_x, center_y, center_z]) - # Calculate the size - size_x = max_corner[0] - min_corner[0] - size_y = max_corner[1] - min_corner[1] - size_z = max_corner[2] - min_corner[2] - size = max(size_x, size_y, size_z) - return center, size - - -def randn_tensor( - shape: Union[Tuple, List], - generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, - device: Optional["torch.device"] = None, - dtype: Optional["torch.dtype"] = None, - layout: Optional["torch.layout"] = None, -): - """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When - passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor - will always be created on CPU. - """ - # device on which tensor is created defaults to device - rand_device = device - batch_size = shape[0] - - layout = layout or torch.strided - device = device or torch.device("cpu") - - if generator is not None: - gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type - if gen_device_type != device.type and gen_device_type == "cpu": - rand_device = "cpu" - elif gen_device_type != device.type and gen_device_type == "cuda": - raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") - - if isinstance(generator, list): - shape = (1,) + shape[1:] - latents = [ - torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) - for i in range(batch_size) - ] - latents = torch.cat(latents, dim=0).to(device) - else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) - - return latents - - -def pad_repeat(x, max_len): - repeat_times = math.floor(max_len/len(x)) - sep = max_len-repeat_times*len(x) - sep1 = np.repeat(x[:sep], repeat_times+1, axis=0) - sep2 = np.repeat(x[sep:], repeat_times, axis=0) - x_repeat = np.concatenate([sep1, sep2], 0) - return x_repeat - -''' -def pad_zero(x, max_len, return_mask=False): - keys = np.ones(len(x)) - padding = np.zeros((max_len-len(x))).astype(int) - mask = 1-np.concatenate([keys, padding]) == 1 - padding = np.zeros((max_len-len(x), *x.shape[1:])) - x_padded = np.concatenate([x, padding], axis=0) - if return_mask: - return x_padded, mask - else: - return x_padded -''' -def pad_zero(x, max_len, return_mask=False): - """填充或截断数组到指定长度 - - Args: - x: 输入数组 - max_len: 目标长度 - return_mask: 是否返回掩码 - - Returns: - x_padded: 处理后的数组 - mask: (可选) 掩码,标记实际数据(True)和填充(False) - """ - # 获取实际长度 - actual_len = len(x) - - # 如果实际长度超过最大长度,进行截断 - if actual_len > max_len: - x = x[:max_len] - mask = np.ones(max_len, dtype=bool) - if return_mask: - return x, mask - return x - - # 如果需要填充 - if actual_len < max_len: - # 创建掩码 - keys = np.ones(actual_len) - padding_mask = np.zeros(max_len - actual_len) - mask = np.concatenate([keys, padding_mask]) == 1 - - # 填充数据 - padding = np.zeros((max_len - actual_len, *x.shape[1:])) - x_padded = np.concatenate([x, padding], axis=0) - else: - # 长度正好 - mask = np.ones(max_len, dtype=bool) - x_padded = x - - if return_mask: - return x_padded, mask - return x_padded - - -def plot_3d_bbox(ax, min_corner, max_corner, color='r'): - """ - Helper function for plotting 3D bounding boxese - """ - vertices = [ - (min_corner[0], min_corner[1], min_corner[2]), - (max_corner[0], min_corner[1], min_corner[2]), - (max_corner[0], max_corner[1], min_corner[2]), - (min_corner[0], max_corner[1], min_corner[2]), - (min_corner[0], min_corner[1], max_corner[2]), - (max_corner[0], min_corner[1], max_corner[2]), - (max_corner[0], max_corner[1], max_corner[2]), - (min_corner[0], max_corner[1], max_corner[2]) - ] - # Define the 12 triangles composing the box - faces = [ - [vertices[0], vertices[1], vertices[2], vertices[3]], - [vertices[4], vertices[5], vertices[6], vertices[7]], - [vertices[0], vertices[1], vertices[5], vertices[4]], - [vertices[2], vertices[3], vertices[7], vertices[6]], - [vertices[1], vertices[2], vertices[6], vertices[5]], - [vertices[4], vertices[7], vertices[3], vertices[0]] - ] - ax.add_collection3d(Poly3DCollection(faces, facecolors='blue', linewidths=1, edgecolors=color, alpha=0)) - return - - -def get_args_vae(): - parser = argparse.ArgumentParser() - parser.add_argument('--data', type=str, default='data_process/deepcad_parsed', - help='Path to data folder') - parser.add_argument('--train_list', type=str, default='data_process/deepcad_data_split_6bit_surface.pkl', - help='Path to training list') - parser.add_argument('--val_list', type=str, default='data_process/deepcad_data_split_6bit.pkl', - help='Path to validation list') - # Training parameters - parser.add_argument("--option", type=str, choices=['surface', 'edge'], default='surface', - help="Choose between option surface or edge (default: surface)") - parser.add_argument('--batch_size', type=int, default=512, help='input batch size') - parser.add_argument('--train_nepoch', type=int, default=200, help='number of epochs to train for') - parser.add_argument('--save_nepoch', type=int, default=20, help='number of epochs to save model') - parser.add_argument('--test_nepoch', type=int, default=10, help='number of epochs to test model') - parser.add_argument("--data_aug", action='store_true', help='Use data augmentation') - parser.add_argument("--finetune", action='store_true', help='Finetune from existing weights') - parser.add_argument("--weight", type=str, default=None, help='Weight path when finetuning') - parser.add_argument("--gpu", type=int, nargs='+', default=[0], help="GPU IDs to use for training (default: [0])") - # Save dirs and reload - parser.add_argument('--env', type=str, default="surface_vae", help='environment') - parser.add_argument('--dir_name', type=str, default="proj_log", help='name of the log folder.') - args = parser.parse_args() - # saved folder - args.save_dir = f'{args.dir_name}/{args.env}' - return args - - -def get_args_ldm(): - parser = argparse.ArgumentParser() - parser.add_argument('--data', type=str, default='data_process/deepcad_parsed', - help='Path to data folder') - parser.add_argument('--list', type=str, default='data_process/deepcad_data_split_6bit.pkl', - help='Path to data list') - parser.add_argument('--surfvae', type=str, default='proj_log/deepcad_surfvae/epoch_400.pt', - help='Path to pretrained surface vae weights') - parser.add_argument('--edgevae', type=str, default='proj_log/deepcad_edgevae/epoch_300.pt', - help='Path to pretrained edge vae weights') - parser.add_argument("--option", type=str, choices=['surfpos', 'surfz', 'edgepos', 'edgez'], default='surfpos', - help="Choose between option [surfpos,edgepos,surfz,edgez] (default: surfpos)") - # Training parameters - parser.add_argument('--batch_size', type=int, default=512, help='input batch size') - parser.add_argument('--train_nepoch', type=int, default=3000, help='number of epochs to train for') - parser.add_argument('--test_nepoch', type=int, default=25, help='number of epochs to test model') - parser.add_argument('--save_nepoch', type=int, default=50, help='number of epochs to save model') - parser.add_argument('--max_face', type=int, default=50, help='maximum number of faces') - parser.add_argument('--max_edge', type=int, default=30, help='maximum number of edges per face') - parser.add_argument('--threshold', type=float, default=0.05, help='minimum threshold between two faces') - parser.add_argument('--bbox_scaled', type=float, default=3, help='scaled the bbox') - parser.add_argument('--z_scaled', type=float, default=1, help='scaled the latent z') - parser.add_argument("--gpu", type=int, nargs='+', default=[0, 1], help="GPU IDs to use for training (default: [0, 1])") - parser.add_argument("--data_aug", action='store_true', help='Use data augmentation') - parser.add_argument("--cf", action='store_true', help='Use data augmentation') - # Save dirs and reload - parser.add_argument('--env', type=str, default="surface_pos", help='environment') - parser.add_argument('--dir_name', type=str, default="proj_log", help='name of the log folder.') - args = parser.parse_args() - # saved folder - args.save_dir = f'{args.dir_name}/{args.env}' - return args - - -def rotate_point_cloud(point_cloud, angle_degrees, axis): - """ - Rotate a point cloud around its center by a specified angle in degrees along a specified axis. - - Args: - - point_cloud: Numpy array of shape (N, 3) representing the point cloud. - - angle_degrees: Angle of rotation in degrees. - - axis: Axis of rotation. Can be 'x', 'y', or 'z'. - - Returns: - - rotated_point_cloud: Numpy array of shape (N, 3) representing the rotated point cloud. - """ - - # Convert angle to radians - angle_radians = np.radians(angle_degrees) - - # Compute rotation matrix based on the specified axis - if axis == 'x': - rotation_matrix = np.array([[1, 0, 0], - [0, np.cos(angle_radians), -np.sin(angle_radians)], - [0, np.sin(angle_radians), np.cos(angle_radians)]]) - elif axis == 'y': - rotation_matrix = np.array([[np.cos(angle_radians), 0, np.sin(angle_radians)], - [0, 1, 0], - [-np.sin(angle_radians), 0, np.cos(angle_radians)]]) - elif axis == 'z': - rotation_matrix = np.array([[np.cos(angle_radians), -np.sin(angle_radians), 0], - [np.sin(angle_radians), np.cos(angle_radians), 0], - [0, 0, 1]]) - else: - raise ValueError("Invalid axis. Must be 'x', 'y', or 'z'.") - - # Center the point cloud - center = np.mean(point_cloud, axis=0) - centered_point_cloud = point_cloud - center - - # Apply rotation - rotated_point_cloud = np.dot(centered_point_cloud, rotation_matrix.T) - - # Translate back to original position - rotated_point_cloud += center - - # Find the maximum absolute coordinate value - max_abs_coord = np.max(np.abs(rotated_point_cloud)) - - # Scale the point cloud to fit within the -1 to 1 cube - normalized_point_cloud = rotated_point_cloud / max_abs_coord - - return normalized_point_cloud - - -def get_bbox(pnts): - """ - Get the tighest fitting 3D (axis-aligned) bounding box giving a set of points - """ - bbox_corners = [] - for point_cloud in pnts: - # Find the minimum and maximum coordinates along each axis - min_x = np.min(point_cloud[:, 0]) - max_x = np.max(point_cloud[:, 0]) - - min_y = np.min(point_cloud[:, 1]) - max_y = np.max(point_cloud[:, 1]) - - min_z = np.min(point_cloud[:, 2]) - max_z = np.max(point_cloud[:, 2]) - - # Create the 3D bounding box using the min and max values - min_point = np.array([min_x, min_y, min_z]) - max_point = np.array([max_x, max_y, max_z]) - bbox_corners.append([min_point, max_point]) - return np.array(bbox_corners) - - -def bbox_corners(bboxes): - """ - Given the bottom-left and top-right corners of the bbox - Return all eight corners - """ - bboxes_all_corners = [] - for bbox in bboxes: - bottom_left, top_right = bbox[:3], bbox[3:] - # Bottom 4 corners - bottom_front_left = bottom_left - bottom_front_right = (top_right[0], bottom_left[1], bottom_left[2]) - bottom_back_left = (bottom_left[0], top_right[1], bottom_left[2]) - bottom_back_right = (top_right[0], top_right[1], bottom_left[2]) - - # Top 4 corners - top_front_left = (bottom_left[0], bottom_left[1], top_right[2]) - top_front_right = (top_right[0], bottom_left[1], top_right[2]) - top_back_left = (bottom_left[0], top_right[1], top_right[2]) - top_back_right = top_right - - # Combine all coordinates - all_corners = [ - bottom_front_left, - bottom_front_right, - bottom_back_left, - bottom_back_right, - top_front_left, - top_front_right, - top_back_left, - top_back_right, - ] - bboxes_all_corners.append(np.vstack(all_corners)) - bboxes_all_corners = np.array(bboxes_all_corners) - return bboxes_all_corners - - -def rotate_axis(pnts, angle_degrees, axis, normalized=False): - """ - Rotate a point cloud around its center by a specified angle in degrees along a specified axis. - - Args: - - point_cloud: Numpy array of shape (N, ..., 3) representing the point cloud. - - angle_degrees: Angle of rotation in degrees. - - axis: Axis of rotation. Can be 'x', 'y', or 'z'. - - Returns: - - rotated_point_cloud: Numpy array of shape (N, 3) representing the rotated point cloud. - """ - - # Convert angle to radians - angle_radians = np.radians(angle_degrees) - - # Convert points to homogeneous coordinates - shape = list(np.shape(pnts)) - shape[-1] = 1 - pnts_homogeneous = np.concatenate((pnts, np.ones(shape)), axis=-1) - - # Compute rotation matrix based on the specified axis - if axis == 'x': - rotation_matrix = np.array([ - [1, 0, 0, 0], - [0, np.cos(angle_radians), -np.sin(angle_radians), 0], - [0, np.sin(angle_radians), np.cos(angle_radians), 0], - [0, 0, 0, 1] - ]) - elif axis == 'y': - rotation_matrix = np.array([ - [np.cos(angle_radians), 0, np.sin(angle_radians), 0], - [0, 1, 0, 0], - [-np.sin(angle_radians), 0, np.cos(angle_radians), 0], - [0, 0, 0, 1] - ]) - elif axis == 'z': - rotation_matrix = np.array([ - [np.cos(angle_radians), -np.sin(angle_radians), 0, 0], - [np.sin(angle_radians), np.cos(angle_radians), 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1] - ]) - else: - raise ValueError("Invalid axis. Must be 'x', 'y', or 'z'.") - - # Apply rotation - rotated_pnts_homogeneous = np.dot(pnts_homogeneous, rotation_matrix.T) - rotated_pnts = rotated_pnts_homogeneous[...,:3] - - # Scale the point cloud to fit within the -1 to 1 cube - if normalized: - max_abs_coord = np.max(np.abs(rotated_pnts)) - rotated_pnts = rotated_pnts / max_abs_coord - - return rotated_pnts - - -def rescale_bbox(bboxes, scale): - # Apply scaling factors to bounding boxes - scaled_bboxes = bboxes*scale - return scaled_bboxes - - -def translate_bbox(bboxes): - """ - Randomly move object within the cube (x,y,z direction) - """ - point_cloud = bboxes.reshape(-1,3) - min_x = np.min(point_cloud[:, 0]) - max_x = np.max(point_cloud[:, 0]) - min_y = np.min(point_cloud[:, 1]) - max_y = np.max(point_cloud[:, 1]) - min_z = np.min(point_cloud[:, 2]) - max_z = np.max(point_cloud[:, 2]) - x_offset = np.random.uniform( np.min(-1-min_x,0), np.max(1-max_x,0) ) - y_offset = np.random.uniform( np.min(-1-min_y,0), np.max(1-max_y,0) ) - z_offset = np.random.uniform( np.min(-1-min_z,0), np.max(1-max_z,0) ) - random_translation = np.array([x_offset,y_offset,z_offset]) - bboxes_translated = bboxes + random_translation - return bboxes_translated - - -def edge2loop(face_edges): - face_edges_flatten = face_edges.reshape(-1,3) - # connect end points by closest distance - merged_vertex_id = [] - for edge_idx, startend in enumerate(face_edges): - self_id = [2*edge_idx, 2*edge_idx+1] - # left endpoint - distance = np.linalg.norm(face_edges_flatten - startend[0], axis=1) - min_id = list(np.argsort(distance)) - min_id_noself = [x for x in min_id if x not in self_id] - merged_vertex_id.append(sorted([2*edge_idx, min_id_noself[0]])) - # right endpoint - distance = np.linalg.norm(face_edges_flatten - startend[1], axis=1) - min_id = list(np.argsort(distance)) - min_id_noself = [x for x in min_id if x not in self_id] - merged_vertex_id.append(sorted([2*edge_idx+1, min_id_noself[0]])) - - merged_vertex_id = np.unique(np.array(merged_vertex_id),axis=0) - return merged_vertex_id - - -def keep_largelist(int_lists): - # Initialize a list to store the largest integer lists - largest_int_lists = [] - - # Convert each list to a set for efficient comparison - sets = [set(lst) for lst in int_lists] - - # Iterate through the sets and check if they are subsets of others - for i, s1 in enumerate(sets): - is_subset = False - for j, s2 in enumerate(sets): - if i!=j and s1.issubset(s2) and s1 != s2: - is_subset = True - break - if not is_subset: - largest_int_lists.append(list(s1)) - - # Initialize a set to keep track of seen tuples - seen_tuples = set() - - # Initialize a list to store unique integer lists - unique_int_lists = [] - - # Iterate through the input list - for int_list in largest_int_lists: - # Convert the list to a tuple for hashing - int_tuple = tuple(sorted(int_list)) - - # Check if the tuple is not in the set of seen tuples - if int_tuple not in seen_tuples: - # Add the tuple to the set of seen tuples - seen_tuples.add(int_tuple) - - # Add the original list to the list of unique integer lists - unique_int_lists.append(int_list) - - return unique_int_lists - - -def detect_shared_vertex(edgeV_cad, edge_mask_cad, edgeV_bbox): - """ - Find the shared vertices - """ - edge_id_offset = 2 * np.concatenate([np.array([0]),np.cumsum((edge_mask_cad==False).sum(1))])[:-1] - valid = True - - # Detect shared-vertex on seperate face loop - used_vertex = [] - face_sep_merges = [] - for face_idx, (face_edges, face_edges_mask, bbox_edges) in enumerate(zip(edgeV_cad, edge_mask_cad, edgeV_bbox)): - face_edges = face_edges[~face_edges_mask] - face_edges = face_edges.reshape(len(face_edges),2,3) - face_start_id = edge_id_offset[face_idx] - - # connect end points by closest distance (edge bbox) - merged_vertex_id = edge2loop(bbox_edges) - if len(merged_vertex_id) == len(face_edges): - merged_vertex_id = face_start_id + merged_vertex_id - face_sep_merges.append(merged_vertex_id) - used_vertex.append(bbox_edges*3) - print('[PASS]') - continue - - # connect end points by closest distance (vertex pos) - merged_vertex_id = edge2loop(face_edges) - if len(merged_vertex_id) == len(face_edges): - merged_vertex_id = face_start_id + merged_vertex_id - face_sep_merges.append(merged_vertex_id) - used_vertex.append(face_edges) - print('[PASS]') - continue - - print('[FAILED]') - valid = False - break - - # Invalid - if not valid: - assert False +# 导入OpenCASCADE相关库 +from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器 +from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历 +from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类型定义 +from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射 +from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算 +from OCC.Core.Bnd import Bnd_Box # 包围盒 +from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构 + +import numpy as np +from scipy.spatial import cKDTree +import trimesh - # Detect shared-vertex across faces - total_pnts = np.vstack(used_vertex) - total_pnts = total_pnts.reshape(len(total_pnts),2,3) - total_pnts_flatten = total_pnts.reshape(-1,3) - - total_ids = [] - for face_idx, face_merge in enumerate(face_sep_merges): - # non-self merge centers - nonself_face_idx = list(set(np.arange(len(face_sep_merges))) - set([face_idx])) - nonself_face_merges = [face_sep_merges[x] for x in nonself_face_idx] - nonself_face_merges = np.vstack(nonself_face_merges) - nonself_merged_centers = total_pnts_flatten[nonself_face_merges].mean(1) - - # connect end points by closest distance - across_merge_id = [] - for merge_id in face_merge: - merged_center = total_pnts_flatten[merge_id].mean(0) - distance = np.linalg.norm(nonself_merged_centers - merged_center, axis=1) - nonself_match_id = nonself_face_merges[np.argsort(distance)[0]] - joint_merge_id = list(nonself_match_id) + list(merge_id) - across_merge_id.append(joint_merge_id) - total_ids += across_merge_id - - # Merge T-junctions - while (True): - no_merge = True - final_merge_id = [] +from brep2sdf.utils.logger import logger - # iteratelly merge until no changes happen - for i in range(len(total_ids)): - perform_merge = False +def load_step(step_path): + """Load STEP file and return solids""" + reader = STEPControl_Reader() + reader.ReadFile(step_path) + reader.TransferRoots() + return [reader.OneShape()] + +def get_bbox(shape, subshape): + """ + 计算形状的包围盒 + + 参数: + shape: 完整的CAD模型形状 + subshape: 需要计算包围盒的子形状(面或边) + + 返回: + 包围盒的六个参数 [xmin, ymin, zmin, xmax, ymax, zmax] + """ + bbox = Bnd_Box() + brepbndlib.Add(subshape, bbox) + xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get() + return np.array([xmin, ymin, zmin, xmax, ymax, zmax]) + + + + +def normalize(surfs, edges, corners): + """ + 将CAD模型归一化到单位立方体空间 + + 参数: + surfs: 面的点集列表 + edges: 边的点集列表 + corners: 顶点坐标数组 [num_edges, 2, 3] + + 返回: + surfs_wcs: 原始坐标系下的面点集 + edges_wcs: 原始坐标系下的边点集 + surfs_ncs: 归一化坐标系下的面点集 + edges_ncs: 归一化坐标系下的边点集 + corner_wcs: 归一化后的顶点坐标 [num_edges, 2, 3] + center: 使用的中心点坐标 [3,] + scale: 使用的缩放系数 (float) + """ + if len(corners) == 0: + return None, None, None, None, None, None, None + + # 计算包围盒和缩放因子 + corners_array = corners.reshape(-1, 3) # [num_edges*2, 3] + center = (corners_array.max(0) + corners_array.min(0)) / 2 # 计算中心点 + scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max() # 计算缩放系数 + + # 归一化面的点集 + surfs_wcs = [] # 原始世界坐标系下的面点集 + surfs_ncs = [] # 归一化坐标系下的面点集 + for surf in surfs: + surf_wcs = np.array(surf) + surf_ncs = (surf_wcs - center) * scale # 归一化变换 + surfs_wcs.append(surf_wcs) + surfs_ncs.append(surf_ncs) + + # 归一化边的点集 + edges_wcs = [] # 原始世界坐标系下的边点集 + edges_ncs = [] # 归一化坐标系下的边点集 + for edge in edges: + edge_wcs = np.array(edge) + edge_ncs = (edge_wcs - center) * scale # 归一化变换 + edges_wcs.append(edge_wcs) + edges_ncs.append(edge_ncs) + + # 归一化顶点坐标 - 保持[num_edges, 2, 3]的形状 + corner_wcs = (corners - center) * scale # 广播操作会保持原有维度 + + return (np.array(surfs_wcs, dtype=object), + np.array(edges_wcs, dtype=object), + np.array(surfs_ncs, dtype=object), + np.array(edges_ncs, dtype=object), + corner_wcs.astype(np.float32), + center.astype(np.float32), + scale + ) - for j in range(i+1,len(total_ids)): - # check if vertex can be further merged - max_num = max(len(total_ids[i]), len(total_ids[j])) - union = set(total_ids[i]).union(set(total_ids[j])) - common = set(total_ids[i]).intersection(set(total_ids[j])) - if len(union) > max_num and len(common)>0: - final_merge_id.append(list(union)) - perform_merge = True - no_merge = False +def get_adjacency_info(shape, faces, edges, vertices): + """ + 优化后的邻接关系计算函数,直接使用已收集的几何元素 + + 参数新增: + faces: 已收集的面列表 + edges: 已收集的边列表 + vertices: 已收集的顶点列表 + """ + logger.debug("Get adjacency infos...") + # 创建边-面映射关系 + edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape() + topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map) + + # 直接使用传入的几何元素列表 + num_faces = len(faces) + num_edges = len(edges) + num_vertices = len(vertices) + logger.debug(f"num_faces: {num_faces}, num_edges: {num_edges}, num_vertices: {num_vertices}") + + edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32) + faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32) + edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32) + + # 填充边-面邻接矩阵 + for i, edge in enumerate(edges): + # 检查每个面是否与当前边相连 + for j, face in enumerate(faces): + edge_explorer = TopExp_Explorer(face, TopAbs_EDGE) + while edge_explorer.More(): + if edge.IsSame(edge_explorer.Current()): + edgeFace_adj[i, j] = 1 + faceEdge_adj[j, i] = 1 break - - if not perform_merge: - final_merge_id.append(total_ids[i]) # no-merge - - total_ids = final_merge_id - if no_merge: break - - # remove subsets - total_ids = keep_largelist(total_ids) - - # merge again base on absolute coordinate value, required for >3 T-junction - tobe_merged_centers = [total_pnts_flatten[x].mean(0) for x in total_ids] - tobe_centers = np.array(tobe_merged_centers) - distances = np.linalg.norm(tobe_centers[:, np.newaxis, :] - tobe_centers, axis=2) - close_points = distances < 0.1 - mask = np.tril(np.ones_like(close_points, dtype=bool), k=-1) - non_diagonal_indices = np.where(close_points & mask) - row_indices, column_indices = non_diagonal_indices - - # update the total_ids - total_ids_updated = [] - for row, col in zip(row_indices, column_indices): - total_ids_updated.append(total_ids[row] + total_ids[col]) - for index, ids in enumerate(total_ids): - if index not in list(row_indices) and index not in list(column_indices): - total_ids_updated.append(ids) - total_ids = total_ids_updated - - # merged vertices - unique_vertices = [] - for center_id in total_ids: - center_pnts = total_pnts_flatten[center_id].mean(0) / 3.0 - unique_vertices.append(center_pnts) - unique_vertices = np.vstack(unique_vertices) - - new_vertex_dict = {} - for new_id, old_ids in enumerate(total_ids): - new_vertex_dict[new_id] = old_ids - - return [unique_vertices, new_vertex_dict] - - -def detect_shared_edge(unique_vertices, new_vertex_dict, edge_z_cad, surf_z_cad, z_threshold, edge_mask_cad): - """ - Find the shared edges - """ - init_edges = edge_z_cad - - # re-assign edge start/end to unique vertices - new_ids = [] - for old_id in np.arange(2*len(init_edges)): - new_id = [] - for key, value in new_vertex_dict.items(): - # Check if the desired number is in the associated list - if old_id in value: - new_id.append(key) - assert len(new_id) == 1 # should only return one unique value - new_ids.append(new_id[0]) - - EdgeVertexAdj = np.array(new_ids).reshape(-1,2) - - # find edges assigned to the same start/end - similar_edges = [] - for i, s1 in enumerate(EdgeVertexAdj): - for j, s2 in enumerate(EdgeVertexAdj): - if i!=j and set(s1) == set(s2): # same start/end - z1 = init_edges[i] - z2 = init_edges[j] - z_diff = np.abs(z1-z2).mean() - if z_diff < z_threshold: # check z difference - similar_edges.append(sorted([i,j])) - # else: - # print('z latent beyond...') - similar_edges = np.unique(np.array(similar_edges),axis=0) - - # should reduce total edges by two - if not 2*len(similar_edges) == len(EdgeVertexAdj): - assert False, 'edge not reduced by 2' - - # unique edges - unique_edge_id = similar_edges[:,0] - EdgeVertexAdj = EdgeVertexAdj[unique_edge_id] - unique_edges = init_edges[unique_edge_id] - - # unique faces - unique_faces = surf_z_cad - FaceEdgeAdj = [] - ranges = np.concatenate([np.array([0]),np.cumsum((edge_mask_cad==False).sum(1))]) - for index in range(len(ranges)-1): - adj_ids = np.arange(ranges[index], ranges[index+1]) - new_ids = [] - for id in adj_ids: - new_id = np.where(similar_edges == id)[0] - assert len(new_id) == 1 - new_ids.append(new_id[0]) - FaceEdgeAdj.append(new_ids) - - print(f'Post-process: F-{len(unique_faces)} E-{len(unique_edges)} V-{len(unique_vertices)}') - - return [unique_faces, unique_edges, FaceEdgeAdj, EdgeVertexAdj] - - -class STModel(nn.Module): - def __init__(self, num_edge, num_surf): - super().__init__() - self.edge_t = nn.Parameter(torch.zeros((num_edge, 3))) - self.surf_st = nn.Parameter(torch.FloatTensor([1,0,0,0]).unsqueeze(0).repeat(num_surf,1)) - - -def get_bbox_minmax(point_cloud): - # Find the minimum and maximum coordinates along each axis - min_x = np.min(point_cloud[:, 0]) - max_x = np.max(point_cloud[:, 0]) - - min_y = np.min(point_cloud[:, 1]) - max_y = np.max(point_cloud[:, 1]) - - min_z = np.min(point_cloud[:, 2]) - max_z = np.max(point_cloud[:, 2]) - - # Create the 3D bounding box using the min and max values - min_point = np.array([min_x, min_y, min_z]) - max_point = np.array([max_x, max_y, max_z]) - return (min_point, max_point) - - -def joint_optimize(surf_ncs, edge_ncs, surfPos, unique_vertices, EdgeVertexAdj, FaceEdgeAdj, num_edge, num_surf): - """ - Jointly optimize the face/edge/vertex based on topology - """ - loss_func = ChamferDistance() - - model = STModel(num_edge, num_surf) - model = model.cuda().train() - optimizer = torch.optim.AdamW( - model.parameters(), - lr=1e-3, - betas=(0.95, 0.999), - weight_decay=1e-6, - eps=1e-08, - ) - - # Optimize edges (directly compute) - edge_ncs_se = edge_ncs[:,[0,-1]] - edge_vertex_se = unique_vertices[EdgeVertexAdj] - - edge_wcs = [] - print('Joint Optimization...') - for wcs, ncs_se, vertex_se in zip(edge_ncs, edge_ncs_se, edge_vertex_se): - # scale - scale_target = np.linalg.norm(vertex_se[0] - vertex_se[1]) - scale_ncs = np.linalg.norm(ncs_se[0] - ncs_se[1]) - edge_scale = scale_target / scale_ncs - - edge_updated = wcs*edge_scale - edge_se = ncs_se*edge_scale - - # offset - offset = (vertex_se - edge_se) - offset_rev = (vertex_se - edge_se[::-1]) - - # swap start / end if necessary - offset_error = np.abs(offset[0] - offset[1]).mean() - offset_rev_error =np.abs(offset_rev[0] - offset_rev[1]).mean() - if offset_rev_error < offset_error: - edge_updated = edge_updated[::-1] - offset = offset_rev - - edge_updated = edge_updated + offset.mean(0)[np.newaxis,np.newaxis,:] - edge_wcs.append(edge_updated) - - edge_wcs = np.vstack(edge_wcs) - - # Replace start/end points with corner, and backprop change along curve - for index in range(len(edge_wcs)): - start_vec = edge_vertex_se[index,0] - edge_wcs[index, 0] - end_vec = edge_vertex_se[index,1] - edge_wcs[index, -1] - weight = np.tile((np.arange(32)/31)[:,np.newaxis], (1,3)) - weighted_vec = np.tile(start_vec[np.newaxis,:],(32,1))*(1-weight) + np.tile(end_vec,(32,1))*weight - edge_wcs[index] += weighted_vec - - # Optimize surfaces - face_edges = [] - for adj in FaceEdgeAdj: - all_pnts = edge_wcs[adj] - face_edges.append(torch.FloatTensor(all_pnts).cuda()) - - # Initialize surface in wcs based on surface pos - surf_wcs_init = [] - bbox_threshold_min = [] - bbox_threshold_max = [] - for edges_perface, ncs, bbox in zip(face_edges, surf_ncs, surfPos): - surf_center, surf_scale = compute_bbox_center_and_size(bbox[0:3], bbox[3:]) - edges_perface_flat = edges_perface.reshape(-1, 3).detach().cpu().numpy() - min_point, max_point = get_bbox_minmax(edges_perface_flat) - edge_center, edge_scale = compute_bbox_center_and_size(min_point, max_point) - bbox_threshold_min.append(min_point) - bbox_threshold_max.append(max_point) - - # increase surface size if does not fully cover the wire bbox - if surf_scale < edge_scale: - surf_scale = 1.05*edge_scale - - wcs = ncs * (surf_scale/2) + surf_center - surf_wcs_init.append(wcs) - - surf_wcs_init = np.stack(surf_wcs_init) - - # optimize the surface offset - surf = torch.FloatTensor(surf_wcs_init).cuda() - for iters in range(200): - surf_scale = model.surf_st[:,0].reshape(-1,1,1,1) - surf_offset = model.surf_st[:,1:].reshape(-1,1,1,3) - surf_updated = surf + surf_offset + edge_explorer.Next() - surf_loss = 0 - for surf_pnt, edge_pnts in zip(surf_updated, face_edges): - surf_pnt = surf_pnt.reshape(-1,3) - edge_pnts = edge_pnts.reshape(-1,3).detach() - surf_loss += loss_func(surf_pnt.unsqueeze(0), edge_pnts.unsqueeze(0), bidirectional=False, reverse=True) - surf_loss /= len(surf_updated) - - optimizer.zero_grad() - (surf_loss).backward() - optimizer.step() - - # print(f'Iter {iters} surf:{surf_loss:.5f}') - - surf_wcs = surf_updated.detach().cpu().numpy() - - return (surf_wcs, edge_wcs) - - -def add_pcurves_to_edges(face): - edge_fixer = ShapeFix_Edge() - top_exp = TopologyExplorer(face) - for wire in top_exp.wires(): - wire_exp = WireExplorer(wire) - for edge in wire_exp.ordered_edges(): - edge_fixer.FixAddPCurve(edge, face, False, 0.001) - - -def fix_wires(face, debug=False): - top_exp = TopologyExplorer(face) - for wire in top_exp.wires(): - if debug: - wire_checker = ShapeAnalysis_Wire(wire, face, 0.01) - print(f"Check order 3d {wire_checker.CheckOrder()}") - print(f"Check 3d gaps {wire_checker.CheckGaps3d()}") - print(f"Check closed {wire_checker.CheckClosed()}") - print(f"Check connected {wire_checker.CheckConnected()}") - wire_fixer = ShapeFix_Wire(wire, face, 0.01) - - # wire_fixer.SetClosedWireMode(True) - # wire_fixer.SetFixConnectedMode(True) - # wire_fixer.SetFixSeamMode(True) - - assert wire_fixer.IsReady() - ok = wire_fixer.Perform() - # assert ok - - -def fix_face(face): - fixer = ShapeFix_Face(face) - fixer.SetPrecision(0.01) - fixer.SetMaxTolerance(0.1) - ok = fixer.Perform() - # assert ok - fixer.FixOrientation() - face = fixer.Face() - return face - - -def construct_brep(surf_wcs, edge_wcs, FaceEdgeAdj, EdgeVertexAdj): - """ - Fit parametric surfaces / curves and trim into B-rep - """ - print('Building the B-rep...') - # Fit surface bspline - recon_faces = [] - for points in surf_wcs: - num_u_points, num_v_points = 32, 32 - uv_points_array = TColgp_Array2OfPnt(1, num_u_points, 1, num_v_points) - for u_index in range(1,num_u_points+1): - for v_index in range(1,num_v_points+1): - pt = points[u_index-1, v_index-1] - point_3d = gp_Pnt(float(pt[0]), float(pt[1]), float(pt[2])) - uv_points_array.SetValue(u_index, v_index, point_3d) - approx_face = GeomAPI_PointsToBSplineSurface(uv_points_array, 3, 8, GeomAbs_C2, 5e-2).Surface() - recon_faces.append(approx_face) - - recon_edges = [] - for points in edge_wcs: - num_u_points = 32 - u_points_array = TColgp_Array1OfPnt(1, num_u_points) - for u_index in range(1,num_u_points+1): - pt = points[u_index-1] - point_2d = gp_Pnt(float(pt[0]), float(pt[1]), float(pt[2])) - u_points_array.SetValue(u_index, point_2d) - try: - approx_edge = GeomAPI_PointsToBSpline(u_points_array, 0, 8, GeomAbs_C2, 5e-3).Curve() - except Exception as e: - print('high precision failed, trying mid precision...') - try: - approx_edge = GeomAPI_PointsToBSpline(u_points_array, 0, 8, GeomAbs_C2, 8e-3).Curve() - except Exception as e: - print('mid precision failed, trying low precision...') - approx_edge = GeomAPI_PointsToBSpline(u_points_array, 0, 8, GeomAbs_C2, 5e-2).Curve() - recon_edges.append(approx_edge) - - # Create edges from the curve list - edge_list = [] - for curve in recon_edges: - edge = BRepBuilderAPI_MakeEdge(curve).Edge() - edge_list.append(edge) - - # Cut surface by wire - post_faces = [] - post_edges = [] - for idx,(surface, edge_incides) in enumerate(zip(recon_faces, FaceEdgeAdj)): - corner_indices = EdgeVertexAdj[edge_incides] + # 获取边的两个端点 + v1 = TopoDS_Vertex() + v2 = TopoDS_Vertex() + topexp.Vertices(edge, v1, v2) - # ordered loop - loops = [] - ordered = [0] - seen_corners = [corner_indices[0,0], corner_indices[0,1]] - next_index = corner_indices[0,1] - - while len(ordered) Tuple[torch.Tensor, ...]: +def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3): """ - 处理B-rep数据的函数 - - Args: - data (dict): 包含B-rep数据的字典,结构如下: - { - 'surf_ncs': np.ndarray, # 面归一化点云 [num_faces, num_surf_sample_points, 3] - 'edge_ncs': np.ndarray, # 边归一化点云 [num_edges, num_edge_sample_points, 3] - 'corner_wcs': np.ndarray, # 顶点坐标 [num_edges, 2, 3] - 每条边的两个端点 - 'faceEdge_adj': np.ndarray, # 面-边邻接矩阵 [num_faces, num_edges] - 'surf_pos': np.ndarray, # 面位置(包围盒) [num_faces, 6] - 'edge_pos': np.ndarray, # 边位置(包围盒) [num_edges, 6] - } - max_face (int): 最大面数,用于填充, config.data.max_face - max_edge (int): 最大边数,用于填充, config.data.max_edge - bbox_scaled (float): 边界框缩放因子, config.data.bbox_scaled - aug (bool): 是否使用数据增强, config.data.aug - data_class (Optional[int]): 数据类别标签 + 为嵌套点云数据计算法向量,并保持嵌套格式。 - Returns: - Tuple[torch.Tensor, ...]: 包含以下张量的元组: - - edge_ncs: 边归一化特征 [max_face, max_edge, num_edge_sample_points, 3] - - edge_pos: 边位置 [max_face, max_edge, 6] - - edge_mask: 边掩码 [max_face, max_edge] - - surf_ncs: 面归一化特征 [max_face, num_surf_sample_points, 3] - - surf_pos: 面位置 [max_face, 6] - - vertex_pos: 顶点位置 [max_face, max_edge, 2, 3] - 每个面的每条边的两个端点 - - data_class: (可选) 类别标签 [1] + 参数: + mesh: trimesh.Trimesh 对象,包含顶点和法向量信息 + surf_wcs: np.ndarray(dtype=object) 形状为 (N,) 的数组,每个元素是形状为 (M, 3) 的 float32 数组 + normal_type: str 法向量类型,可选 'vertex' 或 'face' + k_neighbors: int 用于平滑的最近邻数量 - 数据处理流程: - 1. 数据增强(可选): - - 对几何元素进行随机旋转 - - 重新计算包围盒 - 2. 特征复制: - - 根据面-边邻接关系复制边和顶点特征 - - 保持顶点对的结构 [2, 3] - 3. 特征打乱: - - 随机打乱每个面的边顺序 - - 随机打乱面的顺序 - 4. 填充处理: - - 填充到最大边数 - - 填充到最大面数 - 5. 转换为张量 + 返回: + normals: np.ndarray(dtype=object) 形状为 (N,) 的数组,每个元素是形状为 (M, 3) 的 float32 数组 """ - # 获取配置 - config = get_default_config() - num_surf_points = config.model.num_surf_points # 16 - num_edge_points = config.model.num_edge_points # 4 - - # 解包数据 - #_, _, surf_ncs, edge_ncs, corner_wcs, _, _, faceEdge_adj, surf_pos, edge_pos, _ = data.values() - # 直接获取需要的数据 - surf_ncs = data['surf_ncs'] # (num_faces,) -> 每个元素形状 (N, 3) - edge_ncs = data['edge_ncs'] # (num_edges, num_edge_points, 3) - corner_wcs = data['corner_wcs'] # (num_edges, 2, 3) - faceEdge_adj = data['faceEdge_adj'] # (num_faces, num_edges) - edgeCorner_adj = data['edgeCorner_adj'] # (num_edges, 2) 每条边连接2个顶点 - surf_pos = data['surf_bbox_wcs'] # (num_faces, 6) - edge_pos = data['edge_bbox_wcs'] # (num_edges, 6) + # 预处理网格数据 + tree, normals_source = preprocess_mesh(mesh, normal_type=normal_type) - # 数据增强 - random_num = np.random.rand() - if random_num > 0.5 and aug: - # 获取边界框八个角点 - surfpos_corners = bbox_corners(surf_pos) # [num_faces, 8, 3] - edgepos_corners = bbox_corners(edge_pos) # [num_edges, 8, 3] - - # 随机旋转 - for axis in ['x', 'y', 'z']: - angle = random.choice([90, 180, 270]) - # 旋转所有几何元素,保持形状不变 - surfpos_corners = rotate_axis(surfpos_corners, angle, axis, normalized=True) - edgepos_corners = rotate_axis(edgepos_corners, angle, axis, normalized=True) - corner_wcs = rotate_axis(corner_wcs, angle, axis, normalized=True) # 直接旋转,保持形状 - - - # 重新计算边界框 - surf_pos = get_bbox(surfpos_corners) # [num_faces, 2, 3] - surf_pos = surf_pos.reshape(len(surf_pos), 6) # [num_faces, 6] - edge_pos = get_bbox(edgepos_corners) # [num_edges, 2, 3] - edge_pos = edge_pos.reshape(len(edge_pos), 6) # [num_edges, 6] + # 展平所有点云为一个二维数组 [P, 3],并记录分割索引 + lengths = [len(point_cloud) for point_cloud in surf_wcs] + query_points = np.concatenate(surf_wcs, axis=0).astype(np.float32) # 避免多次内存分配 - # 缩放值范围 - surf_pos = surf_pos * bbox_scaled # [num_faces, 6] - edge_pos = edge_pos * bbox_scaled # [num_edges, 6] - corner_wcs = corner_wcs * bbox_scaled # [num_edges, 2, 3] - - # 特征复制 - edge_pos_duplicated = [] # [num_edges_per_face, 6] - vertex_pos_duplicated = [] # [num_edges_per_face, 2, 3] - edge_ncs_duplicated = [] # [num_edges_per_face, num_edge_points, 3] + # 批量查询最近邻 + distances, indices = tree.query(query_points, k=k_neighbors) - for adj in faceEdge_adj: # [num_faces, num_edges] - edge_indices = np.where(adj)[0] # 获取当前面的边索引 - - # 复制边的特征 - edge_ncs_duplicated.append(edge_ncs[edge_indices]) # [num_edges_per_face, num_edge_points, 3] - edge_pos_duplicated.append(edge_pos[edge_indices]) # [num_edges_per_face, 6] - - # 直接获取对应边的顶点对 - vertex_pairs = corner_wcs[edge_indices] # [num_edges_per_face, 2, 3] - vertex_pos_duplicated.append(vertex_pairs) - - # 边特征打乱和填充 - edge_pos_new = [] # 最终形状: [num_faces, max_edge, 6] - edge_ncs_new = [] # 最终形状: [num_faces, max_edge, num_edge_points, 3] - vert_pos_new = [] # 最终形状: [num_faces, max_edge, 6] - edge_mask = [] # 最终形状: [num_faces, max_edge] - - for pos, ncs, vert in zip(edge_pos_duplicated, edge_ncs_duplicated, vertex_pos_duplicated): - # 生成随机排列 - num_edges = pos.shape[0] - random_indices = np.random.permutation(num_edges) - - # 同时打乱所有特征 - pos = pos[random_indices] # [num_edges_per_face, 6] - ncs = ncs[random_indices] # [num_edges_per_face, num_edge_points, 3] - vert = vert[random_indices] # [num_edges_per_face, 2, 3] - - # 填充到最大边数 - pos, mask = pad_zero(pos, max_edge, return_mask=True) # [max_edge, 6], [max_edge] - ncs = pad_zero(ncs, max_edge) # [max_edge, num_edge_points, 3] - vert = pad_zero(vert, max_edge) # [max_edge, 2, 3] - - edge_pos_new.append(pos) - edge_ncs_new.append(ncs) - edge_mask.append(mask) - vert_pos_new.append(vert) - - edge_pos = np.stack(edge_pos_new) # [num_faces, max_edge, 6] - edge_ncs = np.stack(edge_ncs_new) # [num_faces, max_edge, num_edge_points, 3] - edge_mask = np.stack(edge_mask) # [num_faces, max_edge] - vertex_pos = np.stack(vert_pos_new) # [num_faces, max_edge, 2, 3] - - # 处理edge_ncs (确保形状为 [num_faces, max_edge, num_edge_points, 3]) - edge_ncs_list = [] - for face_idx in range(len(edge_ncs)): - face_edges = [] - for edge_idx in range(len(edge_ncs[face_idx])): - edge_points = edge_ncs[face_idx][edge_idx] - # 确保每条边有num_edge_points个点 - if len(edge_points) > num_edge_points: - indices = np.random.choice(len(edge_points), num_edge_points, replace=False) - edge_points = edge_points[indices] - elif len(edge_points) < num_edge_points: - indices = np.random.choice(len(edge_points), num_edge_points-len(edge_points)) - edge_points = np.concatenate([edge_points, edge_points[indices]], axis=0) - face_edges.append(edge_points) - - # 填充到最大边数 - while len(face_edges) < max_edge: - face_edges.append(np.zeros((num_edge_points, 3), dtype=np.float32)) - - edge_ncs_list.append(np.stack(face_edges)) - - edge_ncs = np.stack(edge_ncs_list).astype(np.float32) # [num_faces, max_edge, num_edge_points, 3] - - # 处理surf_ncs (确保形状为 [num_faces, num_surf_points, 3]) - surf_ncs_list = [] - for points in surf_ncs: - if len(points) > num_surf_points: - indices = np.random.choice(len(points), num_surf_points, replace=False) - points = points[indices] - elif len(points) < num_surf_points: - indices = np.random.choice(len(points), num_surf_points-len(points)) - points = np.concatenate([points, points[indices]], axis=0) - surf_ncs_list.append(points) - - surf_ncs = np.stack(surf_ncs_list).astype(np.float32) # [num_faces, num_surf_points, 3] - - # 面特征打乱 - random_indices = np.random.permutation(surf_pos.shape[0]) - surf_pos = surf_pos[random_indices] # [num_faces, 6] - edge_pos = edge_pos[random_indices] # [num_faces, max_edge, 6] - surf_ncs = surf_ncs[random_indices] # [num_faces, num_surf_points, 3] - edge_ncs = edge_ncs[random_indices] # [num_faces, max_edge, num_edge_points, 3] - edge_mask = edge_mask[random_indices] # [num_faces, max_edge] - vertex_pos = vertex_pos[random_indices] # [num_faces, max_edge, 2, 3] - - # 填充到最大面数 - surf_pos, face_mask = pad_zero(surf_pos, max_face, return_mask=True) # [max_face, 6] - surf_ncs = pad_zero(surf_ncs, max_face) # [max_face, num_surf_points, 3] - edge_pos = pad_zero(edge_pos, max_face) # [max_face, max_edge, 6] - edge_ncs = pad_zero(edge_ncs, max_face) # [max_face, max_edge, num_edge_points, 3] - vertex_pos = pad_zero(vertex_pos, max_face) # [max_face, max_edge, 2, 3] - - # 扩展边掩码 - 使用face_mask来创建新的edge_mask - if len(edge_mask) > max_face: - edge_mask = edge_mask[:max_face] - else: - # 创建填充掩码 - padding = np.zeros((max_face-len(edge_mask), max_edge), dtype=bool) - edge_mask = np.concatenate([edge_mask, padding], axis=0) # [max_face, max_edge] - - # 转换为张量并返回 - if data_class is not None: - return ( - torch.FloatTensor(edge_ncs), # [max_face, max_edge, num_edge_points, 3] - torch.FloatTensor(edge_pos), # [max_face, max_edge, 6] - torch.BoolTensor(edge_mask), # [max_face, max_edge] - torch.FloatTensor(surf_ncs), # [max_face, num_surf_points, 3] - torch.FloatTensor(surf_pos), # [max_face, 6] - torch.FloatTensor(vertex_pos), # [max_face, max_edge, 2, 3] - torch.LongTensor([data_class+1]) # [1] - ) + # 处理k=1的特殊情况 + if k_neighbors == 1: + nearest_normals = normals_source[indices] else: - return ( - torch.FloatTensor(edge_ncs), # [max_face, max_edge, num_edge_points, 3] - torch.FloatTensor(edge_pos), # [max_face, max_edge, 6] - torch.BoolTensor(edge_mask), # [max_face, max_edge] - torch.FloatTensor(surf_ncs), # [max_face, num_surf_points, 3] - torch.FloatTensor(surf_pos), # [max_face, 6] - torch.FloatTensor(vertex_pos) # [max_face, max_edge, 2, 3] - ) \ No newline at end of file + # 加权平均(权重为距离倒数) + inv_distances = 1 / (distances + 1e-8) # 防止除以零 + sum_inv_distances = inv_distances.sum(axis=1, keepdims=True) + valid_mask = sum_inv_distances > 1e-6 + weights = np.divide(inv_distances, sum_inv_distances, out=np.zeros_like(inv_distances), where=valid_mask) + nearest_normals = np.einsum('ijk,ij->ik', normals_source[indices], weights) + + # 标准化结果 + norms = np.linalg.norm(nearest_normals, axis=1) + valid_mask = norms > 1e-6 + nearest_normals[valid_mask] /= norms[valid_mask, None] + + # 按原始嵌套结构分割法向量 + start = 0 + normals_output = np.empty(len(surf_wcs), dtype=object) + for i, length in enumerate(lengths): + end = start + length + normals_output[i] = nearest_normals[start:end] + start = end + + return normals_output \ No newline at end of file diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index 28ac6d7..b987db0 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -6,6 +6,10 @@ import torch.nn.functional as F import numpy as np from brep2sdf.utils.logger import logger +from brep2sdf.networks.patch_graph import PatchGraph + + + def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor: """判断两个轴对齐包围盒(AABB)是否相交 @@ -27,7 +31,7 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor: return torch.all((max1 >= min2) & (max2 >= min1)) class OctreeNode(nn.Module): - def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None): + def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None): super().__init__() # 静态张量存储节点信息 self.register_buffer('bbox', bbox) # 当前节点的边界框 @@ -38,6 +42,9 @@ class OctreeNode(nn.Module): self.register_buffer('face_indices', torch.from_numpy(face_indices).to(bbox.device)) # 面片索引张量 self.register_buffer('surf_bbox', surf_bbox) # 面片边界框 + # PatchGraph作为普通属性 + self.patch_graph = patch_graph # 不再使用register_buffer + self.max_depth = max_depth # 将param_key改为张量 self.register_buffer('param_key', torch.tensor(-1, dtype=torch.long)) @@ -51,7 +58,7 @@ class OctreeNode(nn.Module): k: 参数索引值 """ self.param_key.fill_(k) - + @torch.jit.export def build_static_tree(self) -> None: """构建静态八叉树结构""" @@ -72,7 +79,8 @@ class OctreeNode(nn.Module): node_idx, bbox, faces = queue.pop(0) self.node_bboxes[node_idx] = bbox - if faces.shape[0] <= 2 or current_idx >= self.max_depth: + # 判断 要不要继续分裂 + if not self._should_split_node(current_idx): self.is_leaf_mask[node_idx] = True continue @@ -103,7 +111,20 @@ class OctreeNode(nn.Module): # 将子节点加入队列 if intersecting_faces: queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.bbox.device))) - + + def _should_split_node(self, current_depth: int) -> bool: + """判断节点是否需要分裂""" + # 检查是否达到最大深度 + if current_depth >= self.max_depth: + return False + + # 检查是否为完全图 + is_clique = self.patch_graph.is_clique(self.face_indices) + if is_clique: + return False + + return True + @torch.jit.export def _generate_child_bboxes(self, min_coords: torch.Tensor, mid_coords: torch.Tensor, max_coords: torch.Tensor) -> torch.Tensor: """生成8个子节点的边界框""" @@ -208,6 +229,7 @@ class OctreeNode(nn.Module): 'is_leaf_mask': self.is_leaf_mask, 'face_indices': self.face_indices, 'surf_bbox': self.surf_bbox, + 'patch_graph': self.patch_graph, 'max_depth': self.max_depth, 'param_key': self.param_key, '_is_leaf': self._is_leaf @@ -223,6 +245,7 @@ class OctreeNode(nn.Module): self.is_leaf_mask = state['is_leaf_mask'] self.face_indices = state['face_indices'] self.surf_bbox = state['surf_bbox'] + self.patch_graph = state['patch_graph'] self.max_depth = state['max_depth'] self.param_key = state['param_key'] self._is_leaf = state['_is_leaf'] \ No newline at end of file diff --git a/brep2sdf/networks/patch_graph.py b/brep2sdf/networks/patch_graph.py new file mode 100644 index 0000000..3e78c2b --- /dev/null +++ b/brep2sdf/networks/patch_graph.py @@ -0,0 +1,182 @@ +from typing import Tuple, Optional +import torch +import torch.nn as nn +import numpy as np +from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE +from OCC.Core.TopExp import TopExp_Explorer +from OCC.Core.TopoDS import TopoDS_Edge, TopoDS_Face, topods_Edge, topods_Face +from OCC.Core.BRep import BRep_Tool +from OCC.Core.GeomLProp import GeomLProp_SLProps +from OCC.Core.BRepAdaptor import BRepAdaptor_Surface + +class PatchGraph(nn.Module): + def __init__(self, num_patches: int, device: torch.device = None): + super().__init__() + self.num_patches = num_patches + self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # 注册缓冲区 + self.register_buffer('edge_index', None) # 边的连接关系 (2, E) + self.register_buffer('edge_type', None) # 边的类型 (E,) 0:凹边 1:凸边 + self.register_buffer('patch_features', None) # 面片特征 (N, F) + + def set_edges(self, edge_index: torch.Tensor, edge_type: torch.Tensor) -> None: + """设置边的信息 + + 参数: + edge_index: 形状为 (2, E) 的张量,表示边的连接关系 + edge_type: 形状为 (E,) 的张量,0表示凹边,1表示凸边 + """ + if edge_index.shape[0] != 2: + raise ValueError(f"edge_index 必须是形状为 (2, E) 的张量,但得到 {edge_index.shape}") + if edge_index.shape[1] != edge_type.shape[0]: + raise ValueError("edge_index 和 edge_type 的边数量不匹配") + + self.edge_index = edge_index.to(self.device) + self.edge_type = edge_type.to(self.device) + + def get_subgraph(self, node_faces: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """获取子图的边和类型""" + if self.edge_index is None: + return None, None + + node_faces = node_faces.to(self.device) + mask = torch.isin(self.edge_index[0], node_faces) & torch.isin(self.edge_index[1], node_faces) + subgraph_edges = self.edge_index[:, mask] + subgraph_types = self.edge_type[mask] + + return subgraph_edges, subgraph_types + + @staticmethod + def from_preprocessed_data(surf_wcs: np.ndarray, edgeFace_adj: np.ndarray, edge_types: np.ndarray, device: torch.device = None) -> 'PatchGraph': + num_faces = len(surf_wcs) + graph = PatchGraph(num_faces, device) + + edge_pairs = [] + edge_types_list = [] + + for edge_idx in range(len(edgeFace_adj)): + connected_faces = np.where(edgeFace_adj[edge_idx] == 1)[0] + if len(connected_faces) == 2: + face1, face2 = connected_faces + edge_pairs.extend([[face1, face2], [face2, face1]]) + edge_type = edge_types[edge_idx] + edge_types_list.extend([edge_type, edge_type]) + + if edge_pairs: + edge_index = torch.tensor(edge_pairs, dtype=torch.long, device=graph.device).t() + edge_type = torch.tensor(edge_types_list, dtype=torch.long, device=graph.device) + graph.set_edges(edge_index, edge_type) + + return graph + + def set_features(self, features: torch.Tensor) -> None: + """设置面片特征 + + 参数: + features: 形状为 (N, F) 的张量,表示面片的特征向量 + """ + if features.shape[0] != self.num_patches: + raise ValueError(f"特征数量 {features.shape[0]} 与面片数量 {self.num_patches} 不匹配") + self.patch_features = features + + def is_clique(self, node_faces: torch.Tensor) -> bool: + """检查给定面片集合是否构成完全图 + + 参数: + node_faces: 要检查的面片索引集合 + + 返回: + bool: 是否为完全图 + """ + if self.edge_index is None: + return False + + # 获取子图的边 + mask = torch.isin(self.edge_index[0], node_faces) & torch.isin(self.edge_index[1], node_faces) + subgraph_edges = self.edge_index[:, mask] + + # 计算完全图应有的边数 + n = len(node_faces) + expected_edges = n * (n - 1) // 2 + + # 计算实际的边数(考虑无向图) + actual_edges = len(subgraph_edges[0]) // 2 + + return actual_edges == expected_edges + + def combine_sdf(self, sdf_values: torch.Tensor) -> torch.Tensor: + """根据邻接关系组合SDF值 + + 参数: + sdf_values: 形状为 (N,) 的张量,表示每个面片的SDF值 + + 返回: + torch.Tensor: 组合后的SDF值 + """ + if self.edge_index is None or self.edge_type is None: + raise RuntimeError("请先设置边的信息") + + # 获取所有相连面片对的SDF值 + sdf_i = sdf_values[self.edge_index[0]] # (E,) + sdf_j = sdf_values[self.edge_index[1]] # (E,) + + # 根据边的类型选择组合方式 + concave_mask = self.edge_type == 0 + convex_mask = self.edge_type == 1 + + # 初始化结果为第一个SDF值 + result = sdf_values[0].clone() + + # 凹边取最大值,凸边取最小值 + if torch.any(concave_mask): + result = torch.max(result, torch.max(torch.stack([sdf_i[concave_mask], + sdf_j[concave_mask]]))) + if torch.any(convex_mask): + result = torch.min(result, torch.min(torch.stack([sdf_i[convex_mask], + sdf_j[convex_mask]]))) + + return result + + @staticmethod + def from_preprocessed_data( + surf_wcs: np.ndarray, # 形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组 + edgeFace_adj: np.ndarray, # 形状为(num_edges, num_faces)的int32数组 + edge_types: np.ndarray # 形状为(num_edges,)的int32数组 + ) -> 'PatchGraph': + """从预处理的数据直接构建面片邻接图 + + 参数: + surf_wcs: 世界坐标系下的曲面几何数据,形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组 + edgeFace_adj: 边-面邻接矩阵,形状为(num_edges, num_faces)的int32数组,1表示边与面相邻 + edge_types: 边的类型数组,形状为(num_edges,)的int32数组,0表示凹边,1表示凸边 + + 返回: + PatchGraph: 初始化好的面片邻接图,包含: + - edge_index: 形状为(2, num_edges*2)的torch.long张量,表示双向边的连接关系 + - edge_type: 形状为(num_edges*2,)的torch.long张量,表示每条边的类型 + """ + num_faces = len(surf_wcs) + graph = PatchGraph(num_faces) + + # 构建边的索引和类型 + edge_pairs = [] + edge_types_list = [] + + # 遍历边-面邻接矩阵 + for edge_idx in range(len(edgeFace_adj)): + connected_faces = np.where(edgeFace_adj[edge_idx] == 1)[0] + if len(connected_faces) == 2: + face1, face2 = connected_faces + # 添加双向边 + edge_pairs.extend([[face1, face2], [face2, face1]]) + # 使用预计算的边类型 + edge_type = edge_types[edge_idx] + edge_types_list.extend([edge_type, edge_type]) # 双向边使用相同的类型 + + if edge_pairs: # 确保有边存在 + edge_index = torch.tensor(edge_pairs, dtype=torch.long).t() + edge_type = torch.tensor(edge_types_list, dtype=torch.long) + graph.set_edges(edge_index, edge_type) + + return graph diff --git a/brep2sdf/test.py b/brep2sdf/test.py index 00f7695..498d4b8 100644 --- a/brep2sdf/test.py +++ b/brep2sdf/test.py @@ -1,161 +1,4 @@ -import os import torch -import numpy as np -from torch.utils.data import DataLoader -from brep2sdf.data.data import BRepSDFDataset -from brep2sdf.networks.network import BRepToSDF -from brep2sdf.utils.logger import logger -from brep2sdf.config.default_config import get_default_config -import matplotlib.pyplot as plt -from tqdm import tqdm -class Tester: - def __init__(self, config, checkpoint_path): - self.config = config - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - # 初始化测试数据集 - self.test_dataset = BRepSDFDataset( - brep_dir=config.data.brep_dir, - sdf_dir=config.data.sdf_dir, - valid_data_dir=config.data.valid_data_dir, - split='test' - ) - - # 初始化数据加载器 - self.test_loader = DataLoader( - self.test_dataset, - batch_size=1, # 测试时使用batch_size=1 - shuffle=False, - num_workers=config.train.num_workers, - pin_memory=False - ) - - # 加载模型 - self.model = BRepToSDF(config).to(self.device) - self.load_checkpoint(checkpoint_path) - - # 创建结果保存目录 - self.result_dir = os.path.join(config.data.result_save_dir, 'test_results') - os.makedirs(self.result_dir, exist_ok=True) - - def load_checkpoint(self, checkpoint_path): - """加载检查点""" - if not os.path.exists(checkpoint_path): - raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}") - - checkpoint = torch.load(checkpoint_path, map_location=self.device) - self.model.load_state_dict(checkpoint['model_state_dict']) - logger.info(f"Loaded checkpoint from {checkpoint_path}") - - def compute_metrics(self, pred_sdf, gt_sdf): - """计算评估指标""" - mse = torch.mean((pred_sdf - gt_sdf) ** 2).item() - mae = torch.mean(torch.abs(pred_sdf - gt_sdf)).item() - max_error = torch.max(torch.abs(pred_sdf - gt_sdf)).item() - - return { - 'mse': mse, - 'mae': mae, - 'max_error': max_error - } - - def visualize_results(self, pred_sdf, gt_sdf, points, save_path): - """可视化预测结果""" - fig = plt.figure(figsize=(15, 5)) - - # 绘制预测SDF - ax1 = fig.add_subplot(131, projection='3d') - scatter = ax1.scatter(points[:, 0], points[:, 1], points[:, 2], - c=pred_sdf.squeeze().cpu(), cmap='coolwarm') - ax1.set_title('Predicted SDF') - plt.colorbar(scatter) - - # 绘制真实SDF - ax2 = fig.add_subplot(132, projection='3d') - scatter = ax2.scatter(points[:, 0], points[:, 1], points[:, 2], - c=gt_sdf.squeeze().cpu(), cmap='coolwarm') - ax2.set_title('Ground Truth SDF') - plt.colorbar(scatter) - - # 绘制误差图 - ax3 = fig.add_subplot(133, projection='3d') - error = torch.abs(pred_sdf - gt_sdf) - scatter = ax3.scatter(points[:, 0], points[:, 1], points[:, 2], - c=error.squeeze().cpu(), cmap='Reds') - ax3.set_title('Absolute Error') - plt.colorbar(scatter) - - plt.tight_layout() - plt.savefig(save_path) - plt.close() - - def test(self): - """执行测试""" - self.model.eval() - total_metrics = {'mse': 0, 'mae': 0, 'max_error': 0} - - logger.info("Starting testing...") - - with torch.no_grad(): - for idx, batch in enumerate(tqdm(self.test_loader)): - # 获取数据并移动到设备 - surf_ncs = batch['surf_ncs'].to(self.device) - edge_ncs = batch['edge_ncs'].to(self.device) - surf_pos = batch['surf_pos'].to(self.device) - edge_pos = batch['edge_pos'].to(self.device) - vertex_pos = batch['vertex_pos'].to(self.device) - edge_mask = batch['edge_mask'].to(self.device) - points = batch['points'].to(self.device) - gt_sdf = batch['sdf'].to(self.device) - - # 前向传播 - pred_sdf = self.model( - surf_ncs=surf_ncs, edge_ncs=edge_ncs, - surf_pos=surf_pos, edge_pos=edge_pos, - vertex_pos=vertex_pos, edge_mask=edge_mask, - query_points=points - ) - - # 计算指标 - metrics = self.compute_metrics(pred_sdf, gt_sdf) - for k, v in metrics.items(): - total_metrics[k] += v - - # 可视化结果 - if idx % self.config.test.vis_freq == 0: - save_path = os.path.join(self.result_dir, f'result_{idx}.png') - self.visualize_results(pred_sdf, gt_sdf, points[0].cpu(), save_path) - - # 计算平均指标 - num_samples = len(self.test_loader) - avg_metrics = {k: v / num_samples for k, v in total_metrics.items()} - - # 保存测试结果 - logger.info("Test Results:") - for k, v in avg_metrics.items(): - logger.info(f"{k}: {v:.6f}") - - # 保存指标到文件 - with open(os.path.join(self.result_dir, 'test_metrics.txt'), 'w') as f: - for k, v in avg_metrics.items(): - f.write(f"{k}: {v:.6f}\n") - - return avg_metrics - -def main(): - # 获取配置 - config = get_default_config() - - # 设置检查点路径 - checkpoint_path = os.path.join( - config.data.model_save_dir, - config.data.best_model_name.format(model_name=config.data.model_name) - ) - - # 初始化测试器并执行测试 - tester = Tester(config, checkpoint_path) - metrics = tester.test() - -if __name__ == '__main__': - main() \ No newline at end of file +model = torch.jit.load("/home/wch/brep2sdf/data/output_data/00000054.pt") +print(model) \ No newline at end of file diff --git a/brep2sdf/train.py b/brep2sdf/train.py index f835b17..8a2f18a 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -13,6 +13,7 @@ from brep2sdf.data.pre_process_by_mesh import process_single_step from brep2sdf.networks.network import Net from brep2sdf.networks.octree import OctreeNode from brep2sdf.networks.loss import LossManager +from brep2sdf.networks.patch_graph import PatchGraph from brep2sdf.utils.logger import logger @@ -31,10 +32,17 @@ parser.add_argument( help='只采样零表面点 SDF 训练' ) parser.add_argument( - '--force-reprocess', + '--force-reprocess','-f', action='store_true', # 默认为 False,如果用户指定该参数,则为 True help='强制重新进行数据预处理,忽略缓存或已有结果' ) +parser.add_argument( + '--resume-checkpoint-path', + type=str, + default=None, + help='从指定的checkpoint文件继续训练' +) + args = parser.parse_args() @@ -86,8 +94,13 @@ class Trainer: #logger.info( self.brep_data ) #self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device) + # 构建面片邻接图 + graph = PatchGraph.from_preprocessed_data( + surf_wcs=self.data['surf_wcs'], + edgeFace_adj=self.data['edgeFace_adj'], + edge_types=self.data['edge_types'] + ) # 初始化网络 - surf_bbox=torch.tensor( self.data['surf_bbox_ncs'], dtype=torch.float32, @@ -95,7 +108,7 @@ class Trainer: ) - self.build_tree(surf_bbox=surf_bbox, max_depth=4) + self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=4) self.model = Net( @@ -115,12 +128,13 @@ class Trainer: logger.info(f"初始化完成,正在处理模型 {self.model_name}") - def build_tree(self,surf_bbox, max_depth=6): + def build_tree(self,surf_bbox, graph, max_depth=6): num_faces = surf_bbox.shape[0] bbox = self._calculate_global_bbox(surf_bbox) self.root = OctreeNode( bbox=bbox, face_indices=np.arange(num_faces), # 初始包含所有面 + patch_graph=graph, max_depth=max_depth, surf_bbox=surf_bbox ) @@ -302,8 +316,13 @@ class Trainer: best_val_loss = float('inf') logger.info("Starting training...") start_time = time.time() + + start_epoch = 1 + if args.resume_checkpoint_path: + start_epoch = self._load_checkpoint(args.resume_checkpoint_path) + logger.info(f"Loaded model from {args.resume_checkpoint_path}") - for epoch in range(1, self.config.train.num_epochs + 1): + for epoch in range(start_epoch, self.config.train.num_epochs + 1): # 训练一个epoch train_loss = self.train_epoch(epoch) @@ -329,7 +348,8 @@ class Trainer: # 训练完成 total_time = time.time() - start_time - self._tracing_model() + self._tracing_model_by_script() + #self._tracing_model() logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s') logger.info(f'Best validation loss: {best_val_loss:.6f}') #self.test_load() @@ -349,8 +369,8 @@ class Trainer: self.model.eval() # 确保模型中的所有逻辑都兼容 TorchScript scripted_model = torch.jit.script(self.model) - optimized_model = optimize_for_mobile(scripted_model) - torch.jit.save(optimized_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt") + #optimized_model = optimize_for_mobile(scripted_model) + torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt") def _tracing_model(self): """保存模型""" @@ -375,11 +395,6 @@ class Trainer: except Exception as e: logger.error(f"模型验证失败:{e}") - def _load_checkpoint(self, checkpoint_path): - """从检查点恢复训练状态""" - model = torch.load(checkpoint_path) - return model - def _save_checkpoint(self, epoch: int, train_loss: float): """保存训练检查点""" checkpoint_dir = os.path.join( @@ -387,18 +402,26 @@ class Trainer: self.model_name ) os.makedirs(checkpoint_dir, exist_ok=True) - checkpoint_path = os.path.join(checkpoint_dir,f"epoch_{epoch:03d}.pth") - ''' + checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth") + + # 只保存状态字典 torch.save({ 'epoch': epoch, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'loss': train_loss, - 'config': self.config }, checkpoint_path) - ''' - torch.save(self.model,checkpoint_path) + def _load_checkpoint(self, checkpoint_path): + """从检查点恢复训练状态""" + try: + checkpoint = torch.load(checkpoint_path) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + return checkpoint['epoch'] + 1 + except Exception as e: + logger.error(f"加载checkpoint失败: {str(e)}") + raise def main(): # 这里需要初始化配置 config = get_default_config()