|
|
@ -1,59 +1,107 @@ |
|
|
|
""" |
|
|
|
CAD模型处理脚本 |
|
|
|
功能:将STEP格式的CAD模型转换为结构化数据,包括: |
|
|
|
- 几何信息:面、边、顶点的坐标数据 |
|
|
|
- 拓扑信息:面-边-顶点的邻接关系 |
|
|
|
- 空间信息:包围盒数据 |
|
|
|
""" |
|
|
|
|
|
|
|
import os |
|
|
|
import pickle |
|
|
|
import argparse |
|
|
|
import pickle # 用于数据序列化 |
|
|
|
import argparse # 命令行参数解析 |
|
|
|
import numpy as np |
|
|
|
from tqdm import tqdm |
|
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError |
|
|
|
from tqdm import tqdm # 进度条显示 |
|
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError # 并行处理 |
|
|
|
import logging |
|
|
|
from datetime import datetime |
|
|
|
|
|
|
|
# 创建logs目录 |
|
|
|
os.makedirs('logs', exist_ok=True) |
|
|
|
|
|
|
|
# 设置日志记录器 |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
logger.setLevel(logging.INFO) |
|
|
|
|
|
|
|
# 创建格式化器 |
|
|
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') |
|
|
|
|
|
|
|
from OCC.Core.STEPControl import STEPControl_Reader |
|
|
|
from OCC.Core.TopExp import TopExp_Explorer, topexp |
|
|
|
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX |
|
|
|
from OCC.Core.BRep import BRep_Tool |
|
|
|
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh |
|
|
|
from OCC.Core.TopLoc import TopLoc_Location |
|
|
|
from OCC.Core.IFSelect import IFSelect_RetDone |
|
|
|
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape |
|
|
|
from OCC.Core.BRepBndLib import brepbndlib |
|
|
|
from OCC.Core.Bnd import Bnd_Box |
|
|
|
from OCC.Core.TopoDS import TopoDS_Shape |
|
|
|
from OCC.Core.TopoDS import topods |
|
|
|
from OCC.Core.TopoDS import TopoDS_Vertex |
|
|
|
# 创建文件处理器 |
|
|
|
current_time = datetime.now().strftime('%Y%m%d_%H%M%S') |
|
|
|
log_file = f'logs/process_brep_{current_time}.log' |
|
|
|
file_handler = logging.FileHandler(log_file, encoding='utf-8') |
|
|
|
file_handler.setLevel(logging.INFO) |
|
|
|
file_handler.setFormatter(formatter) |
|
|
|
|
|
|
|
# To speed up processing, define maximum threshold |
|
|
|
# 添加文件处理器到日志记录器 |
|
|
|
logger.addHandler(file_handler) |
|
|
|
|
|
|
|
# 记录脚本开始执行 |
|
|
|
logger.info("="*50) |
|
|
|
logger.info("Script started") |
|
|
|
logger.info(f"Log file: {log_file}") |
|
|
|
logger.info("="*50) |
|
|
|
|
|
|
|
# 导入OpenCASCADE相关库 |
|
|
|
from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器 |
|
|
|
from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历 |
|
|
|
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类型定义 |
|
|
|
from OCC.Core.BRep import BRep_Tool # B-rep工具 |
|
|
|
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分 |
|
|
|
from OCC.Core.TopLoc import TopLoc_Location # 位置变换 |
|
|
|
from OCC.Core.IFSelect import IFSelect_RetDone # 操作状态码 |
|
|
|
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射 |
|
|
|
from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算 |
|
|
|
from OCC.Core.Bnd import Bnd_Box # 包围盒 |
|
|
|
from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构 |
|
|
|
|
|
|
|
# 设置最大面数阈值,用于加速处理 |
|
|
|
MAX_FACE = 70 |
|
|
|
|
|
|
|
def normalize(surfs, edges, corners):
    """Normalize the CAD model to the unit cube.

    Args:
        surfs: list of per-face point sets.
        edges: list of per-edge point sets.
        corners: list of corner (vertex) coordinates.

    Returns:
        A 5-tuple:
            surfs_wcs: face point sets in the original world coordinate system
            edges_wcs: edge point sets in the original world coordinate system
            surfs_ncs: face point sets in the normalized coordinate system
            edges_ncs: edge point sets in the normalized coordinate system
            corner_wcs: normalized corner coordinates
        or (None, None, None, None, None) when there are no corners.
    """
    if len(corners) == 0:
        return None, None, None, None, None

    # Bounding-box center and a uniform scale so the longest axis spans 1.0.
    corners_array = np.array(corners)
    center = (corners_array.max(0) + corners_array.min(0)) / 2
    scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max()

    # Normalize surface point sets.
    surfs_wcs = []
    surfs_ncs = []
    for surf in surfs:
        surf_wcs = np.array(surf)  # ensure numpy array
        surf_ncs = (surf_wcs - center) * scale
        surfs_wcs.append(surf_wcs)
        surfs_ncs.append(surf_ncs)

    # Normalize edge point sets.
    edges_wcs = []
    edges_ncs = []
    for edge in edges:
        edge_wcs = np.array(edge)  # ensure numpy array
        edge_ncs = (edge_wcs - center) * scale
        edges_wcs.append(edge_wcs)
        edges_ncs.append(edge_ncs)

    # Normalize corner coordinates.
    corner_wcs = (corners_array - center) * scale

    # dtype=object keeps ragged per-face/per-edge point counts intact.
    return (np.array(surfs_wcs, dtype=object),
            np.array(edges_wcs, dtype=object),
            np.array(surfs_ncs, dtype=object),
            np.array(edges_ncs, dtype=object),
            corner_wcs)
