You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
615 lines
24 KiB
615 lines
24 KiB
"""
|
|
CAD模型处理脚本
|
|
功能:将STEP格式的CAD模型转换为结构化数据,包括:
|
|
- 几何信息:面、边、顶点的坐标数据
|
|
- 拓扑信息:面-边-顶点的邻接关系
|
|
- 空间信息:包围盒数据
|
|
"""
|
|
|
|
import os
|
|
import pickle # 用于数据序列化
|
|
import argparse # 命令行参数解析
|
|
import numpy as np
|
|
from tqdm import tqdm # 进度条显示
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError # 并行处理
|
|
import logging
|
|
from datetime import datetime
|
|
from brep2sdf.utils.logger import logger
|
|
|
|
|
|
# 导入OpenCASCADE相关库
|
|
from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器
|
|
from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历
|
|
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类型定义
|
|
from OCC.Core.BRep import BRep_Tool # B-rep工具
|
|
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分
|
|
from OCC.Core.TopLoc import TopLoc_Location # 位置变换
|
|
from OCC.Core.IFSelect import IFSelect_RetDone # 操作状态码
|
|
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射
|
|
from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算
|
|
from OCC.Core.Bnd import Bnd_Box # 包围盒
|
|
from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构
|
|
|
|
# 导入配置
|
|
from brep2sdf.config.default_config import get_default_config
|
|
config = get_default_config()
|
|
|
|
# 设置最大面数阈值,用于加速处理
|
|
MAX_FACE = config.data.max_face
|
|
|
|
def normalize(surfs, edges, corners):
|
|
"""
|
|
将CAD模型归一化到单位立方体空间
|
|
|
|
参数:
|
|
surfs: 面的点集列表
|
|
edges: 边的点集列表
|
|
corners: 顶点坐标数组 [num_edges, 2, 3]
|
|
|
|
返回:
|
|
surfs_wcs: 原始坐标系下的面点集
|
|
edges_wcs: 原始坐标系下的边点集
|
|
surfs_ncs: 归一化坐标系下的面点集
|
|
edges_ncs: 归一化坐标系下的边点集
|
|
corner_wcs: 归一化后的顶点坐标 [num_edges, 2, 3]
|
|
"""
|
|
if len(corners) == 0:
|
|
return None, None, None, None, None
|
|
|
|
# 计算包围盒和缩放因子
|
|
corners_array = corners.reshape(-1, 3) # [num_edges*2, 3]
|
|
center = (corners_array.max(0) + corners_array.min(0)) / 2 # 计算中心点
|
|
scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max() # 计算缩放系数
|
|
|
|
# 归一化面的点集
|
|
surfs_wcs = [] # 原始世界坐标系下的面点集
|
|
surfs_ncs = [] # 归一化坐标系下的面点集
|
|
for surf in surfs:
|
|
surf_wcs = np.array(surf)
|
|
surf_ncs = (surf_wcs - center) * scale # 归一化变换
|
|
surfs_wcs.append(surf_wcs)
|
|
surfs_ncs.append(surf_ncs)
|
|
|
|
# 归一化边的点集
|
|
edges_wcs = [] # 原始世界坐标系下的边点集
|
|
edges_ncs = [] # 归一化坐标系下的边点集
|
|
for edge in edges:
|
|
edge_wcs = np.array(edge)
|
|
edge_ncs = (edge_wcs - center) * scale # 归一化变换
|
|
edges_wcs.append(edge_wcs)
|
|
edges_ncs.append(edge_ncs)
|
|
|
|
# 归一化顶点坐标 - 保持[num_edges, 2, 3]的形状
|
|
corner_wcs = (corners - center) * scale # 广播操作会保持原有维度
|
|
|
|
return (np.array(surfs_wcs, dtype=object),
|
|
np.array(edges_wcs, dtype=object),
|
|
np.array(surfs_ncs, dtype=object),
|
|
np.array(edges_ncs, dtype=object),
|
|
corner_wcs)
|
|
|
|
def get_adjacency_info(shape):
|
|
"""
|
|
获取CAD模型中面、边、顶点之间的邻接关系
|
|
|
|
参数:
|
|
shape: CAD模型的形状对象
|
|
|
|
返回:
|
|
edgeFace_adj: 边-面邻接矩阵 (num_edges × num_faces)
|
|
faceEdge_adj: 面-边邻接矩阵 (num_faces × num_edges)
|
|
edgeCorner_adj: 边-顶点邻接矩阵 (num_edges × 2)
|
|
"""
|
|
# 创建边-面映射关系
|
|
edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape()
|
|
topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map)
|
|
|
|
# 获取所有几何元素
|
|
faces = [] # 存储所有面
|
|
edges = [] # 存储所有边
|
|
vertices = [] # 存储所有顶点
|
|
|
|
# 创建拓扑结构探索器
|
|
face_explorer = TopExp_Explorer(shape, TopAbs_FACE)
|
|
edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE)
|
|
vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX)
|
|
|
|
# 收集所有几何元素
|
|
while face_explorer.More():
|
|
faces.append(topods.Face(face_explorer.Current()))
|
|
face_explorer.Next()
|
|
|
|
while edge_explorer.More():
|
|
edges.append(topods.Edge(edge_explorer.Current()))
|
|
edge_explorer.Next()
|
|
|
|
while vertex_explorer.More():
|
|
vertices.append(topods.Vertex(vertex_explorer.Current()))
|
|
vertex_explorer.Next()
|
|
|
|
# 创建邻接矩阵
|
|
num_faces = len(faces)
|
|
num_edges = len(edges)
|
|
num_vertices = len(vertices)
|
|
|
|
edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32)
|
|
faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32)
|
|
edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32)
|
|
|
|
# 填充边-面邻接矩阵
|
|
for i, edge in enumerate(edges):
|
|
# 检查每个面是否与当前边相连
|
|
for j, face in enumerate(faces):
|
|
edge_explorer = TopExp_Explorer(face, TopAbs_EDGE)
|
|
while edge_explorer.More():
|
|
if edge.IsSame(edge_explorer.Current()):
|
|
edgeFace_adj[i, j] = 1
|
|
faceEdge_adj[j, i] = 1
|
|
break
|
|
edge_explorer.Next()
|
|
|
|
# 获取边的两个端点
|
|
v1 = TopoDS_Vertex()
|
|
v2 = TopoDS_Vertex()
|
|
topexp.Vertices(edge, v1, v2)
|
|
|
|
# 记录边的端点索引
|
|
if not v1.IsNull() and not v2.IsNull():
|
|
v1_vertex = topods.Vertex(v1)
|
|
v2_vertex = topods.Vertex(v2)
|
|
|
|
for k, vertex in enumerate(vertices):
|
|
if v1_vertex.IsSame(vertex):
|
|
edgeCorner_adj[i, 0] = k
|
|
if v2_vertex.IsSame(vertex):
|
|
edgeCorner_adj[i, 1] = k
|
|
|
|
return edgeFace_adj, faceEdge_adj, edgeCorner_adj
|
|
|
|
def get_bbox(shape, subshape):
|
|
"""
|
|
计算形状的包围盒
|
|
|
|
参数:
|
|
shape: 完整的CAD模型形状
|
|
subshape: 需要计算包围盒的子形状(面或边)
|
|
|
|
返回:
|
|
包围盒的六个参数 [xmin, ymin, zmin, xmax, ymax, zmax]
|
|
"""
|
|
bbox = Bnd_Box()
|
|
brepbndlib.Add(subshape, bbox)
|
|
xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
|
|
return np.array([xmin, ymin, zmin, xmax, ymax, zmax])
|
|
|
|
|
|
|
|
def parse_solid(step_path):
|
|
"""
|
|
解析STEP文件中的CAD模型数据
|
|
|
|
返回:
|
|
dict: 包含以下键值对的字典:
|
|
# 几何数据
|
|
'surf_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示面的点云坐标
|
|
'edge_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示边的采样点坐标
|
|
'surf_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示归一化后的面点云
|
|
'edge_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示归一化后的边采样点
|
|
'corner_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 2, 3)的数组,表示每条边的两个端点坐标
|
|
'corner_unique': np.ndarray(dtype=float32) # 形状为(num_vertices, 3)的数组,表示所有顶点的唯一坐标,num_vertices <= num_edges * 2
|
|
|
|
# 拓扑关系
|
|
'edgeFace_adj': np.ndarray(dtype=int32) # 形状为(num_edges, num_faces)的数组,表示边-面邻接关系
|
|
'edgeCorner_adj': np.ndarray(dtype=int32) # 形状为(num_edges, 2)的数组,表示边-顶点邻接关系
|
|
'faceEdge_adj': np.ndarray(dtype=int32) # 形状为(num_faces, num_edges)的数组,表示面-边邻接关系
|
|
|
|
# 包围盒数据
|
|
'surf_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_faces, 6)的数组,表示每个面的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
|
|
'edge_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 6)的数组,表示每条边的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
|
|
"""
|
|
# Load STEP file
|
|
reader = STEPControl_Reader()
|
|
status = reader.ReadFile(step_path)
|
|
if status != IFSelect_RetDone:
|
|
raise Exception("Failed to read STEP file")
|
|
|
|
reader.TransferRoots()
|
|
shape = reader.OneShape()
|
|
|
|
# Create mesh
|
|
mesh = BRepMesh_IncrementalMesh(shape, 0.01)
|
|
mesh.Perform()
|
|
|
|
# Initialize explorers
|
|
face_explorer = TopExp_Explorer(shape, TopAbs_FACE)
|
|
edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE)
|
|
vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX)
|
|
|
|
face_pnts = []
|
|
edge_pnts = []
|
|
corner_pnts = []
|
|
surf_bbox_wcs = []
|
|
edge_bbox_wcs = []
|
|
|
|
# Extract face points
|
|
while face_explorer.More():
|
|
face = topods.Face(face_explorer.Current())
|
|
loc = TopLoc_Location()
|
|
triangulation = BRep_Tool.Triangulation(face, loc)
|
|
|
|
if triangulation is not None:
|
|
points = []
|
|
for i in range(1, triangulation.NbNodes() + 1):
|
|
node = triangulation.Node(i)
|
|
pnt = node.Transformed(loc.Transformation())
|
|
points.append([pnt.X(), pnt.Y(), pnt.Z()])
|
|
|
|
if points:
|
|
points = np.array(points, dtype=np.float32)
|
|
if len(points.shape) == 2 and points.shape[1] == 3:
|
|
# 确保每个面至少有一些点
|
|
if len(points) < 3: # 如果点数太少,跳过这个面
|
|
continue
|
|
face_pnts.append(points)
|
|
surf_bbox_wcs.append(get_bbox(shape, face))
|
|
|
|
face_explorer.Next()
|
|
|
|
# Extract edge points
|
|
num_samples = config.model.num_edge_points # 使用配置中的边采样点数
|
|
while edge_explorer.More():
|
|
edge = topods.Edge(edge_explorer.Current())
|
|
curve, first, last = BRep_Tool.Curve(edge)
|
|
|
|
if curve is not None:
|
|
points = []
|
|
for i in range(num_samples):
|
|
param = first + (last - first) * float(i) / (num_samples - 1)
|
|
pnt = curve.Value(param)
|
|
points.append([pnt.X(), pnt.Y(), pnt.Z()])
|
|
|
|
if points:
|
|
points = np.array(points, dtype=np.float32)
|
|
if len(points.shape) == 2 and points.shape[1] == 3:
|
|
edge_pnts.append(points) # 现在points是(num_edge_points, 3)形状
|
|
edge_bbox_wcs.append(get_bbox(shape, edge))
|
|
|
|
edge_explorer.Next()
|
|
|
|
# Extract vertex points
|
|
while vertex_explorer.More():
|
|
vertex = topods.Vertex(vertex_explorer.Current())
|
|
pnt = BRep_Tool.Pnt(vertex)
|
|
corner_pnts.append([pnt.X(), pnt.Y(), pnt.Z()])
|
|
vertex_explorer.Next()
|
|
|
|
# 获取邻接信息
|
|
edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(shape)
|
|
|
|
# 转换为numpy数组时确保类型正确
|
|
face_pnts = [np.array(points, dtype=np.float32) for points in face_pnts]
|
|
edge_pnts = [np.array(points, dtype=np.float32) for points in edge_pnts]
|
|
|
|
# 转换为对象数组
|
|
face_pnts = np.array(face_pnts, dtype=object)
|
|
edge_pnts = np.array(edge_pnts, dtype=object)
|
|
corner_pnts = np.array(corner_pnts, dtype=np.float32)
|
|
|
|
# 重组顶点数据为每条边两个端点的形式
|
|
corner_pairs = []
|
|
for edge_idx in range(len(edge_pnts)):
|
|
v1_idx, v2_idx = edgeCorner_adj[edge_idx]
|
|
v1_pos = corner_pnts[v1_idx]
|
|
v2_pos = corner_pnts[v2_idx]
|
|
# 按坐标排序确保顺序一致
|
|
if (v1_pos > v2_pos).any():
|
|
v1_pos, v2_pos = v2_pos, v1_pos
|
|
corner_pairs.append(np.stack([v1_pos, v2_pos]))
|
|
|
|
corner_pairs = np.stack(corner_pairs).astype(np.float32) # [num_edges, 2, 3]
|
|
|
|
# 确保所有数组都有正确的类型
|
|
surf_bbox_wcs = np.array(surf_bbox_wcs, dtype=np.float32)
|
|
edge_bbox_wcs = np.array(edge_bbox_wcs, dtype=np.float32)
|
|
|
|
# Normalize the CAD model
|
|
surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs = normalize(
|
|
face_pnts, edge_pnts, corner_pairs)
|
|
|
|
# 验证归一化后的数据
|
|
if any(x is None for x in [surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs]):
|
|
logger.error(f"Normalization failed for {step_path}")
|
|
return None
|
|
|
|
# 创建结果字典并确保所有数组都有正确的类型
|
|
data = {
|
|
'surf_wcs': np.array(surfs_wcs, dtype=object), # 保持对象数组
|
|
'edge_wcs': np.array(edges_wcs, dtype=object), # 保持对象数组
|
|
'surf_ncs': np.array(surfs_ncs, dtype=object), # 保持对象数组
|
|
'edge_ncs': np.array(edges_ncs, dtype=object), # 保持对象数组
|
|
'corner_wcs': corner_wcs.astype(np.float32), # [num_edges, 2, 3]
|
|
'edgeFace_adj': edgeFace_adj.astype(np.int32),
|
|
'edgeCorner_adj': edgeCorner_adj.astype(np.int32),
|
|
'faceEdge_adj': faceEdge_adj.astype(np.int32),
|
|
'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),
|
|
'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),
|
|
'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32) # 先展平再去重
|
|
}
|
|
|
|
return data
|
|
|
|
def load_step(step_path):
|
|
"""Load STEP file and return solids"""
|
|
reader = STEPControl_Reader()
|
|
reader.ReadFile(step_path)
|
|
reader.TransferRoots()
|
|
return [reader.OneShape()]
|
|
|
|
def check_data_format(data, step_file):
|
|
"""检查数据格式是否正确"""
|
|
required_keys = [
|
|
'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs', 'corner_wcs',
|
|
'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj',
|
|
'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'
|
|
]
|
|
|
|
# 检查所有必需的键是否存在
|
|
for key in required_keys:
|
|
if key not in data:
|
|
return False, f"Missing key: {key}"
|
|
|
|
# 检查几何数据
|
|
geometry_arrays = ['surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs']
|
|
for key in geometry_arrays:
|
|
if not isinstance(data[key], np.ndarray):
|
|
return False, f"{key} should be a numpy array"
|
|
# 允许对象数组
|
|
if data[key].dtype != object:
|
|
return False, f"{key} should be a numpy array with dtype=object"
|
|
|
|
# 检查其他数组
|
|
float32_arrays = ['corner_wcs', 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique']
|
|
for key in float32_arrays:
|
|
if not isinstance(data[key], np.ndarray):
|
|
return False, f"{key} should be a numpy array"
|
|
if data[key].dtype != np.float32:
|
|
return False, f"{key} should be a numpy array with dtype=float32"
|
|
|
|
int32_arrays = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
|
|
for key in int32_arrays:
|
|
if not isinstance(data[key], np.ndarray):
|
|
return False, f"{key} should be a numpy array"
|
|
if data[key].dtype != np.int32:
|
|
return False, f"{key} should be a numpy array with dtype=int32"
|
|
|
|
return True, ""
|
|
|
|
def process_single_step(step_path:str, output_path:str=None, timeout:int=300) -> dict:
|
|
"""处理单个STEP文件"""
|
|
try:
|
|
# 解析STEP文件
|
|
data = parse_solid(step_path)
|
|
if data is None:
|
|
logger.error(f"Failed to parse STEP file: {step_path}")
|
|
return None
|
|
|
|
# 检查数据格式
|
|
is_valid, msg = check_data_format(data, step_path)
|
|
if not is_valid:
|
|
logger.error(f"Data format check failed for {step_path}: {msg}")
|
|
return None
|
|
|
|
# 保存结果
|
|
if output_path:
|
|
try:
|
|
logger.info(f"Saving results to: {output_path}")
|
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
|
with open(output_path, 'wb') as f:
|
|
pickle.dump(data, f)
|
|
logger.info(f"Results saved successfully: {output_path}")
|
|
return data
|
|
except Exception as e:
|
|
logger.error(f'Failed to save {output_path}: {str(e)}')
|
|
return None
|
|
|
|
return data
|
|
|
|
except Exception as e:
|
|
logger.error(f'Error processing {step_path}: {str(e)}')
|
|
return None
|
|
|
|
def test(step_file_path, output_path=None):
|
|
"""
|
|
测试函数:转换单个STEP文件并保存结果
|
|
"""
|
|
try:
|
|
logger.info(f"Processing STEP file: {step_file_path}")
|
|
|
|
# 解析STEP文件
|
|
data = parse_solid(step_file_path)
|
|
if data is None:
|
|
logger.error(f"Failed to parse STEP file: {step_file_path}")
|
|
return None
|
|
|
|
# 检查数据格式
|
|
is_valid, msg = check_data_format(data, step_file_path)
|
|
if not is_valid:
|
|
logger.error(f"Data format check failed for {step_file_path}: {msg}")
|
|
return None
|
|
|
|
# 打印统计信息
|
|
logger.info("\nStatistics:")
|
|
logger.info(f"Number of surfaces: {len(data['surf_wcs'])}")
|
|
logger.info(f"Number of edges: {len(data['edge_wcs'])}")
|
|
logger.info(f"Number of corners: {len(data['corner_wcs'])}") # 修正为corner_wcs
|
|
|
|
# 保存结果
|
|
if output_path:
|
|
try:
|
|
logger.info(f"Saving results to: {output_path}")
|
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
|
with open(output_path, 'wb') as f:
|
|
pickle.dump(data, f)
|
|
logger.info(f"Results saved successfully: {output_path}")
|
|
except Exception as e:
|
|
logger.error(f"Failed to save {output_path}: {str(e)}")
|
|
return None
|
|
|
|
return data
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing {step_file_path}: {str(e)}")
|
|
return None
|
|
|
|
def process_furniture_step(data_path):
|
|
"""
|
|
处理家具数据集的STEP文件
|
|
|
|
参数:
|
|
data_path: 数据集路径
|
|
|
|
返回:
|
|
包含训练、验证和测试集的STEP文件路径字典
|
|
{
|
|
'train': [step_file_path1, step_file_path2, ...],
|
|
'val': [step_file_path1, step_file_path2, ...],
|
|
'test': [step_file_path1, step_file_path2, ...]
|
|
}
|
|
"""
|
|
|
|
step_dirs = {}
|
|
for split in ['train', 'val', 'test']:
|
|
tmp_step_dirs = []
|
|
split_path = os.path.join(data_path, split)
|
|
if os.path.exists(split_path):
|
|
for f in os.listdir(split_path):
|
|
if f.endswith('.step'):
|
|
tmp_step_dirs.append(f)
|
|
step_dirs[split] = tmp_step_dirs
|
|
return step_dirs
|
|
|
|
|
|
def main():
|
|
"""主函数:处理多个STEP文件"""
|
|
parser = argparse.ArgumentParser(description='STEP文件批量处理工具')
|
|
parser.add_argument('-i', '--input_dir', required=True,
|
|
help='输入目录路径,包含STEP文件的文件夹')
|
|
parser.add_argument('-o', '--output_dir', required=True,
|
|
help='输出目录路径,用于保存OBJ文件')
|
|
parser.add_argument('-d', '--deflection', type=float, default=0.01,
|
|
help='网格精度参数 (默认: 0.01)')
|
|
parser.add_argument('-f', '--force', action='store_true',
|
|
help='覆盖已存在的输出文件')
|
|
parser.add_argument('-v', '--verbose', action='store_true',
|
|
help='显示详细处理信息')
|
|
|
|
args = parser.parse_args()
|
|
|
|
# 创建输出目录
|
|
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
|
# 获取所有STEP文件
|
|
step_files = glob.glob(os.path.join(args.input_dir, "**/*.step"), recursive=True)
|
|
|
|
# 清理输出目录
|
|
def clean_directory(directory):
|
|
if os.path.exists(directory):
|
|
logger.info(f"Cleaning directory: {directory}")
|
|
for root, dirs, files in os.walk(directory, topdown=False):
|
|
for name in files:
|
|
os.remove(os.path.join(root, name))
|
|
for name in dirs:
|
|
os.rmdir(os.path.join(root, name))
|
|
logger.info(f"Directory cleaned: {directory}")
|
|
|
|
# 清理之前的输出
|
|
clean_directory(OUTPUT)
|
|
clean_directory(RESULT)
|
|
|
|
# 确保输出目录存在
|
|
os.makedirs(OUTPUT, exist_ok=True)
|
|
os.makedirs(RESULT, exist_ok=True)
|
|
|
|
# 获取所有STEP文件
|
|
step_dirs_dict = process_furniture_step(INPUT)
|
|
total_processed = 0
|
|
total_success = 0
|
|
|
|
# 记录开始时间
|
|
start_time = datetime.now()
|
|
|
|
# 按数据集分割处理文件
|
|
for split in ['train', 'val', 'test']:
|
|
current_step_dirs = step_dirs_dict[split]
|
|
if not current_step_dirs:
|
|
logger.warning(f"No files found in {split} split")
|
|
continue
|
|
|
|
# 确保输出目录存在
|
|
split_output_dir = os.path.join(OUTPUT, split)
|
|
os.makedirs(split_output_dir, exist_ok=True)
|
|
|
|
success_files = [] # 只存储基础文件名(不含扩展名)
|
|
failed_files = [] # 只存储基础文件名(不含扩展名)
|
|
|
|
# 并行处理文件
|
|
with ProcessPoolExecutor(max_workers=os.cpu_count() // 2) as executor:
|
|
futures = {}
|
|
for step_file in current_step_dirs:
|
|
input_path = os.path.join(INPUT, split, step_file)
|
|
output_path = os.path.join(split_output_dir, step_file.replace('.step', '.pkl'))
|
|
future = executor.submit(process_single_step, input_path, output_path, timeout=300)
|
|
futures[future] = step_file
|
|
|
|
# 处理结果
|
|
for future in tqdm(as_completed(futures), total=len(current_step_dirs),
|
|
desc=f"Processing {split} set"):
|
|
step_file = futures[future]
|
|
base_name = step_file.replace('.step', '') # 获取不含扩展名的文件名
|
|
try:
|
|
status = future.result(timeout=300)
|
|
if status is not None:
|
|
success_files.append(base_name)
|
|
total_success += 1
|
|
else:
|
|
failed_files.append(base_name)
|
|
except (TimeoutError, Exception):
|
|
failed_files.append(base_name)
|
|
finally:
|
|
total_processed += 1
|
|
|
|
# 保存处理结果
|
|
os.makedirs(RESULT, exist_ok=True)
|
|
|
|
# 保存成功文件列表 (只保存文件名)
|
|
success_path = os.path.join(RESULT, f'{split}_success.txt')
|
|
with open(success_path, 'w') as f:
|
|
f.write('\n'.join(success_files))
|
|
|
|
# 保存失败文件列表 (只保存文件名)
|
|
failed_path = os.path.join(RESULT, f'{split}_failed.txt')
|
|
with open(failed_path, 'w') as f:
|
|
f.write('\n'.join(failed_files))
|
|
|
|
logger.info(f"{split} set - Success: {len(success_files)}, Failed: {len(failed_files)}")
|
|
|
|
# 打印最终统计信息
|
|
end_time = datetime.now()
|
|
duration = end_time - start_time
|
|
|
|
if total_processed > 0:
|
|
success_rate = (total_success / total_processed) * 100
|
|
logger.info("\nProcessing Summary:")
|
|
logger.info(f"Start time: {start_time}")
|
|
logger.info(f"End time: {end_time}")
|
|
logger.info(f"Duration: {duration}")
|
|
logger.info(f"Total files processed: {total_processed}")
|
|
logger.info(f"Successfully processed: {total_success}")
|
|
logger.info(f"Failed: {total_processed - total_success}")
|
|
logger.info(f"Success rate: {success_rate:.2f}%")
|
|
else:
|
|
logger.warning("No files were processed")
|
|
|
|
if __name__ == '__main__':
|
|
#main()
|
|
#test("/home/wch/brep2sdf/00000031_ad34a3f60c4a4caa99646600_step_011.step", "/home/wch/brep2sdf/test_data/pkl/train/00000031.pkl")
|
|
|