"""Convert STEP files into pickled B-rep point/adjacency data using pythonocc."""
import argparse
import os
import pickle
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError

import numpy as np
from tqdm import tqdm
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.TopExp import TopExp_Explorer, topexp
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX
from OCC.Core.BRep import BRep_Tool
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Core.IFSelect import IFSelect_RetDone
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape
from OCC.Core.BRepBndLib import brepbndlib
from OCC.Core.Bnd import Bnd_Box
from OCC.Core.TopoDS import TopoDS_Shape
from OCC.Core.TopoDS import topods
from OCC.Core.TopoDS import TopoDS_Vertex

# To speed up processing, `process` skips models with more faces than this.
MAX_FACE = 70

# Output root folder. `main` sets it and also passes it explicitly to workers,
# because spawned worker processes do not inherit globals assigned at runtime.
OUTPUT = None


def _as_object_array(arrays):
    """Pack a list of per-primitive arrays into a 1-D object array.

    ``np.array(arrays, dtype=object)`` silently builds a multi-dimensional
    object array when every item has the same shape (e.g. all edges sampled
    with the same point count), which breaks the "one entry per primitive"
    layout. Allocating the 1-D container first avoids that.
    """
    packed = np.empty(len(arrays), dtype=object)
    for i, arr in enumerate(arrays):
        packed[i] = arr
    return packed


def normalize(surfs, edges, corners):
    """Center and scale the CAD model using the bounding box of its corners.

    The corner bounding box is moved to the origin and scaled so that its
    longest side has length 1.

    Args:
        surfs: sequence of point arrays, one per face.
        edges: sequence of point arrays, one per edge.
        corners: array-like of vertex positions, one row per vertex.

    Returns:
        ``(surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs)``:
        ``surfs_wcs``/``edges_wcs`` hold copies of the input points,
        ``surfs_ncs``/``edges_ncs`` hold the normalized points (all four as
        1-D object arrays, one entry per primitive), and ``corner_wcs`` is
        the normalized corner array. Returns five ``None`` values when
        ``corners`` is empty.

    NOTE(review): ``corner_wcs`` is normalized while ``surfs_wcs``/``edges_wcs``
    are not; confirm this naming matches what downstream consumers expect.
    NOTE(review): the box comes from corners only, so points on curved faces
    may fall outside the unit cube.
    """
    if len(corners) == 0:
        return None, None, None, None, None

    corners_array = np.array(corners)
    upper = corners_array.max(0)
    lower = corners_array.min(0)
    center = (upper + lower) / 2
    extent = (upper - lower).max()
    # All corners coincide: avoid dividing by zero and just translate.
    scale = 1.0 / extent if extent > 0 else 1.0

    surfs_wcs = [np.array(surf) for surf in surfs]
    surfs_ncs = [(surf - center) * scale for surf in surfs_wcs]
    edges_wcs = [np.array(edge) for edge in edges]
    edges_ncs = [(edge - center) * scale for edge in edges_wcs]
    corner_wcs = (corners_array - center) * scale

    return (_as_object_array(surfs_wcs),
            _as_object_array(edges_wcs),
            _as_object_array(surfs_ncs),
            _as_object_array(edges_ncs),
            corner_wcs)


def _explore(shape, shape_type, cast=None):
    """Return the sub-shapes of ``shape`` of ``shape_type`` in TopExp_Explorer order.

    ``cast`` (e.g. ``topods.Face``) is applied to each sub-shape when given.
    A sub-shape shared by several parents can be visited more than once.
    """
    items = []
    explorer = TopExp_Explorer(shape, shape_type)
    while explorer.More():
        current = explorer.Current()
        items.append(cast(current) if cast is not None else current)
        explorer.Next()
    return items