|
|
|
|
|
|
|
def get_adjacency_info(shape): |
|
|
|
"""获取形状的邻接信息""" |
|
|
|
# 创建数据映射 |
|
|
|
""" |
|
|
|
获取CAD模型中面、边、顶点之间的邻接关系 |
|
|
|
|
|
|
|
参数: |
|
|
|
shape: CAD模型的形状对象 |
|
|
|
|
|
|
|
返回: |
|
|
|
edgeFace_adj: 边-面邻接矩阵 (num_edges × num_faces) |
|
|
|
faceEdge_adj: 面-边邻接矩阵 (num_faces × num_edges) |
|
|
|
edgeCorner_adj: 边-顶点邻接矩阵 (num_edges × 2) |
|
|
|
""" |
|
|
|
# 创建边-面映射关系 |
|
|
|
edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape() |
|
|
|
topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map) |
|
|
|
|
|
|
|
# 获取所有面和边 |
|
|
|
faces = [] |
|
|
|
edges = [] |
|
|
|
vertices = [] |
|
|
|
# 获取所有几何元素 |
|
|
|
faces = [] # 存储所有面 |
|
|
|
edges = [] # 存储所有边 |
|
|
|
vertices = [] # 存储所有顶点 |
|
|
|
|
|
|
|
# 创建拓扑结构探索器 |
|
|
|
face_explorer = TopExp_Explorer(shape, TopAbs_FACE) |
|
|
|
edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE) |
|
|
|
vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX) |
|
|
|
|
|
|
|
# 收集所有面、边和顶点 |
|
|
|
# 收集所有几何元素 |
|
|
|
while face_explorer.More(): |
|
|
|
faces.append(topods.Face(face_explorer.Current())) |
|
|
|
face_explorer.Next() |
|
|
@ -99,9 +158,8 @@ def get_adjacency_info(shape): |
|
|
|
|
|
|
|
# 填充边-面邻接矩阵 |
|
|
|
for i, edge in enumerate(edges): |
|
|
|
# 获取与边相连的面 |
|
|
|
# 检查每个面是否与当前边相连 |
|
|
|
for j, face in enumerate(faces): |
|
|
|
# 使用 explorer 检查边是否属于面 |
|
|
|
edge_explorer = TopExp_Explorer(face, TopAbs_EDGE) |
|
|
|
while edge_explorer.More(): |
|
|
|
if edge.IsSame(edge_explorer.Current()): |
|
|
@ -110,16 +168,16 @@ def get_adjacency_info(shape): |
|
|
|
break |
|
|
|
edge_explorer.Next() |
|
|
|
|
|
|
|
# 获取边的顶点 |
|
|
|
# 获取边的两个端点 |
|
|
|
v1 = TopoDS_Vertex() |
|
|
|
v2 = TopoDS_Vertex() |
|
|
|
topexp.Vertices(edge, v1, v2) |
|
|
|
|
|
|
|
# 记录边的端点索引 |
|
|
|
if not v1.IsNull() and not v2.IsNull(): |
|
|
|
v1_vertex = topods.Vertex(v1) |
|
|
|
v2_vertex = topods.Vertex(v2) |
|
|
|
|
|
|
|
# 找到顶点的索引 |
|
|
|
for k, vertex in enumerate(vertices): |
|
|
|
if v1_vertex.IsSame(vertex): |
|
|
|
edgeCorner_adj[i, 0] = k |
|
|
@ -129,7 +187,16 @@ def get_adjacency_info(shape): |
|
|
|
return edgeFace_adj, faceEdge_adj, edgeCorner_adj |
|
|
|
|
|
|
|
def get_bbox(shape, subshape):
    """Compute the axis-aligned bounding box of a sub-shape.

    Args:
        shape: the full CAD model shape (kept for interface compatibility;
            not used in the computation).
        subshape: the sub-shape (face or edge) whose bounds are wanted.

    Returns:
        List of six floats [xmin, ymin, zmin, xmax, ymax, zmax].
    """
    bbox = Bnd_Box()
    brepbndlib.Add(subshape, bbox)
    xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
    # NOTE(review): the original return was cut off in the diff; restored
    # as a list per the docstring's stated contract — confirm against callers.
    return [xmin, ymin, zmin, xmax, ymax, zmax]
