
Add adjacency graph as the criterion for octree subdivide

final
mckay, 12 months ago
commit 2bb0864190
8 changed files:
1. brep2sdf/data/data.py (1)
2. brep2sdf/data/pre_process_by_mesh.py (295)
3. brep2sdf/data/sampler.py (133)
4. brep2sdf/data/utils.py (1402)
5. brep2sdf/networks/octree.py (27)
6. brep2sdf/networks/patch_graph.py (182)
7. brep2sdf/test.py (161)
8. brep2sdf/train.py (57)

brep2sdf/data/data.py (1)

@@ -4,7 +4,6 @@ from torch.utils.data import Dataset
import numpy as np
import pickle
from brep2sdf.utils.logger import logger
from brep2sdf.data.utils import process_brep_data
from brep2sdf.config.default_config import get_default_config

brep2sdf/data/pre_process_by_mesh.py (295)

@@ -15,7 +15,6 @@ from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import logging
from datetime import datetime
from scipy.spatial import cKDTree
from brep2sdf.utils.logger import logger
import tempfile
import trimesh
from trimesh.proximity import ProximityQuery
@@ -36,6 +35,8 @@ from OCC.Core.StlAPI import StlAPI_Writer
from brep2sdf.data.sampler import sample_sdf_points_and_normals
from brep2sdf.data.data import check_data_format
from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,load_step, preprocess_mesh,batch_compute_normals
from brep2sdf.utils.logger import logger
# 导入配置
from brep2sdf.config.default_config import get_default_config
config = get_default_config()
@@ -43,133 +44,6 @@ config = get_default_config()
# 设置最大面数阈值,用于加速处理
MAX_FACE = config.data.max_face
def normalize(surfs, edges, corners):
"""
将CAD模型归一化到单位立方体空间
参数:
surfs: 面的点集列表
edges: 边的点集列表
corners: 顶点坐标数组 [num_edges, 2, 3]
返回:
surfs_wcs: 原始坐标系下的面点集
edges_wcs: 原始坐标系下的边点集
surfs_ncs: 归一化坐标系下的面点集
edges_ncs: 归一化坐标系下的边点集
corner_wcs: 归一化后的顶点坐标 [num_edges, 2, 3]
center: 使用的中心点坐标 [3,]
scale: 使用的缩放系数 (float)
"""
if len(corners) == 0:
return None, None, None, None, None, None, None
# 计算包围盒和缩放因子
corners_array = corners.reshape(-1, 3) # [num_edges*2, 3]
center = (corners_array.max(0) + corners_array.min(0)) / 2 # 计算中心点
scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max() # 计算缩放系数
# 归一化面的点集
surfs_wcs = [] # 原始世界坐标系下的面点集
surfs_ncs = [] # 归一化坐标系下的面点集
for surf in surfs:
surf_wcs = np.array(surf)
surf_ncs = (surf_wcs - center) * scale # 归一化变换
surfs_wcs.append(surf_wcs)
surfs_ncs.append(surf_ncs)
# 归一化边的点集
edges_wcs = [] # 原始世界坐标系下的边点集
edges_ncs = [] # 归一化坐标系下的边点集
for edge in edges:
edge_wcs = np.array(edge)
edge_ncs = (edge_wcs - center) * scale # 归一化变换
edges_wcs.append(edge_wcs)
edges_ncs.append(edge_ncs)
# 归一化顶点坐标 - 保持[num_edges, 2, 3]的形状
corner_wcs = (corners - center) * scale # 广播操作会保持原有维度
return (np.array(surfs_wcs, dtype=object),
np.array(edges_wcs, dtype=object),
np.array(surfs_ncs, dtype=object),
np.array(edges_ncs, dtype=object),
corner_wcs.astype(np.float32),
center.astype(np.float32),
scale
)
def get_adjacency_info(shape, faces, edges, vertices):
"""
优化后的邻接关系计算函数直接使用已收集的几何元素
参数新增:
faces: 已收集的面列表
edges: 已收集的边列表
vertices: 已收集的顶点列表
"""
logger.debug("Get adjacency infos...")
# 创建边-面映射关系
edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape()
topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map)
# 直接使用传入的几何元素列表
num_faces = len(faces)
num_edges = len(edges)
num_vertices = len(vertices)
logger.debug(f"num_faces: {num_faces}, num_edges: {num_edges}, num_vertices: {num_vertices}")
edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32)
faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32)
edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32)
# 填充边-面邻接矩阵
for i, edge in enumerate(edges):
# 检查每个面是否与当前边相连
for j, face in enumerate(faces):
edge_explorer = TopExp_Explorer(face, TopAbs_EDGE)
while edge_explorer.More():
if edge.IsSame(edge_explorer.Current()):
edgeFace_adj[i, j] = 1
faceEdge_adj[j, i] = 1
break
edge_explorer.Next()
# 获取边的两个端点
v1 = TopoDS_Vertex()
v2 = TopoDS_Vertex()
topexp.Vertices(edge, v1, v2)
# 记录边的端点索引
if not v1.IsNull() and not v2.IsNull():
v1_vertex = topods.Vertex(v1)
v2_vertex = topods.Vertex(v2)
for k, vertex in enumerate(vertices):
if v1_vertex.IsSame(vertex):
edgeCorner_adj[i, 0] = k
if v2_vertex.IsSame(vertex):
edgeCorner_adj[i, 1] = k
return edgeFace_adj, faceEdge_adj, edgeCorner_adj
def get_bbox(shape, subshape):
"""
计算形状的包围盒
参数:
shape: 完整的CAD模型形状
subshape: 需要计算包围盒的子形状面或边
返回:
包围盒的六个参数 [xmin, ymin, zmin, xmax, ymax, zmax]
"""
bbox = Bnd_Box()
brepbndlib.Add(subshape, bbox)
xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
return np.array([xmin, ymin, zmin, xmax, ymax, zmax])
def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
"""
@@ -360,6 +234,73 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
if any(x is None for x in [surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs]):
logger.error(f"Normalization failed for {step_path}")
return None
# --- 计算边的类型 ---
logger.debug("计算边的类型...")
edge_types = [] # 0:凹边 1:凸边
# 对每条边进行处理
for edge_idx in range(len(edges)):
# 获取与当前边相邻的面
adjacent_faces = np.where(edgeFace_adj[edge_idx] == 1)[0]
# 如果边只有一个相邻面或没有相邻面,默认为凹边
if len(adjacent_faces) < 2:
edge_types.append(0)
continue
# 获取两个相邻面
face1, face2 = faces[adjacent_faces[0]], faces[adjacent_faces[1]]
# 使用BRep_Tool获取面的几何信息
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
from OCC.Core.gp import gp_Pnt, gp_Vec
# 获取第一个面的法向量
surf1 = BRepAdaptor_Surface(face1)
u1 = (surf1.FirstUParameter() + surf1.LastUParameter()) / 2
v1 = (surf1.FirstVParameter() + surf1.LastVParameter()) / 2
pnt1 = gp_Pnt()
du1 = gp_Vec()
dv1 = gp_Vec()
surf1.D1(u1, v1, pnt1, du1, dv1)
normal1 = du1.Crossed(dv1)
normal1.Normalize()
normal1_np = np.array([normal1.X(), normal1.Y(), normal1.Z()])
# 获取第二个面的法向量
surf2 = BRepAdaptor_Surface(face2)
u2 = (surf2.FirstUParameter() + surf2.LastUParameter()) / 2
v2 = (surf2.FirstVParameter() + surf2.LastVParameter()) / 2
pnt2 = gp_Pnt()
du2 = gp_Vec()
dv2 = gp_Vec()
surf2.D1(u2, v2, pnt2, du2, dv2)
normal2 = du2.Crossed(dv2)
normal2.Normalize()
normal2_np = np.array([normal2.X(), normal2.Y(), normal2.Z()])
# 获取边的方向向量
edge = edges[edge_idx]
curve_info = BRep_Tool.Curve(edge)
if curve_info is None or len(curve_info) < 3:
edge_types.append(0)
continue
curve, first, last = curve_info
# 计算边的方向向量
start_point = np.array([curve.Value(first).X(), curve.Value(first).Y(), curve.Value(first).Z()])
end_point = np.array([curve.Value(last).X(), curve.Value(last).Y(), curve.Value(last).Z()])
edge_vector = end_point - start_point
edge_vector = edge_vector / np.linalg.norm(edge_vector)
# 使用混合积判断凹凸性
# 如果混合积为正,说明是凸边;为负,说明是凹边
mixed_product = np.dot(np.cross(normal1_np, normal2_np), edge_vector)
# 根据混合积的符号确定边的类型
edge_types.append(1 if mixed_product > 0 else 0)
edge_types = np.array(edge_types, dtype=np.int32)
# 创建结果字典并确保所有数组都有正确的类型
data = {
@@ -368,9 +309,10 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
'surf_ncs': np.array(surfs_ncs, dtype=object), # 保持对象数组
'edge_ncs': np.array(edges_ncs, dtype=object), # 保持对象数组
'corner_wcs': corner_wcs.astype(np.float32), # [num_edges, 2, 3]
'edgeFace_adj': edgeFace_adj.astype(np.int32), # [num_edges, num_faces], 1 表示边与面相邻
'edgeCorner_adj': edgeCorner_adj.astype(np.int32),
'faceEdge_adj': faceEdge_adj.astype(np.int32),
'edge_types': np.array(edge_types, dtype=np.int32), # [num_edges]
'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),
'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),
'surf_bbox_ncs': surf_bbox_ncs.astype(np.float32), # 归一化坐标系 [num_faces, 6]
@@ -450,94 +392,6 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
logger.warning("请求了 SDF 点采样,但 Trimesh 加载失败。")
return data
def load_step(step_path):
"""Load STEP file and return solids"""
reader = STEPControl_Reader()
reader.ReadFile(step_path)
reader.TransferRoots()
return [reader.OneShape()]
def preprocess_mesh(mesh, normal_type='vertex'):
"""
预处理网格数据生成 KDTree 和法向量源
参数
mesh: trimesh.Trimesh 对象包含顶点和法向量信息
normal_type: str 法向量类型可选 'vertex' 'face'
返回
tree: cKDTree 用于加速最近邻查询
normals_source: np.ndarray 包含顶点法向量或面法向量
"""
if normal_type == 'vertex':
tree = cKDTree(mesh.vertices)
normals_source = mesh.vertex_normals
elif normal_type == 'face':
# 计算每个面的中心点
face_centers = np.mean(mesh.vertices[mesh.faces], axis=1)
tree = cKDTree(face_centers)
normals_source = mesh.face_normals
else:
raise ValueError(f"Unsupported normal type: {normal_type}")
return tree, normals_source
def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3):
"""
为嵌套点云数据计算法向量并保持嵌套格式
参数
mesh: trimesh.Trimesh 对象包含顶点和法向量信息
surf_wcs: np.ndarray(dtype=object) 形状为 (N,) 的数组每个元素是形状为 (M, 3) float32 数组
normal_type: str 法向量类型可选 'vertex' 'face'
k_neighbors: int 用于平滑的最近邻数量
返回
normals: np.ndarray(dtype=object) 形状为 (N,) 的数组每个元素是形状为 (M, 3) float32 数组
"""
# 预处理网格数据
tree, normals_source = preprocess_mesh(mesh, normal_type=normal_type)
# 展平所有点云为一个二维数组 [P, 3],并记录分割索引
lengths = [len(point_cloud) for point_cloud in surf_wcs]
query_points = np.concatenate(surf_wcs, axis=0).astype(np.float32) # 避免多次内存分配
# 批量查询最近邻
distances, indices = tree.query(query_points, k=k_neighbors)
# 处理k=1的特殊情况
if k_neighbors == 1:
nearest_normals = normals_source[indices]
else:
# 加权平均(权重为距离倒数)
inv_distances = 1 / (distances + 1e-8) # 防止除以零
sum_inv_distances = inv_distances.sum(axis=1, keepdims=True)
valid_mask = sum_inv_distances > 1e-6
weights = np.divide(inv_distances, sum_inv_distances, out=np.zeros_like(inv_distances), where=valid_mask)
nearest_normals = np.einsum('ijk,ij->ik', normals_source[indices], weights)
# 标准化结果
norms = np.linalg.norm(nearest_normals, axis=1)
valid_mask = norms > 1e-6
nearest_normals[valid_mask] /= norms[valid_mask, None]
# 按原始嵌套结构分割法向量
start = 0
normals_output = np.empty(len(surf_wcs), dtype=object)
for i, length in enumerate(lengths):
end = start + length
normals_output[i] = nearest_normals[start:end]
start = end
return normals_output
def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, sample_sdf_points=False, timeout:int=300) -> dict:
"""处理单个STEP文件, 从 brep 2 pkl
return data = {
@@ -549,9 +403,16 @@ def process_single_step(step_path:str, output_path:str=None, sample_normal_vecto
'edgeFace_adj': edgeFace_adj.astype(np.int32), # 边-面的邻接关系矩阵
'edgeCorner_adj': edgeCorner_adj.astype(np.int32),# 边-角点的邻接关系矩阵
'faceEdge_adj': faceEdge_adj.astype(np.int32), # 面-边的邻接关系矩阵
'edge_types': np.array(edge_types, dtype=np.int32), # [num_edges]
'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),# 曲面在世界坐标系下的包围盒
'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),# 边在世界坐标系下的包围盒
'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32), # 去重后的唯一角点坐标
'normalization_params': { # 归一化参数
'center': center.astype(np.float32), # 归一化中心点 [3,]
'scale': float(scale), # 归一化缩放系数
},
'surf_pnt_normals': np.array(dtype=object), # 表面点的法线数据 [num_faces, num_surf_sample_points, 3],仅当 sample_normal_vector=True
'sampled_points_normals_sdf': np.array(dtype=float32), # 采样点的位置、法线和SDF值 [num_samples, 7],仅当 sample_sdf_points=True
}"""
try:
logger.info("数据预处理……")

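The edge classification added to parse_solid above boils down to a scalar triple product test: with the two adjacent face normals n1 and n2 and the edge direction e, the edge is labelled convex when (n1 × n2) · e > 0 and concave otherwise. A minimal numpy sketch of that test (the normals and edge direction below are made-up unit vectors, not values read from a STEP file):

import numpy as np

def classify_edge(n1, n2, e):
    # Scalar triple product (n1 x n2) . e; 1 = convex, 0 = concave (same convention as edge_types)
    mixed = np.dot(np.cross(n1, n2), e)
    return 1 if mixed > 0 else 0

n1 = np.array([1.0, 0.0, 0.0])   # normal of face 1 (assumed unit length)
n2 = np.array([0.0, 1.0, 0.0])   # normal of face 2
e = np.array([0.0, 0.0, 1.0])    # edge direction from start vertex to end vertex
print(classify_edge(n1, n2, e))    # 1 -> convex for this orientation
print(classify_edge(n1, n2, -e))   # 0 -> flipping the edge direction flips the label

As the second call shows, the sign depends on the edge orientation, which is why the block above derives the edge vector from the curve's first and last parameters.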
brep2sdf/data/sampler.py (133)

@@ -36,100 +36,14 @@ from OCC.Core.StlAPI import StlAPI_Writer
# 导入配置
from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,load_step, preprocess_mesh,batch_compute_normals
config = get_default_config()
# 设置最大面数阈值,用于加速处理
MAX_FACE = config.data.max_face
def _sample_uniform_points(num_points: int) -> np.ndarray:
"""在 [-0.5, 0.5] 范围内均匀采样点
参数:
num_points: 要采样的点数
返回:
np.ndarray: 形状为 (num_points, 3) 的采样点数组
"""
return np.random.uniform(-0.5, 0.5, (num_points, 3))
def _sample_near_surface_points(
mesh: trimesh.Trimesh,
num_points: int,
std_dev: float
) -> np.ndarray:
"""在网格表面附近采样点
参数:
mesh: 输入的trimesh网格
num_points: 要采样的点数
std_dev: 沿法线方向的扰动标准差
返回:
np.ndarray: 形状为 (num_points, 3) 的采样点数组
"""
if mesh.faces.shape[0] == 0:
logger.warning("网格没有面,无法执行近表面采样。将替换为均匀采样点。")
return _sample_uniform_points(num_points)
try:
near_points_on_surface = mesh.sample(num_points)
proximity_query_near = ProximityQuery(mesh)
closest_points_near, _, face_indices_near = proximity_query_near.on_surface(near_points_on_surface)
if np.any(face_indices_near >= len(mesh.face_normals)):
raise IndexError("Face index out of bounds during near-surface normal lookup")
normals_near = mesh.face_normals[face_indices_near]
perturbations = np.random.randn(num_points, 1) * std_dev
near_points = near_points_on_surface + normals_near * perturbations
return np.clip(near_points, -0.5, 0.5)
except Exception as e:
logger.warning(f"近表面采样失败: {e}。将替换为均匀采样点。")
return _sample_uniform_points(num_points)
def sample_points(
trimesh_mesh_ncs: trimesh.Trimesh,
num_uniform_samples: int,
num_near_surface_samples: int,
sdf_sampling_std_dev: float
) -> np.ndarray | None:
"""组合均匀采样和近表面采样的点
参数:
trimesh_mesh_ncs: 归一化的trimesh网格
num_uniform_samples: 均匀采样点数
num_near_surface_samples: 近表面采样点数
sdf_sampling_std_dev: 近表面采样的标准差
返回:
np.ndarray | None: 合并后的采样点数组失败时返回None
"""
sampled_points_list = []
# 均匀采样
if num_uniform_samples > 0:
uniform_points = _sample_uniform_points(num_uniform_samples)
sampled_points_list.append(uniform_points)
# 近表面采样
if num_near_surface_samples > 0:
near_points = _sample_near_surface_points(
trimesh_mesh_ncs,
num_near_surface_samples,
sdf_sampling_std_dev
)
sampled_points_list.append(near_points)
# 合并采样点
if not sampled_points_list:
logger.warning("没有采样到任何点。")
return None
return np.vstack(sampled_points_list).astype(np.float32)
# 在原始的sample_sdf_points_and_normals函数中使用新的采样函数
def sample_sdf_points_and_normals(
trimesh_mesh_ncs: trimesh.Trimesh,
surf_bbox_ncs: np.ndarray,
@@ -169,12 +83,43 @@ def sample_sdf_points_and_normals(
logger.debug(f"固定 SDF 采样点数: {num_sdf_samples} (均匀: {num_uniform_samples}, 近表面: {num_near_surface_samples})")
# --- 执行采样 ---
sampled_points_ncs = sample_points(
trimesh_mesh_ncs,
num_uniform_samples,
num_near_surface_samples,
sdf_sampling_std_dev
)
sampled_points_list = []
# 均匀采样 (在 [-0.5, 0.5] 范围内)
if num_uniform_samples > 0:
uniform_points = np.random.uniform(-0.5, 0.5, (num_uniform_samples, 3))
sampled_points_list.append(uniform_points)
# 近表面采样
if num_near_surface_samples > 0:
if trimesh_mesh_ncs.faces.shape[0] > 0:
try:
near_points_on_surface = trimesh_mesh_ncs.sample(num_near_surface_samples)
proximity_query_near = ProximityQuery(trimesh_mesh_ncs)
closest_points_near, distances_near, face_indices_near = proximity_query_near.on_surface(near_points_on_surface)
if np.any(face_indices_near >= len(trimesh_mesh_ncs.face_normals)):
raise IndexError("Face index out of bounds during near-surface normal lookup")
normals_near = trimesh_mesh_ncs.face_normals[face_indices_near]
perturbations = np.random.randn(num_near_surface_samples, 1) * sdf_sampling_std_dev
near_points = near_points_on_surface + normals_near * perturbations
# 确保近表面点也在 [-0.5, 0.5] 范围内
near_points = np.clip(near_points, -0.5, 0.5)
sampled_points_list.append(near_points)
except Exception as e:
logger.warning(f"近表面采样失败: {e}。将替换为均匀采样点。")
fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3))
sampled_points_list.append(fallback_uniform)
else:
logger.warning("网格没有面,无法执行近表面采样。将替换为均匀采样点。")
fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3))
sampled_points_list.append(fallback_uniform)
# --- 合并采样点 ---
if not sampled_points_list:
logger.warning("没有为SDF采样到任何点。")
return None
sampled_points_ncs = np.vstack(sampled_points_list).astype(np.float32)
try:
proximity_query = ProximityQuery(trimesh_mesh_ncs)

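The inlined sampling above follows a common pattern for SDF training data: uniform points in the normalized cube plus surface samples perturbed along their face normals. A self-contained sketch of that pattern on a toy trimesh box (the box, the point counts and the 0.1 standard deviation are illustrative, not the project's config values):

import numpy as np
import trimesh
from trimesh.proximity import ProximityQuery

def sample_query_points(mesh, n_uniform, n_near, std):
    # Uniform points in [-0.5, 0.5]^3
    uniform = np.random.uniform(-0.5, 0.5, (n_uniform, 3))
    # Points on the surface, pushed along their face normals
    on_surface = mesh.sample(n_near)
    _, _, face_idx = ProximityQuery(mesh).on_surface(on_surface)
    normals = mesh.face_normals[face_idx]
    near = on_surface + normals * np.random.randn(n_near, 1) * std
    near = np.clip(near, -0.5, 0.5)  # keep near-surface points inside the normalized cube
    return np.vstack([uniform, near]).astype(np.float32)

box = trimesh.creation.box(extents=(0.8, 0.8, 0.8))
pts = sample_query_points(box, n_uniform=64, n_near=64, std=0.1)
print(pts.shape)  # (128, 3)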
brep2sdf/data/utils.py (1402)

File diff suppressed because it is too large

brep2sdf/networks/octree.py (27)

@@ -6,6 +6,10 @@ import torch.nn.functional as F
import numpy as np
from brep2sdf.utils.logger import logger
from brep2sdf.networks.patch_graph import PatchGraph
def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
"""判断两个轴对齐包围盒(AABB)是否相交
@@ -27,7 +31,7 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
return torch.all((max1 >= min2) & (max2 >= min1))
class OctreeNode(nn.Module):
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None):
super().__init__()
# 静态张量存储节点信息
self.register_buffer('bbox', bbox) # 当前节点的边界框
@@ -38,6 +42,9 @@ class OctreeNode(nn.Module):
self.register_buffer('face_indices', torch.from_numpy(face_indices).to(bbox.device)) # 面片索引张量
self.register_buffer('surf_bbox', surf_bbox) # 面片边界框
# PatchGraph作为普通属性
self.patch_graph = patch_graph # 不再使用register_buffer
self.max_depth = max_depth
# 将param_key改为张量
self.register_buffer('param_key', torch.tensor(-1, dtype=torch.long))
@@ -72,7 +79,8 @@ class OctreeNode(nn.Module):
node_idx, bbox, faces = queue.pop(0)
self.node_bboxes[node_idx] = bbox
if faces.shape[0] <= 2 or current_idx >= self.max_depth:
# 判断 要不要继续分裂
if not self._should_split_node(current_idx):
self.is_leaf_mask[node_idx] = True
continue
@@ -104,6 +112,19 @@ class OctreeNode(nn.Module):
if intersecting_faces:
queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.bbox.device)))
def _should_split_node(self, current_depth: int) -> bool:
"""判断节点是否需要分裂"""
# 检查是否达到最大深度
if current_depth >= self.max_depth:
return False
# 检查是否为完全图
is_clique = self.patch_graph.is_clique(self.face_indices)
if is_clique:
return False
return True
@torch.jit.export
def _generate_child_bboxes(self, min_coords: torch.Tensor, mid_coords: torch.Tensor, max_coords: torch.Tensor) -> torch.Tensor:
"""生成8个子节点的边界框"""
@@ -208,6 +229,7 @@
'is_leaf_mask': self.is_leaf_mask,
'face_indices': self.face_indices,
'surf_bbox': self.surf_bbox,
'patch_graph': self.patch_graph,
'max_depth': self.max_depth,
'param_key': self.param_key,
'_is_leaf': self._is_leaf
@@ -223,6 +245,7 @@
self.is_leaf_mask = state['is_leaf_mask']
self.face_indices = state['face_indices']
self.surf_bbox = state['surf_bbox']
self.patch_graph = state['patch_graph']
self.max_depth = state['max_depth']
self.param_key = state['param_key']
self._is_leaf = state['_is_leaf']

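The new _should_split_node stops subdividing a node once the faces it contains form a clique in the patch adjacency graph (every pair of remaining patches is adjacent), in addition to the depth limit. A small sketch of that criterion on a plain 0/1 adjacency matrix, independent of the PatchGraph class (the 4-face adjacency below is made up):

import numpy as np

def should_split(face_ids, adj, depth, max_depth):
    # Split unless the depth limit is reached or the faces are pairwise adjacent
    if depth >= max_depth:
        return False
    sub = adj[np.ix_(face_ids, face_ids)]  # adjacency restricted to this node's faces
    n = len(face_ids)
    is_clique = sub.sum() == n * (n - 1)   # symmetric 0/1 matrix with zero diagonal
    return not is_clique

adj = np.array([[0, 1, 1, 0],
                [1, 0, 1, 0],
                [1, 1, 0, 1],
                [0, 0, 1, 0]])
print(should_split(np.array([0, 1, 2]), adj, 2, 6))     # False: faces 0-1-2 form a clique, stop splitting
print(should_split(np.array([0, 1, 2, 3]), adj, 2, 6))  # True: face 3 breaks the clique, keep splitting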
brep2sdf/networks/patch_graph.py (182)

@@ -0,0 +1,182 @@
from typing import Tuple, Optional
import torch
import torch.nn as nn
import numpy as np
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.TopoDS import TopoDS_Edge, TopoDS_Face, topods_Edge, topods_Face
from OCC.Core.BRep import BRep_Tool
from OCC.Core.GeomLProp import GeomLProp_SLProps
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
class PatchGraph(nn.Module):
def __init__(self, num_patches: int, device: torch.device = None):
super().__init__()
self.num_patches = num_patches
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 注册缓冲区
self.register_buffer('edge_index', None) # 边的连接关系 (2, E)
self.register_buffer('edge_type', None) # 边的类型 (E,) 0:凹边 1:凸边
self.register_buffer('patch_features', None) # 面片特征 (N, F)
def set_edges(self, edge_index: torch.Tensor, edge_type: torch.Tensor) -> None:
"""设置边的信息
参数:
edge_index: 形状为 (2, E) 的张量表示边的连接关系
edge_type: 形状为 (E,) 的张量0表示凹边1表示凸边
"""
if edge_index.shape[0] != 2:
raise ValueError(f"edge_index 必须是形状为 (2, E) 的张量,但得到 {edge_index.shape}")
if edge_index.shape[1] != edge_type.shape[0]:
raise ValueError("edge_index 和 edge_type 的边数量不匹配")
self.edge_index = edge_index.to(self.device)
self.edge_type = edge_type.to(self.device)
def get_subgraph(self, node_faces: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""获取子图的边和类型"""
if self.edge_index is None:
return None, None
node_faces = node_faces.to(self.device)
mask = torch.isin(self.edge_index[0], node_faces) & torch.isin(self.edge_index[1], node_faces)
subgraph_edges = self.edge_index[:, mask]
subgraph_types = self.edge_type[mask]
return subgraph_edges, subgraph_types
@staticmethod
def from_preprocessed_data(surf_wcs: np.ndarray, edgeFace_adj: np.ndarray, edge_types: np.ndarray, device: torch.device = None) -> 'PatchGraph':
num_faces = len(surf_wcs)
graph = PatchGraph(num_faces, device)
edge_pairs = []
edge_types_list = []
for edge_idx in range(len(edgeFace_adj)):
connected_faces = np.where(edgeFace_adj[edge_idx] == 1)[0]
if len(connected_faces) == 2:
face1, face2 = connected_faces
edge_pairs.extend([[face1, face2], [face2, face1]])
edge_type = edge_types[edge_idx]
edge_types_list.extend([edge_type, edge_type])
if edge_pairs:
edge_index = torch.tensor(edge_pairs, dtype=torch.long, device=graph.device).t()
edge_type = torch.tensor(edge_types_list, dtype=torch.long, device=graph.device)
graph.set_edges(edge_index, edge_type)
return graph
def set_features(self, features: torch.Tensor) -> None:
"""设置面片特征
参数:
features: 形状为 (N, F) 的张量表示面片的特征向量
"""
if features.shape[0] != self.num_patches:
raise ValueError(f"特征数量 {features.shape[0]} 与面片数量 {self.num_patches} 不匹配")
self.patch_features = features
def is_clique(self, node_faces: torch.Tensor) -> bool:
"""检查给定面片集合是否构成完全图
参数:
node_faces: 要检查的面片索引集合
返回:
bool: 是否为完全图
"""
if self.edge_index is None:
return False
# 获取子图的边
mask = torch.isin(self.edge_index[0], node_faces) & torch.isin(self.edge_index[1], node_faces)
subgraph_edges = self.edge_index[:, mask]
# 计算完全图应有的边数
n = len(node_faces)
expected_edges = n * (n - 1) // 2
# 计算实际的边数(考虑无向图)
actual_edges = len(subgraph_edges[0]) // 2
return actual_edges == expected_edges
def combine_sdf(self, sdf_values: torch.Tensor) -> torch.Tensor:
"""根据邻接关系组合SDF值
参数:
sdf_values: 形状为 (N,) 的张量表示每个面片的SDF值
返回:
torch.Tensor: 组合后的SDF值
"""
if self.edge_index is None or self.edge_type is None:
raise RuntimeError("请先设置边的信息")
# 获取所有相连面片对的SDF值
sdf_i = sdf_values[self.edge_index[0]] # (E,)
sdf_j = sdf_values[self.edge_index[1]] # (E,)
# 根据边的类型选择组合方式
concave_mask = self.edge_type == 0
convex_mask = self.edge_type == 1
# 初始化结果为第一个SDF值
result = sdf_values[0].clone()
# 凹边取最大值,凸边取最小值
if torch.any(concave_mask):
result = torch.max(result, torch.max(torch.stack([sdf_i[concave_mask],
sdf_j[concave_mask]])))
if torch.any(convex_mask):
result = torch.min(result, torch.min(torch.stack([sdf_i[convex_mask],
sdf_j[convex_mask]])))
return result
@staticmethod
def from_preprocessed_data(
surf_wcs: np.ndarray, # 形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组
edgeFace_adj: np.ndarray, # 形状为(num_edges, num_faces)的int32数组
edge_types: np.ndarray # 形状为(num_edges,)的int32数组
) -> 'PatchGraph':
"""从预处理的数据直接构建面片邻接图
参数:
surf_wcs: 世界坐标系下的曲面几何数据形状为(N,)的对象数组每个元素是形状为(M, 3)的float32数组
edgeFace_adj: -面邻接矩阵形状为(num_edges, num_faces)的int32数组1表示边与面相邻
edge_types: 边的类型数组形状为(num_edges,)的int32数组0表示凹边1表示凸边
返回:
PatchGraph: 初始化好的面片邻接图包含
- edge_index: 形状为(2, num_edges*2)的torch.long张量表示双向边的连接关系
- edge_type: 形状为(num_edges*2,)的torch.long张量表示每条边的类型
"""
num_faces = len(surf_wcs)
graph = PatchGraph(num_faces)
# 构建边的索引和类型
edge_pairs = []
edge_types_list = []
# 遍历边-面邻接矩阵
for edge_idx in range(len(edgeFace_adj)):
connected_faces = np.where(edgeFace_adj[edge_idx] == 1)[0]
if len(connected_faces) == 2:
face1, face2 = connected_faces
# 添加双向边
edge_pairs.extend([[face1, face2], [face2, face1]])
# 使用预计算的边类型
edge_type = edge_types[edge_idx]
edge_types_list.extend([edge_type, edge_type]) # 双向边使用相同的类型
if edge_pairs: # 确保有边存在
edge_index = torch.tensor(edge_pairs, dtype=torch.long).t()
edge_type = torch.tensor(edge_types_list, dtype=torch.long)
graph.set_edges(edge_index, edge_type)
return graph

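from_preprocessed_data turns each row of edgeFace_adj into a bidirectional entry of edge_index in the usual COO graph layout: every edge touching exactly two faces contributes both (i, j) and (j, i), and its type is duplicated accordingly, which is why is_clique later counts undirected edges as half the matching columns. A sketch of that conversion on a toy 3-face, 2-edge model (the arrays are illustrative, not taken from a real solid):

import numpy as np
import torch

# Edge 0 joins faces 0 and 1 (convex), edge 1 joins faces 1 and 2 (concave)
edgeFace_adj = np.array([[1, 1, 0],
                         [0, 1, 1]], dtype=np.int32)  # [num_edges, num_faces]
edge_types = np.array([1, 0], dtype=np.int32)

pairs, types = [], []
for e in range(edgeFace_adj.shape[0]):
    faces = np.where(edgeFace_adj[e] == 1)[0]
    if len(faces) == 2:            # skip boundary or degenerate edges
        i, j = faces
        pairs += [[i, j], [j, i]]  # store both directions
        types += [edge_types[e]] * 2

edge_index = torch.tensor(pairs, dtype=torch.long).t()  # tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_type = torch.tensor(types, dtype=torch.long)       # tensor([1, 1, 0, 0])
print(edge_index.shape, edge_type.shape)                 # torch.Size([2, 4]) torch.Size([4])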
brep2sdf/test.py (161)

@@ -1,161 +1,4 @@
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from brep2sdf.data.data import BRepSDFDataset
from brep2sdf.networks.network import BRepToSDF
from brep2sdf.utils.logger import logger
from brep2sdf.config.default_config import get_default_config
import matplotlib.pyplot as plt
from tqdm import tqdm
model = torch.jit.load("/home/wch/brep2sdf/data/output_data/00000054.pt")
print(model)
class Tester:
def __init__(self, config, checkpoint_path):
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化测试数据集
self.test_dataset = BRepSDFDataset(
brep_dir=config.data.brep_dir,
sdf_dir=config.data.sdf_dir,
valid_data_dir=config.data.valid_data_dir,
split='test'
)
# 初始化数据加载器
self.test_loader = DataLoader(
self.test_dataset,
batch_size=1, # 测试时使用batch_size=1
shuffle=False,
num_workers=config.train.num_workers,
pin_memory=False
)
# 加载模型
self.model = BRepToSDF(config).to(self.device)
self.load_checkpoint(checkpoint_path)
# 创建结果保存目录
self.result_dir = os.path.join(config.data.result_save_dir, 'test_results')
os.makedirs(self.result_dir, exist_ok=True)
def load_checkpoint(self, checkpoint_path):
"""加载检查点"""
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
logger.info(f"Loaded checkpoint from {checkpoint_path}")
def compute_metrics(self, pred_sdf, gt_sdf):
"""计算评估指标"""
mse = torch.mean((pred_sdf - gt_sdf) ** 2).item()
mae = torch.mean(torch.abs(pred_sdf - gt_sdf)).item()
max_error = torch.max(torch.abs(pred_sdf - gt_sdf)).item()
return {
'mse': mse,
'mae': mae,
'max_error': max_error
}
def visualize_results(self, pred_sdf, gt_sdf, points, save_path):
"""可视化预测结果"""
fig = plt.figure(figsize=(15, 5))
# 绘制预测SDF
ax1 = fig.add_subplot(131, projection='3d')
scatter = ax1.scatter(points[:, 0], points[:, 1], points[:, 2],
c=pred_sdf.squeeze().cpu(), cmap='coolwarm')
ax1.set_title('Predicted SDF')
plt.colorbar(scatter)
# 绘制真实SDF
ax2 = fig.add_subplot(132, projection='3d')
scatter = ax2.scatter(points[:, 0], points[:, 1], points[:, 2],
c=gt_sdf.squeeze().cpu(), cmap='coolwarm')
ax2.set_title('Ground Truth SDF')
plt.colorbar(scatter)
# 绘制误差图
ax3 = fig.add_subplot(133, projection='3d')
error = torch.abs(pred_sdf - gt_sdf)
scatter = ax3.scatter(points[:, 0], points[:, 1], points[:, 2],
c=error.squeeze().cpu(), cmap='Reds')
ax3.set_title('Absolute Error')
plt.colorbar(scatter)
plt.tight_layout()
plt.savefig(save_path)
plt.close()
def test(self):
"""执行测试"""
self.model.eval()
total_metrics = {'mse': 0, 'mae': 0, 'max_error': 0}
logger.info("Starting testing...")
with torch.no_grad():
for idx, batch in enumerate(tqdm(self.test_loader)):
# 获取数据并移动到设备
surf_ncs = batch['surf_ncs'].to(self.device)
edge_ncs = batch['edge_ncs'].to(self.device)
surf_pos = batch['surf_pos'].to(self.device)
edge_pos = batch['edge_pos'].to(self.device)
vertex_pos = batch['vertex_pos'].to(self.device)
edge_mask = batch['edge_mask'].to(self.device)
points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
# 前向传播
pred_sdf = self.model(
surf_ncs=surf_ncs, edge_ncs=edge_ncs,
surf_pos=surf_pos, edge_pos=edge_pos,
vertex_pos=vertex_pos, edge_mask=edge_mask,
query_points=points
)
# 计算指标
metrics = self.compute_metrics(pred_sdf, gt_sdf)
for k, v in metrics.items():
total_metrics[k] += v
# 可视化结果
if idx % self.config.test.vis_freq == 0:
save_path = os.path.join(self.result_dir, f'result_{idx}.png')
self.visualize_results(pred_sdf, gt_sdf, points[0].cpu(), save_path)
# 计算平均指标
num_samples = len(self.test_loader)
avg_metrics = {k: v / num_samples for k, v in total_metrics.items()}
# 保存测试结果
logger.info("Test Results:")
for k, v in avg_metrics.items():
logger.info(f"{k}: {v:.6f}")
# 保存指标到文件
with open(os.path.join(self.result_dir, 'test_metrics.txt'), 'w') as f:
for k, v in avg_metrics.items():
f.write(f"{k}: {v:.6f}\n")
return avg_metrics
def main():
# 获取配置
config = get_default_config()
# 设置检查点路径
checkpoint_path = os.path.join(
config.data.model_save_dir,
config.data.best_model_name.format(model_name=config.data.model_name)
)
# 初始化测试器并执行测试
tester = Tester(config, checkpoint_path)
metrics = tester.test()
if __name__ == '__main__':
main()

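test.py is reduced to loading the TorchScript module exported at the end of training and printing it. A hedged sketch of how such a scripted network would typically be queried; the path mirrors the one above, and the assumption that the module's forward takes an (N, 3) tensor of query points and returns SDF values is mine, not something this diff shows:

import torch

model = torch.jit.load("/home/wch/brep2sdf/data/output_data/00000054.pt", map_location="cpu")
model.eval()

points = torch.tensor([[0.0, 0.0, 0.0],
                       [0.25, 0.1, -0.3]], dtype=torch.float32)
with torch.no_grad():
    sdf = model(points)  # assumed signature: (N, 3) -> (N,) or (N, 1)
print(sdf)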
brep2sdf/train.py (57)

@@ -13,6 +13,7 @@ from brep2sdf.data.pre_process_by_mesh import process_single_step
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.loss import LossManager
from brep2sdf.networks.patch_graph import PatchGraph
from brep2sdf.utils.logger import logger
@@ -31,10 +32,17 @@ parser.add_argument(
help='只采样零表面点 SDF 训练'
)
parser.add_argument(
'--force-reprocess','-f',
action='store_true', # 默认为 False,如果用户指定该参数,则为 True
help='强制重新进行数据预处理,忽略缓存或已有结果'
)
parser.add_argument(
'--resume-checkpoint-path',
type=str,
default=None,
help='从指定的checkpoint文件继续训练'
)
args = parser.parse_args()
@@ -86,8 +94,13 @@ class Trainer:
#logger.info( self.brep_data )
#self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device)
# 构建面片邻接图
graph = PatchGraph.from_preprocessed_data(
surf_wcs=self.data['surf_wcs'],
edgeFace_adj=self.data['edgeFace_adj'],
edge_types=self.data['edge_types']
)
# 初始化网络
surf_bbox=torch.tensor(
self.data['surf_bbox_ncs'],
dtype=torch.float32,
@@ -95,7 +108,7 @@
)
self.build_tree(surf_bbox=surf_bbox, graph=graph, max_depth=4)
self.model = Net(
@@ -115,12 +128,13 @@ class Trainer:
logger.info(f"初始化完成,正在处理模型 {self.model_name}")
def build_tree(self, surf_bbox, graph, max_depth=6):
num_faces = surf_bbox.shape[0]
bbox = self._calculate_global_bbox(surf_bbox)
self.root = OctreeNode(
bbox=bbox,
face_indices=np.arange(num_faces), # 初始包含所有面
patch_graph=graph,
max_depth=max_depth,
surf_bbox=surf_bbox
)
@@ -303,7 +317,12 @@ class Trainer:
logger.info("Starting training...")
start_time = time.time()
start_epoch = 1
if args.resume_checkpoint_path:
start_epoch = self._load_checkpoint(args.resume_checkpoint_path)
logger.info(f"Loaded model from {args.resume_checkpoint_path}")
for epoch in range(start_epoch, self.config.train.num_epochs + 1):
# 训练一个epoch
train_loss = self.train_epoch(epoch)
@@ -329,7 +348,8 @@ class Trainer:
# 训练完成
total_time = time.time() - start_time
self._tracing_model_by_script()
#self._tracing_model()
logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
logger.info(f'Best validation loss: {best_val_loss:.6f}')
#self.test_load()
@@ -349,8 +369,8 @@ class Trainer:
self.model.eval()
# 确保模型中的所有逻辑都兼容 TorchScript
scripted_model = torch.jit.script(self.model)
#optimized_model = optimize_for_mobile(scripted_model)
torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
def _tracing_model(self):
"""保存模型"""
@@ -375,11 +395,6 @@ class Trainer:
except Exception as e:
logger.error(f"模型验证失败:{e}")
def _load_checkpoint(self, checkpoint_path):
"""从检查点恢复训练状态"""
model = torch.load(checkpoint_path)
return model
def _save_checkpoint(self, epoch: int, train_loss: float):
"""保存训练检查点"""
checkpoint_dir = os.path.join(
@@ -388,17 +403,25 @@
)
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth")
'''
# 只保存状态字典
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'loss': train_loss,
'config': self.config
}, checkpoint_path)
'''
torch.save(self.model,checkpoint_path)
def _load_checkpoint(self, checkpoint_path):
"""从检查点恢复训练状态"""
try:
checkpoint = torch.load(checkpoint_path)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch'] + 1
except Exception as e:
logger.error(f"加载checkpoint失败: {str(e)}")
raise
def main():
# 这里需要初始化配置
config = get_default_config()

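The checkpoint code above now saves the whole model object and keeps the state-dict variant commented out, while _load_checkpoint (used by --resume-checkpoint-path) reads the 'model_state_dict' and 'optimizer_state_dict' keys. For reference, a minimal sketch of that state-dict pattern with a stand-in model and optimizer (not the project's Net or Trainer):

import torch
import torch.nn as nn

model = nn.Linear(3, 1)  # stand-in for the SDF network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Save epoch, model weights and optimizer state together
torch.save({
    "epoch": 10,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": 0.123,
}, "epoch_010.pth")

# Resume: restore both states and continue from the next epoch
ckpt = torch.load("epoch_010.pth", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
start_epoch = ckpt["epoch"] + 1
print(start_epoch)  # 11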