You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
412 lines
14 KiB
412 lines
14 KiB
import os
|
|
import pickle
|
|
import argparse
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
|
|
|
|
from OCC.Core.STEPControl import STEPControl_Reader
|
|
from OCC.Core.TopExp import TopExp_Explorer, topexp
|
|
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX
|
|
from OCC.Core.BRep import BRep_Tool
|
|
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
|
|
from OCC.Core.TopLoc import TopLoc_Location
|
|
from OCC.Core.IFSelect import IFSelect_RetDone
|
|
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape
|
|
from OCC.Core.BRepBndLib import brepbndlib
|
|
from OCC.Core.Bnd import Bnd_Box
|
|
from OCC.Core.TopoDS import TopoDS_Shape
|
|
from OCC.Core.TopoDS import topods
|
|
from OCC.Core.TopoDS import TopoDS_Vertex
|
|
|
|
# Per-model face-count ceiling, meant to keep processing time bounded.
# NOTE(review): nothing visible in this file reads MAX_FACE (parse_solid never
# returns None because of it) — confirm whether the limit should be enforced.
MAX_FACE = 70
|
|
|
|
def _as_object_array(arrays):
    """Pack a sequence of arrays into a 1-D object array with one element per entry.

    ``np.array(list_of_arrays, dtype=object)`` silently builds a higher-rank
    object array when all entries share a shape. This helper always keeps
    exactly one array per element.
    """
    packed = np.empty(len(arrays), dtype=object)
    for i, arr in enumerate(arrays):
        packed[i] = arr
    return packed


def normalize(surfs, edges, corners):
    """Normalize the CAD model to a unit cube.

    The axis-aligned bounding box of ``corners`` defines the frame. Every point
    set is translated so the box center is at the origin. It is then scaled
    uniformly so the box's longest side has length 1.

    Args:
        surfs: Sequence of per-face point arrays, each shaped ``(n_i, 3)``.
        edges: Sequence of per-edge point arrays, each shaped ``(m_i, 3)``.
        corners: ``(k, 3)`` array-like of vertex positions.

    Returns:
        ``(surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs)``.
        The first four are 1-D ``object`` arrays holding one point array per
        face/edge. ``*_wcs`` entries are copies of the input coordinates, and
        ``*_ncs`` entries are the normalized ones. ``corner_wcs`` holds the
        *normalized* corners, despite its name. All five are ``None`` when
        ``corners`` is empty.
    """
    if len(corners) == 0:
        return None, None, None, None, None

    # Bounding box of the corners.
    corners_array = np.array(corners)
    upper = corners_array.max(0)
    lower = corners_array.min(0)
    center = (upper + lower) / 2
    extent = (upper - lower).max()
    # When all corners coincide the extent is 0. Fall back to a unit scale
    # rather than dividing by zero, which would produce inf/NaN coordinates.
    scale = 1.0 / extent if extent > 0 else 1.0

    surfs_wcs = [np.array(surf) for surf in surfs]
    surfs_ncs = [(surf - center) * scale for surf in surfs_wcs]

    edges_wcs = [np.array(edge) for edge in edges]
    edges_ncs = [(edge - center) * scale for edge in edges_wcs]

    corner_wcs = (corners_array - center) * scale

    # Keep "list of arrays" semantics even when every face/edge has the same
    # number of points (edges always do).
    return (_as_object_array(surfs_wcs),
            _as_object_array(edges_wcs),
            _as_object_array(surfs_ncs),
            _as_object_array(edges_ncs),
            corner_wcs)
|
|
|
|
def _explore_subshapes(shape, kind, cast):
    """Return every sub-shape of ``shape`` of type ``kind`` in TopExp_Explorer order, each passed through ``cast``."""
    found = []
    explorer = TopExp_Explorer(shape, kind)
    while explorer.More():
        found.append(cast(explorer.Current()))
        explorer.Next()
    return found