|
|
@ -247,45 +314,32 @@ def load_step(step_path): |
|
|
|
reader.TransferRoots() |
|
|
|
return [reader.OneShape()] |
|
|
|
|
|
|
|
def process_single_step(
        step_path: str,
        output_path: str = None,
        timeout: int = 300
) -> dict:
    """Parse one STEP file and optionally pickle the parsed result.

    Args:
        step_path: path of the input STEP file.
        output_path: where to pickle the parsed data; skipped when None.
        timeout: kept for interface compatibility; the caller enforces it
            via Future.result(timeout=...), not this function.

    Returns:
        The parsed data dict, or None on failure.
    """
    try:
        # Parse the STEP file into structured B-rep data.
        data = parse_solid(step_path)
        if data is None:
            logger.error("Failed to parse STEP file")
            return None

        # Save the parsed result when an output path was given.
        if output_path:
            try:
                logger.info(f"Saving results to: {output_path}")
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                with open(output_path, 'wb') as f:
                    pickle.dump(data, f)
                logger.info("Results saved successfully")
            except Exception as e:
                logger.error(f'Not saving due to error: {str(e)}')

        return data
    except Exception as e:
        logger.error(f'Not saving due to error: {str(e)}')
        return None


# Backward-compatible alias for the legacy entry-point name.
# NOTE(review): the legacy caller accumulated an int status (0/1); this
# now returns the data dict or None — confirm remaining call sites.
process = process_single_step
|
|
|
|
|
|
|
def test(step_file_path, output_path=None):
    """Convert a single STEP file and optionally save the parsed result.

    Args:
        step_file_path: path of the STEP file to process.
        output_path: optional pickle output path.

    Returns:
        The parsed data dict, or None on failure.
    """
    try:
        logger.info(f"Processing STEP file: {step_file_path}")

        # Parse the STEP file.
        data = parse_solid(step_file_path)
        if data is None:
            logger.error("Failed to parse STEP file")
            return None

        # Log basic statistics about the parsed model.
        logger.info("\nStatistics:")
        logger.info(f"Number of surfaces: {len(data['surf_wcs'])}")
        logger.info(f"Number of edges: {len(data['edge_wcs'])}")
        logger.info(f"Number of corners: {len(data['corner_unique'])}")

        # Save the result when an output path was given.
        if output_path:
            logger.info(f"Saving results to: {output_path}")
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            with open(output_path, 'wb') as f:
                pickle.dump(data, f)
            logger.info("Results saved successfully")

        return data

    except Exception as e:
        logger.error(f"Error processing STEP file: {str(e)}")
        return None
|
|
|
|
|
|
|
def process_furniture_step(data_path):
    """Collect the STEP file names of the furniture dataset, per split.

    Args:
        data_path: dataset root containing 'train'/'val'/'test' folders.

    Returns:
        Dict mapping each split name to a list of '.step' file names
        (bare names, not full paths); a missing split maps to [].
        {
            'train': [name1.step, ...],
            'val':   [...],
            'test':  [...]
        }
    """
    step_dirs = {}
    for split in ['train', 'val', 'test']:
        tmp_step_dirs = []
        split_path = os.path.join(data_path, split)
        if os.path.exists(split_path):
            for f in os.listdir(split_path):
                if f.endswith('.step'):
                    tmp_step_dirs.append(f)
        step_dirs[split] = tmp_step_dirs
    return step_dirs