def get_adjacency_info(shape):
    """Build face/edge/vertex adjacency tables for ``shape``.

    Faces, edges and vertices are indexed in TopExp_Explorer traversal order,
    the same order ``parse_solid`` uses when it extracts geometry.

    Returns:
        edgeFace_adj: (num_edges, num_faces) int32, 1 where the edge is one of
            the face's edges.
        faceEdge_adj: (num_faces, num_edges) int32, the transpose of edgeFace_adj.
        edgeCorner_adj: (num_edges, 2) int32, indices into the explored vertex
            list of each edge's two end vertices. Rows stay 0 when either
            vertex returned by ``topexp.Vertices`` is null.
    """
    faces = _explore(shape, TopAbs_FACE, topods.Face)
    edges = _explore(shape, TopAbs_EDGE, topods.Edge)
    vertices = _explore(shape, TopAbs_VERTEX, topods.Vertex)

    num_faces = len(faces)
    num_edges = len(edges)
    edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32)
    edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32)

    # Explore each face's edges once, instead of once per (edge, face) pair.
    face_edges = [_explore(face, TopAbs_EDGE) for face in faces]

    for i, edge in enumerate(edges):
        # Edge-face incidence (IsSame ignores orientation).
        for j, boundary in enumerate(face_edges):
            if any(edge.IsSame(candidate) for candidate in boundary):
                edgeFace_adj[i, j] = 1

        # End vertices of the edge.
        v1 = TopoDS_Vertex()
        v2 = TopoDS_Vertex()
        topexp.Vertices(edge, v1, v2)
        if not v1.IsNull() and not v2.IsNull():
            # No early break: when a vertex appears several times in the
            # explored list, the last matching index is kept.
            for k, vertex in enumerate(vertices):
                if v1.IsSame(vertex):
                    edgeCorner_adj[i, 0] = k
                if v2.IsSame(vertex):
                    edgeCorner_adj[i, 1] = k

    faceEdge_adj = np.ascontiguousarray(edgeFace_adj.T)
    return edgeFace_adj, faceEdge_adj, edgeCorner_adj


def get_bbox(shape, subshape):
    """Return the axis-aligned bounding box of ``subshape``.

    Args:
        shape: unused; kept for call-site compatibility.
        subshape: the face/edge whose box is computed.

    Returns:
        ``np.array([xmin, ymin, zmin, xmax, ymax, zmax])``.
    """
    bbox = Bnd_Box()
    brepbndlib.Add(subshape, bbox)
    xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
    return np.array([xmin, ymin, zmin, xmax, ymax, zmax])


def _face_points(face):
    """Return the face's triangulation nodes as a float32 (N, 3) array, or None.

    Nodes are transformed by the face location. None when the face has no
    triangulation or the triangulation has no nodes.
    """
    loc = TopLoc_Location()
    triangulation = BRep_Tool.Triangulation(face, loc)
    if triangulation is None:
        return None
    trsf = loc.Transformation()
    points = []
    for i in range(1, triangulation.NbNodes() + 1):
        pnt = triangulation.Node(i).Transformed(trsf)
        points.append([pnt.X(), pnt.Y(), pnt.Z()])
    if not points:
        return None
    return np.array(points, dtype=np.float32)


def _edge_points(edge, num_samples):
    """Sample ``num_samples`` points uniformly in parameter space along the edge's 3D curve.

    Returns a float32 (num_samples, 3) array, or None when the edge has no 3D
    curve or no samples were taken.
    """
    curve, first, last = BRep_Tool.Curve(edge)
    if curve is None:
        return None
    points = []
    for i in range(num_samples):
        param = first + (last - first) * float(i) / (num_samples - 1)
        pnt = curve.Value(param)
        points.append([pnt.X(), pnt.Y(), pnt.Z()])
    if not points:
        return None
    return np.array(points, dtype=np.float32)


