
Support normals

final
mckay 4 months ago
parent commit cc88dfb798
  1. brep2sdf/data/pre_process_by_mesh.py (621)
  2. brep2sdf/networks/loss.py (254)
  3. brep2sdf/networks/network.py (11)
  4. brep2sdf/train.py (97)

brep2sdf/data/pre_process_by_mesh.py (621 changed lines)

@@ -0,0 +1,621 @@
"""
CAD模型处理脚本
功能将STEP格式的CAD模型转换为结构化数据包括
- 几何信息顶点的坐标数据
- 拓扑信息--顶点的邻接关系
- 空间信息包围盒数据
"""
import os
import pickle    # data serialization
import argparse  # command-line argument parsing
import numpy as np
from tqdm import tqdm  # progress bars
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError  # parallel processing
import logging
from datetime import datetime
from scipy.spatial import cKDTree
from brep2sdf.utils.logger import logger
import tempfile
import trimesh

# OpenCASCADE imports
from OCC.Core.STEPControl import STEPControl_Reader                  # STEP file reader
from OCC.Core.TopExp import TopExp_Explorer, topexp                  # topology traversal
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX  # topology type enums
from OCC.Core.BRep import BRep_Tool                                  # B-rep tools
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh               # meshing
from OCC.Core.TopLoc import TopLoc_Location                          # location transforms
from OCC.Core.IFSelect import IFSelect_RetDone, IFSelect_RetError, IFSelect_RetFail, IFSelect_RetVoid  # status codes
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape  # shape maps
from OCC.Core.BRepBndLib import brepbndlib                           # bounding-box computation
from OCC.Core.Bnd import Bnd_Box                                     # bounding box
from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex      # topology data structures
from OCC.Core.StlAPI import StlAPI_Writer

# Load configuration
from brep2sdf.config.default_config import get_default_config
config = get_default_config()

# Maximum face-count threshold, used to keep processing fast
MAX_FACE = config.data.max_face
def normalize(surfs, edges, corners):
    """
    Normalize a CAD model into a unit cube.

    Args:
        surfs: list of per-face point sets
        edges: list of per-edge point sets
        corners: corner coordinates, shape [num_edges, 2, 3]

    Returns:
        surfs_wcs: face point sets in world coordinates
        edges_wcs: edge point sets in world coordinates
        surfs_ncs: face point sets in normalized coordinates
        edges_ncs: edge point sets in normalized coordinates
        corner_wcs: normalized corner coordinates [num_edges, 2, 3]
        center: center point used for normalization [3,]
        scale: scale factor used (float)
    """
    if len(corners) == 0:
        return None, None, None, None, None, None, None

    # Bounding box and scale factor
    corners_array = corners.reshape(-1, 3)  # [num_edges*2, 3]
    center = (corners_array.max(0) + corners_array.min(0)) / 2          # center point
    scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max()   # scale factor

    # Normalize face point sets
    surfs_wcs = []  # faces in world coordinates
    surfs_ncs = []  # faces in normalized coordinates
    for surf in surfs:
        surf_wcs = np.array(surf)
        surf_ncs = (surf_wcs - center) * scale  # normalization transform
        surfs_wcs.append(surf_wcs)
        surfs_ncs.append(surf_ncs)

    # Normalize edge point sets
    edges_wcs = []  # edges in world coordinates
    edges_ncs = []  # edges in normalized coordinates
    for edge in edges:
        edge_wcs = np.array(edge)
        edge_ncs = (edge_wcs - center) * scale  # normalization transform
        edges_wcs.append(edge_wcs)
        edges_ncs.append(edge_ncs)

    # Normalize corner coordinates; broadcasting keeps the [num_edges, 2, 3] shape
    corner_wcs = (corners - center) * scale

    return (np.array(surfs_wcs, dtype=object),
            np.array(edges_wcs, dtype=object),
            np.array(surfs_ncs, dtype=object),
            np.array(edges_ncs, dtype=object),
            corner_wcs.astype(np.float32),
            center.astype(np.float32),
            scale)
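
# A minimal round-trip sketch of the convention above (synthetic data, not part
# of the pipeline): the longest bounding-box side maps to length 1, and the
# inverse transform recovers the input.
#
#   corners = np.array([[[0., 0., 0.], [2., 1., 0.]],
#                       [[2., 1., 0.], [0., 0., 4.]]])       # [num_edges, 2, 3]
#   pts = corners.reshape(-1, 3)
#   center = (pts.max(0) + pts.min(0)) / 2                   # -> [1., 0.5, 2.]
#   scale = 1.0 / (pts.max(0) - pts.min(0)).max()            # -> 0.25
#   ncs = (corners - center) * scale                         # longest side becomes 1
#   assert np.allclose(ncs / scale + center, corners)        # round-trip holds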
def get_adjacency_info(shape, faces, edges, vertices):
    """
    Optimized adjacency computation that reuses the already-collected
    geometric elements.

    New args:
        faces: list of collected faces
        edges: list of collected edges
        vertices: list of collected vertices
    """
    logger.debug("Get adjacency infos...")
    # Edge-to-face map
    edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape()
    topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map)

    # Use the element lists passed in directly
    num_faces = len(faces)
    num_edges = len(edges)
    num_vertices = len(vertices)
    logger.debug(f"num_faces: {num_faces}, num_edges: {num_edges}, num_vertices: {num_vertices}")

    edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32)
    faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32)
    edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32)

    # Fill the edge-face adjacency matrix
    for i, edge in enumerate(edges):
        # Check whether each face contains the current edge
        for j, face in enumerate(faces):
            edge_explorer = TopExp_Explorer(face, TopAbs_EDGE)
            while edge_explorer.More():
                if edge.IsSame(edge_explorer.Current()):
                    edgeFace_adj[i, j] = 1
                    faceEdge_adj[j, i] = 1
                    break
                edge_explorer.Next()

        # Get the edge's two end vertices
        v1 = TopoDS_Vertex()
        v2 = TopoDS_Vertex()
        topexp.Vertices(edge, v1, v2)

        # Record the endpoint indices
        if not v1.IsNull() and not v2.IsNull():
            v1_vertex = topods.Vertex(v1)
            v2_vertex = topods.Vertex(v2)
            for k, vertex in enumerate(vertices):
                if v1_vertex.IsSame(vertex):
                    edgeCorner_adj[i, 0] = k
                if v2_vertex.IsSame(vertex):
                    edgeCorner_adj[i, 1] = k

    return edgeFace_adj, faceEdge_adj, edgeCorner_adj
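
# To make the adjacency outputs concrete, a minimal sketch (synthetic indices,
# no OpenCASCADE needed) of how edgeCorner_adj is later combined with the
# vertex coordinates in parse_solid to build the [num_edges, 2, 3] corner pairs:
#
#   corner_pnts = np.array([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.]], dtype=np.float32)
#   edgeCorner_adj = np.array([[0, 1], [1, 2]], dtype=np.int32)  # edge i -> its two vertex indices
#   corner_pairs = []
#   for v1_idx, v2_idx in edgeCorner_adj:
#       v1, v2 = corner_pnts[v1_idx], corner_pnts[v2_idx]
#       if (v1 > v2).any():                         # same ordering rule as parse_solid
#           v1, v2 = v2, v1
#       corner_pairs.append(np.stack([v1, v2]))
#   print(np.stack(corner_pairs).shape)             # (2, 2, 3)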
def get_bbox(shape, subshape):
    """
    Compute the bounding box of a shape.

    Args:
        shape: the full CAD model shape
        subshape: the sub-shape (face or edge) whose bounding box is wanted

    Returns:
        the six bounding-box parameters [xmin, ymin, zmin, xmax, ymax, zmax]
    """
    bbox = Bnd_Box()
    brepbndlib.Add(subshape, bbox)
    xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
    return np.array([xmin, ymin, zmin, xmax, ymax, zmax])
def parse_solid(step_path, sample_normal_vector=False):
    """
    Parse the CAD model data in a STEP file.

    Returns:
        dict with the following keys:
        # geometry
        'surf_wcs': np.ndarray(dtype=object)   # shape (N,); each element is an (M, 3) float32 array of face point-cloud coordinates
        'edge_wcs': np.ndarray(dtype=object)   # shape (N,); each element is a (num_edge_sample_points, 3) float32 array of edge sample points
        'surf_ncs': np.ndarray(dtype=object)   # shape (N,); each element is an (M, 3) float32 array of normalized face points
        'edge_ncs': np.ndarray(dtype=object)   # shape (N,); each element is a (num_edge_sample_points, 3) float32 array of normalized edge samples
        'corner_wcs': np.ndarray(dtype=float32)    # shape (num_edges, 2, 3); the two endpoints of each edge
        'corner_unique': np.ndarray(dtype=float32) # shape (num_vertices, 3); unique vertex coordinates, num_vertices <= num_edges * 2
        # topology
        'edgeFace_adj': np.ndarray(dtype=int32)    # shape (num_edges, num_faces); edge-face adjacency
        'edgeCorner_adj': np.ndarray(dtype=int32)  # shape (num_edges, 2); edge-vertex adjacency
        'faceEdge_adj': np.ndarray(dtype=int32)    # shape (num_faces, num_edges); face-edge adjacency
        # bounding boxes
        'surf_bbox_wcs': np.ndarray(dtype=float32) # shape (num_faces, 6); per-face bbox [xmin,ymin,zmin,xmax,ymax,zmax]
        'edge_bbox_wcs': np.ndarray(dtype=float32) # shape (num_edges, 6); per-edge bbox [xmin,ymin,zmin,xmax,ymax,zmax]
    """
    # Load STEP file
    reader = STEPControl_Reader()
    status = reader.ReadFile(step_path)
    if status != IFSelect_RetDone:
        if status == IFSelect_RetError:
            print("Error: An error occurred while reading the file.")
        elif status == IFSelect_RetFail:
            print("Error: Failed to read the file.")
        elif status == IFSelect_RetVoid:
            print("Error: No data was read from the file.")
        else:
            print(f"Unexpected status code: {status}")
        raise Exception(f"Failed to read STEP file. {status}")
    reader.TransferRoots()
    shape = reader.OneShape()

    # Create mesh
    mesh = BRepMesh_IncrementalMesh(shape, 0.01)
    mesh.Perform()

    # Initialize explorers
    face_explorer = TopExp_Explorer(shape, TopAbs_FACE)
    edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE)
    vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX)

    face_pnts = []
    edge_pnts = []
    corner_pnts = []
    surf_bbox_wcs = []
    edge_bbox_wcs = []
    faces, edges, vertices = [], [], []

    # Extract face points
    logger.debug("Extract face points...")
    while face_explorer.More():
        face = topods.Face(face_explorer.Current())
        faces.append(face)
        loc = TopLoc_Location()
        triangulation = BRep_Tool.Triangulation(face, loc)
        if triangulation is not None:
            points = []
            for i in range(1, triangulation.NbNodes() + 1):
                node = triangulation.Node(i)
                pnt = node.Transformed(loc.Transformation())
                points.append([pnt.X(), pnt.Y(), pnt.Z()])
            if points:
                points = np.array(points, dtype=np.float32)
                # Keep only faces that yield at least 3 valid 3D points
                if len(points.shape) == 2 and points.shape[1] == 3 and len(points) >= 3:
                    face_pnts.append(points)
                    surf_bbox_wcs.append(get_bbox(shape, face))
        face_explorer.Next()

    face_count = len(faces)
    if face_count > MAX_FACE:
        logger.error(f"step has {face_count} faces, which exceeds MAX_FACE {MAX_FACE}")
        return None

    # Extract edge points
    logger.debug("Extract edge points...")
    num_samples = config.model.num_edge_points  # edge sample count from config
    while edge_explorer.More():
        edge = topods.Edge(edge_explorer.Current())
        edges.append(edge)
        logger.debug(len(edges))
        curve_info = BRep_Tool.Curve(edge)
        curve = None
        if curve_info is not None:  # skip invalid edges
            try:
                if len(curve_info) == 3:
                    curve, first, last = curve_info
                elif len(curve_info) == 2:
                    curve = None  # degenerate edge; skip sampling
                else:
                    raise ValueError(f"Unexpected curve info: {curve_info}")
            except Exception as e:
                logger.error(f"Failed to process edge {edge}: {str(e)}")
                curve = None
        if curve is not None:
            points = []
            for i in range(num_samples):
                param = first + (last - first) * float(i) / (num_samples - 1)
                pnt = curve.Value(param)
                points.append([pnt.X(), pnt.Y(), pnt.Z()])
            if points:
                points = np.array(points, dtype=np.float32)
                if len(points.shape) == 2 and points.shape[1] == 3:
                    edge_pnts.append(points)  # points has shape (num_edge_points, 3)
                    edge_bbox_wcs.append(get_bbox(shape, edge))
        edge_explorer.Next()

    # Extract vertex points
    logger.debug("Extract vertex points...")
    while vertex_explorer.More():
        vertex = topods.Vertex(vertex_explorer.Current())
        vertices.append(vertex)
        pnt = BRep_Tool.Pnt(vertex)
        corner_pnts.append([pnt.X(), pnt.Y(), pnt.Z()])
        vertex_explorer.Next()

    # Adjacency information
    edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(
        shape,
        faces=faces,        # collected faces
        edges=edges,        # collected edges
        vertices=vertices   # collected vertices
    )
    logger.debug("complete.")

    # Convert to numpy arrays with the correct dtypes
    face_pnts = [np.array(points, dtype=np.float32) for points in face_pnts]
    edge_pnts = [np.array(points, dtype=np.float32) for points in edge_pnts]

    # Object arrays (ragged)
    face_pnts = np.array(face_pnts, dtype=object)
    edge_pnts = np.array(edge_pnts, dtype=object)
    corner_pnts = np.array(corner_pnts, dtype=np.float32)

    # Regroup vertex data as two endpoints per edge
    corner_pairs = []
    for edge_idx in range(len(edge_pnts)):
        v1_idx, v2_idx = edgeCorner_adj[edge_idx]
        v1_pos = corner_pnts[v1_idx]
        v2_pos = corner_pnts[v2_idx]
        # Sort by coordinate for a consistent order
        if (v1_pos > v2_pos).any():
            v1_pos, v2_pos = v2_pos, v1_pos
        corner_pairs.append(np.stack([v1_pos, v2_pos]))
    corner_pairs = np.stack(corner_pairs).astype(np.float32)  # [num_edges, 2, 3]

    # Make sure the bbox arrays are typed correctly
    surf_bbox_wcs = np.array(surf_bbox_wcs, dtype=np.float32)
    edge_bbox_wcs = np.array(edge_bbox_wcs, dtype=np.float32)

    # Normalize the CAD model
    surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs, center, scale = normalize(
        face_pnts, edge_pnts, corner_pairs)

    # Bounding boxes in normalized coordinates
    surf_bbox_ncs = np.empty_like(surf_bbox_wcs)
    edge_bbox_ncs = np.empty_like(edge_bbox_wcs)

    # Face bboxes to normalized coordinates
    surf_bbox_ncs[:, :3] = (surf_bbox_wcs[:, :3] - center) * scale  # min corner
    surf_bbox_ncs[:, 3:] = (surf_bbox_wcs[:, 3:] - center) * scale  # max corner

    # Edge bboxes to normalized coordinates
    edge_bbox_ncs[:, :3] = (edge_bbox_wcs[:, :3] - center) * scale  # min corner
    edge_bbox_ncs[:, 3:] = (edge_bbox_wcs[:, 3:] - center) * scale  # max corner

    # Validate the normalized data
    if any(x is None for x in [surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs]):
        logger.error(f"Normalization failed for {step_path}")
        return None

    # Build the result dict, again pinning down dtypes
    data = {
        'surf_wcs': np.array(surfs_wcs, dtype=object),   # keep as object array
        'edge_wcs': np.array(edges_wcs, dtype=object),   # keep as object array
        'surf_ncs': np.array(surfs_ncs, dtype=object),   # keep as object array
        'edge_ncs': np.array(edges_ncs, dtype=object),   # keep as object array
        'corner_wcs': corner_wcs.astype(np.float32),     # [num_edges, 2, 3]
        'edgeFace_adj': edgeFace_adj.astype(np.int32),
        'edgeCorner_adj': edgeCorner_adj.astype(np.int32),
        'faceEdge_adj': faceEdge_adj.astype(np.int32),
        'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),
        'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),
        'surf_bbox_ncs': surf_bbox_ncs.astype(np.float32),  # normalized coordinates [num_faces, 6]
        'edge_bbox_ncs': edge_bbox_ncs.astype(np.float32),  # normalized coordinates [num_edges, 6]
        'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32),  # flatten, then deduplicate
        'normalization_params': {
            'center': center.astype(np.float32),  # normalization center [3,]
            'scale': float(scale),                # normalization scale factor
        }
    }

    if sample_normal_vector:
        # Read normals from the mesh: export a temporary STL and load it with trimesh
        mesh.Perform()
        stl_writer = StlAPI_Writer()
        stl_writer.SetASCIIMode(False)
        with tempfile.NamedTemporaryFile(suffix='.stl') as tmp:
            stl_writer.Write(shape, tmp.name)
            trimesh_mesh = trimesh.load(tmp.name)
        data['surf_pnt_normals'] = batch_compute_normals(trimesh_mesh, surfs_wcs)

    return data
def load_step(step_path):
    """Load STEP file and return solids"""
    reader = STEPControl_Reader()
    reader.ReadFile(step_path)
    reader.TransferRoots()
    return [reader.OneShape()]
def preprocess_mesh(mesh, normal_type='vertex'):
    """
    Preprocess mesh data: build a KDTree and the normal source.

    Args:
        mesh: trimesh.Trimesh object with vertices and normals
        normal_type: str, either 'vertex' or 'face'

    Returns:
        tree: cKDTree for fast nearest-neighbor queries
        normals_source: np.ndarray of vertex normals or face normals
    """
    if normal_type == 'vertex':
        tree = cKDTree(mesh.vertices)
        normals_source = mesh.vertex_normals
    elif normal_type == 'face':
        # Use each face's centroid as the query point
        face_centers = np.mean(mesh.vertices[mesh.faces], axis=1)
        tree = cKDTree(face_centers)
        normals_source = mesh.face_normals
    else:
        raise ValueError(f"Unsupported normal type: {normal_type}")
    return tree, normals_source
def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3):
    """
    Compute normals for nested point-cloud data, preserving the nested layout.

    Args:
        mesh: trimesh.Trimesh object with vertices and normals
        surf_wcs: np.ndarray(dtype=object) of shape (N,); each element is an (M, 3) float32 array
        normal_type: str, either 'vertex' or 'face'
        k_neighbors: int, number of nearest neighbors used for smoothing

    Returns:
        normals: np.ndarray(dtype=object) of shape (N,); each element is an (M, 3) float32 array
    """
    # Preprocess the mesh
    tree, normals_source = preprocess_mesh(mesh, normal_type=normal_type)

    # Flatten all point clouds into a single [P, 3] array and remember the split lengths
    lengths = [len(point_cloud) for point_cloud in surf_wcs]
    query_points = np.concatenate(surf_wcs, axis=0).astype(np.float32)  # single allocation

    # Batched nearest-neighbor query
    distances, indices = tree.query(query_points, k=k_neighbors)

    # Special case: k = 1
    if k_neighbors == 1:
        nearest_normals = normals_source[indices]
    else:
        # Weighted average (weights are inverse distances)
        inv_distances = 1 / (distances + 1e-8)  # avoid division by zero
        sum_inv_distances = inv_distances.sum(axis=1, keepdims=True)
        valid_mask = sum_inv_distances > 1e-6
        weights = np.divide(inv_distances, sum_inv_distances, out=np.zeros_like(inv_distances), where=valid_mask)
        nearest_normals = np.einsum('ijk,ij->ik', normals_source[indices], weights)

    # Re-normalize to unit length
    norms = np.linalg.norm(nearest_normals, axis=1)
    valid_mask = norms > 1e-6
    nearest_normals[valid_mask] /= norms[valid_mask, None]

    # Split the normals back into the original nested structure
    start = 0
    normals_output = np.empty(len(surf_wcs), dtype=object)
    for i, length in enumerate(lengths):
        end = start + length
        normals_output[i] = nearest_normals[start:end]
        start = end

    return normals_output
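
# A toy sketch of the inverse-distance weighting above (hand-built stand-in for
# the mesh; hypothetical values, runnable on its own):
#
#   vertices = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
#   vertex_normals = np.array([[0., 0., 1.], [0., 1., 0.], [0., 0., 1.]])
#   tree = cKDTree(vertices)
#   distances, indices = tree.query(np.array([[0.1, 0.1, 0.0]]), k=3)
#   inv = 1.0 / (distances + 1e-8)
#   weights = inv / inv.sum(axis=1, keepdims=True)
#   blended = np.einsum('ijk,ij->ik', vertex_normals[indices], weights)
#   blended /= np.linalg.norm(blended, axis=1, keepdims=True)
#   # -> a unit normal dominated by the nearest vertex's [0, 0, 1]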
def check_data_format(data, step_file):
    """Check that the data dict has the expected format."""
    required_keys = [
        'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs', 'corner_wcs',
        'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj',
        'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'
    ]

    # All required keys must exist
    for key in required_keys:
        if key not in data:
            return False, f"Missing key: {key}"

    # Geometry arrays must be object arrays (ragged)
    geometry_arrays = ['surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs']
    for key in geometry_arrays:
        if not isinstance(data[key], np.ndarray):
            return False, f"{key} should be a numpy array"
        if data[key].dtype != object:
            return False, f"{key} should be a numpy array with dtype=object"

    # The remaining arrays have fixed dtypes
    float32_arrays = ['corner_wcs', 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique']
    for key in float32_arrays:
        if not isinstance(data[key], np.ndarray):
            return False, f"{key} should be a numpy array"
        if data[key].dtype != np.float32:
            return False, f"{key} should be a numpy array with dtype=float32"

    int32_arrays = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
    for key in int32_arrays:
        if not isinstance(data[key], np.ndarray):
            return False, f"{key} should be a numpy array"
        if data[key].dtype != np.int32:
            return False, f"{key} should be a numpy array with dtype=int32"

    return True, ""
def process_single_step(step_path: str, output_path: str = None, sample_normal_vector=False, timeout: int = 300) -> dict:
    """Process a single STEP file (B-rep -> pkl).

    Returns:
        data = {
            'surf_wcs': np.array(surfs_wcs, dtype=object),     # face geometry in world coordinates (object array)
            'edge_wcs': np.array(edges_wcs, dtype=object),     # edge geometry in world coordinates (object array)
            'surf_ncs': np.array(surfs_ncs, dtype=object),     # normalized face point clouds [num_faces, num_surf_sample_points, 3]
            'edge_ncs': np.array(edges_ncs, dtype=object),     # normalized edge point clouds [num_edges, num_edge_sample_points, 3]
            'corner_wcs': corner_wcs.astype(np.float32),       # corner data in world coordinates [num_edges, 2, 3]
            'edgeFace_adj': edgeFace_adj.astype(np.int32),     # edge-face adjacency matrix
            'edgeCorner_adj': edgeCorner_adj.astype(np.int32), # edge-corner adjacency matrix
            'faceEdge_adj': faceEdge_adj.astype(np.int32),     # face-edge adjacency matrix
            'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32), # face bounding boxes in world coordinates
            'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32), # edge bounding boxes in world coordinates
            'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32)  # deduplicated corner coordinates
        }"""
    try:
        logger.info("Preprocessing data...")
        if not os.path.exists(step_path):
            logger.error(f"STEP file does not exist: {step_path}")
            return None
        if not step_path.lower().endswith('.step') and not step_path.lower().endswith('.stp'):
            logger.error(f"Unsupported file format, must be .step or .stp: {step_path}")
            return None

        # Parse the STEP file
        data = parse_solid(step_path, sample_normal_vector)
        if data is None:
            logger.error(f"Failed to parse STEP file: {step_path}")
            return None

        # Check the data format
        is_valid, msg = check_data_format(data, step_path)
        if not is_valid:
            logger.error(f"Data format check failed for {step_path}: {msg}")
            return None

        # Save the result
        if output_path:
            try:
                logger.info(f"Saving results to: {output_path}")
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                with open(output_path, 'wb') as f:
                    pickle.dump(data, f)
                logger.info("Preprocessing finished")
                logger.info(f"Results saved successfully: {output_path}")
                return data
            except Exception as e:
                logger.error(f'Failed to save {output_path}: {str(e)}')
                return None

        logger.info("Preprocessing finished")
        return data

    except Exception as e:
        logger.error(f'Error processing {step_path}: {str(e)}')
        return None
def test(step_file_path, output_path=None):
    """
    Test helper: convert a single STEP file and save the result.
    """
    try:
        logger.info(f"Processing STEP file: {step_file_path}")

        # Parse the STEP file
        data = parse_solid(step_file_path)
        if data is None:
            logger.error(f"Failed to parse STEP file: {step_file_path}")
            return None

        # Check the data format
        is_valid, msg = check_data_format(data, step_file_path)
        if not is_valid:
            logger.error(f"Data format check failed for {step_file_path}: {msg}")
            return None

        # Print statistics
        logger.info("\nStatistics:")
        logger.info(f"Number of surfaces: {len(data['surf_wcs'])}")
        logger.info(f"Number of edges: {len(data['edge_wcs'])}")
        logger.info(f"Number of corners: {len(data['corner_wcs'])}")  # corrected to corner_wcs

        # Save the result
        if output_path:
            try:
                logger.info(f"Saving results to: {output_path}")
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                with open(output_path, 'wb') as f:
                    pickle.dump(data, f)
                logger.info(f"Results saved successfully: {output_path}")
            except Exception as e:
                logger.error(f"Failed to save {output_path}: {str(e)}")
                return None

        return data

    except Exception as e:
        logger.error(f"Error processing {step_file_path}: {str(e)}")
        return None
if __name__ == '__main__':
    # main()
    test("/home/wch/brep2sdf/data/step/00000000/00000000_290a9120f9f249a7a05cfe9c_step_000.step", "/home/wch/brep2sdf/test_data/pkl/train/00000031xx.pkl")
    #test("/home/wch/brep2sdf/00000031_ad34a3f60c4a4caa99646600_step_011.step", "/home/wch/brep2sdf/test_data/pkl/train/00000031.pkl")
    #test("/mnt/mynewdisk/dataset/furniture/step/furniture_dataset_step/train/bathtub_0004.step", "/home/wch/brep2sdf/test_data/pkl/train/0004.pkl")
    #reader = STEPControl_Reader()
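
# Typical usage of this module (a sketch; the paths below are placeholders,
# not files shipped with the repo):
#
#   import pickle
#   from brep2sdf.data.pre_process_by_mesh import process_single_step
#
#   data = process_single_step("model.step", "model.pkl", sample_normal_vector=True)
#   if data is not None:
#       with open("model.pkl", "rb") as f:
#           cached = pickle.load(f)
#       print(cached["surf_ncs"].shape, cached["normalization_params"]["scale"])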

brep2sdf/networks/loss.py (254 changed lines)

@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
from .network import gradient
from brep2sdf.config.default_config import get_default_config
from brep2sdf.utils.logger import logger
@@ -66,35 +66,229 @@ class Brep2SDFLoss(nn.Module):
        return grad_loss
# -- removed in this commit: the old standalone sdf_loss --
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
    """SDF loss function."""
    # Make sure the points require gradients
    if not points.requires_grad:
        points = points.detach().requires_grad_(True)

    # L1 loss
    l1_loss = F.l1_loss(pred_sdf, gt_sdf)

    try:
        # Gradient-constraint loss
        grad = torch.autograd.grad(
            pred_sdf.sum(),
            points,
            create_graph=True,
            retain_graph=True,
            allow_unused=True
        )[0]
        if grad is not None:
            grad_constraint = F.mse_loss(
                torch.norm(grad, dim=-1),
                torch.ones_like(pred_sdf.squeeze(-1))
            )
        else:
            grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
    except Exception as e:
        logger.warning(f"Gradient computation failed: {str(e)}")
        grad_constraint = torch.tensor(0.0, device=pred_sdf.device)

    return l1_loss + grad_weight * grad_constraint

# -- added in this commit: LossManager --
class LossManager:
    def __init__(self, ablation, **condition_kwargs):
        self.weights = {
            "manifold": 1,
            "feature_manifold": 1,  # the original paper uses the same weight as manifold
            "normals": 1,
            "eikonal": 1,
            "offsurface": 1,
            "consistency": 1,
            "correction": 1,
        }
        self.condition_kwargs = condition_kwargs
        self.ablation = ablation  # ablation-study switch

    def _get_condition_kwargs(self, key):
        """
        Fetch a condition parameter. Expected keys:
        ab: loss type: overall, patch, off, cons, cc, cor
        siren: whether SIREN is used
        epoch: current epoch
        baseline: whether this is the baseline
        """
        if key in self.condition_kwargs:
            return self.condition_kwargs[key]
        else:
            raise ValueError(f"Key {key} not found in condition_kwargs")

    def pre_process(self, mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
        """
        Preprocessing shared by the loss terms.
        """
        mnfld_pred_h = mnfld_pred_all[:, 0]        # manifold predictions
        nonmnfld_pred_h = nonmnfld_pred_all[:, 0]  # non-manifold predictions
        mnfld_grad = gradient(mnfld_pnts, mnfld_pred_h)  # gradients at manifold points

        all_fi = torch.zeros([n_batchsize, 1], device='cuda')  # per-branch manifold predictions
        for i in range(n_branch - 1):
            all_fi[i * n_patch_batch: (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch: (i + 1) * n_patch_batch, i + 1]
        # last patch
        all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch]

        return mnfld_pred_h, nonmnfld_pred_h, mnfld_grad, all_fi

    def position_loss(self, pred_sdfs: torch.Tensor, gt_sdfs: torch.Tensor) -> torch.Tensor:
        """
        Manifold (position) loss.
        :param pred_sdfs: predicted SDF values, shape (N, 1)
        :param gt_sdfs: ground-truth SDF values, shape (N, 1)
        :return: scalar manifold loss
        """
        # Difference between prediction and ground truth
        diff = pred_sdfs - gt_sdfs
        # Squared error
        squared_diff = torch.pow(diff, 2)
        # Mean
        manifold_loss = torch.mean(squared_diff)
        return manifold_loss

    def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, pred_sdfs, patch_sup: bool = True) -> torch.Tensor:
        """
        Normal-alignment loss.
        :param normals: ground-truth normals
        :param mnfld_pnts: manifold points
        :param pred_sdfs: predicted SDF values at those points
        :param patch_sup: whether patch supervision is enabled
        :return: normal loss
        """
        # NOTE: the reference implementation has more complex logic here
        # Branch gradients
        branch_grad = gradient(mnfld_pnts, pred_sdfs)
        # Normal loss
        normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean()
        return normals_loss

    def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred):
        """
        Eikonal loss.
        """
        single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred)  # gradients at non-manifold points
        eikonal_loss = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
        return eikonal_loss

    def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred):
        """
        Eo: penalize points that are far from the surface yet predict values close to 0.
        """
        offsurface_loss = torch.zeros(1).cuda()
        if not self.ablation == 'off':
            offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred)).mean()
        return offsurface_loss

    def consistency_loss(self, mnfld_pnts, mnfld_pred, all_fi):
        """
        Penalize disagreement between the overall prediction and the per-branch predictions at manifold points.
        """
        mnfld_consistency_loss = torch.zeros(1).cuda()
        if not (self.ablation == 'cons' or self.ablation == 'cc'):
            mnfld_consistency_loss = (mnfld_pred - all_fi[:, 0]).abs().mean()
        return mnfld_consistency_loss

    def correction_loss(self, mnfld_pnts, mnfld_pred, all_fi, th_closeness=1e-5, a_correction=100):
        """
        Correction loss.
        """
        correction_loss = torch.zeros(1).cuda()
        if not (self.ablation == 'cor' or self.ablation == 'cc'):
            mismatch_id = torch.abs(mnfld_pred - all_fi[:, 0]) > th_closeness  # mismatching point IDs
            if mismatch_id.sum() != 0:
                correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:, 0])[mismatch_id]).mean()
        return correction_loss

    def compute_loss(self, points, normals, gt_sdfs, pred_sdfs):
        """
        Compute the simple (position + normal) training loss.
        :return: total loss and a per-term breakdown
        """
        # Manifold loss
        manifold_loss = self.position_loss(pred_sdfs, gt_sdfs)

        # Normal loss
        normals_loss = self.normals_loss(normals, points, pred_sdfs)

        # Collect the terms
        loss_details = {
            "manifold": self.weights["manifold"] * manifold_loss,
            "normals": self.weights["normals"] * normals_loss,
        }

        # Total loss
        total_loss = sum(loss_details.values())
        return total_loss, loss_details

    def _compute_loss(self, mnfld_pnts, normals, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
        """
        Compute the full multi-branch loss.
        :return: total loss and a per-term breakdown
        """
        mnfld_pred, nonmnfld_pred, mnfld_grad, all_fi = self.pre_process(mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last)

        manifold_loss = torch.zeros(1).cuda()
        if not self.ablation == 'overall':
            manifold_loss = (mnfld_pred.abs()).mean()  # manifold loss

        '''
        if args.feature_sample:  # if feature sampling is enabled
            feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda()  # randomly pick feature points
            feature_pnts = self.feature_data[feature_indices]  # feature point data
            feature_mask_pair = self.feature_data_mask_pair[feature_indices]  # feature mask pairs
            feature_pred_all = self.network(feature_pnts)  # forward pass on the feature points
            feature_pred = feature_pred_all[:, 0]  # feature predictions
            feature_mnfld_loss = feature_pred.abs().mean()  # feature manifold loss
            loss = loss + weight_mnfld_h * feature_mnfld_loss  # add the weighted feature manifold loss

            # patch loss:
            feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:, 0].tolist()]  # left feature IDs
            feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:, 1].tolist()]  # right feature IDs
            feature_fis_left = feature_pred_all[feature_id_left]  # left feature predictions
            feature_fis_right = feature_pred_all[feature_id_right]  # right feature predictions
            feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean()  # patch loss
            loss += feature_loss_patch  # add the patch loss

            # consistency loss:
            feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean()  # consistency loss
        '''

        manifold_loss_patch = torch.zeros(1).cuda()
        if self.ablation == 'patch':
            manifold_loss_patch = all_fi[:, 0].abs().mean()

        # Normal loss
        normals_loss = self.normals_loss(normals, mnfld_pnts, all_fi, patch_sup=True)
        # Eikonal loss
        eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred_all)
        # Off-surface loss
        offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred_all)
        # Consistency loss
        consistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi)
        # Correction loss
        correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)

        loss_details = {
            "manifold": self.weights["manifold"] * manifold_loss,
            "manifold_patch": manifold_loss_patch,
            "normals": self.weights["normals"] * normals_loss,
            "eikonal": self.weights["eikonal"] * eikonal_loss,
            "offsurface": self.weights["offsurface"] * offsurface_loss,
            "consistency": self.weights["consistency"] * consistency_loss,
            "correction": self.weights["correction"] * correction_loss,
        }

        # Total loss
        total_loss = sum(loss_details.values())
        return total_loss, loss_details
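
# A minimal smoke-test sketch for the new compute_loss path (random tensors
# standing in for a network; shapes follow the docstrings above, CPU is fine):
#
#   import torch
#   from brep2sdf.networks.loss import LossManager
#
#   points = torch.randn(8, 3, requires_grad=True)
#   normals = torch.nn.functional.normalize(torch.randn(8, 3), dim=1)
#   gt_sdfs = torch.zeros(8, 1)
#   pred_sdfs = (points ** 2).sum(dim=1, keepdim=True) - 1.0   # stand-in prediction
#
#   manager = LossManager(ablation="none")
#   total, details = manager.compute_loss(points, normals, gt_sdfs, pred_sdfs)
#   total.backward()
#   print({k: v.item() for k, v in details.items()})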

brep2sdf/networks/network.py (11 changed lines)

@@ -48,6 +48,7 @@ class GridNet:
import torch
import torch.nn as nn
from torch.autograd import grad
from .encoder import Encoder
from .decoder import Decoder
@@ -90,3 +91,13 @@ class Net(nn.Module):
        return output
def gradient(inputs, outputs):
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    points_grad = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=d_points,
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0][:, -3:]
    return points_grad
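
# Sanity-check sketch for gradient(): for f(x) = ||x||^2 the analytic gradient
# is 2x, and the [:, -3:] slice keeps the three spatial coordinates.
#
#   x = torch.randn(5, 3, requires_grad=True)
#   y = (x ** 2).sum(dim=1, keepdim=True)
#   assert torch.allclose(gradient(x, y), 2 * x, atol=1e-6)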

brep2sdf/train.py (97 changed lines)

@@ -8,26 +8,64 @@ import argparse
from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file, load_sdf_file
from brep2sdf.data.pre_process import process_single_step          # removed
from brep2sdf.data.pre_process_by_mesh import process_single_step  # added
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.loss import LossManager                     # added
from brep2sdf.utils.logger import logger
# -- removed in this commit: the old prepare_sdf_data --
def prepare_sdf_data(surf_data, max_points=100000, device='cuda'):
    total_points = sum(len(s) for s in surf_data)
    # Downsampling logic (fixed version)
    if total_points > max_points:
        # Shuffle all points first
        all_points = np.concatenate(surf_data)
        np.random.shuffle(all_points)
        # Take the first max_points points
        sampled_points = all_points[:max_points]
        sdf_array = np.zeros((max_points, 4), dtype=np.float32)
        sdf_array[:, :3] = sampled_points
    else:
        sdf_array = np.zeros((total_points, 4), dtype=np.float32)
        sdf_array[:, :3] = np.concatenate(surf_data)
    return torch.tensor(sdf_array, dtype=torch.float32, device=device)

# -- added in this commit: module-level CLI arguments and a normal-aware prepare_sdf_data --
parser = argparse.ArgumentParser(description='Batch processing tool for STEP files')
parser.add_argument('-i', '--input', required=True,
                    help='path to the B-rep (.step) file to process')
parser.add_argument(
    '--use-normal',
    action='store_true',  # defaults to False; True when the flag is given
    help='force sampled points to carry normal vectors'
)
parser.add_argument(
    '--force-reprocess',
    action='store_true',  # defaults to False; True when the flag is given
    help='force re-running preprocessing, ignoring caches and existing results'
)
args = parser.parse_args()

def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
    total_points = sum(len(s) for s in surf_data)

    if total_points > max_points:
        # Build (surface index, point index) pairs, shuffle them,
        # and keep the first max_points
        indices = []
        for i, points in enumerate(surf_data):
            indices.extend([(i, j) for j in range(len(points))])
        np.random.shuffle(indices)
        selected_indices = indices[:max_points]

        if normals is None:
            # Columns: (x, y, z, sdf); surface points have sdf = 0
            sdf_array = np.zeros((max_points, 4), dtype=np.float32)
            for idx, (i, j) in enumerate(selected_indices):
                sdf_array[idx, :3] = surf_data[i][j]
        else:
            # Columns: (x, y, z, nx, ny, nz, sdf)
            sdf_array = np.zeros((max_points, 7), dtype=np.float32)
            for idx, (i, j) in enumerate(selected_indices):
                sdf_array[idx, :3] = surf_data[i][j]
                sdf_array[idx, 3:6] = normals[i][j]
    else:
        if normals is None:
            sdf_array = np.zeros((total_points, 4), dtype=np.float32)
            sdf_array[:, :3] = np.concatenate(surf_data)
        else:
            sdf_array = np.zeros((total_points, 7), dtype=np.float32)
            sdf_array[:, :3] = np.concatenate(surf_data)
            sdf_array[:, 3:6] = np.concatenate(normals)

    return torch.tensor(sdf_array, dtype=torch.float32, device=device)
@@ -40,15 +78,16 @@ class Trainer:
        self.model_name = os.path.basename(input_step).split('_')[0]
        self.base_name = self.model_name + ".xyz"
        data_path = os.path.join("/home/wch/brep2sdf/data/output_data", self.base_name)

        if os.path.exists(data_path):                               # removed
        if os.path.exists(data_path) and not args.force_reprocess:  # added
            self.data = load_brep_file(data_path)
        else:
            self.data = process_single_step(step_path=input_step, output_path=data_path)  # removed
            self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)  # added

        # Flatten the surface point-cloud list into an (N*M, 4) array
        # (or (N*M, 7) when normals are attached)
        surfs = self.data["surf_ncs"]
        self.sdf_data = prepare_sdf_data(
            surfs,
            normals=self.data.get("surf_pnt_normals"),  # absent when --use-normal is off
            max_points=4096,
            device=self.device
        )
@@ -81,6 +120,8 @@ class Trainer:
            weight_decay=config.train.weight_decay
        )

        self.loss_manager = LossManager(ablation="none")

    def build_tree(self, surf_bbox, max_depth=6):
        num_faces = surf_bbox.shape[0]
@@ -131,14 +172,27 @@ class Trainer:
        # Fetch the data and move it to the device
        points = self.sdf_data[:, 0:3]
        points.requires_grad_(True)

        if args.use_normal:
            normals = self.sdf_data[:, 3:6]
            gt_sdf = self.sdf_data[:, 6]
        else:
            gt_sdf = self.sdf_data[:, 3]

        # Forward pass
        self.optimizer.zero_grad()
        pred_sdf = self.model(points)

        # Loss (the bare MSE was replaced by this normal-aware branch)
        if args.use_normal:
            loss, loss_details = self.loss_manager.compute_loss(
                points,
                normals,
                gt_sdf,
                pred_sdf
            )
        else:
            loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)

        # Backward pass and optimization
        loss.backward()
@@ -173,7 +227,7 @@ class Trainer:
        best_val_loss = float('inf')
        logger.info("Starting training...")
        start_time = time.time()

        """
        for epoch in range(1, self.config.train.num_epochs + 1):
            # Train one epoch
            train_loss = self.train_epoch(epoch)
@@ -202,8 +256,8 @@ class Trainer:
        logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
        logger.info(f'Best validation loss: {best_val_loss:.6f}')
        self._tracing_model()
        """
        self.test_load()      # removed
        #self.test_load()     # added

    def test_load(self):
        model = self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt")
@@ -250,14 +304,7 @@ class Trainer:
def main():
    # Configuration is initialized here
    # -- removed in this commit: argument parsing (now done at module level) --
    #   parser = argparse.ArgumentParser(description='Batch processing tool for STEP files')
    #   parser.add_argument('-i', '--input', required=True,
    #                       help='path to the B-rep (.step) file to process')
    #   args = parser.parse_args()
    config = get_default_config()

    # Create the trainer and start training
    trainer = Trainer(config, input_step=args.input)
    trainer.train()
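
# A sketch of prepare_sdf_data on synthetic input (note this module parses CLI
# arguments at import time, so run it in-module or pass -i):
#
#   surfs = [np.random.rand(50, 3).astype(np.float32) for _ in range(4)]
#   normals = [np.tile([0., 0., 1.], (50, 1)).astype(np.float32) for _ in range(4)]
#   sdf = prepare_sdf_data(surfs, normals=normals, max_points=128, device='cpu')
#   print(sdf.shape)   # torch.Size([128, 7]): xyz, normal, sdf (0 on the surface)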
