diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index 96e71bf..c41db00 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -27,8 +27,8 @@ class ModelConfig: @dataclass class DataConfig: """数据相关配置""" - max_face: int = 16 - max_edge: int = 64 + max_face: int = 80 + max_edge: int = 16 num_query_points: int = 32*32*32 # 限制查询点数量,sdf 采样点数 本来是 128*128*128 ,在data load时随机采样 bbox_scaled: float = 1.0 @@ -47,7 +47,7 @@ class TrainConfig: # 基本训练参数 batch_size: int = 8 num_workers: int = 4 - num_epochs: int = 500 + num_epochs: int = 20 learning_rate: float = 0.01 min_lr: float = 1e-5 weight_decay: float = 0.01 @@ -62,7 +62,7 @@ class TrainConfig: warmup_epochs: int = 5 # 保存和验证 - save_freq: int = 10 # 每多少个epoch保存一次 + save_freq: int = 20 # 每多少个epoch保存一次 val_freq: int = 1 # 每多少个epoch验证一次 # 保存路径 diff --git a/brep2sdf/data/pre_process.py b/brep2sdf/data/pre_process.py index 903a2ae..8a65b46 100644 --- a/brep2sdf/data/pre_process.py +++ b/brep2sdf/data/pre_process.py @@ -93,49 +93,25 @@ def normalize(surfs, edges, corners): scale ) -def get_adjacency_info(shape): +def get_adjacency_info(shape, faces, edges, vertices): """ - 获取CAD模型中面、边、顶点之间的邻接关系 + 优化后的邻接关系计算函数,直接使用已收集的几何元素 - 参数: - shape: CAD模型的形状对象 - - 返回: - edgeFace_adj: 边-面邻接矩阵 (num_edges × num_faces) - faceEdge_adj: 面-边邻接矩阵 (num_faces × num_edges) - edgeCorner_adj: 边-顶点邻接矩阵 (num_edges × 2) + 参数新增: + faces: 已收集的面列表 + edges: 已收集的边列表 + vertices: 已收集的顶点列表 """ + logger.debug("Get adjacency infos...") # 创建边-面映射关系 edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape() topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map) - # 获取所有几何元素 - faces = [] # 存储所有面 - edges = [] # 存储所有边 - vertices = [] # 存储所有顶点 - - # 创建拓扑结构探索器 - face_explorer = TopExp_Explorer(shape, TopAbs_FACE) - edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE) - vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX) - - # 收集所有几何元素 - while face_explorer.More(): - faces.append(topods.Face(face_explorer.Current())) - face_explorer.Next() - - while edge_explorer.More(): - edges.append(topods.Edge(edge_explorer.Current())) - edge_explorer.Next() - - while vertex_explorer.More(): - vertices.append(topods.Vertex(vertex_explorer.Current())) - vertex_explorer.Next() - - # 创建邻接矩阵 + # 直接使用传入的几何元素列表 num_faces = len(faces) num_edges = len(edges) num_vertices = len(vertices) + logger.debug(f"num_faces: {num_faces}, num_edges: {num_edges}, num_vertices: {num_vertices}") edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32) faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32) @@ -243,10 +219,14 @@ def parse_solid(step_path): corner_pnts = [] surf_bbox_wcs = [] edge_bbox_wcs = [] + + faces, edges, vertices = [], [], [] # Extract face points + logger.debug("Extract face points...") while face_explorer.More(): face = topods.Face(face_explorer.Current()) + faces.append(face) loc = TopLoc_Location() triangulation = BRep_Tool.Triangulation(face, loc) @@ -267,11 +247,18 @@ def parse_solid(step_path): surf_bbox_wcs.append(get_bbox(shape, face)) face_explorer.Next() + face_count = len(faces) + if face_count > MAX_FACE: + logger.error(f"step has {face_count} faces, which exceeds MAX_FACE {MAX_FACE}") + return None # Extract edge points + logger.debug("Extract edge points...") num_samples = config.model.num_edge_points # 使用配置中的边采样点数 while edge_explorer.More(): edge = topods.Edge(edge_explorer.Current()) + edges.append(edge) + logger.debug(len(edges)) curve_info = BRep_Tool.Curve(edge) if curve_info is None: continue # 跳过无效边 @@ -280,15 +267,12 @@ def parse_solid(step_path): if len(curve_info) == 3: curve, first, last = curve_info elif len(curve_info) == 2: - continue - curve, location = curve_info - logger.info(curve) - first, last = BRep_Tool.Range(edge) # 显式获取参数范围 + curve = None # 跳过判断 else: raise ValueError(f"Unexpected curve info: {curve_info}") except Exception as e: logger.error(f"Failed to process edge {edge}: {str(e)}") - continue + curve = None if curve is not None: points = [] @@ -306,14 +290,22 @@ def parse_solid(step_path): edge_explorer.Next() # Extract vertex points + logger.debug("Extract vertex points...") while vertex_explorer.More(): vertex = topods.Vertex(vertex_explorer.Current()) + vertices.append(vertex) pnt = BRep_Tool.Pnt(vertex) corner_pnts.append([pnt.X(), pnt.Y(), pnt.Z()]) vertex_explorer.Next() # 获取邻接信息 - edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(shape) + edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info( + shape, + faces=faces, # 传入已收集的面列表 + edges=edges, # 传入已收集的边列表 + vertices=vertices # 传入已收集的顶点列表 + ) + logger.debug("complete.") # 转换为numpy数组时确保类型正确 face_pnts = [np.array(points, dtype=np.float32) for points in face_pnts]