def parse_solid(step_path, max_face=None):
    """Parse faces, edges, vertices and their adjacency from a STEP file using OCC.

    Args:
        step_path: path to the STEP file.
        max_face: if given, return None when the model has more faces (with a
            triangulation) than this. The default, None, applies no limit.

    Returns:
        A dict with keys 'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs',
        'corner_wcs', 'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj',
        'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'. Returns None only
        when ``max_face`` is exceeded.

    Raises:
        Exception: the file cannot be read, or the shape has no vertices.
    """
    reader = STEPControl_Reader()
    status = reader.ReadFile(step_path)
    if status != IFSelect_RetDone:
        raise Exception("Failed to read STEP file")
    reader.TransferRoots()
    shape = reader.OneShape()

    # The constructor already performs the meshing, so no separate Perform() call is needed.
    BRepMesh_IncrementalMesh(shape, 0.01)

    # Faces: triangulation nodes. Faces without a triangulation are skipped;
    # their explorer indices are recorded so adjacency can be filtered to match.
    face_pnts = []
    surf_bbox_wcs = []
    kept_faces = []
    for face_idx, face in enumerate(_explore(shape, TopAbs_FACE, topods.Face)):
        points = _face_points(face)
        if points is None:
            continue
        face_pnts.append(points)
        surf_bbox_wcs.append(get_bbox(shape, face))
        kept_faces.append(face_idx)

    if max_face is not None and len(face_pnts) > max_face:
        return None

    # Edges: uniformly sampled 3D curve points. Edges without a 3D curve are skipped.
    num_samples = 100
    edge_pnts = []
    edge_bbox_wcs = []
    kept_edges = []
    for edge_idx, edge in enumerate(_explore(shape, TopAbs_EDGE, topods.Edge)):
        points = _edge_points(edge, num_samples)
        if points is None:
            continue
        edge_pnts.append(points)
        edge_bbox_wcs.append(get_bbox(shape, edge))
        kept_edges.append(edge_idx)

    # Vertices: every explored vertex, in the order get_adjacency_info indexes them.
    corner_pnts = []
    for vertex in _explore(shape, TopAbs_VERTEX, topods.Vertex):
        pnt = BRep_Tool.Pnt(vertex)
        corner_pnts.append([pnt.X(), pnt.Y(), pnt.Z()])

    edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(shape)
    # Drop rows/columns of skipped faces/edges so adjacency indices line up
    # with the surf_*/edge_* entries below.
    face_idx_arr = np.asarray(kept_faces, dtype=np.intp)
    edge_idx_arr = np.asarray(kept_edges, dtype=np.intp)
    edgeFace_adj = edgeFace_adj[np.ix_(edge_idx_arr, face_idx_arr)]
    faceEdge_adj = faceEdge_adj[np.ix_(face_idx_arr, edge_idx_arr)]
    edgeCorner_adj = edgeCorner_adj[edge_idx_arr]

    corner_pnts = np.array(corner_pnts, dtype=np.float32)
    surf_bbox_wcs = np.array(surf_bbox_wcs, dtype=np.float32)
    edge_bbox_wcs = np.array(edge_bbox_wcs, dtype=np.float32)

    surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs = normalize(face_pnts, edge_pnts, corner_pnts)
    if corner_wcs is None:
        raise Exception("Shape has no vertices")

    data = {
        'surf_wcs': surfs_wcs,
        'edge_wcs': edges_wcs,
        'surf_ncs': surfs_ncs,
        'edge_ncs': edges_ncs,
        'corner_wcs': corner_wcs.astype(np.float32),
        'edgeFace_adj': edgeFace_adj,
        'edgeCorner_adj': edgeCorner_adj,
        'faceEdge_adj': faceEdge_adj,
        'surf_bbox_wcs': surf_bbox_wcs,
        'edge_bbox_wcs': edge_bbox_wcs,
        'corner_unique': np.unique(corner_wcs, axis=0).astype(np.float32)
    }
    return data


def load_step(step_path):
    """Read a STEP file and return a one-element list holding its root shape.

    NOTE(review): this always returns exactly one shape (``OneShape``), so the
    caller's multi-solid check can never trigger; confirm whether solids were
    meant to be counted here.
    """
    reader = STEPControl_Reader()
    reader.ReadFile(step_path)
    reader.TransferRoots()
    return [reader.OneShape()]


def process(step_path, timeout=300, output_dir=None):
    """Parse one STEP file and pickle the result under ``output_dir``.

    Args:
        step_path: path to the STEP file.
        timeout: accepted for call-site compatibility; not used here.
        output_dir: output root folder; falls back to the module-level OUTPUT.

    Returns:
        1 if the file was converted and saved, 0 otherwise. Errors are
        printed, not raised.
    """
    try:
        # Check single solid
        cad_solid = load_step(step_path)
        if len(cad_solid) != 1:
            print('Skipping multi solids...')
            return 0

        data = parse_solid(step_path, max_face=MAX_FACE)
        if data is None:
            print('Exceeding threshold...')
            return 0

        # Split on the platform separator so paths built with os.path.join work everywhere.
        parts = os.path.normpath(step_path).split(os.sep)
        if 'furniture' in step_path:
            data_uid = parts[-2] + '_' + parts[-1]
            sub_folder = parts[-3]
        else:
            data_uid = parts[-2]
            sub_folder = data_uid[:4]

        if data_uid.endswith('.step'):
            data_uid = data_uid[:-5]

        data['uid'] = data_uid
        root = output_dir if output_dir is not None else OUTPUT
        if root is None:
            raise ValueError("output directory is not set")
        save_folder = os.path.join(root, sub_folder)
        # exist_ok avoids a check-then-create race between worker processes.
        os.makedirs(save_folder, exist_ok=True)

        save_path = os.path.join(save_folder, data['uid'] + '.pkl')
        with open(save_path, "wb") as tf:
            pickle.dump(data, tf)
        return 1

    except Exception as e:
        print('not saving due to error...', str(e))
        return 0


