diff --git a/scripts/process_brep.py b/scripts/process_brep.py
index 4fd92de..244488f 100644
--- a/scripts/process_brep.py
+++ b/scripts/process_brep.py
@@ -1,59 +1,107 @@
+"""
+CAD model processing script.
+Converts STEP-format CAD models into structured data, including:
+- geometry: coordinate data for faces, edges, and vertices
+- topology: face-edge-vertex adjacency relations
+- spatial info: bounding-box data
+"""
+
 import os
-import pickle
-import argparse
+import pickle    # data serialization
+import argparse  # command-line argument parsing
 import numpy as np
-from tqdm import tqdm
-from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
+from tqdm import tqdm  # progress bars
+from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError  # parallel processing
+import logging
+from datetime import datetime
+
+# Create the logs directory
+os.makedirs('logs', exist_ok=True)
+
+# Set up the logger
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+# Create the formatter
+formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
 
-from OCC.Core.STEPControl import STEPControl_Reader
-from OCC.Core.TopExp import TopExp_Explorer, topexp
-from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX
-from OCC.Core.BRep import BRep_Tool
-from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
-from OCC.Core.TopLoc import TopLoc_Location
-from OCC.Core.IFSelect import IFSelect_RetDone
-from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape
-from OCC.Core.BRepBndLib import brepbndlib
-from OCC.Core.Bnd import Bnd_Box
-from OCC.Core.TopoDS import TopoDS_Shape
-from OCC.Core.TopoDS import topods
-from OCC.Core.TopoDS import TopoDS_Vertex
+# Create the file handler
+current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
+log_file = f'logs/process_brep_{current_time}.log'
+file_handler = logging.FileHandler(log_file, encoding='utf-8')
+file_handler.setLevel(logging.INFO)
+file_handler.setFormatter(formatter)
 
-# To speed up processing, define maximum threshold
+# Attach the file handler to the logger
+logger.addHandler(file_handler)
+
+# Log the start of the script
+logger.info("="*50)
+logger.info("Script started")
+logger.info(f"Log file: {log_file}")
+logger.info("="*50)
+
+# OpenCASCADE imports
+from OCC.Core.STEPControl import STEPControl_Reader  # STEP file reader
+from OCC.Core.TopExp import TopExp_Explorer, topexp  # topology traversal
+from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX  # topology type enums
+from OCC.Core.BRep import BRep_Tool  # B-rep tools
+from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh  # mesh tessellation
+from OCC.Core.TopLoc import TopLoc_Location  # placement transforms
+from OCC.Core.IFSelect import IFSelect_RetDone  # operation status code
+from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape  # shape maps
+from OCC.Core.BRepBndLib import brepbndlib  # bounding-box computation
+from OCC.Core.Bnd import Bnd_Box  # bounding box
+from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex  # topological data structures
+
+# Maximum face-count threshold, used to keep processing fast
 MAX_FACE = 70
 
 def normalize(surfs, edges, corners):
-    """Normalize the CAD model to unit cube"""
+    """
+    Normalize the CAD model into the unit cube.
+
+    Args:
+        surfs: list of per-face point sets
+        edges: list of per-edge point sets
+        corners: list of vertex coordinates
+
+    Returns:
+        surfs_wcs: face point sets in the original (world) coordinate system
+        edges_wcs: edge point sets in the original (world) coordinate system
+        surfs_ncs: face point sets in the normalized coordinate system
+        edges_ncs: edge point sets in the normalized coordinate system
+        corner_wcs: normalized vertex coordinates
+    """
     if len(corners) == 0:
         return None, None, None, None, None
 
-    # Get bounding box
+    # Compute the bounding box and the scale factor
     corners_array = np.array(corners)
-    center = (corners_array.max(0) + corners_array.min(0)) / 2
-    scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max()
+    center = (corners_array.max(0) + corners_array.min(0)) / 2         # center point
+    scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max()  # scale factor
 
-    # Normalize surfaces
-    surfs_wcs = []
-    surfs_ncs = []
+    # Normalize the face point sets
+    surfs_wcs = []  # face point sets in world coordinates
+    surfs_ncs = []  # face point sets in normalized coordinates
     for surf in surfs:
-        surf_wcs = np.array(surf)  # make sure this is a numpy array
-        surf_ncs = (surf_wcs - center) * scale
+        surf_wcs = np.array(surf)
+        surf_ncs = (surf_wcs - center) * scale  # normalization transform
         surfs_wcs.append(surf_wcs)
         surfs_ncs.append(surf_ncs)
 
-    # Normalize edges
-    edges_wcs = []
-    edges_ncs = []
+    # Normalize the edge point sets
+    edges_wcs = []  # edge point sets in world coordinates
+    edges_ncs = []  # edge point sets in normalized coordinates
     for edge in edges:
-        edge_wcs = np.array(edge)  # make sure this is a numpy array
-        edge_ncs = (edge_wcs - center) * scale
+        edge_wcs = np.array(edge)
+        edge_ncs = (edge_wcs - center) * scale  # normalization transform
         edges_wcs.append(edge_wcs)
         edges_ncs.append(edge_ncs)
 
-    # Normalize corners
+    # Normalize the vertex coordinates
    corner_wcs = (corners_array - center) * scale
 
-    # Keep list form when returning
    return (np.array(surfs_wcs, dtype=object),
            np.array(edges_wcs, dtype=object),
            np.array(surfs_ncs, dtype=object),
@@ -61,21 +109,32 @@ def normalize(surfs, edges, corners):
            corner_wcs)
 
 def get_adjacency_info(shape):
-    """Get the adjacency information of the shape"""
-    # Create the data maps
+    """
+    Get the adjacency relations between the faces, edges, and vertices of a CAD model.
+
+    Args:
+        shape: shape object of the CAD model
+
+    Returns:
+        edgeFace_adj: edge-face adjacency matrix (num_edges x num_faces)
+        faceEdge_adj: face-edge adjacency matrix (num_faces x num_edges)
+        edgeCorner_adj: edge-vertex adjacency matrix (num_edges x 2)
+    """
+    # Build the edge-face map
    edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape()
    topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map)
 
-    # Get all faces and edges
-    faces = []
-    edges = []
-    vertices = []
+    # Collect all geometric elements
+    faces = []     # all faces
+    edges = []     # all edges
+    vertices = []  # all vertices
 
+    # Create the topology explorers
    face_explorer = TopExp_Explorer(shape, TopAbs_FACE)
    edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE)
    vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX)
 
-    # Collect all faces, edges, and vertices
+    # Gather all geometric elements
    while face_explorer.More():
        faces.append(topods.Face(face_explorer.Current()))
        face_explorer.Next()
@@ -99,9 +158,8 @@ def get_adjacency_info(shape):
    # Fill the edge-face adjacency matrix
    for i, edge in enumerate(edges):
-        # Get the faces connected to this edge
+        # Check whether each face is connected to the current edge
        for j, face in enumerate(faces):
-            # Use an explorer to check whether the edge belongs to the face
            edge_explorer = TopExp_Explorer(face, TopAbs_EDGE)
            while edge_explorer.More():
                if edge.IsSame(edge_explorer.Current()):
@@ -110,16 +168,16 @@ def get_adjacency_info(shape):
                    break
                edge_explorer.Next()
 
-        # Get the vertices of the edge
+        # Get the two endpoints of the edge
        v1 = TopoDS_Vertex()
        v2 = TopoDS_Vertex()
        topexp.Vertices(edge, v1, v2)
 
+        # Record the indices of the edge's endpoints
        if not v1.IsNull() and not v2.IsNull():
            v1_vertex = topods.Vertex(v1)
            v2_vertex = topods.Vertex(v2)
 
-            # Find the vertex indices
            for k, vertex in enumerate(vertices):
                if v1_vertex.IsSame(vertex):
                    edgeCorner_adj[i, 0] = k
@@ -129,7 +187,16 @@
    return edgeFace_adj, faceEdge_adj, edgeCorner_adj
 
 def get_bbox(shape, subshape):
-    """Compute the bounding box"""
+    """
+    Compute the bounding box of a shape.
+
+    Args:
+        shape: the complete CAD model shape
+        subshape: the sub-shape (face or edge) whose bounding box is required
+
+    Returns:
+        The six bounding-box parameters [xmin, ymin, zmin, xmax, ymax, zmax]
+    """
    bbox = Bnd_Box()
    brepbndlib.Add(subshape, bbox)
    xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
@@ -247,45 +314,32 @@ def load_step(step_path):
    reader.TransferRoots()
    return [reader.OneShape()]
 
-def process(step_path, timeout=300):
+def process_single_step(
+    step_path:str,
+    output_path:str=None,
+    timeout:int=300
+) -> dict:
    """Process single STEP file"""
    try:
-        # Check single solid
-        cad_solid = load_step(step_path)
-        if len(cad_solid)!=1:
-            print('Skipping multi solids...')
-            return 0
-
-        # Start data parsing
+        # Parse the STEP file
        data = parse_solid(step_path)
-        if data is None:
-            print ('Exceeding threshold...')
-            return 0
-
-        # Save the parsed result
-        if 'furniture' in step_path:
-            data_uid = step_path.split('/')[-2] + '_' + step_path.split('/')[-1]
-            sub_folder = step_path.split('/')[-3]
-        else:
-            data_uid = step_path.split('/')[-2]
-            sub_folder = data_uid[:4]
-
-        if data_uid.endswith('.step'):
-            data_uid = data_uid[:-5]
-
-        data['uid'] = data_uid
-        save_folder = os.path.join(OUTPUT, sub_folder)
-        if not os.path.exists(save_folder):
-            os.makedirs(save_folder)
-
-        save_path = os.path.join(save_folder, data['uid']+'.pkl')
-        with open(save_path, "wb") as tf:
-            pickle.dump(data, tf)
-
-        return 1
-
+        if data is None:
+            logger.error("Failed to parse STEP file")
+            return None
+        # Save the result
+        if output_path:
+            try:
+                logger.info(f"Saving results to: {output_path}")
+                os.makedirs(os.path.dirname(output_path), exist_ok=True)
+                with open(output_path, 'wb') as f:
+                    pickle.dump(data, f)
+                logger.info("Results saved successfully")
+            except Exception as e:
+                logger.error(f'Not saving due to error: {str(e)}')
+
+        return data
    except Exception as e:
-        print('not saving due to error...', str(e))
-        return 0
+        logger.error(f'Not saving due to error: {str(e)}')
+        return None
 
@@ -293,120 +347,145 @@ def test(step_file_path, output_path=None):
    """
    Test helper: convert a single STEP file and save the result
    """
    try:
-        print(f"Processing STEP file: {step_file_path}")
+        logger.info(f"Processing STEP file: {step_file_path}")
 
        # Parse the STEP file
        data = parse_solid(step_file_path)
        if data is None:
-            print("Failed to parse STEP file")
+            logger.error("Failed to parse STEP file")
            return None
 
        # Print statistics
-        print("\nStatistics:")
-        print(f"Number of surfaces: {len(data['surf_wcs'])}")
-        print(f"Number of edges: {len(data['edge_wcs'])}")
-        print(f"Number of corners: {len(data['corner_unique'])}")
+        logger.info("\nStatistics:")
+        logger.info(f"Number of surfaces: {len(data['surf_wcs'])}")
+        logger.info(f"Number of edges: {len(data['edge_wcs'])}")
+        logger.info(f"Number of corners: {len(data['corner_unique'])}")
 
        # Save the result
        if output_path:
-            print(f"\nSaving results to: {output_path}")
+            logger.info(f"Saving results to: {output_path}")
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            with open(output_path, 'wb') as f:
                pickle.dump(data, f)
-            print("Results saved successfully")
+            logger.info("Results saved successfully")
 
        return data
 
    except Exception as e:
-        print(f"Error processing STEP file: {str(e)}")
+        logger.error(f"Error processing STEP file: {str(e)}")
        return None
 
-def load_furniture_step(data_path):
-    """Load furniture STEP files"""
-    step_dirs = []
+def process_furniture_step(data_path):
+    """
+    Process the STEP files of the furniture dataset.
+
+    Args:
+        data_path: path to the dataset
+
+    Returns:
+        A dict with the STEP files of the train, validation, and test splits:
+        {
+            'train': [step_file_path1, step_file_path2, ...],
+            'val': [step_file_path1, step_file_path2, ...],
+            'test': [step_file_path1, step_file_path2, ...]
+        }
+    """
+
+    step_dirs = {}
    for split in ['train', 'val', 'test']:
+        tmp_step_dirs = []
        split_path = os.path.join(data_path, split)
        if os.path.exists(split_path):
            for f in os.listdir(split_path):
                if f.endswith('.step'):
-                    step_dirs.append(os.path.join(split_path, f))
+                    tmp_step_dirs.append(f)
+        step_dirs[split] = tmp_step_dirs
    return step_dirs
 
-def load_abc_step(data_path, is_deepcad=False):
-    """Load ABC/DeepCAD STEP files"""
-    step_dirs = []
-    for f in sorted(os.listdir(data_path)):
-        if os.path.isdir(os.path.join(data_path, f)):
-            if is_deepcad:
-                step_path = os.path.join(data_path, f, f+'.step')
-            else:
-                step_path = os.path.join(data_path, f, 'shape.step')
-            if os.path.exists(step_path):
-                step_dirs.append(step_path)
-    return step_dirs
+
 
 def main():
    """
    Main function: process multiple STEP files
    """
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input", type=str, help="Data folder path", required=True)
-    parser.add_argument("--option", type=str, choices=['abc', 'deepcad', 'furniture'], default='abc',
-                        help="Choose between dataset option [abc/deepcad/furniture] (default: abc)")
-    parser.add_argument("--interval", type=int, help="Data range index, only required for abc/deepcad")
-    args = parser.parse_args()
-
-    global OUTPUT
-    if args.option == 'deepcad':
-        OUTPUT = 'deepcad_parsed'
-    elif args.option == 'abc':
-        OUTPUT = 'abc_parsed'
-    else:
-        OUTPUT = 'furniture_parsed'
+    # Path constants
+    INPUT = '/mnt/disk2/dataset/furniture/step/furniture_dataset_step/'
+    OUTPUT = 'test_data/pkl/'
+    RESULT = 'test_data/result/'  # stores the success/failure records
 
-    # Load all STEP files
-    if args.option == 'furniture':
-        step_dirs = load_furniture_step(args.input)
-    else:
-        step_dirs = load_abc_step(args.input, args.option=='deepcad')
-    step_dirs = step_dirs[args.interval*10000 : (args.interval+1)*10000]
-
-    # Process B-reps in parallel
-    valid = 0
-    with ProcessPoolExecutor(max_workers=os.cpu_count() // 2) as executor:
-        futures = {}
-        for step_folder in step_dirs:
-            future = executor.submit(process, step_folder, timeout=300)
-            futures[future] = step_folder
-
-        for future in tqdm(as_completed(futures), total=len(step_dirs)):
-            try:
-                status = future.result(timeout=300)
-                valid += status
-            except TimeoutError:
-                print(f"Timeout occurred while processing {futures[future]}")
-            except Exception as e:
-                print(f"An error occurred while processing {futures[future]}: {e}")
-
-    print(f'Done... Data Converted Ratio {100.0*valid/len(step_dirs)}%')
-
-if __name__ == '__main__':
-    import sys
+    # Make sure the output directories exist
+    os.makedirs(OUTPUT, exist_ok=True)
+    os.makedirs(RESULT, exist_ok=True)
+
+    # Collect all STEP files
+    step_dirs_dict = process_furniture_step(INPUT)
+    total_processed = 0
+    total_success = 0
 
-    if len(sys.argv) > 1 and sys.argv[1] == '--test':
-        # Test mode
-        if len(sys.argv) < 3:
-            print("Usage: python process_brep.py --test [output_path]")
-            sys.exit(1)
+    # Process the files of each dataset split
+    for split in ['train', 'val', 'test']:
+        current_step_dirs = step_dirs_dict[split]
+        if not current_step_dirs:
+            logger.warning(f"No files found in {split} split")
+            continue
 
-        step_file = sys.argv[2]
-        output_file = sys.argv[3] if len(sys.argv) > 3 else None
+        # Make sure the per-split directories exist
+        split_output_dir = os.path.join(OUTPUT, split)
+        split_result_dir = os.path.join(RESULT, split)
+        os.makedirs(split_output_dir, exist_ok=True)
+        os.makedirs(split_result_dir, exist_ok=True)
 
-        print("Running in test mode...")
-        result = test(step_file, output_file)
+        success_files = []  # names of files processed successfully
+        failed_files = []   # names of failed files, together with the reason
 
-        if result is not None:
-            print("\nTest completed successfully!")
+        # Process the files in parallel
+        with ProcessPoolExecutor(max_workers=os.cpu_count() // 2) as executor:
+            futures = {}
+            for step_file in current_step_dirs:
+                input_path = os.path.join(INPUT, split, step_file)
+                output_path = os.path.join(split_output_dir, step_file.replace('.step', '.pkl'))
+                future = executor.submit(process_single_step, input_path, output_path, timeout=300)
+                futures[future] = step_file
+
+            # Collect the results
+            for future in tqdm(as_completed(futures), total=len(current_step_dirs),
+                               desc=f"Processing {split} set"):
+                try:
+                    status = future.result(timeout=300)
+                    if status is not None:
+                        success_files.append(futures[future])
+                        total_success += 1
+                except TimeoutError:
+                    logger.error(f"Timeout occurred while processing {futures[future]}")
+                    failed_files.append((futures[future], "Timeout"))
+                except Exception as e:
+                    logger.error(f"Error processing {futures[future]}: {str(e)}")
+                    failed_files.append((futures[future], str(e)))
+                finally:
+                    total_processed += 1
+
+        # Save the list of successfully processed files
+        success_file_path = os.path.join(split_result_dir, 'success.txt')
+        with open(success_file_path, 'w', encoding='utf-8') as f:
+            f.write('\n'.join(success_files))
+        logger.info(f"Saved {len(success_files)} successful files to {success_file_path}")
+
+        # Save the list of failed files (with the error messages)
+        failed_file_path = os.path.join(split_result_dir, 'failed.txt')
+        with open(failed_file_path, 'w', encoding='utf-8') as f:
+            for file, error in failed_files:
+                f.write(f"{file}: {error}\n")
+        logger.info(f"Saved {len(failed_files)} failed files to {failed_file_path}")
+
+    # Log the final statistics
+    if total_processed > 0:
+        success_rate = (total_success / total_processed) * 100
+        logger.info(f"Processing completed:")
+        logger.info(f"Total files processed: {total_processed}")
+        logger.info(f"Successfully processed: {total_success}")
+        logger.info(f"Success rate: {success_rate:.2f}%")
    else:
-        # Normal batch-processing mode
-        main()
\ No newline at end of file
+        logger.warning("No files were processed")
+
+if __name__ == '__main__':
+    main()
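
A minimal inspection sketch for one of the pickles written by process_single_step (illustrative only, not part of the patch). It relies only on the keys that test() above already reads ('surf_wcs', 'edge_wcs', 'corner_unique'); the 'surf_ncs'/'edge_ncs' keys and the (N, 3) point layout are assumptions about parse_solid's output and are therefore checked only if present. Per normalize(), normalized coordinates are centered on the corner bounding box and scaled by its largest extent, so they should lie roughly within [-0.5, 0.5].

# check_pkl.py -- illustrative sketch; keys other than surf_wcs/edge_wcs/corner_unique are assumptions
import pickle
import sys
import numpy as np

def summarize(pkl_path):
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)

    # Same counts that test() logs
    print(f"surfaces: {len(data['surf_wcs'])}")
    print(f"edges:    {len(data['edge_wcs'])}")
    print(f"corners:  {len(data['corner_unique'])}")

    # Optional: report the extent of the normalized point sets, if stored.
    # normalize() centers on the corner bounding box and scales by its largest
    # side, so values should lie roughly within [-0.5, 0.5].
    for key in ('surf_ncs', 'edge_ncs'):  # assumed key names
        if key in data:
            pts = np.concatenate([np.asarray(p, dtype=float).reshape(-1, 3) for p in data[key]])
            print(f"{key}: min {pts.min(0)}, max {pts.max(0)}")

if __name__ == '__main__':
    summarize(sys.argv[1])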