# Backward-compatible alias for the legacy name.
# NOTE(review): the legacy function returned a flat list of full paths;
# this returns a per-split dict of bare names — confirm remaining callers.
load_furniture_step = process_furniture_step
|
|
|
|
|
|
|
def load_abc_step(data_path, is_deepcad=False):
    """Collect STEP file paths from an ABC/DeepCAD dataset folder.

    Args:
        data_path: root folder whose sub-directories each hold one model.
        is_deepcad: True for the DeepCAD layout (<name>/<name>.step),
            False for the ABC layout (<name>/shape.step).

    Returns:
        List of existing STEP file paths, in sorted directory order.
    """
    collected = []
    for entry in sorted(os.listdir(data_path)):
        entry_dir = os.path.join(data_path, entry)
        if not os.path.isdir(entry_dir):
            continue
        filename = entry + '.step' if is_deepcad else 'shape.step'
        candidate = os.path.join(entry_dir, filename)
        if os.path.exists(candidate):
            collected.append(candidate)
    return collected
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
""" |
|
|
|
主函数:处多个STEP文件 |
|
|
|
主函数:处理多个STEP文件 |
|
|
|
""" |
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument("--input", type=str, help="Data folder path", required=True) |
|
|
|
parser.add_argument("--option", type=str, choices=['abc', 'deepcad', 'furniture'], default='abc', |
|
|
|
help="Choose between dataset option [abc/deepcad/furniture] (default: abc)") |
|
|
|
parser.add_argument("--interval", type=int, help="Data range index, only required for abc/deepcad") |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
global OUTPUT |
|
|
|
if args.option == 'deepcad': |
|
|
|
OUTPUT = 'deepcad_parsed' |
|
|
|
elif args.option == 'abc': |
|
|
|
OUTPUT = 'abc_parsed' |
|
|
|
else: |
|
|
|
OUTPUT = 'furniture_parsed' |
|
|
|
# 定义路径常量 |
|
|
|
INPUT = '/mnt/disk2/dataset/furniture/step/furniture_dataset_step/' |
|
|
|
OUTPUT = 'test_data/pkl/' |
|
|
|
RESULT = 'test_data/result/' # 用于存储成功/失败文件记录 |
|
|
|
|
|
|
|
# Load all STEP files |
|
|
|
if args.option == 'furniture': |
|
|
|
step_dirs = load_furniture_step(args.input) |
|
|
|
else: |
|
|
|
step_dirs = load_abc_step(args.input, args.option=='deepcad') |
|
|
|
step_dirs = step_dirs[args.interval*10000 : (args.interval+1)*10000] |
|
|
|
|
|
|
|
# Process B-reps in parallel |
|
|
|
valid = 0 |
|
|
|
with ProcessPoolExecutor(max_workers=os.cpu_count() // 2) as executor: |
|
|
|
futures = {} |
|
|
|
for step_folder in step_dirs: |
|
|
|
future = executor.submit(process, step_folder, timeout=300) |
|
|
|
futures[future] = step_folder |
|
|
|
|
|
|
|
for future in tqdm(as_completed(futures), total=len(step_dirs)): |
|
|
|
try: |
|
|
|
status = future.result(timeout=300) |
|
|
|
valid += status |
|
|
|
except TimeoutError: |
|
|
|
print(f"Timeout occurred while processing {futures[future]}") |
|
|
|
except Exception as e: |
|
|
|
print(f"An error occurred while processing {futures[future]}: {e}") |
|
|
|
|
|
|
|
print(f'Done... Data Converted Ratio {100.0*valid/len(step_dirs)}%') |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
import sys |
|
|
|
# 确保输出目录存在 |
|
|
|
os.makedirs(OUTPUT, exist_ok=True) |
|
|
|
os.makedirs(RESULT, exist_ok=True) |
|
|
|
|
|
|
|
# 获取所有STEP文件 |
|
|
|
step_dirs_dict = process_furniture_step(INPUT) |
|
|
|
total_processed = 0 |
|
|
|
total_success = 0 |
|
|
|
|
|
|
|
if len(sys.argv) > 1 and sys.argv[1] == '--test': |
|
|
|
# 测试模式 |
|
|
|
if len(sys.argv) < 3: |
|
|
|
print("Usage: python process_brep.py --test <step_file_path> [output_path]") |
|
|
|
sys.exit(1) |
|
|
|
# 按数据集分割处理文件 |
|
|
|
for split in ['train', 'val', 'test']: |
|
|
|
current_step_dirs = step_dirs_dict[split] |
|
|
|
if not current_step_dirs: |
|
|
|
logger.warning(f"No files found in {split} split") |
|
|
|
continue |
|
|
|
|
|
|
|
step_file = sys.argv[2] |
|
|
|
output_file = sys.argv[3] if len(sys.argv) > 3 else None |
|
|
|
# 确保分割目录存在 |
|
|
|
split_output_dir = os.path.join(OUTPUT, split) |
|
|
|
split_result_dir = os.path.join(RESULT, split) |
|
|
|
os.makedirs(split_output_dir, exist_ok=True) |
|
|
|
os.makedirs(split_result_dir, exist_ok=True) |
|
|
|
|
|
|
|
print("Running in test mode...") |
|
|
|
result = test(step_file, output_file) |
|
|
|
success_files = [] # 存储成功处理的文件名 |
|
|
|
failed_files = [] # 存储失败的文件名及原因 |
|
|
|
|
|
|
|
if result is not None: |
|
|
|
print("\nTest completed successfully!") |
|
|
|
# 并行处理文件 |
|
|
|
with ProcessPoolExecutor(max_workers=os.cpu_count() // 2) as executor: |
|
|
|
futures = {} |
|
|
|
for step_file in current_step_dirs: |
|
|
|
input_path = os.path.join(INPUT, split, step_file) |
|
|
|
output_path = os.path.join(split_output_dir, step_file.replace('.step', '.pkl')) |
|
|
|
future = executor.submit(process_single_step, input_path, output_path, timeout=300) |
|
|
|
futures[future] = step_file |
|
|
|
|
|
|
|
# 处理结果 |
|
|
|
for future in tqdm(as_completed(futures), total=len(current_step_dirs), |
|
|
|
desc=f"Processing {split} set"): |
|
|
|
try: |
|
|
|
status = future.result(timeout=300) |
|
|
|
if status is not None: |
|
|
|
success_files.append(futures[future]) |
|
|
|
total_success += 1 |
|
|
|
except TimeoutError: |
|
|
|
logger.error(f"Timeout occurred while processing {futures[future]}") |
|
|
|
failed_files.append((futures[future], "Timeout")) |
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error processing {futures[future]}: {str(e)}") |
|
|
|
failed_files.append((futures[future], str(e))) |
|
|
|
finally: |
|
|
|
total_processed += 1 |
|
|
|
|
|
|
|
# 保存成功文件列表 |
|
|
|
success_file_path = os.path.join(split_result_dir, 'success.txt') |
|
|
|
with open(success_file_path, 'w', encoding='utf-8') as f: |
|
|
|
f.write('\n'.join(success_files)) |
|
|
|
logger.info(f"Saved {len(success_files)} successful files to {success_file_path}") |
|
|
|
|
|
|
|
# 保存失败文件列表(包含错误信息) |
|
|
|
failed_file_path = os.path.join(split_result_dir, 'failed.txt') |
|
|
|
with open(failed_file_path, 'w', encoding='utf-8') as f: |
|
|
|
for file, error in failed_files: |
|
|
|
f.write(f"{file}: {error}\n") |
|
|
|
logger.info(f"Saved {len(failed_files)} failed files to {failed_file_path}") |
|
|
|
|
|
|
|
# 打印最终统计信息 |
|
|
|
if total_processed > 0: |
|
|
|
success_rate = (total_success / total_processed) * 100 |
|
|
|
logger.info(f"Processing completed:") |
|
|
|
logger.info(f"Total files processed: {total_processed}") |
|
|
|
logger.info(f"Successfully processed: {total_success}") |
|
|
|
logger.info(f"Success rate: {success_rate:.2f}%") |
|
|
|
else: |
|
|
|
# 正常批处理模式 |
|
|
|
main() |
|
|
|
logger.warning("No files were processed") |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
main() |
|
|
|