def test(step_file_path, output_path=None):
    """Convert a single STEP file, print statistics, and optionally pickle it.

    Returns the parsed data dict, or None on failure.
    """
    try:
        print(f"Processing STEP file: {step_file_path}")

        data = parse_solid(step_file_path)
        if data is None:
            print("Failed to parse STEP file")
            return None

        print("\nStatistics:")
        print(f"Number of surfaces: {len(data['surf_wcs'])}")
        print(f"Number of edges: {len(data['edge_wcs'])}")
        print(f"Number of corners: {len(data['corner_unique'])}")

        if output_path:
            print(f"\nSaving results to: {output_path}")
            out_dir = os.path.dirname(output_path)
            # os.makedirs('') raises, so only create a directory when the path has one.
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            with open(output_path, 'wb') as f:
                pickle.dump(data, f)
            print("Results saved successfully")

        return data

    except Exception as e:
        print(f"Error processing STEP file: {str(e)}")
        return None


def load_furniture_step(data_path):
    """Return paths of '.step' files directly inside data_path/{train,val,test}."""
    step_dirs = []
    for split in ['train', 'val', 'test']:
        split_path = os.path.join(data_path, split)
        if os.path.exists(split_path):
            for f in os.listdir(split_path):
                if f.endswith('.step'):
                    step_dirs.append(os.path.join(split_path, f))
    return step_dirs


def load_abc_step(data_path, is_deepcad=False):
    """Return existing STEP paths, one per sub-folder of data_path (sorted by name).

    DeepCAD layout: <folder>/<folder>.step; ABC layout: <folder>/shape.step.
    """
    step_dirs = []
    for f in sorted(os.listdir(data_path)):
        if os.path.isdir(os.path.join(data_path, f)):
            if is_deepcad:
                step_path = os.path.join(data_path, f, f + '.step')
            else:
                step_path = os.path.join(data_path, f, 'shape.step')
            if os.path.exists(step_path):
                step_dirs.append(step_path)
    return step_dirs


def main():
    """Entry point: convert many STEP files in parallel."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, help="Data folder path", required=True)
    parser.add_argument("--option", type=str, choices=['abc', 'deepcad', 'furniture'], default='abc',
                        help="Choose between dataset option [abc/deepcad/furniture] (default: abc)")
    parser.add_argument("--interval", type=int, help="Data range index, only required for abc/deepcad")
    args = parser.parse_args()

    global OUTPUT
    if args.option == 'deepcad':
        OUTPUT = 'deepcad_parsed'
    elif args.option == 'abc':
        OUTPUT = 'abc_parsed'
    else:
        OUTPUT = 'furniture_parsed'

    if args.option == 'furniture':
        step_dirs = load_furniture_step(args.input)
    else:
        if args.interval is None:
            parser.error("--interval is required for abc/deepcad")
        step_dirs = load_abc_step(args.input, args.option == 'deepcad')
        step_dirs = step_dirs[args.interval * 10000: (args.interval + 1) * 10000]

    if not step_dirs:
        print('No STEP files found, nothing to do.')
        return

    valid = 0
    # cpu_count() may return None, and // 2 is 0 on a single-core machine.
    max_workers = max(1, (os.cpu_count() or 2) // 2)
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Pass the output folder explicitly: spawned workers do not see OUTPUT set above.
        futures = {
            executor.submit(process, step_folder, timeout=300, output_dir=OUTPUT): step_folder
            for step_folder in step_dirs
        }
        for future in tqdm(as_completed(futures), total=len(futures)):
            try:
                # NOTE(review): as_completed only yields finished futures, so this
                # timeout never fires; per-file processing time is not bounded.
                status = future.result(timeout=300)
                valid += status
            except TimeoutError:
                print(f"Timeout occurred while processing {futures[future]}")
            except Exception as e:
                print(f"An error occurred while processing {futures[future]}: {e}")

    print(f'Done... Data Converted Ratio {100.0*valid/len(step_dirs)}%')


if __name__ == '__main__':
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == '--test':
        # Test mode: convert a single STEP file.
        if len(sys.argv) < 3:
            print("Usage: python process_brep.py --test <step_file> [output_path]")
            sys.exit(1)

        step_file = sys.argv[2]
        output_file = sys.argv[3] if len(sys.argv) > 3 else None

        print("Running in test mode...")
        result = test(step_file, output_file)

        if result is not None:
            print("\nTest completed successfully!")
    else:
        # Normal batch mode.
        main()