def get_adjacency_info(shape):
    """Compute face/edge/vertex adjacency for ``shape``.

    Faces, edges and vertices are indexed in the order a TopExp_Explorer over
    ``shape`` visits them.
    NOTE(review): TopExp_Explorer may visit a shared sub-shape more than once,
    so these lists can contain repeated entries — confirm this is intended.

    Returns:
        edgeFace_adj: int32 ``(num_edges, num_faces)``. Entry ``[i, j]`` is 1 when
            edge ``i`` is found among the edges of face ``j``.
        faceEdge_adj: int32 ``(num_faces, num_edges)``, the transpose of the above.
        edgeCorner_adj: int32 ``(num_edges, 2)``. Indices into the vertex list of
            the two vertices ``topexp.Vertices`` returns for each edge. Rows stay
            ``[0, 0]`` when either vertex is null.
    """
    # The previous version also built an edge->face ancestor map with
    # topexp.MapShapesAndAncestors but never read it; that extra walk of the
    # whole shape is dropped.
    faces = _explore_subshapes(shape, TopAbs_FACE, topods.Face)
    edges = _explore_subshapes(shape, TopAbs_EDGE, topods.Edge)
    vertices = _explore_subshapes(shape, TopAbs_VERTEX, topods.Vertex)

    num_faces = len(faces)
    num_edges = len(edges)

    edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32)
    faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32)
    edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32)

    # Collect each face's edges once, instead of creating a new explorer for
    # every (edge, face) pair.
    face_boundaries = [_explore_subshapes(face, TopAbs_EDGE, topods.Edge) for face in faces]

    for i, edge in enumerate(edges):
        # Edge-face incidence: same-ness ignores orientation (IsSame).
        for j, boundary in enumerate(face_boundaries):
            if any(edge.IsSame(other) for other in boundary):
                edgeFace_adj[i, j] = 1
                faceEdge_adj[j, i] = 1

        # Edge end vertices.
        v1 = TopoDS_Vertex()
        v2 = TopoDS_Vertex()
        topexp.Vertices(edge, v1, v2)
        if v1.IsNull() or v2.IsNull():
            continue

        # No early break: if the same vertex appears several times in
        # `vertices`, the last matching index is recorded.
        for k, vertex in enumerate(vertices):
            if v1.IsSame(vertex):
                edgeCorner_adj[i, 0] = k
            if v2.IsSame(vertex):
                edgeCorner_adj[i, 1] = k

    return edgeFace_adj, faceEdge_adj, edgeCorner_adj
|
|
|
|
def get_bbox(shape, subshape):
    """Return the bounding box of ``subshape``.

    The result is a NumPy array ``[xmin, ymin, zmin, xmax, ymax, zmax]`` built
    from ``Bnd_Box.Get()`` after ``brepbndlib.Add``. ``shape`` is accepted for
    call-site compatibility but is not used.
    """
    box = Bnd_Box()
    brepbndlib.Add(subshape, box)
    x0, y0, z0, x1, y1, z1 = box.Get()
    return np.array([x0, y0, z0, x1, y1, z1])
|
|
|
|
def _read_step_shape(step_path):
    """Read ``step_path`` and return the shape produced by ``OneShape()``.

    Raises:
        Exception: "Failed to read STEP file" when ReadFile does not return
            IFSelect_RetDone.
    """
    reader = STEPControl_Reader()
    status = reader.ReadFile(step_path)
    if status != IFSelect_RetDone:
        raise Exception("Failed to read STEP file")
    reader.TransferRoots()
    return reader.OneShape()


def _triangulation_points(face):
    """Return the triangulation nodes of ``face`` as float32 ``(n, 3)``, with the face location applied.

    Returns None when the face has no triangulation or it has no nodes.
    """
    loc = TopLoc_Location()
    triangulation = BRep_Tool.Triangulation(face, loc)
    if triangulation is None:
        return None
    trsf = loc.Transformation()
    nodes = []
    for i in range(1, triangulation.NbNodes() + 1):  # node indices are 1-based
        pnt = triangulation.Node(i).Transformed(trsf)
        nodes.append([pnt.X(), pnt.Y(), pnt.Z()])
    return np.array(nodes, dtype=np.float32) if nodes else None


def _curve_points(edge, num_samples):
    """Sample ``num_samples`` points along the 3D curve of ``edge``.

    Samples are evenly spaced in parameter space over ``[first, last]``, and
    the result is float32 ``(num_samples, 3)``. Returns None when
    ``BRep_Tool.Curve`` yields no curve.
    """
    curve, first, last = BRep_Tool.Curve(edge)
    if curve is None:
        return None
    samples = []
    for i in range(num_samples):
        param = first + (last - first) * float(i) / (num_samples - 1)
        pnt = curve.Value(param)
        samples.append([pnt.X(), pnt.Y(), pnt.Z()])
    return np.array(samples, dtype=np.float32) if samples else None


