Browse Source

支持normal

final
mckay 4 months ago
parent
commit
cc88dfb798
  1. 621
      brep2sdf/data/pre_process_by_mesh.py
  2. 254
      brep2sdf/networks/loss.py
  3. 11
      brep2sdf/networks/network.py
  4. 97
      brep2sdf/train.py

621
brep2sdf/data/pre_process_by_mesh.py

@ -0,0 +1,621 @@
"""
CAD模型处理脚本
功能将STEP格式的CAD模型转换为结构化数据包括
- 几何信息顶点的坐标数据
- 拓扑信息--顶点的邻接关系
- 空间信息包围盒数据
"""
import os
import pickle # 用于数据序列化
import argparse # 命令行参数解析
import numpy as np
from tqdm import tqdm # 进度条显示
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError # 并行处理
import logging
from datetime import datetime
from scipy.spatial import cKDTree
from brep2sdf.utils.logger import logger
import tempfile
import trimesh
# 导入OpenCASCADE相关库
from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器
from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类型定义
from OCC.Core.BRep import BRep_Tool # B-rep工具
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分
from OCC.Core.TopLoc import TopLoc_Location # 位置变换
from OCC.Core.IFSelect import IFSelect_RetDone,IFSelect_RetError, IFSelect_RetFail, IFSelect_RetVoid # 操作状态码
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射
from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算
from OCC.Core.Bnd import Bnd_Box # 包围盒
from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构
from OCC.Core.StlAPI import StlAPI_Writer
# 导入配置
from brep2sdf.config.default_config import get_default_config
config = get_default_config()
# 设置最大面数阈值,用于加速处理
MAX_FACE = config.data.max_face
def normalize(surfs, edges, corners):
"""
将CAD模型归一化到单位立方体空间
参数:
surfs: 面的点集列表
edges: 边的点集列表
corners: 顶点坐标数组 [num_edges, 2, 3]
返回:
surfs_wcs: 原始坐标系下的面点集
edges_wcs: 原始坐标系下的边点集
surfs_ncs: 归一化坐标系下的面点集
edges_ncs: 归一化坐标系下的边点集
corner_wcs: 归一化后的顶点坐标 [num_edges, 2, 3]
center: 使用的中心点坐标 [3,]
scale: 使用的缩放系数 (float)
"""
if len(corners) == 0:
return None, None, None, None, None, None, None
# 计算包围盒和缩放因子
corners_array = corners.reshape(-1, 3) # [num_edges*2, 3]
center = (corners_array.max(0) + corners_array.min(0)) / 2 # 计算中心点
scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max() # 计算缩放系数
# 归一化面的点集
surfs_wcs = [] # 原始世界坐标系下的面点集
surfs_ncs = [] # 归一化坐标系下的面点集
for surf in surfs:
surf_wcs = np.array(surf)
surf_ncs = (surf_wcs - center) * scale # 归一化变换
surfs_wcs.append(surf_wcs)
surfs_ncs.append(surf_ncs)
# 归一化边的点集
edges_wcs = [] # 原始世界坐标系下的边点集
edges_ncs = [] # 归一化坐标系下的边点集
for edge in edges:
edge_wcs = np.array(edge)
edge_ncs = (edge_wcs - center) * scale # 归一化变换
edges_wcs.append(edge_wcs)
edges_ncs.append(edge_ncs)
# 归一化顶点坐标 - 保持[num_edges, 2, 3]的形状
corner_wcs = (corners - center) * scale # 广播操作会保持原有维度
return (np.array(surfs_wcs, dtype=object),
np.array(edges_wcs, dtype=object),
np.array(surfs_ncs, dtype=object),
np.array(edges_ncs, dtype=object),
corner_wcs.astype(np.float32),
center.astype(np.float32),
scale
)
def get_adjacency_info(shape, faces, edges, vertices):
"""
优化后的邻接关系计算函数直接使用已收集的几何元素
参数新增:
faces: 已收集的面列表
edges: 已收集的边列表
vertices: 已收集的顶点列表
"""
logger.debug("Get adjacency infos...")
# 创建边-面映射关系
edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape()
topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map)
# 直接使用传入的几何元素列表
num_faces = len(faces)
num_edges = len(edges)
num_vertices = len(vertices)
logger.debug(f"num_faces: {num_faces}, num_edges: {num_edges}, num_vertices: {num_vertices}")
edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32)
faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32)
edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32)
# 填充边-面邻接矩阵
for i, edge in enumerate(edges):
# 检查每个面是否与当前边相连
for j, face in enumerate(faces):
edge_explorer = TopExp_Explorer(face, TopAbs_EDGE)
while edge_explorer.More():
if edge.IsSame(edge_explorer.Current()):
edgeFace_adj[i, j] = 1
faceEdge_adj[j, i] = 1
break
edge_explorer.Next()
# 获取边的两个端点
v1 = TopoDS_Vertex()
v2 = TopoDS_Vertex()
topexp.Vertices(edge, v1, v2)
# 记录边的端点索引
if not v1.IsNull() and not v2.IsNull():
v1_vertex = topods.Vertex(v1)
v2_vertex = topods.Vertex(v2)
for k, vertex in enumerate(vertices):
if v1_vertex.IsSame(vertex):
edgeCorner_adj[i, 0] = k
if v2_vertex.IsSame(vertex):
edgeCorner_adj[i, 1] = k
return edgeFace_adj, faceEdge_adj, edgeCorner_adj
def get_bbox(shape, subshape):
"""
计算形状的包围盒
参数:
shape: 完整的CAD模型形状
subshape: 需要计算包围盒的子形状面或边
返回:
包围盒的六个参数 [xmin, ymin, zmin, xmax, ymax, zmax]
"""
bbox = Bnd_Box()
brepbndlib.Add(subshape, bbox)
xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
return np.array([xmin, ymin, zmin, xmax, ymax, zmax])
def parse_solid(step_path,sample_normal_vector=False):
"""
解析STEP文件中的CAD模型数据
返回:
dict: 包含以下键值对的字典:
# 几何数据
'surf_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示面的点云坐标
'edge_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示边的采样点坐标
'surf_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示归一化后的面点云
'edge_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示归一化后的边采样点
'corner_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 2, 3)的数组,表示每条边的两个端点坐标
'corner_unique': np.ndarray(dtype=float32) # 形状为(num_vertices, 3)的数组,表示所有顶点的唯一坐标,num_vertices <= num_edges * 2
# 拓扑关系
'edgeFace_adj': np.ndarray(dtype=int32) # 形状为(num_edges, num_faces)的数组,表示边-面邻接关系
'edgeCorner_adj': np.ndarray(dtype=int32) # 形状为(num_edges, 2)的数组,表示边-顶点邻接关系
'faceEdge_adj': np.ndarray(dtype=int32) # 形状为(num_faces, num_edges)的数组,表示面-边邻接关系
# 包围盒数据
'surf_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_faces, 6)的数组,表示每个面的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
'edge_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 6)的数组,表示每条边的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
"""
# Load STEP file
reader = STEPControl_Reader()
status = reader.ReadFile(step_path)
if status != IFSelect_RetDone:
if status == IFSelect_RetError:
print("Error: An error occurred while reading the file.")
elif status == IFSelect_RetFail:
print("Error: Failed to read the file.")
elif status == IFSelect_RetVoid:
print("Error: No data was read from the file.")
else:
print(f"Unexpected status code: {status}")
raise Exception(f"Failed to read STEP file. {status}")
reader.TransferRoots()
shape = reader.OneShape()
# Create mesh
mesh = BRepMesh_IncrementalMesh(shape, 0.01)
mesh.Perform()
# Initialize explorers
face_explorer = TopExp_Explorer(shape, TopAbs_FACE)
edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE)
vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX)
face_pnts = []
edge_pnts = []
corner_pnts = []
surf_bbox_wcs = []
edge_bbox_wcs = []
faces, edges, vertices = [], [], []
# Extract face points
logger.debug("Extract face points...")
while face_explorer.More():
face = topods.Face(face_explorer.Current())
faces.append(face)
loc = TopLoc_Location()
triangulation = BRep_Tool.Triangulation(face, loc)
if triangulation is not None:
points = []
for i in range(1, triangulation.NbNodes() + 1):
node = triangulation.Node(i)
pnt = node.Transformed(loc.Transformation())
points.append([pnt.X(), pnt.Y(), pnt.Z()])
if points:
points = np.array(points, dtype=np.float32)
if len(points.shape) == 2 and points.shape[1] == 3:
# 确保每个面至少有一些点
if len(points) < 3: # 如果点数太少,跳过这个面
continue
face_pnts.append(points)
surf_bbox_wcs.append(get_bbox(shape, face))
face_explorer.Next()
face_count = len(faces)
if face_count > MAX_FACE:
logger.error(f"step has {face_count} faces, which exceeds MAX_FACE {MAX_FACE}")
return None
# Extract edge points
logger.debug("Extract edge points...")
num_samples = config.model.num_edge_points # 使用配置中的边采样点数
while edge_explorer.More():
edge = topods.Edge(edge_explorer.Current())
edges.append(edge)
logger.debug(len(edges))
curve_info = BRep_Tool.Curve(edge)
if curve_info is None:
continue # 跳过无效边
try:
if len(curve_info) == 3:
curve, first, last = curve_info
elif len(curve_info) == 2:
curve = None # 跳过判断
else:
raise ValueError(f"Unexpected curve info: {curve_info}")
except Exception as e:
logger.error(f"Failed to process edge {edge}: {str(e)}")
curve = None
if curve is not None:
points = []
for i in range(num_samples):
param = first + (last - first) * float(i) / (num_samples - 1)
pnt = curve.Value(param)
points.append([pnt.X(), pnt.Y(), pnt.Z()])
if points:
points = np.array(points, dtype=np.float32)
if len(points.shape) == 2 and points.shape[1] == 3:
edge_pnts.append(points) # 现在points是(num_edge_points, 3)形状
edge_bbox_wcs.append(get_bbox(shape, edge))
edge_explorer.Next()
# Extract vertex points
logger.debug("Extract vertex points...")
while vertex_explorer.More():
vertex = topods.Vertex(vertex_explorer.Current())
vertices.append(vertex)
pnt = BRep_Tool.Pnt(vertex)
corner_pnts.append([pnt.X(), pnt.Y(), pnt.Z()])
vertex_explorer.Next()
# 获取邻接信息
edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(
shape,
faces=faces, # 传入已收集的面列表
edges=edges, # 传入已收集的边列表
vertices=vertices # 传入已收集的顶点列表
)
logger.debug("complete.")
# 转换为numpy数组时确保类型正确
face_pnts = [np.array(points, dtype=np.float32) for points in face_pnts]
edge_pnts = [np.array(points, dtype=np.float32) for points in edge_pnts]
# 转换为对象数组
face_pnts = np.array(face_pnts, dtype=object)
edge_pnts = np.array(edge_pnts, dtype=object)
corner_pnts = np.array(corner_pnts, dtype=np.float32)
# 重组顶点数据为每条边两个端点的形式
corner_pairs = []
for edge_idx in range(len(edge_pnts)):
v1_idx, v2_idx = edgeCorner_adj[edge_idx]
v1_pos = corner_pnts[v1_idx]
v2_pos = corner_pnts[v2_idx]
# 按坐标排序确保顺序一致
if (v1_pos > v2_pos).any():
v1_pos, v2_pos = v2_pos, v1_pos
corner_pairs.append(np.stack([v1_pos, v2_pos]))
corner_pairs = np.stack(corner_pairs).astype(np.float32) # [num_edges, 2, 3]
# 确保所有数组都有正确的类型
surf_bbox_wcs = np.array(surf_bbox_wcs, dtype=np.float32)
edge_bbox_wcs = np.array(edge_bbox_wcs, dtype=np.float32)
# Normalize the CAD model
surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs,center, scale = normalize(
face_pnts, edge_pnts, corner_pairs)
# 计算归一化后的包围盒
surf_bbox_ncs = np.empty_like(surf_bbox_wcs)
edge_bbox_ncs = np.empty_like(edge_bbox_wcs)
# 转换曲面包围盒到归一化坐标系
surf_bbox_ncs[:, :3] = (surf_bbox_wcs[:, :3] - center) * scale # 最小点
surf_bbox_ncs[:, 3:] = (surf_bbox_wcs[:, 3:] - center) * scale # 最大点
# 转换边包围盒到归一化坐标系
edge_bbox_ncs[:, :3] = (edge_bbox_wcs[:, :3] - center) * scale # 最小点
edge_bbox_ncs[:, 3:] = (edge_bbox_wcs[:, 3:] - center) * scale # 最大点
# 验证归一化后的数据
if any(x is None for x in [surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs]):
logger.error(f"Normalization failed for {step_path}")
return None
# 创建结果字典并确保所有数组都有正确的类型
data = {
'surf_wcs': np.array(surfs_wcs, dtype=object), # 保持对象数组
'edge_wcs': np.array(edges_wcs, dtype=object), # 保持对象数组
'surf_ncs': np.array(surfs_ncs, dtype=object), # 保持对象数组
'edge_ncs': np.array(edges_ncs, dtype=object), # 保持对象数组
'corner_wcs': corner_wcs.astype(np.float32), # [num_edges, 2, 3]
'edgeFace_adj': edgeFace_adj.astype(np.int32),
'edgeCorner_adj': edgeCorner_adj.astype(np.int32),
'faceEdge_adj': faceEdge_adj.astype(np.int32),
'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),
'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),
'surf_bbox_ncs': surf_bbox_ncs.astype(np.float32), # 归一化坐标系 [num_faces, 6]
'edge_bbox_ncs': edge_bbox_ncs.astype(np.float32), # 归一化坐标系 [num_edges, 6]
'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32), # 先展平再去重
'normalization_params': {
'center': center.astype(np.float32), # 归一化中心点 [3,]
'scale': float(scale), # 归一化缩放系数
}
}
if sample_normal_vector:
# 从 mesh 读 法向量
mesh.Perform()
# 导出为STL临时文件
stl_writer = StlAPI_Writer()
stl_writer.SetASCIIMode(False)
with tempfile.NamedTemporaryFile(suffix='.stl') as tmp:
stl_writer.Write(shape, tmp.name)
trimesh_mesh = trimesh.load(tmp.name)
data['surf_pnt_normals']= batch_compute_normals(trimesh_mesh,surfs_wcs)
return data
def load_step(step_path):
"""Load STEP file and return solids"""
reader = STEPControl_Reader()
reader.ReadFile(step_path)
reader.TransferRoots()
return [reader.OneShape()]
def preprocess_mesh(mesh, normal_type='vertex'):
"""
预处理网格数据生成 KDTree 和法向量源
参数
mesh: trimesh.Trimesh 对象包含顶点和法向量信息
normal_type: str 法向量类型可选 'vertex' 'face'
返回
tree: cKDTree 用于加速最近邻查询
normals_source: np.ndarray 包含顶点法向量或面法向量
"""
if normal_type == 'vertex':
tree = cKDTree(mesh.vertices)
normals_source = mesh.vertex_normals
elif normal_type == 'face':
# 计算每个面的中心点
face_centers = np.mean(mesh.vertices[mesh.faces], axis=1)
tree = cKDTree(face_centers)
normals_source = mesh.face_normals
else:
raise ValueError(f"Unsupported normal type: {normal_type}")
return tree, normals_source
def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3):
"""
为嵌套点云数据计算法向量并保持嵌套格式
参数
mesh: trimesh.Trimesh 对象包含顶点和法向量信息
surf_wcs: np.ndarray(dtype=object) 形状为 (N,) 的数组每个元素是形状为 (M, 3) float32 数组
normal_type: str 法向量类型可选 'vertex' 'face'
k_neighbors: int 用于平滑的最近邻数量
返回
normals: np.ndarray(dtype=object) 形状为 (N,) 的数组每个元素是形状为 (M, 3) float32 数组
"""
# 预处理网格数据
tree, normals_source = preprocess_mesh(mesh, normal_type=normal_type)
# 展平所有点云为一个二维数组 [P, 3],并记录分割索引
lengths = [len(point_cloud) for point_cloud in surf_wcs]
query_points = np.concatenate(surf_wcs, axis=0).astype(np.float32) # 避免多次内存分配
# 批量查询最近邻
distances, indices = tree.query(query_points, k=k_neighbors)
# 处理k=1的特殊情况
if k_neighbors == 1:
nearest_normals = normals_source[indices]
else:
# 加权平均(权重为距离倒数)
inv_distances = 1 / (distances + 1e-8) # 防止除以零
sum_inv_distances = inv_distances.sum(axis=1, keepdims=True)
valid_mask = sum_inv_distances > 1e-6
weights = np.divide(inv_distances, sum_inv_distances, out=np.zeros_like(inv_distances), where=valid_mask)
nearest_normals = np.einsum('ijk,ij->ik', normals_source[indices], weights)
# 标准化结果
norms = np.linalg.norm(nearest_normals, axis=1)
valid_mask = norms > 1e-6
nearest_normals[valid_mask] /= norms[valid_mask, None]
# 按原始嵌套结构分割法向量
start = 0
normals_output = np.empty(len(surf_wcs), dtype=object)
for i, length in enumerate(lengths):
end = start + length
normals_output[i] = nearest_normals[start:end]
start = end
return normals_output
def check_data_format(data, step_file):
"""检查数据格式是否正确"""
required_keys = [
'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs', 'corner_wcs',
'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj',
'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'
]
# 检查所有必需的键是否存在
for key in required_keys:
if key not in data:
return False, f"Missing key: {key}"
# 检查几何数据
geometry_arrays = ['surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs']
for key in geometry_arrays:
if not isinstance(data[key], np.ndarray):
return False, f"{key} should be a numpy array"
# 允许对象数组
if data[key].dtype != object:
return False, f"{key} should be a numpy array with dtype=object"
# 检查其他数组
float32_arrays = ['corner_wcs', 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique']
for key in float32_arrays:
if not isinstance(data[key], np.ndarray):
return False, f"{key} should be a numpy array"
if data[key].dtype != np.float32:
return False, f"{key} should be a numpy array with dtype=float32"
int32_arrays = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
for key in int32_arrays:
if not isinstance(data[key], np.ndarray):
return False, f"{key} should be a numpy array"
if data[key].dtype != np.int32:
return False, f"{key} should be a numpy array with dtype=int32"
return True, ""
def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, timeout:int=300) -> dict:
"""处理单个STEP文件, 从 brep 2 pkl
return data = {
'surf_wcs': np.array(surfs_wcs, dtype=object), # 世界坐标系下的曲面几何数据(对象数组)
'edge_wcs': np.array(edges_wcs, dtype=object), # 世界坐标系下的边几何数据(对象数组)
'surf_ncs': np.array(surfs_ncs, dtype=object), # 归一化坐标系下的曲面几何数据(对象数组) 面归一化点云 [num_faces, num_surf_sample_points, 3]
'edge_ncs': np.array(edges_ncs, dtype=object), # 归一化坐标系下的边几何数据(对象数组) 边归一化点云 [num_edges, num_edge_sample_points, 3]
'corner_wcs': corner_wcs.astype(np.float32), # 世界坐标系下的角点数据 [num_edges, 2, 3]
'edgeFace_adj': edgeFace_adj.astype(np.int32), # 边-面的邻接关系矩阵
'edgeCorner_adj': edgeCorner_adj.astype(np.int32),# 边-角点的邻接关系矩阵
'faceEdge_adj': faceEdge_adj.astype(np.int32), # 面-边的邻接关系矩阵
'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),# 曲面在世界坐标系下的包围盒
'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),# 边在世界坐标系下的包围盒
'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32) # 去重后的唯一角点坐标
}"""
try:
logger.info("数据预处理……")
if not os.path.exists(step_path):
logger.error(f"STEP文件不存在: {step_path}")
return None
if not step_path.lower().endswith('.step') and not step_path.lower().endswith('.stp'):
logger.error(f"文件格式不支持,必须是.step或.stp文件: {step_path}")
return None
# 解析STEP文件
data = parse_solid(step_path, sample_normal_vector)
if data is None:
logger.error(f"Failed to parse STEP file: {step_path}")
return None
# 检查数据格式
is_valid, msg = check_data_format(data, step_path)
if not is_valid:
logger.error(f"Data format check failed for {step_path}: {msg}")
return None
# 保存结果
if output_path:
try:
logger.info(f"Saving results to: {output_path}")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as f:
pickle.dump(data, f)
logger.info("数据预处理完成")
logger.info(f"Results saved successfully: {output_path}")
return data
except Exception as e:
logger.error(f'Failed to save {output_path}: {str(e)}')
return None
logger.info("数据预处理完成")
return data
except Exception as e:
logger.error(f'Error processing {step_path}: {str(e)}')
return None
def test(step_file_path, output_path=None):
"""
测试函数转换单个STEP文件并保存结果
"""
try:
logger.info(f"Processing STEP file: {step_file_path}")
# 解析STEP文件
data = parse_solid(step_file_path)
if data is None:
logger.error(f"Failed to parse STEP file: {step_file_path}")
return None
# 检查数据格式
is_valid, msg = check_data_format(data, step_file_path)
if not is_valid:
logger.error(f"Data format check failed for {step_file_path}: {msg}")
return None
# 打印统计信息
logger.info("\nStatistics:")
logger.info(f"Number of surfaces: {len(data['surf_wcs'])}")
logger.info(f"Number of edges: {len(data['edge_wcs'])}")
logger.info(f"Number of corners: {len(data['corner_wcs'])}") # 修正为corner_wcs
# 保存结果
if output_path:
try:
logger.info(f"Saving results to: {output_path}")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as f:
pickle.dump(data, f)
logger.info(f"Results saved successfully: {output_path}")
except Exception as e:
logger.error(f"Failed to save {output_path}: {str(e)}")
return None
return data
except Exception as e:
logger.error(f"Error processing {step_file_path}: {str(e)}")
return None
if __name__ == '__main__':
# main()
test("/home/wch/brep2sdf/data/step/00000000/00000000_290a9120f9f249a7a05cfe9c_step_000.step","/home/wch/brep2sdf/test_data/pkl/train/00000031xx.pkl")
#test("/home/wch/brep2sdf/00000031_ad34a3f60c4a4caa99646600_step_011.step", "/home/wch/brep2sdf/test_data/pkl/train/00000031.pkl")
#test("/mnt/mynewdisk/dataset/furniture/step/furniture_dataset_step/train/bathtub_0004.step", "/home/wch/brep2sdf/test_data/pkl/train/0004.pkl")
#reader = STEPControl_Reader()

