Browse Source

可以直接从 step 开始处理的版本

final
mckay 2 months ago
parent
commit
2e8dc3b0fe
  1. 2
      brep2sdf/data/data.py
  2. 540
      brep2sdf/data/pre_process.py
  3. 23
      brep2sdf/networks/encoder.py
  4. 14
      brep2sdf/networks/network.py
  5. 84
      brep2sdf/networks/octree.py
  6. 27
      brep2sdf/scripts/process_brep.py
  7. 6
      brep2sdf/scripts/process_furniture.py
  8. 105
      brep2sdf/train.py

2
brep2sdf/data/data.py

@ -251,6 +251,8 @@ class BRepSDFDataset(Dataset):
return num_faces, num_edges return num_faces, num_edges
def load_brep_file(brep_path): def load_brep_file(brep_path):
with open(brep_path, 'rb') as f: with open(brep_path, 'rb') as f:
brep_raw = pickle.load(f) brep_raw = pickle.load(f)

540
brep2sdf/data/pre_process.py

@ -0,0 +1,540 @@
"""
CAD模型处理脚本
功能将STEP格式的CAD模型转换为结构化数据包括
- 几何信息顶点的坐标数据
- 拓扑信息--顶点的邻接关系
- 空间信息包围盒数据
"""
import os
import pickle # 用于数据序列化
import argparse # 命令行参数解析
import numpy as np
from tqdm import tqdm # 进度条显示
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError # 并行处理
import logging
from datetime import datetime
from brep2sdf.utils.logger import logger
# 导入OpenCASCADE相关库
from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器
from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类型定义
from OCC.Core.BRep import BRep_Tool # B-rep工具
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分
from OCC.Core.TopLoc import TopLoc_Location # 位置变换
from OCC.Core.IFSelect import IFSelect_RetDone,IFSelect_RetError, IFSelect_RetFail, IFSelect_RetVoid # 操作状态码
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射
from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算
from OCC.Core.Bnd import Bnd_Box # 包围盒
from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构
# 导入配置
from brep2sdf.config.default_config import get_default_config
config = get_default_config()
# 设置最大面数阈值,用于加速处理
MAX_FACE = config.data.max_face
def normalize(surfs, edges, corners):
"""
将CAD模型归一化到单位立方体空间
参数:
surfs: 面的点集列表
edges: 边的点集列表
corners: 顶点坐标数组 [num_edges, 2, 3]
返回:
surfs_wcs: 原始坐标系下的面点集
edges_wcs: 原始坐标系下的边点集
surfs_ncs: 归一化坐标系下的面点集
edges_ncs: 归一化坐标系下的边点集
corner_wcs: 归一化后的顶点坐标 [num_edges, 2, 3]
center: 使用的中心点坐标 [3,]
scale: 使用的缩放系数 (float)
"""
if len(corners) == 0:
return None, None, None, None, None, None, None
# 计算包围盒和缩放因子
corners_array = corners.reshape(-1, 3) # [num_edges*2, 3]
center = (corners_array.max(0) + corners_array.min(0)) / 2 # 计算中心点
scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max() # 计算缩放系数
# 归一化面的点集
surfs_wcs = [] # 原始世界坐标系下的面点集
surfs_ncs = [] # 归一化坐标系下的面点集
for surf in surfs:
surf_wcs = np.array(surf)
surf_ncs = (surf_wcs - center) * scale # 归一化变换
surfs_wcs.append(surf_wcs)
surfs_ncs.append(surf_ncs)
# 归一化边的点集
edges_wcs = [] # 原始世界坐标系下的边点集
edges_ncs = [] # 归一化坐标系下的边点集
for edge in edges:
edge_wcs = np.array(edge)
edge_ncs = (edge_wcs - center) * scale # 归一化变换
edges_wcs.append(edge_wcs)
edges_ncs.append(edge_ncs)
# 归一化顶点坐标 - 保持[num_edges, 2, 3]的形状
corner_wcs = (corners - center) * scale # 广播操作会保持原有维度
return (np.array(surfs_wcs, dtype=object),
np.array(edges_wcs, dtype=object),
np.array(surfs_ncs, dtype=object),
np.array(edges_ncs, dtype=object),
corner_wcs.astype(np.float32),
center.astype(np.float32),
scale
)
def get_adjacency_info(shape):
"""
获取CAD模型中面顶点之间的邻接关系
参数:
shape: CAD模型的形状对象
返回:
edgeFace_adj: -面邻接矩阵 (num_edges × num_faces)
faceEdge_adj: -边邻接矩阵 (num_faces × num_edges)
edgeCorner_adj: -顶点邻接矩阵 (num_edges × 2)
"""
# 创建边-面映射关系
edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape()
topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map)
# 获取所有几何元素
faces = [] # 存储所有面
edges = [] # 存储所有边
vertices = [] # 存储所有顶点
# 创建拓扑结构探索器
face_explorer = TopExp_Explorer(shape, TopAbs_FACE)
edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE)
vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX)
# 收集所有几何元素
while face_explorer.More():
faces.append(topods.Face(face_explorer.Current()))
face_explorer.Next()
while edge_explorer.More():
edges.append(topods.Edge(edge_explorer.Current()))
edge_explorer.Next()
while vertex_explorer.More():
vertices.append(topods.Vertex(vertex_explorer.Current()))
vertex_explorer.Next()
# 创建邻接矩阵
num_faces = len(faces)
num_edges = len(edges)
num_vertices = len(vertices)
edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32)
faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32)
edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32)
# 填充边-面邻接矩阵
for i, edge in enumerate(edges):
# 检查每个面是否与当前边相连
for j, face in enumerate(faces):
edge_explorer = TopExp_Explorer(face, TopAbs_EDGE)
while edge_explorer.More():
if edge.IsSame(edge_explorer.Current()):
edgeFace_adj[i, j] = 1
faceEdge_adj[j, i] = 1
break
edge_explorer.Next()
# 获取边的两个端点
v1 = TopoDS_Vertex()
v2 = TopoDS_Vertex()
topexp.Vertices(edge, v1, v2)
# 记录边的端点索引
if not v1.IsNull() and not v2.IsNull():
v1_vertex = topods.Vertex(v1)
v2_vertex = topods.Vertex(v2)
for k, vertex in enumerate(vertices):
if v1_vertex.IsSame(vertex):
edgeCorner_adj[i, 0] = k
if v2_vertex.IsSame(vertex):
edgeCorner_adj[i, 1] = k
return edgeFace_adj, faceEdge_adj, edgeCorner_adj
def get_bbox(shape, subshape):
"""
计算形状的包围盒
参数:
shape: 完整的CAD模型形状
subshape: 需要计算包围盒的子形状面或边
返回:
包围盒的六个参数 [xmin, ymin, zmin, xmax, ymax, zmax]
"""
bbox = Bnd_Box()
brepbndlib.Add(subshape, bbox)
xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
return np.array([xmin, ymin, zmin, xmax, ymax, zmax])
def parse_solid(step_path):
"""
解析STEP文件中的CAD模型数据
返回:
dict: 包含以下键值对的字典:
# 几何数据
'surf_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示面的点云坐标
'edge_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示边的采样点坐标
'surf_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示归一化后的面点云
'edge_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示归一化后的边采样点
'corner_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 2, 3)的数组,表示每条边的两个端点坐标
'corner_unique': np.ndarray(dtype=float32) # 形状为(num_vertices, 3)的数组,表示所有顶点的唯一坐标,num_vertices <= num_edges * 2
# 拓扑关系
'edgeFace_adj': np.ndarray(dtype=int32) # 形状为(num_edges, num_faces)的数组,表示边-面邻接关系
'edgeCorner_adj': np.ndarray(dtype=int32) # 形状为(num_edges, 2)的数组,表示边-顶点邻接关系
'faceEdge_adj': np.ndarray(dtype=int32) # 形状为(num_faces, num_edges)的数组,表示面-边邻接关系
# 包围盒数据
'surf_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_faces, 6)的数组,表示每个面的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
'edge_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 6)的数组,表示每条边的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
"""
# Load STEP file
reader = STEPControl_Reader()
status = reader.ReadFile(step_path)
if status != IFSelect_RetDone:
if status == IFSelect_RetError:
print("Error: An error occurred while reading the file.")
elif status == IFSelect_RetFail:
print("Error: Failed to read the file.")
elif status == IFSelect_RetVoid:
print("Error: No data was read from the file.")
else:
print(f"Unexpected status code: {status}")
raise Exception(f"Failed to read STEP file. {status}")
reader.TransferRoots()
shape = reader.OneShape()
# Create mesh
mesh = BRepMesh_IncrementalMesh(shape, 0.01)
mesh.Perform()
# Initialize explorers
face_explorer = TopExp_Explorer(shape, TopAbs_FACE)
edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE)
vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX)
face_pnts = []
edge_pnts = []
corner_pnts = []
surf_bbox_wcs = []
edge_bbox_wcs = []
# Extract face points
while face_explorer.More():
face = topods.Face(face_explorer.Current())
loc = TopLoc_Location()
triangulation = BRep_Tool.Triangulation(face, loc)
if triangulation is not None:
points = []
for i in range(1, triangulation.NbNodes() + 1):
node = triangulation.Node(i)
pnt = node.Transformed(loc.Transformation())
points.append([pnt.X(), pnt.Y(), pnt.Z()])
if points:
points = np.array(points, dtype=np.float32)
if len(points.shape) == 2 and points.shape[1] == 3:
# 确保每个面至少有一些点
if len(points) < 3: # 如果点数太少,跳过这个面
continue
face_pnts.append(points)
surf_bbox_wcs.append(get_bbox(shape, face))
face_explorer.Next()
# Extract edge points
num_samples = config.model.num_edge_points # 使用配置中的边采样点数
while edge_explorer.More():
edge = topods.Edge(edge_explorer.Current())
curve_info = BRep_Tool.Curve(edge)
if curve_info is None:
continue # 跳过无效边
try:
if len(curve_info) == 3:
curve, first, last = curve_info
elif len(curve_info) == 2:
continue
curve, location = curve_info
logger.info(curve)
first, last = BRep_Tool.Range(edge) # 显式获取参数范围
else:
raise ValueError(f"Unexpected curve info: {curve_info}")
except Exception as e:
logger.error(f"Failed to process edge {edge}: {str(e)}")
continue
if curve is not None:
points = []
for i in range(num_samples):
param = first + (last - first) * float(i) / (num_samples - 1)
pnt = curve.Value(param)
points.append([pnt.X(), pnt.Y(), pnt.Z()])
if points:
points = np.array(points, dtype=np.float32)
if len(points.shape) == 2 and points.shape[1] == 3:
edge_pnts.append(points) # 现在points是(num_edge_points, 3)形状
edge_bbox_wcs.append(get_bbox(shape, edge))
edge_explorer.Next()
# Extract vertex points
while vertex_explorer.More():
vertex = topods.Vertex(vertex_explorer.Current())
pnt = BRep_Tool.Pnt(vertex)
corner_pnts.append([pnt.X(), pnt.Y(), pnt.Z()])
vertex_explorer.Next()
# 获取邻接信息
edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(shape)
# 转换为numpy数组时确保类型正确
face_pnts = [np.array(points, dtype=np.float32) for points in face_pnts]
edge_pnts = [np.array(points, dtype=np.float32) for points in edge_pnts]
# 转换为对象数组
face_pnts = np.array(face_pnts, dtype=object)
edge_pnts = np.array(edge_pnts, dtype=object)
corner_pnts = np.array(corner_pnts, dtype=np.float32)
# 重组顶点数据为每条边两个端点的形式
corner_pairs = []
for edge_idx in range(len(edge_pnts)):
v1_idx, v2_idx = edgeCorner_adj[edge_idx]
v1_pos = corner_pnts[v1_idx]
v2_pos = corner_pnts[v2_idx]
# 按坐标排序确保顺序一致
if (v1_pos > v2_pos).any():
v1_pos, v2_pos = v2_pos, v1_pos
corner_pairs.append(np.stack([v1_pos, v2_pos]))
corner_pairs = np.stack(corner_pairs).astype(np.float32) # [num_edges, 2, 3]
# 确保所有数组都有正确的类型
surf_bbox_wcs = np.array(surf_bbox_wcs, dtype=np.float32)
edge_bbox_wcs = np.array(edge_bbox_wcs, dtype=np.float32)
# Normalize the CAD model
surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs,center, scale = normalize(
face_pnts, edge_pnts, corner_pairs)
# 计算归一化后的包围盒
surf_bbox_ncs = np.empty_like(surf_bbox_wcs)
edge_bbox_ncs = np.empty_like(edge_bbox_wcs)
# 转换曲面包围盒到归一化坐标系
surf_bbox_ncs[:, :3] = (surf_bbox_wcs[:, :3] - center) * scale # 最小点
surf_bbox_ncs[:, 3:] = (surf_bbox_wcs[:, 3:] - center) * scale # 最大点
# 转换边包围盒到归一化坐标系
edge_bbox_ncs[:, :3] = (edge_bbox_wcs[:, :3] - center) * scale # 最小点
edge_bbox_ncs[:, 3:] = (edge_bbox_wcs[:, 3:] - center) * scale # 最大点
# 验证归一化后的数据
if any(x is None for x in [surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs]):
logger.error(f"Normalization failed for {step_path}")
return None
# 创建结果字典并确保所有数组都有正确的类型
data = {
'surf_wcs': np.array(surfs_wcs, dtype=object), # 保持对象数组
'edge_wcs': np.array(edges_wcs, dtype=object), # 保持对象数组
'surf_ncs': np.array(surfs_ncs, dtype=object), # 保持对象数组
'edge_ncs': np.array(edges_ncs, dtype=object), # 保持对象数组
'corner_wcs': corner_wcs.astype(np.float32), # [num_edges, 2, 3]
'edgeFace_adj': edgeFace_adj.astype(np.int32),
'edgeCorner_adj': edgeCorner_adj.astype(np.int32),
'faceEdge_adj': faceEdge_adj.astype(np.int32),
'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),
'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),
'surf_bbox_ncs': surf_bbox_ncs.astype(np.float32), # 归一化坐标系 [num_faces, 6]
'edge_bbox_ncs': edge_bbox_ncs.astype(np.float32), # 归一化坐标系 [num_edges, 6]
'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32), # 先展平再去重
'normalization_params': {
'center': center.astype(np.float32), # 归一化中心点 [3,]
'scale': float(scale), # 归一化缩放系数
}
}
return data
def load_step(step_path):
"""Load STEP file and return solids"""
reader = STEPControl_Reader()
reader.ReadFile(step_path)
reader.TransferRoots()
return [reader.OneShape()]
def check_data_format(data, step_file):
"""检查数据格式是否正确"""
required_keys = [
'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs', 'corner_wcs',
'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj',
'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'
]
# 检查所有必需的键是否存在
for key in required_keys:
if key not in data:
return False, f"Missing key: {key}"
# 检查几何数据
geometry_arrays = ['surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs']
for key in geometry_arrays:
if not isinstance(data[key], np.ndarray):
return False, f"{key} should be a numpy array"
# 允许对象数组
if data[key].dtype != object:
return False, f"{key} should be a numpy array with dtype=object"
# 检查其他数组
float32_arrays = ['corner_wcs', 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique']
for key in float32_arrays:
if not isinstance(data[key], np.ndarray):
return False, f"{key} should be a numpy array"
if data[key].dtype != np.float32:
return False, f"{key} should be a numpy array with dtype=float32"
int32_arrays = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
for key in int32_arrays:
if not isinstance(data[key], np.ndarray):
return False, f"{key} should be a numpy array"
if data[key].dtype != np.int32:
return False, f"{key} should be a numpy array with dtype=int32"
return True, ""
def process_single_step(step_path:str, output_path:str=None, timeout:int=300) -> dict:
"""处理单个STEP文件, 从 brep 2 pkl
return data = {
'surf_wcs': np.array(surfs_wcs, dtype=object), # 世界坐标系下的曲面几何数据(对象数组)
'edge_wcs': np.array(edges_wcs, dtype=object), # 世界坐标系下的边几何数据(对象数组)
'surf_ncs': np.array(surfs_ncs, dtype=object), # 归一化坐标系下的曲面几何数据(对象数组) 面归一化点云 [num_faces, num_surf_sample_points, 3]
'edge_ncs': np.array(edges_ncs, dtype=object), # 归一化坐标系下的边几何数据(对象数组) 边归一化点云 [num_edges, num_edge_sample_points, 3]
'corner_wcs': corner_wcs.astype(np.float32), # 世界坐标系下的角点数据 [num_edges, 2, 3]
'edgeFace_adj': edgeFace_adj.astype(np.int32), # 边-面的邻接关系矩阵
'edgeCorner_adj': edgeCorner_adj.astype(np.int32),# 边-角点的邻接关系矩阵
'faceEdge_adj': faceEdge_adj.astype(np.int32), # 面-边的邻接关系矩阵
'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),# 曲面在世界坐标系下的包围盒
'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),# 边在世界坐标系下的包围盒
'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32) # 去重后的唯一角点坐标
}"""
try:
logger.info("数据预处理……")
if not os.path.exists(step_path):
logger.error(f"STEP文件不存在: {step_path}")
return None
if not step_path.lower().endswith('.step') and not step_path.lower().endswith('.stp'):
logger.error(f"文件格式不支持,必须是.step或.stp文件: {step_path}")
return None
# 解析STEP文件
data = parse_solid(step_path)
if data is None:
logger.error(f"Failed to parse STEP file: {step_path}")
return None
# 检查数据格式
is_valid, msg = check_data_format(data, step_path)
if not is_valid:
logger.error(f"Data format check failed for {step_path}: {msg}")
return None
# 保存结果
if output_path:
try:
logger.info(f"Saving results to: {output_path}")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as f:
pickle.dump(data, f)
logger.info("数据预处理完成")
logger.info(f"Results saved successfully: {output_path}")
return data
except Exception as e:
logger.error(f'Failed to save {output_path}: {str(e)}')
return None
logger.info("数据预处理完成")
return data
except Exception as e:
logger.error(f'Error processing {step_path}: {str(e)}')
return None
def test(step_file_path, output_path=None):
"""
测试函数转换单个STEP文件并保存结果
"""
try:
logger.info(f"Processing STEP file: {step_file_path}")
# 解析STEP文件
data = parse_solid(step_file_path)
if data is None:
logger.error(f"Failed to parse STEP file: {step_file_path}")
return None
# 检查数据格式
is_valid, msg = check_data_format(data, step_file_path)
if not is_valid:
logger.error(f"Data format check failed for {step_file_path}: {msg}")
return None
# 打印统计信息
logger.info("\nStatistics:")
logger.info(f"Number of surfaces: {len(data['surf_wcs'])}")
logger.info(f"Number of edges: {len(data['edge_wcs'])}")
logger.info(f"Number of corners: {len(data['corner_wcs'])}") # 修正为corner_wcs
# 保存结果
if output_path:
try:
logger.info(f"Saving results to: {output_path}")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as f:
pickle.dump(data, f)
logger.info(f"Results saved successfully: {output_path}")
except Exception as e:
logger.error(f"Failed to save {output_path}: {str(e)}")
return None
return data
except Exception as e:
logger.error(f"Error processing {step_file_path}: {str(e)}")
return None
if __name__ == '__main__':
# main()
test("/home/wch/brep2sdf/data/step/00000000/00000000_290a9120f9f249a7a05cfe9c_step_000.step","/home/wch/brep2sdf/test_data/pkl/train/00000031xx.pkl")
#test("/home/wch/brep2sdf/00000031_ad34a3f60c4a4caa99646600_step_011.step", "/home/wch/brep2sdf/test_data/pkl/train/00000031.pkl")
#test("/mnt/mynewdisk/dataset/furniture/step/furniture_dataset_step/train/bathtub_0004.step", "/home/wch/brep2sdf/test_data/pkl/train/0004.pkl")
#reader = STEPControl_Reader()

23
brep2sdf/networks/encoder.py

@ -10,28 +10,33 @@ from brep2sdf.utils.logger import logger
import numpy as np import numpy as np
class Encoder: class Encoder:
def __init__(self, surf_bbox_wcs: torch.Tensor, origin_bbox_wcs: torch.Tensor, max_depth: int, feature_dim:int = 64): def __init__(self, surf_bbox: torch.Tensor, origin_bbox: torch.Tensor, max_depth: int, feature_dim:int = 64):
""" """
初始化表面八叉树管理器 初始化表面八叉树管理器
参数: 参数:
surf_bbox_wcs: 表面包围盒的世界坐标形状为 (num_edges, 6), dtype=float32 surf_bbox: 表面包围盒的世界坐标形状为 (num_edges, 6), dtype=float32
origin_bbox_wcs: 原点包围盒的世界坐标形状为 (6), dtype=float32 origin_bbox: 原点包围盒的世界坐标形状为 (6), dtype=float32
max_depth: 八叉树的最大深度 max_depth: 八叉树的最大深度
""" """
self.max_depth = max_depth self.max_depth = max_depth
# 初始化根节点,包含所有面索引。NOTE: 实际上需要判断 aabb 包含问题,但是这里默认 origin_bbox_wcs 由这些 face 计算,所以不再重复判断 # 初始化根节点,包含所有面索引。NOTE: 实际上需要判断 aabb 包含问题,但是这里默认 origin_bbox 由这些 face 计算,所以不再重复判断
num_faces = surf_bbox_wcs.shape[0] num_faces = surf_bbox.shape[0]
print(f"surf_bbox_wcs: {surf_bbox_wcs.shape}")
print(f"origin_bbox_wcs: {origin_bbox_wcs.shape}") #print(f"surf_bbox: {surf_bbox.shape}")
#print(f"origin_bbox: {origin_bbox.shape}")
self.root = OctreeNode( self.root = OctreeNode(
bbox=origin_bbox_wcs, bbox=origin_bbox,
face_indices=np.arange(num_faces), # 初始包含所有面 face_indices=np.arange(num_faces), # 初始包含所有面
max_depth=self.max_depth, max_depth=self.max_depth,
feature_dim=feature_dim, feature_dim=feature_dim,
surf_bbox_wcs=surf_bbox_wcs surf_bbox=surf_bbox
) )
#print(surf_bbox)
logger.info("starting octree conduction")
self.root.conduct_tree() self.root.conduct_tree()
logger.info("complete octree conduction")
#self.root.print_tree(0)
def get_feature_vector(self, query_point): def get_feature_vector(self, query_point):
return self.root.get_feature_vector(query_point) return self.root.get_feature_vector(query_point)

14
brep2sdf/networks/network.py

@ -7,7 +7,7 @@ class GridNet:
def __init__(self, def __init__(self,
surf_wcs, edge_wcs, surf_ncs, edge_ncs, corner_wcs, corner_unique, surf_wcs, edge_wcs, surf_ncs, edge_ncs, corner_wcs, corner_unique,
edgeFace_adj, edgeCorner_adj, faceEdge_adj, edgeFace_adj, edgeCorner_adj, faceEdge_adj,
surf_bbox_wcs, edge_bbox_wcs): surf_bbox, edge_bbox_wcs):
""" """
初始化 GridNet 初始化 GridNet
@ -26,7 +26,7 @@ class GridNet:
'faceEdge_adj': np.ndarray(dtype=int32) # 形状为(num_faces, num_edges)的数组,表示面-边邻接关系 'faceEdge_adj': np.ndarray(dtype=int32) # 形状为(num_faces, num_edges)的数组,表示面-边邻接关系
# 包围盒数据 # 包围盒数据
'surf_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_faces, 6)的数组,表示每个面的包围盒[xmin,ymin,zmin,xmax,ymax,zmax] 'surf_bbox': np.ndarray(dtype=float32) # 形状为(num_faces, 6)的数组,表示每个面的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
'edge_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 6)的数组,表示每条边的包围盒[xmin,ymin,zmin,xmax,ymax,zmax] 'edge_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 6)的数组,表示每条边的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
""" """
self.surf_wcs = surf_wcs self.surf_wcs = surf_wcs
@ -38,7 +38,7 @@ class GridNet:
self.edgeFace_adj = edgeFace_adj self.edgeFace_adj = edgeFace_adj
self.edgeCorner_adj = edgeCorner_adj self.edgeCorner_adj = edgeCorner_adj
self.faceEdge_adj = faceEdge_adj self.faceEdge_adj = faceEdge_adj
self.surf_bbox_wcs = surf_bbox_wcs self.surf_bbox = surf_bbox
self.edge_bbox_wcs = edge_bbox_wcs self.edge_bbox_wcs = edge_bbox_wcs
# net # net
@ -53,8 +53,8 @@ from .decoder import Decoder
class Net(nn.Module): class Net(nn.Module):
def __init__(self, def __init__(self,
surf_bbox_wcs, surf_bbox,
origin_bbox_wcs, origin_bbox,
max_depth=4, max_depth=4,
feature_dim=64, feature_dim=64,
decoder_input_dim=64, decoder_input_dim=64,
@ -68,8 +68,8 @@ class Net(nn.Module):
# 初始化 Encoder # 初始化 Encoder
self.encoder = Encoder( self.encoder = Encoder(
surf_bbox_wcs=surf_bbox_wcs, # 使用传入的bbox作为表面包围盒 surf_bbox=surf_bbox, # 使用传入的bbox作为表面包围盒
origin_bbox_wcs=origin_bbox_wcs, # 使用相同的bbox作为原点包围盒 origin_bbox=origin_bbox, # 使用相同的bbox作为原点包围盒
max_depth=max_depth, max_depth=max_depth,
feature_dim=feature_dim feature_dim=feature_dim
) )

84
brep2sdf/networks/octree.py

@ -7,18 +7,32 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
from brep2sdf.utils.logger import logger
def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool: def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool:
"""判断两个包围盒是否相交""" """判断两个轴对齐包围盒(AABB)是否相交
return not (bbox1[3] < bbox2[0] or bbox1[0] > bbox2[3] or
bbox1[4] < bbox2[1] or bbox1[1] > bbox2[4] or 参数:
bbox1[5] < bbox2[2] or bbox1[2] > bbox2[5]) bbox1: 形状为 (6,) 的张量格式 [min_x, min_y, min_z, max_x, max_y, max_z]
bbox2: 同bbox1格式
返回:
bool: 两包围盒是否相交(包括刚好接触的情况)
"""
assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的张量"
# 提取min和max坐标
min1, max1 = bbox1[:3], bbox1[3:]
min2, max2 = bbox2[:3], bbox2[3:]
# 向量化比较
return torch.all((max1 >= min2) & (max2 >= min1))
class OctreeNode: class OctreeNode:
feature_dim=None feature_dim=None
device=None device=None
surf_bbox_wcs = None surf_bbox = None
def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, feature_dim:int = None, surf_bbox_wcs = None): def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, feature_dim:int = None, surf_bbox:torch.Tensor = None):
self.bbox = bbox # 节点的边界框 self.bbox = bbox # 节点的边界框
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点
self.children: List['OctreeNode'] = [] # 子节点列表 self.children: List['OctreeNode'] = [] # 子节点列表
@ -29,8 +43,16 @@ class OctreeNode:
if feature_dim is not None: if feature_dim is not None:
OctreeNode.feature_dim = feature_dim OctreeNode.feature_dim = feature_dim
if surf_bbox_wcs is not None: if surf_bbox is not None:
OctreeNode.surf_bbox_wcs = surf_bbox_wcs # NOTE: 只在根节点时创建 if not isinstance(surf_bbox, torch.Tensor):
raise TypeError(
f"surf_bbox 必须是 torch.Tensor 类型,但得到 {type(surf_bbox)}"
)
if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
raise ValueError(
f"surf_bbox 应为二维张量且形状为 (N,6),但得到 {surf_bbox.shape}"
)
OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建
OctreeNode.device = bbox.device OctreeNode.device = bbox.device
def is_leaf(self): def is_leaf(self):
@ -72,10 +94,10 @@ class OctreeNode:
# 找到与子包围盒相交的面 # 找到与子包围盒相交的面
intersecting_faces = [] intersecting_faces = []
for face_idx in self.face_indices: for face_idx in self.face_indices:
face_bbox = OctreeNode.surf_bbox_wcs[face_idx] face_bbox = OctreeNode.surf_bbox[face_idx]
if bbox_intersect(bbox, face_bbox): if bbox_intersect(bbox, face_bbox):
intersecting_faces.append(face_idx) intersecting_faces.append(face_idx)
#print(f"{bbox}: {intersecting_faces}")
if intersecting_faces: if intersecting_faces:
child_node = OctreeNode( child_node = OctreeNode(
bbox=bbox, bbox=bbox,
@ -95,6 +117,8 @@ class OctreeNode:
""" """
#print(query_point) #print(query_point)
x, y, z = query_point x, y, z = query_point
#logger.info(f"query_point: {query_point}")
#logger.info(f"box: {self.bbox}")
min_x, min_y, min_z, max_x, max_y, max_z = self.bbox min_x, min_y, min_z, max_x, max_y, max_z = self.bbox
# 计算中间点 # 计算中间点
@ -109,7 +133,7 @@ class OctreeNode:
index += 2 index += 2
if z >= mid_z: # 修正变量名 if z >= mid_z: # 修正变量名
index += 4 index += 4
#logger.info(f"index: {index}")
return index return index
def get_feature_vector(self, query_point:torch.Tensor): def get_feature_vector(self, query_point:torch.Tensor):
@ -125,7 +149,18 @@ class OctreeNode:
return self.trilinear_interpolation(query_point) return self.trilinear_interpolation(query_point)
else: else:
index = self.get_child_index(query_point) index = self.get_child_index(query_point)
try:
if index < 0 or index >= len(self.children):
raise IndexError(
f"Child index {index} out of range (0-{len(self.children)-1}) "
f"for query point {query_point.cpu().numpy().tolist()}. "
f"Node bbox: {self.bbox.cpu().numpy().tolist()}"
f"dept info: {self.max_depth}"
)
return self.children[index].get_feature_vector(query_point) return self.children[index].get_feature_vector(query_point)
except IndexError as e:
logger.error(str(e))
raise e
def trilinear_interpolation(self, query_point:torch.Tensor) -> torch.Tensor: def trilinear_interpolation(self, query_point:torch.Tensor) -> torch.Tensor:
""" """
@ -167,3 +202,30 @@ class OctreeNode:
c1 = c01 * (1 - y) + c11 * y c1 = c01 * (1 - y) + c11 * y
return c0 * (1 - z) + c1 * z return c0 * (1 - z) + c1 * z
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None:
"""
递归打印八叉树结构
参数:
depth: 当前深度 (内部使用)
max_print_depth: 最大打印深度 (None表示打印全部)
"""
if max_print_depth is not None and depth > max_print_depth:
return
# 打印当前节点信息
indent = " " * depth
node_type = "Leaf" if self._is_leaf else "Internal"
print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.cpu().numpy().tolist()}")
# 打印面片信息(如果有)
if self.face_indices is not None:
print(f"{indent} Face indices: {self.face_indices.tolist()}")
print(f"{indent} len children: {len(self.children)}")
# 递归打印子节点
for i, child in enumerate(self.children):
print(f"{indent} Child {i}:")
child.print_tree(depth + 1, max_print_depth)

27
brep2sdf/scripts/process_brep.py

@ -24,7 +24,7 @@ from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类
from OCC.Core.BRep import BRep_Tool # B-rep工具 from OCC.Core.BRep import BRep_Tool # B-rep工具
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分 from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分
from OCC.Core.TopLoc import TopLoc_Location # 位置变换 from OCC.Core.TopLoc import TopLoc_Location # 位置变换
from OCC.Core.IFSelect import IFSelect_RetDone # 操作状态码 from OCC.Core.IFSelect import IFSelect_RetDone,IFSelect_RetError, IFSelect_RetFail, IFSelect_RetVoid # 操作状态码
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射 from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射
from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算 from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算
from OCC.Core.Bnd import Bnd_Box # 包围盒 from OCC.Core.Bnd import Bnd_Box # 包围盒
@ -211,7 +211,15 @@ def parse_solid(step_path):
reader = STEPControl_Reader() reader = STEPControl_Reader()
status = reader.ReadFile(step_path) status = reader.ReadFile(step_path)
if status != IFSelect_RetDone: if status != IFSelect_RetDone:
raise Exception("Failed to read STEP file") if status == IFSelect_RetError:
print("Error: An error occurred while reading the file.")
elif status == IFSelect_RetFail:
print("Error: Failed to read the file.")
elif status == IFSelect_RetVoid:
print("Error: No data was read from the file.")
else:
print(f"Unexpected status code: {status}")
raise Exception(f"Failed to read STEP file. {status}")
reader.TransferRoots() reader.TransferRoots()
shape = reader.OneShape() shape = reader.OneShape()
@ -385,8 +393,14 @@ def check_data_format(data, step_file):
return True, "" return True, ""
def process_single_step(step_path:str, output_path:str=None, timeout:int=300) -> dict: def process_single_step(step_path:str, output_path:str=None, timeout:int=300) -> dict:
"""处理单个STEP文件""" """处理单个STEP文件, 从 brep 2 pkl"""
try: try:
if not os.path.exists(step_path):
logger.error(f"STEP文件不存在: {step_path}")
return None
if not step_path.lower().endswith('.step') and not step_path.lower().endswith('.stp'):
logger.error(f"文件格式不支持,必须是.step或.stp文件: {step_path}")
return None
# 解析STEP文件 # 解析STEP文件
data = parse_solid(step_path) data = parse_solid(step_path)
if data is None: if data is None:
@ -596,5 +610,8 @@ def main():
logger.warning("No files were processed") logger.warning("No files were processed")
if __name__ == '__main__': if __name__ == '__main__':
main() # main()
#test("/mnt/disk2/dataset/furniture/step/furniture_dataset_step/train/bathtub_0004.step", "/home/wch/brep2sdf/test_data/pkl/train/bathtub_0004.pkl") test("/home/wch/brep2sdf/data/step/00000000/00000000_290a9120f9f249a7a05cfe9c_step_000.step","/home/wch/brep2sdf/test_data/pkl/train/00000031xx.pkl")
#test("/home/wch/brep2sdf/00000031_ad34a3f60c4a4caa99646600_step_011.step", "/home/wch/brep2sdf/test_data/pkl/train/00000031.pkl")
#test("/mnt/mynewdisk/dataset/furniture/step/furniture_dataset_step/train/bathtub_0004.step", "/home/wch/brep2sdf/test_data/pkl/train/0004.pkl")
#reader = STEPControl_Reader()

6
brep2sdf/scripts/process_furniture.py

@ -301,10 +301,12 @@ def main():
success_rate = (valid_conversions / total_files) * 100 # 这个变量在日志中被使用但未定义 success_rate = (valid_conversions / total_files) * 100 # 这个变量在日志中被使用但未定义
logger.info(f"处理完成: {set_name} 集合, 成功率: {success_rate:.2f}% = {valid_conversions}/{total_files}") logger.info(f"处理完成: {set_name} 集合, 成功率: {success_rate:.2f}% = {valid_conversions}/{total_files}")
def test(step_file: str, set_name:str):
process(step_file, set_name)
if __name__ == "__main__": if __name__ == "__main__":
main() # main()
test("/home/wch/brep2sdf/00000031.step","train")

105
brep2sdf/train.py

@ -2,26 +2,70 @@ import torch
import torch.optim as optim import torch.optim as optim
import time import time
import os import os
import numpy as np
import argparse
from brep2sdf.config.default_config import get_default_config from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file,load_sdf_file from brep2sdf.data.data import load_brep_file,load_sdf_file
from brep2sdf.data.pre_process import process_single_step
from brep2sdf.networks.network import Net from brep2sdf.networks.network import Net
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
def prepare_sdf_data(surf_data, max_points=100000, device='cuda'):
total_points = sum(len(s) for s in surf_data)
# 降采样逻辑(修复版)
if total_points > max_points:
# 先随机打乱所有点
all_points = np.concatenate(surf_data)
np.random.shuffle(all_points)
# 直接取前max_points个点
sampled_points = all_points[:max_points]
sdf_array = np.zeros((max_points, 4), dtype=np.float32)
sdf_array[:, :3] = sampled_points
else:
sdf_array = np.zeros((total_points, 4), dtype=np.float32)
sdf_array[:, :3] = np.concatenate(surf_data)
return torch.tensor(sdf_array, dtype=torch.float32, device=device)
class Trainer: class Trainer:
def __init__(self, config): def __init__(self, config, input_step):
self.config = config self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model_name = os.path.basename(input_step).replace(".step", "")
self.base_name = self.model_name + ".xyz"
data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name)
if os.path.exists(data_path):
self.data = load_brep_file(data_path)
else:
self.data = process_single_step(step_path=input_step, output_path=data_path)
# 将曲面点云列表转换为 (N*M, 4) 数组
surfs = self.data["surf_ncs"]
self.sdf_data = prepare_sdf_data(
surfs,
max_points=4096,
device=self.device
)
# 初始化数据集 # 初始化数据集
self.brep_data = load_brep_file(self.config.data.pkl_path) #self.brep_data = load_brep_file(self.config.data.pkl_path)
self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device) #logger.info( self.brep_data )
#self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device)
# 初始化网络 # 初始化网络
bbox = self._calculate_global_bbox()
surf_bbox=torch.tensor(
self.data['surf_bbox_ncs'],
dtype=torch.float32,
device=self.device
)
bbox = self._calculate_global_bbox(surf_bbox)
self.model = Net( self.model = Net(
surf_bbox_wcs=self.brep_data['surf_bbox_wcs'], surf_bbox=surf_bbox,
origin_bbox_wcs=bbox, origin_bbox=bbox,
feature_dim=64 feature_dim=64
).to(self.device) ).to(self.device)
@ -34,28 +78,30 @@ class Trainer:
def _calculate_global_bbox(self) -> torch.Tensor: def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
""" """
计算整个数据集的全局边界框 计算整个数据集的全局边界框综合考虑表面包围盒和采样点
参数:
surf_bbox: 形状为 (num_edges, 6) 的Tensor表示每条边的包围盒
[xmin, ymin, zmin, xmax, ymax, zmax]
返回: 返回:
bbox_tensor: 形状为(6,)的Tensor格式为[x_min, y_min, z_min, x_max, y_max, z_max] 形状为 (6,) 的Tensor格式为 [x_min, y_min, z_min, x_max, y_max, z_max]
""" """
# 获取所有点的坐标 # 验证输入
points = self.sdf_data[:, 0:3] # 假设sdf_data的前三列是点的坐标 if not isinstance(surf_bbox, torch.Tensor):
raise TypeError(f"surf_bbox 必须是 torch.Tensor,但得到 {type(surf_bbox)}")
if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
raise ValueError(f"surf_bbox 形状应为 (num_edges, 6),但得到 {surf_bbox.shape}")
# 计算最小点和最大点 # 计算表面包围盒的全局范围
min_point = torch.min(points, dim=0).values global_min = surf_bbox[:, :3].min(dim=0).values
max_point = torch.max(points, dim=0).values global_max = surf_bbox[:, 3:].max(dim=0).values
# 确保在正确设备上
min_point = min_point.to(self.device)
max_point = max_point.to(self.device)
# 将最小点和最大点合并成一个(6,)的Tensor # 返回合并后的边界框
bbox_tensor = torch.cat([min_point, max_point], dim=0) return torch.cat([global_min, global_max])
#print(f"bbox_tensor shape: {bbox_tensor.shape}")
return bbox_tensor
def train_epoch(self, epoch: int) -> float: def train_epoch(self, epoch: int) -> float:
self.model.train() self.model.train()
@ -154,13 +200,12 @@ class Trainer:
def _save_checkpoint(self, epoch: int, train_loss: float): def _save_checkpoint(self, epoch: int, train_loss: float):
"""保存训练检查点""" """保存训练检查点"""
checkpoint_path = os.path.join( checkpoint_dir = os.path.join(
self.config.train.checkpoint_dir, self.config.train.checkpoint_dir,
self.config.train.checkpoint_format.format( self.model_name
model_name=self.config.train.model_name,
epoch=epoch
)
) )
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir,f"epoch_{epoch:03d}.pth")
torch.save({ torch.save({
'epoch': epoch, 'epoch': epoch,
'model_state_dict': self.model.state_dict(), 'model_state_dict': self.model.state_dict(),
@ -171,10 +216,16 @@ class Trainer:
def main(): def main():
# 这里需要初始化配置 # 这里需要初始化配置
# 配置命令行参数
parser = argparse.ArgumentParser(description='STEP文件批量处理工具')
parser.add_argument('-i', '--input', required=True,
help='待处理 brep (.step) 路径')
args = parser.parse_args()
config = get_default_config() config = get_default_config()
# 初始化训练器并开始训练 # 初始化训练器并开始训练
trainer = Trainer(config) trainer = Trainer(config, input_step=args.input)
trainer.train() trainer.train()
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save