def parse_solid(step_path):
    """Parse the surface, curve, face, edge, vertex in a CAD solid using OCC.

    Faces are sampled from their triangulation nodes, edges at 100 evenly
    spaced curve parameters, and corners from the vertices. The points are
    normalized with ``normalize``. Face/edge indices are kept consistent across
    all returned arrays: a face without a triangulation or an edge without a
    3D curve is dropped everywhere, including from the adjacency matrices.

    Returns:
        dict with keys 'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs',
        'corner_wcs', 'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj',
        'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'.

    Raises:
        Exception: the STEP file could not be read.
        ValueError: the shape has no vertices, so it cannot be normalized.

    NOTE(review): MAX_FACE is defined at module level, and ``process`` treats a
    None result as "Exceeding threshold", but no face limit is applied here.
    Confirm whether it should be.
    """
    shape = _read_step_shape(step_path)

    # Mesh every face with an absolute linear deflection of 0.01. The
    # BRepMesh_IncrementalMesh constructor performs the meshing itself, so no
    # separate Perform() call is needed.
    BRepMesh_IncrementalMesh(shape, 0.01)

    # Faces: remember the explorer index of every face we keep.
    face_pnts, surf_bbox_wcs, kept_faces = [], [], []
    explorer = TopExp_Explorer(shape, TopAbs_FACE)
    position = 0
    while explorer.More():
        face = topods.Face(explorer.Current())
        points = _triangulation_points(face)
        if points is not None:
            face_pnts.append(points)
            surf_bbox_wcs.append(get_bbox(shape, face))
            kept_faces.append(position)
        position += 1
        explorer.Next()

    # Edges: same bookkeeping.
    num_samples = 100
    edge_pnts, edge_bbox_wcs, kept_edges = [], [], []
    explorer = TopExp_Explorer(shape, TopAbs_EDGE)
    position = 0
    while explorer.More():
        edge = topods.Edge(explorer.Current())
        points = _curve_points(edge, num_samples)
        if points is not None:
            edge_pnts.append(points)
            edge_bbox_wcs.append(get_bbox(shape, edge))
            kept_edges.append(position)
        position += 1
        explorer.Next()

    # Vertices: all of them are kept.
    corner_list = []
    explorer = TopExp_Explorer(shape, TopAbs_VERTEX)
    while explorer.More():
        pnt = BRep_Tool.Pnt(topods.Vertex(explorer.Current()))
        corner_list.append([pnt.X(), pnt.Y(), pnt.Z()])
        explorer.Next()
    corner_pnts = np.array(corner_list, dtype=np.float32)

    surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs = normalize(face_pnts, edge_pnts, corner_pnts)
    if corner_wcs is None:
        # Previously this surfaced as AttributeError on None.astype below.
        raise ValueError("STEP file contains no vertices")

    # get_adjacency_info indexes every face/edge an explorer visits, but the
    # loops above may have dropped some. Restrict the matrices to the kept
    # entries so row/column i always refers to face_pnts[i] / edge_pnts[i].
    edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(shape)
    face_keep = np.asarray(kept_faces, dtype=np.intp)
    edge_keep = np.asarray(kept_edges, dtype=np.intp)
    edgeFace_adj = edgeFace_adj[np.ix_(edge_keep, face_keep)]
    faceEdge_adj = faceEdge_adj[np.ix_(face_keep, edge_keep)]
    edgeCorner_adj = edgeCorner_adj[edge_keep]

    surf_bbox_wcs = np.array(surf_bbox_wcs, dtype=np.float32)
    edge_bbox_wcs = np.array(edge_bbox_wcs, dtype=np.float32)

    data = {
        'surf_wcs': surfs_wcs,
        'edge_wcs': edges_wcs,
        'surf_ncs': surfs_ncs,
        'edge_ncs': edges_ncs,
        'corner_wcs': corner_wcs.astype(np.float32),
        'edgeFace_adj': edgeFace_adj,
        'edgeCorner_adj': edgeCorner_adj,
        'faceEdge_adj': faceEdge_adj,
        'surf_bbox_wcs': surf_bbox_wcs,
        'edge_bbox_wcs': edge_bbox_wcs,
        'corner_unique': np.unique(corner_wcs, axis=0).astype(np.float32)
    }

    return data