254
brep2sdf/networks/loss.py

@ -1,6 +1,6 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from .network import gradient
from brep2sdf.config.default_config import get_default_config from brep2sdf.config.default_config import get_default_config
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
@ -66,35 +66,229 @@ class Brep2SDFLoss(nn.Module):
return grad_loss return grad_loss
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
"""SDF损失函数"""
# 确保points需要梯度
if not points.requires_grad:
points = points.detach().requires_grad_(True)
# L1损失
l1_loss = F.l1_loss(pred_sdf, gt_sdf) class LossManager:
def __init__(self, ablation, **condition_kwargs):
self.weights = {
"manifold": 1,
"feature_manifold": 1, # 原文里面和manifold的权重是一样的
"normals": 1,
"eikonal": 1,
"offsurface": 1,
"consistency": 1,
"correction": 1,
}
self.condition_kwargs = condition_kwargs
self.ablation = ablation # 消融实验用
def _get_condition_kwargs(self, key):
"""
获取条件参数, 期望
ab: 损失类型 overall, patch, off, cons, cc, cor,
siren: 是否使用SIREN
epoch: 当前epoch
baseline: 是否为baseline
"""
if key in self.condition_kwargs:
return self.condition_kwargs[key]
else:
raise ValueError(f"Key {key} not found in condition_kwargs")
def pre_process(self, mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
"""
预处理
"""
mnfld_pred_h = mnfld_pred_all[:,0] # 提取流形预测结果
nonmnfld_pred_h = nonmnfld_pred_all[:,0] # 提取非流形预测结果
mnfld_grad = gradient(mnfld_pnts, mnfld_pred_h) # 计算流形点的梯度
all_fi = torch.zeros([n_batchsize, 1], device = 'cuda') # 初始化所有流形预测值
for i in range(n_branch - 1):
all_fi[i * n_patch_batch : (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch : (i + 1) * n_patch_batch, i + 1] # 填充流形预测值
# last patch
all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch] # 填充最后一个分支的流形预测值
return mnfld_pred_h, nonmnfld_pred_h, mnfld_grad, all_fi
def position_loss(self, pred_sdfs: torch.Tensor, gt_sdfs: torch.Tensor) -> torch.Tensor:
"""
计算流型损失的逻辑
:param pred_sdfs: 预测的SDF值形状为 (N, 1)
:param gt_sdfs: 真实的SDF值形状为 (N, 1)
:return: 计算得到的流型损失标量
"""
# 计算预测值与真实值的差
diff = pred_sdfs - gt_sdfs
# 计算平方误差
squared_diff = torch.pow(diff, 2)
# 计算均值
manifold_loss = torch.mean(squared_diff)
return manifold_loss
try: def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, pred_sdfs) -> torch.Tensor:
# 梯度约束损失 """
grad = torch.autograd.grad( 计算法线损失
pred_sdf.sum(),
points, :param normals: 法线
create_graph=True, :param mnfld_pnts: 流型点
retain_graph=True, :param all_fi: 所有流型预测值
allow_unused=True :param patch_sup: 是否支持补丁
)[0] :return: 计算得到的法线损失
"""
# NOTE 源代码 这里还有复杂逻辑
# 计算分支梯度
branch_grad = gradient(mnfld_pnts, pred_sdfs) # 计算分支梯度
# 计算法线损失
normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # 计算法线损失
return normals_loss # 返回加权后的法线损失
def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred):
"""
计算Eikonal损失
"""
grad_loss_h = torch.zeros(1).cuda() # 初始化 Eikonal 损失
single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred) # 计算非流形点的梯度
eikonal_loss = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() # 计算 Eikonal 损失
return eikonal_loss
def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred):
"""
Eo
惩罚远离表面但是预测值接近0的点
"""
offsurface_loss = torch.zeros(1).cuda()
if not self.ablation == 'off':
offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred)).mean() # 计算离表面损失
return offsurface_loss
def consistency_loss(self, mnfld_pnts, mnfld_pred, all_fi):
"""
惩罚流形点预测值和非流形点预测值不一致的点
"""
mnfld_consistency_loss = torch.zeros(1).cuda()
if not (self.ablation == 'cons' or self.ablation == 'cc'):
mnfld_consistency_loss = (mnfld_pred - all_fi[:,0]).abs().mean() # 计算流形一致性损失
return mnfld_consistency_loss
def correction_loss(self, mnfld_pnts, mnfld_pred, all_fi, th_closeness = 1e-5, a_correction = 100):
"""
修正损失
"""
correction_loss = torch.zeros(1).cuda() # 初始化修正损失
if not (self.ablation == 'cor' or self.ablation == 'cc'):
mismatch_id = torch.abs(mnfld_pred - all_fi[:,0]) > th_closeness # 计算不匹配的 ID
if mismatch_id.sum() != 0: # 如果存在不匹配
correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:,0])[mismatch_id]).mean() # 计算修正损失
return correction_loss
def compute_loss(self, points,
normals,
gt_sdfs,
pred_sdfs):
"""
计算流型损失的逻辑
:param outputs: 模型的输出
:return: 计算得到的流型损失值
"""
# 计算流形损失
manifold_loss = self.position_loss(pred_sdfs,gt_sdfs)
if grad is not None: # 计算法线损失
grad_constraint = F.mse_loss( normals_loss = self.normals_loss(normals, points, pred_sdfs)
torch.norm(grad, dim=-1),
torch.ones_like(pred_sdf.squeeze(-1)) # 汇总损失
) loss_details = {
else: "manifold": self.weights["manifold"] * manifold_loss,
grad_constraint = torch.tensor(0.0, device=pred_sdf.device) "normals": self.weights["normals"] * normals_loss,
}
# 计算总损失
total_loss = sum(loss_details.values())
return total_loss, loss_details
def _compute_loss(self, mnfld_pnts, normals, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
"""
计算流型损失的逻辑
:param outputs: 模型的输出
:return: 计算得到的流型损失值
"""
mnfld_pred, nonmnfld_pred, mnfld_grad, all_fi = self.pre_process(mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last)
manifold_loss = torch.zeros(1).cuda()
# 计算流型损失(这里使用均方误差作为示例)
if not self.ablation == 'overall':
manifold_loss = (mnfld_pred.abs()).mean() # 计算流型损失
'''
if args.feature_sample: # 如果启用了特征采样
feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda() # 随机选择特征点
feature_pnts = self.feature_data[feature_indices] # 获取特征点数据
feature_mask_pair = self.feature_data_mask_pair[feature_indices] # 获取特征掩码对
feature_pred_all = self.network(feature_pnts) # 进行前向传播,计算特征点的预测值
feature_pred = feature_pred_all[:,0] # 提取特征预测结果
feature_mnfld_loss = feature_pred.abs().mean() # 计算特征流形损失
loss = loss + weight_mnfld_h * feature_mnfld_loss # 将特征流形损失加权到总损失中
except Exception as e: # patch loss:
logger.warning(f"Gradient computation failed: {str(e)}") feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:,0].tolist()] # 获取左侧特征 ID
grad_constraint = torch.tensor(0.0, device=pred_sdf.device) feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:,1].tolist()] # 获取右侧特征 ID
feature_fis_left = feature_pred_all[feature_id_left] # 获取左侧特征预测值
return l1_loss + grad_weight * grad_constraint feature_fis_right = feature_pred_all[feature_id_right] # 获取右侧特征预测值
feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean() # 计算补丁损失
loss += feature_loss_patch # 将补丁损失加权到总损失中
# consistency loss:
feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean() # 计算一致性损失
'''
manifold_loss_patch = torch.zeros(1).cuda()
if self.ablation == 'patch':
manifold_loss_patch = all_fi[:,0].abs().mean()
# 计算法线损失
normals_loss = self.normals_loss(normals, mnfld_pnts, all_fi, patch_sup=True)
# 计算Eikonal损失
eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred_all)
# 计算离表面损失
offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred_all)
# 计算一致性损失
consistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi)
# 计算修正损失
correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
loss_details = {
"manifold": self.weights["manifold"] * manifold_loss,
"manifold_patch": manifold_loss_patch,
"normals": self.weights["normals"] * normals_loss,
"eikonal": self.weights["eikonal"] * eikonal_loss,
"offsurface": self.weights["offsurface"] * offsurface_loss,
"consistency": self.weights["consistency"] * consistency_loss,
"correction": self.weights["correction"] * correction_loss,
}
# 计算总损失
total_loss = sum(loss_details.values())
return total_loss, loss_details

11
brep2sdf/networks/network.py

@ -48,6 +48,7 @@ class GridNet:
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.autograd import grad
from .encoder import Encoder from .encoder import Encoder
from .decoder import Decoder from .decoder import Decoder
@ -90,3 +91,13 @@ class Net(nn.Module):
return output return output
def gradient(inputs, outputs):
d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
points_grad = grad(
outputs=outputs,
inputs=inputs,
grad_outputs=d_points,
create_graph=True,
retain_graph=True,
only_inputs=True)[0][:, -3:]
return points_grad

97
brep2sdf/train.py

@ -8,26 +8,64 @@ import argparse
from brep2sdf.config.default_config import get_default_config from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file,load_sdf_file from brep2sdf.data.data import load_brep_file,load_sdf_file
from brep2sdf.data.pre_process import process_single_step from brep2sdf.data.pre_process_by_mesh import process_single_step
from brep2sdf.networks.network import Net from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.loss import LossManager
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
def prepare_sdf_data(surf_data, max_points=100000, device='cuda'):
# 配置命令行参数
parser = argparse.ArgumentParser(description='STEP文件批量处理工具')
parser.add_argument('-i', '--input', required=True,
help='待处理 brep (.step) 路径')
parser.add_argument(
'--use-normal',
action='store_true', # 默认为 False,如果用户指定该参数,则为 True
help='强制采样点有法向量'
)
parser.add_argument(
'--force-reprocess',
action='store_true', # 默认为 False,如果用户指定该参数,则为 True
help='强制重新进行数据预处理,忽略缓存或已有结果'
)
args = parser.parse_args()
def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
total_points = sum(len(s) for s in surf_data) total_points = sum(len(s) for s in surf_data)
# 降采样逻辑(修复版) # 降采样逻辑(修复版)
if total_points > max_points: if total_points > max_points:
# 先随机打乱所有点 # 生成索引
all_points = np.concatenate(surf_data) indices = []
np.random.shuffle(all_points) for i, points in enumerate(surf_data):
# 直接取前max_points个点 indices.extend([(i, j) for j in range(len(points))])
sampled_points = all_points[:max_points]
sdf_array = np.zeros((max_points, 4), dtype=np.float32) # 随机打乱索引
sdf_array[:, :3] = sampled_points np.random.shuffle(indices)
# 选择前max_points个索引
selected_indices = indices[:max_points]
if not normals is None:
# 根据索引构建sdf_array
sdf_array = np.zeros((max_points, 4), dtype=np.float32)
for idx, (i, j) in enumerate(selected_indices):
sdf_array[idx, :3] = surf_data[i][j]
else:
sdf_array = np.zeros((max_points, 7), dtype=np.float32)
for idx, (i, j) in enumerate(selected_indices):
sdf_array[idx, :3] = surf_data[i][j]
sdf_array[idx, 3:6] = normals[i][j]
else: else:
sdf_array = np.zeros((total_points, 4), dtype=np.float32) if not normals is None:
sdf_array[:, :3] = np.concatenate(surf_data) sdf_array = np.zeros((total_points, 4), dtype=np.float32)
sdf_array[:, :3] = np.concatenate(surf_data)
sdf_array = np.zeros((max_points, 7), dtype=np.float32)
else:
for idx, (i, j) in enumerate(selected_indices):
sdf_array[idx, :3] = surf_data[i][j]
sdf_array[idx, 3:6] = normals[i][j]
return torch.tensor(sdf_array, dtype=torch.float32, device=device) return torch.tensor(sdf_array, dtype=torch.float32, device=device)
@ -40,15 +78,16 @@ class Trainer:
self.model_name = os.path.basename(input_step).split('_')[0] self.model_name = os.path.basename(input_step).split('_')[0]
self.base_name = self.model_name + ".xyz" self.base_name = self.model_name + ".xyz"
data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name) data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name)
if os.path.exists(data_path): if os.path.exists(data_path) and not args.force_reprocess:
self.data = load_brep_file(data_path) self.data = load_brep_file(data_path)
else: else:
self.data = process_single_step(step_path=input_step, output_path=data_path) self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
# 将曲面点云列表转换为 (N*M, 4) 数组 # 将曲面点云列表转换为 (N*M, 4) 数组
surfs = self.data["surf_ncs"] surfs = self.data["surf_ncs"]
self.sdf_data = prepare_sdf_data( self.sdf_data = prepare_sdf_data(
surfs, surfs,
normals = self.data["surf_pnt_normals"],
max_points=4096, max_points=4096,
device=self.device device=self.device
) )
@ -81,6 +120,8 @@ class Trainer:
weight_decay=config.train.weight_decay weight_decay=config.train.weight_decay
) )
self.loss_manager = LossManager(ablation="none")
def build_tree(self,surf_bbox, max_depth=6): def build_tree(self,surf_bbox, max_depth=6):
num_faces = surf_bbox.shape[0] num_faces = surf_bbox.shape[0]
@ -131,14 +172,27 @@ class Trainer:
# 获取数据并移动到设备 # 获取数据并移动到设备
points = self.sdf_data[:,0:3] points = self.sdf_data[:,0:3]
points.requires_grad_(True) points.requires_grad_(True)
gt_sdf = self.sdf_data[:,3] if args.use_normal:
normals = self.sdf_data[:,3:6]
gt_sdf = self.sdf_data[:,6]
else:
gt_sdf = self.sdf_data[:,3]
# 前向传播 # 前向传播
self.optimizer.zero_grad() self.optimizer.zero_grad()
pred_sdf = self.model(points) pred_sdf = self.model(points)
# 计算损失 # 计算损失
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) if args.use_normal:
loss,loss_details = self.loss_manager.compute_loss(
points,
normals,
gt_sdf,
pred_sdf
) # 计算损失
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# 反向传播和优化 # 反向传播和优化
loss.backward() loss.backward()
@ -173,7 +227,7 @@ class Trainer:
best_val_loss = float('inf') best_val_loss = float('inf')
logger.info("Starting training...") logger.info("Starting training...")
start_time = time.time() start_time = time.time()
"""
for epoch in range(1, self.config.train.num_epochs + 1): for epoch in range(1, self.config.train.num_epochs + 1):
# 训练一个epoch # 训练一个epoch
train_loss = self.train_epoch(epoch) train_loss = self.train_epoch(epoch)
@ -202,8 +256,8 @@ class Trainer:
logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s') logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
logger.info(f'Best validation loss: {best_val_loss:.6f}') logger.info(f'Best validation loss: {best_val_loss:.6f}')
self._tracing_model() self._tracing_model()
"""
self.test_load() #self.test_load()
def test_load(self): def test_load(self):
model = self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt") model = self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt")
@ -250,14 +304,7 @@ class Trainer:
def main(): def main():
# 这里需要初始化配置 # 这里需要初始化配置
# 配置命令行参数
parser = argparse.ArgumentParser(description='STEP文件批量处理工具')
parser.add_argument('-i', '--input', required=True,
help='待处理 brep (.step) 路径')
args = parser.parse_args()
config = get_default_config() config = get_default_config()
# 初始化训练器并开始训练 # 初始化训练器并开始训练
trainer = Trainer(config, input_step=args.input) trainer = Trainer(config, input_step=args.input)
trainer.train() trainer.train()

Loading…
Cancel
Save