|
|
|
|
def load_step(step_path):
    """Load a STEP file and return its transferred shape wrapped in a one-element list.

    Raises:
        Exception: "Failed to read STEP file" when ReadFile does not return
            IFSelect_RetDone. Previously the status was ignored, and a failed
            read yielded a null shape.

    NOTE(review): OneShape() merges every transferred root into one shape, so
    this list always has exactly one element. The "multi solids" check in the
    caller can therefore never trigger. Confirm whether solids should be
    counted instead.
    """
    reader = STEPControl_Reader()
    if reader.ReadFile(step_path) != IFSelect_RetDone:
        raise Exception("Failed to read STEP file")
    reader.TransferRoots()
    return [reader.OneShape()]
|
|
|
|
def process(step_path, timeout=300):
    """Parse one STEP file and pickle the result under the global OUTPUT folder.

    The uid and sub-folder come from '/'-separated path components. For paths
    containing 'furniture', uid is "<parent>_<filename>" and the sub-folder is
    the grandparent directory name. Otherwise uid is the parent directory name
    and the sub-folder is its first 4 characters. A trailing '.step' is
    stripped from the uid.

    Args:
        step_path: Path of the STEP file to convert.
        timeout: Unused; kept because callers pass it.

    Returns:
        1 if a pickle was written, otherwise 0. Errors are printed rather than
        raised.

    NOTE(review): OUTPUT is only assigned inside main(). With a 'spawn'
    multiprocessing start method, worker processes would not see it and every
    file would fail with NameError — confirm the start method used.
    """
    try:
        # Check single solid
        cad_solid = load_step(step_path)
        if len(cad_solid) != 1:
            print('Skipping multi solids...')
            return 0

        # Start data parsing
        data = parse_solid(step_path)
        if data is None:
            print ('Exceeding threshold...')
            return 0

        # Save the parsed result
        parts = step_path.split('/')
        if 'furniture' in step_path:
            data_uid = parts[-2] + '_' + parts[-1]
            sub_folder = parts[-3]
        else:
            data_uid = parts[-2]
            sub_folder = data_uid[:4]

        if data_uid.endswith('.step'):
            data_uid = data_uid[:-5]

        data['uid'] = data_uid
        save_folder = os.path.join(OUTPUT, sub_folder)
        # Several pool workers may target the same sub-folder at once. A
        # check-then-create races and raises FileExistsError, which used to
        # drop the file silently; exist_ok makes creation idempotent.
        os.makedirs(save_folder, exist_ok=True)

        save_path = os.path.join(save_folder, data['uid']+'.pkl')
        with open(save_path, "wb") as tf:
            pickle.dump(data, tf)

        return 1

    except Exception as e:
        print('not saving due to error...', str(e))
        return 0
|
|
|
|
def test(step_file_path, output_path=None):
    """Convert a single STEP file, print summary statistics, and optionally save the result.

    Args:
        step_file_path: STEP file passed to parse_solid.
        output_path: Optional pickle destination. Its parent directory is
            created if needed, and a bare filename saves into the current
            directory.

    Returns:
        The parsed data dict. Returns None if parse_solid returned None or any
        exception occurred; the error is printed, not raised.
    """
    try:
        print(f"Processing STEP file: {step_file_path}")

        # Parse the STEP file
        data = parse_solid(step_file_path)
        if data is None:
            print("Failed to parse STEP file")
            return None

        # Print statistics
        print("\nStatistics:")
        print(f"Number of surfaces: {len(data['surf_wcs'])}")
        print(f"Number of edges: {len(data['edge_wcs'])}")
        print(f"Number of corners: {len(data['corner_unique'])}")

        # Save the result
        if output_path:
            print(f"\nSaving results to: {output_path}")
            out_dir = os.path.dirname(output_path)
            # A bare filename has an empty dirname, and os.makedirs('') raises
            # FileNotFoundError, which previously aborted the save.
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            with open(output_path, 'wb') as f:
                pickle.dump(data, f)
            print("Results saved successfully")

        return data

    except Exception as e:
        print(f"Error processing STEP file: {str(e)}")
        return None
|
|
|
|
def load_furniture_step(data_path):
    """Load furniture STEP files.

    Collects entries ending in '.step' directly inside the 'train', 'val' and
    'test' sub-folders of ``data_path``, in that split order. Missing splits
    are skipped. Within a split, names are sorted so the output is
    reproducible (``os.listdir`` order is arbitrary), consistent with
    load_abc_step.
    """
    step_dirs = []
    for split in ['train', 'val', 'test']:
        split_path = os.path.join(data_path, split)
        if not os.path.exists(split_path):
            continue
        step_dirs.extend(
            os.path.join(split_path, name)
            for name in sorted(os.listdir(split_path))
            if name.endswith('.step')
        )
    return step_dirs
|
|
|
|
def load_abc_step(data_path, is_deepcad=False):
    """Load ABC/DeepCAD STEP files.

    Walks the immediate sub-directories of ``data_path`` in sorted order. Each
    one is expected to contain ``<name>/<name>.step`` (DeepCAD) or
    ``<name>/shape.step`` (ABC). Only paths that exist are returned, and plain
    files in ``data_path`` are ignored.
    """
    found = []
    for name in sorted(os.listdir(data_path)):
        folder = os.path.join(data_path, name)
        if not os.path.isdir(folder):
            continue
        filename = name + '.step' if is_deepcad else 'shape.step'
        candidate = os.path.join(folder, filename)
        if os.path.exists(candidate):
            found.append(candidate)
    return found
|
|
|
|
def main():
    """Batch-convert a STEP dataset to pickles with a process pool.

    Command-line options:
        --input     dataset folder (required).
        --option    abc | deepcad | furniture (default: abc). Also picks the
                    output root: abc_parsed / deepcad_parsed / furniture_parsed.
        --interval  abc/deepcad only. Converts entries
                    [interval*10000, (interval+1)*10000) of the file list.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, help="Data folder path", required=True)
    parser.add_argument("--option", type=str, choices=['abc', 'deepcad', 'furniture'], default='abc',
                        help="Choose between dataset option [abc/deepcad/furniture] (default: abc)")
    parser.add_argument("--interval", type=int, help="Data range index, only required for abc/deepcad")
    args = parser.parse_args()

    # process() reads the output root from this module-level global.
    global OUTPUT
    if args.option == 'deepcad':
        OUTPUT = 'deepcad_parsed'
    elif args.option == 'abc':
        OUTPUT = 'abc_parsed'
    else:
        OUTPUT = 'furniture_parsed'

    # Load all STEP files
    if args.option == 'furniture':
        step_dirs = load_furniture_step(args.input)
    else:
        # Without this check a missing --interval crashed with a TypeError
        # (None * 10000) after the whole directory had been scanned.
        if args.interval is None:
            parser.error("--interval is required for abc/deepcad")
        step_dirs = load_abc_step(args.input, args.option == 'deepcad')
        step_dirs = step_dirs[args.interval*10000 : (args.interval+1)*10000]

    if not step_dirs:
        # Avoids a ZeroDivisionError in the final ratio.
        print('No STEP files to process.')
        return

    # Half the cores, but never 0 workers (ProcessPoolExecutor rejects 0), and
    # cope with cpu_count() returning None.
    max_workers = max(1, (os.cpu_count() or 1) // 2)

    # Process B-reps in parallel
    valid = 0
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process, path, timeout=300): path for path in step_dirs}

        # as_completed only yields finished futures, so result() never waits.
        # No per-file time limit is enforced here (a timeout on result() would
        # never fire).
        for future in tqdm(as_completed(futures), total=len(step_dirs)):
            try:
                valid += future.result()
            except Exception as e:
                print(f"An error occurred while processing {futures[future]}: {e}")

    print(f'Done... Data Converted Ratio {100.0*valid/len(step_dirs)}%')
|
|
|
|
if __name__ == '__main__':
    import sys

    cli_args = sys.argv[1:]

    if not cli_args or cli_args[0] != '--test':
        # Normal batch mode
        main()
    else:
        # Single-file test mode: --test <step_file_path> [output_path]
        if len(cli_args) < 2:
            print("Usage: python process_brep.py --test <step_file_path> [output_path]")
            sys.exit(1)

        step_file = cli_args[1]
        output_file = cli_args[2] if len(cli_args) > 2 else None

        print("Running in test mode...")
        if test(step_file, output_file) is not None:
            print("\nTest completed successfully!")
|