Browse Source

加入邻接图作为八叉树subdivide依据

final
mckay 10 months ago
parent
commit
2bb0864190
  1. 1
      brep2sdf/data/data.py
  2. 295
      brep2sdf/data/pre_process_by_mesh.py
  3. 133
      brep2sdf/data/sampler.py
  4. 1400
      brep2sdf/data/utils.py
  5. 31
      brep2sdf/networks/octree.py
  6. 182
      brep2sdf/networks/patch_graph.py
  7. 161
      brep2sdf/test.py
  8. 59
      brep2sdf/train.py

1
brep2sdf/data/data.py

@ -4,7 +4,6 @@ from torch.utils.data import Dataset
import numpy as np
import pickle
from brep2sdf.utils.logger import logger
from brep2sdf.data.utils import process_brep_data
from brep2sdf.config.default_config import get_default_config

295
brep2sdf/data/pre_process_by_mesh.py

@ -15,7 +15,6 @@ from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import logging
from datetime import datetime
from scipy.spatial import cKDTree
from brep2sdf.utils.logger import logger
import tempfile
import trimesh
from trimesh.proximity import ProximityQuery
@ -36,6 +35,8 @@ from OCC.Core.StlAPI import StlAPI_Writer
from brep2sdf.data.sampler import sample_sdf_points_and_normals
from brep2sdf.data.data import check_data_format
from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,load_step, preprocess_mesh,batch_compute_normals
from brep2sdf.utils.logger import logger
# 导入配置
from brep2sdf.config.default_config import get_default_config
config = get_default_config()
@ -43,133 +44,6 @@ config = get_default_config()
# 设置最大面数阈值,用于加速处理
MAX_FACE = config.data.max_face
def normalize(surfs, edges, corners):
"""
将CAD模型归一化到单位立方体空间
参数:
surfs: 面的点集列表
edges: 边的点集列表
corners: 顶点坐标数组 [num_edges, 2, 3]
返回:
surfs_wcs: 原始坐标系下的面点集
edges_wcs: 原始坐标系下的边点集
surfs_ncs: 归一化坐标系下的面点集
edges_ncs: 归一化坐标系下的边点集
corner_wcs: 归一化后的顶点坐标 [num_edges, 2, 3]
center: 使用的中心点坐标 [3,]
scale: 使用的缩放系数 (float)
"""
if len(corners) == 0:
return None, None, None, None, None, None, None
# 计算包围盒和缩放因子
corners_array = corners.reshape(-1, 3) # [num_edges*2, 3]
center = (corners_array.max(0) + corners_array.min(0)) / 2 # 计算中心点
scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max() # 计算缩放系数
# 归一化面的点集
surfs_wcs = [] # 原始世界坐标系下的面点集
surfs_ncs = [] # 归一化坐标系下的面点集
for surf in surfs:
surf_wcs = np.array(surf)
surf_ncs = (surf_wcs - center) * scale # 归一化变换
surfs_wcs.append(surf_wcs)
surfs_ncs.append(surf_ncs)
# 归一化边的点集
edges_wcs = [] # 原始世界坐标系下的边点集
edges_ncs = [] # 归一化坐标系下的边点集
for edge in edges:
edge_wcs = np.array(edge)
edge_ncs = (edge_wcs - center) * scale # 归一化变换
edges_wcs.append(edge_wcs)
edges_ncs.append(edge_ncs)
# 归一化顶点坐标 - 保持[num_edges, 2, 3]的形状
corner_wcs = (corners - center) * scale # 广播操作会保持原有维度
return (np.array(surfs_wcs, dtype=object),
np.array(edges_wcs, dtype=object),
np.array(surfs_ncs, dtype=object),
np.array(edges_ncs, dtype=object),
corner_wcs.astype(np.float32),
center.astype(np.float32),
scale
)
def get_adjacency_info(shape, faces, edges, vertices):
"""
优化后的邻接关系计算函数直接使用已收集的几何元素
参数新增:
faces: 已收集的面列表
edges: 已收集的边列表
vertices: 已收集的顶点列表
"""
logger.debug("Get adjacency infos...")
# 创建边-面映射关系
edge_face_map = TopTools_IndexedDataMapOfShapeListOfShape()
topexp.MapShapesAndAncestors(shape, TopAbs_EDGE, TopAbs_FACE, edge_face_map)
# 直接使用传入的几何元素列表
num_faces = len(faces)
num_edges = len(edges)
num_vertices = len(vertices)
logger.debug(f"num_faces: {num_faces}, num_edges: {num_edges}, num_vertices: {num_vertices}")
edgeFace_adj = np.zeros((num_edges, num_faces), dtype=np.int32)
faceEdge_adj = np.zeros((num_faces, num_edges), dtype=np.int32)
edgeCorner_adj = np.zeros((num_edges, 2), dtype=np.int32)
# 填充边-面邻接矩阵
for i, edge in enumerate(edges):
# 检查每个面是否与当前边相连
for j, face in enumerate(faces):
edge_explorer = TopExp_Explorer(face, TopAbs_EDGE)
while edge_explorer.More():
if edge.IsSame(edge_explorer.Current()):
edgeFace_adj[i, j] = 1
faceEdge_adj[j, i] = 1
break
edge_explorer.Next()
# 获取边的两个端点
v1 = TopoDS_Vertex()
v2 = TopoDS_Vertex()
topexp.Vertices(edge, v1, v2)
# 记录边的端点索引
if not v1.IsNull() and not v2.IsNull():
v1_vertex = topods.Vertex(v1)
v2_vertex = topods.Vertex(v2)
for k, vertex in enumerate(vertices):
if v1_vertex.IsSame(vertex):
edgeCorner_adj[i, 0] = k
if v2_vertex.IsSame(vertex):
edgeCorner_adj[i, 1] = k
return edgeFace_adj, faceEdge_adj, edgeCorner_adj
def get_bbox(shape, subshape):
"""
计算形状的包围盒
参数:
shape: 完整的CAD模型形状
subshape: 需要计算包围盒的子形状面或边
返回:
包围盒的六个参数 [xmin, ymin, zmin, xmax, ymax, zmax]
"""
bbox = Bnd_Box()
brepbndlib.Add(subshape, bbox)
xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
return np.array([xmin, ymin, zmin, xmax, ymax, zmax])
def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
"""
@ -360,7 +234,74 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
if any(x is None for x in [surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs]):
logger.error(f"Normalization failed for {step_path}")
return None
# --- 计算边的类型 ---
logger.debug("计算边的类型...")
edge_types = [] # 0:凹边 1:凸边
# 对每条边进行处理
for edge_idx in range(len(edges)):
# 获取与当前边相邻的面
adjacent_faces = np.where(edgeFace_adj[edge_idx] == 1)[0]
# 如果边只有一个相邻面或没有相邻面,默认为凹边
if len(adjacent_faces) < 2:
edge_types.append(0)
continue
# 获取两个相邻面
face1, face2 = faces[adjacent_faces[0]], faces[adjacent_faces[1]]
# 使用BRep_Tool获取面的几何信息
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
from OCC.Core.gp import gp_Pnt, gp_Vec
# 获取第一个面的法向量
surf1 = BRepAdaptor_Surface(face1)
u1 = (surf1.FirstUParameter() + surf1.LastUParameter()) / 2
v1 = (surf1.FirstVParameter() + surf1.LastVParameter()) / 2
pnt1 = gp_Pnt()
du1 = gp_Vec()
dv1 = gp_Vec()
surf1.D1(u1, v1, pnt1, du1, dv1)
normal1 = du1.Crossed(dv1)
normal1.Normalize()
normal1_np = np.array([normal1.X(), normal1.Y(), normal1.Z()])
# 获取第二个面的法向量
surf2 = BRepAdaptor_Surface(face2)
u2 = (surf2.FirstUParameter() + surf2.LastUParameter()) / 2
v2 = (surf2.FirstVParameter() + surf2.LastVParameter()) / 2
pnt2 = gp_Pnt()
du2 = gp_Vec()
dv2 = gp_Vec()
surf2.D1(u2, v2, pnt2, du2, dv2)
normal2 = du2.Crossed(dv2)
normal2.Normalize()
normal2_np = np.array([normal2.X(), normal2.Y(), normal2.Z()])
# 获取边的方向向量
edge = edges[edge_idx]
curve_info = BRep_Tool.Curve(edge)
if curve_info is None or len(curve_info) < 3:
edge_types.append(0)
continue
curve, first, last = curve_info
# 计算边的方向向量
start_point = np.array([curve.Value(first).X(), curve.Value(first).Y(), curve.Value(first).Z()])
end_point = np.array([curve.Value(last).X(), curve.Value(last).Y(), curve.Value(last).Z()])
edge_vector = end_point - start_point
edge_vector = edge_vector / np.linalg.norm(edge_vector)
# 使用混合积判断凹凸性
# 如果混合积为正,说明是凸边;为负,说明是凹边
mixed_product = np.dot(np.cross(normal1_np, normal2_np), edge_vector)
# 根据混合积的符号确定边的类型
edge_types.append(1 if mixed_product > 0 else 0)
edge_types = np.array(edge_types, dtype=np.int32)
# 创建结果字典并确保所有数组都有正确的类型
data = {
'surf_wcs': np.array(surfs_wcs, dtype=object), # 保持对象数组
@ -368,9 +309,10 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
'surf_ncs': np.array(surfs_ncs, dtype=object), # 保持对象数组
'edge_ncs': np.array(edges_ncs, dtype=object), # 保持对象数组
'corner_wcs': corner_wcs.astype(np.float32), # [num_edges, 2, 3]
'edgeFace_adj': edgeFace_adj.astype(np.int32),
'edgeFace_adj': edgeFace_adj.astype(np.int32), # [num_edges, num_faces], 1 表示边与面相邻
'edgeCorner_adj': edgeCorner_adj.astype(np.int32),
'faceEdge_adj': faceEdge_adj.astype(np.int32),
'edge_types': np.array(edge_types, dtype=np.int32), # [num_edges]
'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),
'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),
'surf_bbox_ncs': surf_bbox_ncs.astype(np.float32), # 归一化坐标系 [num_faces, 6]
@ -450,94 +392,6 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
logger.warning("请求了 SDF 点采样,但 Trimesh 加载失败。")
return data
def load_step(step_path):
"""Load STEP file and return solids"""
reader = STEPControl_Reader()
reader.ReadFile(step_path)
reader.TransferRoots()
return [reader.OneShape()]
def preprocess_mesh(mesh, normal_type='vertex'):
"""
预处理网格数据生成 KDTree 和法向量源
参数
mesh: trimesh.Trimesh 对象包含顶点和法向量信息
normal_type: str 法向量类型可选 'vertex' 'face'
返回
tree: cKDTree 用于加速最近邻查询
normals_source: np.ndarray 包含顶点法向量或面法向量
"""
if normal_type == 'vertex':
tree = cKDTree(mesh.vertices)
normals_source = mesh.vertex_normals
elif normal_type == 'face':
# 计算每个面的中心点
face_centers = np.mean(mesh.vertices[mesh.faces], axis=1)
tree = cKDTree(face_centers)
normals_source = mesh.face_normals
else:
raise ValueError(f"Unsupported normal type: {normal_type}")
return tree, normals_source
def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3):
"""
为嵌套点云数据计算法向量并保持嵌套格式
参数
mesh: trimesh.Trimesh 对象包含顶点和法向量信息
surf_wcs: np.ndarray(dtype=object) 形状为 (N,) 的数组每个元素是形状为 (M, 3) float32 数组
normal_type: str 法向量类型可选 'vertex' 'face'
k_neighbors: int 用于平滑的最近邻数量
返回
normals: np.ndarray(dtype=object) 形状为 (N,) 的数组每个元素是形状为 (M, 3) float32 数组
"""
# 预处理网格数据
tree, normals_source = preprocess_mesh(mesh, normal_type=normal_type)
# 展平所有点云为一个二维数组 [P, 3],并记录分割索引
lengths = [len(point_cloud) for point_cloud in surf_wcs]
query_points = np.concatenate(surf_wcs, axis=0).astype(np.float32) # 避免多次内存分配
# 批量查询最近邻
distances, indices = tree.query(query_points, k=k_neighbors)
# 处理k=1的特殊情况
if k_neighbors == 1:
nearest_normals = normals_source[indices]
else:
# 加权平均(权重为距离倒数)
inv_distances = 1 / (distances + 1e-8) # 防止除以零
sum_inv_distances = inv_distances.sum(axis=1, keepdims=True)
valid_mask = sum_inv_distances > 1e-6
weights = np.divide(inv_distances, sum_inv_distances, out=np.zeros_like(inv_distances), where=valid_mask)
nearest_normals = np.einsum('ijk,ij->ik', normals_source[indices], weights)
# 标准化结果
norms = np.linalg.norm(nearest_normals, axis=1)
valid_mask = norms > 1e-6
nearest_normals[valid_mask] /= norms[valid_mask, None]
# 按原始嵌套结构分割法向量
start = 0
normals_output = np.empty(len(surf_wcs), dtype=object)
for i, length in enumerate(lengths):
end = start + length
normals_output[i] = nearest_normals[start:end]
start = end
return normals_output
def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, sample_sdf_points=False, timeout:int=300) -> dict:
"""处理单个STEP文件, 从 brep 2 pkl
return data = {
@ -549,9 +403,16 @@ def process_single_step(step_path:str, output_path:str=None, sample_normal_vecto
'edgeFace_adj': edgeFace_adj.astype(np.int32), # 边-面的邻接关系矩阵
'edgeCorner_adj': edgeCorner_adj.astype(np.int32),# 边-角点的邻接关系矩阵
'faceEdge_adj': faceEdge_adj.astype(np.int32), # 面-边的邻接关系矩阵
'edge_types': np.array(edge_types, dtype=np.int32)# [num_edges]
'surf_bbox_wcs': surf_bbox_wcs.astype(np.float32),# 曲面在世界坐标系下的包围盒
'edge_bbox_wcs': edge_bbox_wcs.astype(np.float32),# 边在世界坐标系下的包围盒
'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32) # 去重后的唯一角点坐标
'normalization_params': { # 归一化参数
'center': center.astype(np.float32), # 归一化中心点 [3,]
'scale': float(scale), # 归一化缩放系数
},
'surf_pnt_normals': np.array(dtype=object), # 表面点的法线数据 [num_faces, num_surf_sample_points, 3],仅当 sample_normal_vector=True
'sampled_points_normals_sdf': np.array(dtype=float32), # 采样点的位置、法线和SDF值 [num_samples, 7],仅当 sample_sdf_points=True
}"""
try:
logger.info("数据预处理……")

133
brep2sdf/data/sampler.py

@ -36,100 +36,14 @@ from OCC.Core.StlAPI import StlAPI_Writer
# 导入配置
from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,load_step, preprocess_mesh,batch_compute_normals
config = get_default_config()
# 设置最大面数阈值,用于加速处理
MAX_FACE = config.data.max_face
def _sample_uniform_points(num_points: int) -> np.ndarray:
"""在 [-0.5, 0.5] 范围内均匀采样点
参数:
num_points: 要采样的点数
返回:
np.ndarray: 形状为 (num_points, 3) 的采样点数组
"""
return np.random.uniform(-0.5, 0.5, (num_points, 3))
def _sample_near_surface_points(
mesh: trimesh.Trimesh,
num_points: int,
std_dev: float
) -> np.ndarray:
"""在网格表面附近采样点
参数:
mesh: 输入的trimesh网格
num_points: 要采样的点数
std_dev: 沿法线方向的扰动标准差
返回:
np.ndarray: 形状为 (num_points, 3) 的采样点数组
"""
if mesh.faces.shape[0] == 0:
logger.warning("网格没有面,无法执行近表面采样。将替换为均匀采样点。")
return _sample_uniform_points(num_points)
try:
near_points_on_surface = mesh.sample(num_points)
proximity_query_near = ProximityQuery(mesh)
closest_points_near, _, face_indices_near = proximity_query_near.on_surface(near_points_on_surface)
if np.any(face_indices_near >= len(mesh.face_normals)):
raise IndexError("Face index out of bounds during near-surface normal lookup")
normals_near = mesh.face_normals[face_indices_near]
perturbations = np.random.randn(num_points, 1) * std_dev
near_points = near_points_on_surface + normals_near * perturbations
return np.clip(near_points, -0.5, 0.5)
except Exception as e:
logger.warning(f"近表面采样失败: {e}。将替换为均匀采样点。")
return _sample_uniform_points(num_points)
def sample_points(
trimesh_mesh_ncs: trimesh.Trimesh,
num_uniform_samples: int,
num_near_surface_samples: int,
sdf_sampling_std_dev: float
) -> np.ndarray | None:
"""组合均匀采样和近表面采样的点
参数:
trimesh_mesh_ncs: 归一化的trimesh网格
num_uniform_samples: 均匀采样点数
num_near_surface_samples: 近表面采样点数
sdf_sampling_std_dev: 近表面采样的标准差
返回:
np.ndarray | None: 合并后的采样点数组失败时返回None
"""
sampled_points_list = []
# 均匀采样
if num_uniform_samples > 0:
uniform_points = _sample_uniform_points(num_uniform_samples)
sampled_points_list.append(uniform_points)
# 近表面采样
if num_near_surface_samples > 0:
near_points = _sample_near_surface_points(
trimesh_mesh_ncs,
num_near_surface_samples,
sdf_sampling_std_dev
)
sampled_points_list.append(near_points)
# 合并采样点
if not sampled_points_list:
logger.warning("没有采样到任何点。")
return None
return np.vstack(sampled_points_list).astype(np.float32)
# 在原始的sample_sdf_points_and_normals函数中使用新的采样函数
def sample_sdf_points_and_normals(
trimesh_mesh_ncs: trimesh.Trimesh,
surf_bbox_ncs: np.ndarray,
@ -169,12 +83,43 @@ def sample_sdf_points_and_normals(
logger.debug(f"固定 SDF 采样点数: {num_sdf_samples} (均匀: {num_uniform_samples}, 近表面: {num_near_surface_samples})")
# --- 执行采样 ---
sampled_points_ncs = sample_points(
trimesh_mesh_ncs,
num_uniform_samples,
num_near_surface_samples,
sdf_sampling_std_dev
)
sampled_points_list = []
# 均匀采样 (在 [-0.5, 0.5] 范围内)
if num_uniform_samples > 0:
uniform_points = np.random.uniform(-0.5, 0.5, (num_uniform_samples, 3))
sampled_points_list.append(uniform_points)
# 近表面采样
if num_near_surface_samples > 0:
if trimesh_mesh_ncs.faces.shape[0] > 0:
try:
near_points_on_surface = trimesh_mesh_ncs.sample(num_near_surface_samples)
proximity_query_near = ProximityQuery(trimesh_mesh_ncs)
closest_points_near, distances_near, face_indices_near = proximity_query_near.on_surface(near_points_on_surface)
if np.any(face_indices_near >= len(trimesh_mesh_ncs.face_normals)):
raise IndexError("Face index out of bounds during near-surface normal lookup")
normals_near = trimesh_mesh_ncs.face_normals[face_indices_near]
perturbations = np.random.randn(num_near_surface_samples, 1) * sdf_sampling_std_dev
near_points = near_points_on_surface + normals_near * perturbations
# 确保近表面点也在 [-0.5, 0.5] 范围内
near_points = np.clip(near_points, -0.5, 0.5)
sampled_points_list.append(near_points)
except Exception as e:
logger.warning(f"近表面采样失败: {e}。将替换为均匀采样点。")
fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3))
sampled_points_list.append(fallback_uniform)
else:
logger.warning("网格没有面,无法执行近表面采样。将替换为均匀采样点。")
fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3))
sampled_points_list.append(fallback_uniform)
# --- 合并采样点 ---
if not sampled_points_list:
logger.warning("没有为SDF采样到任何点。")
return None
sampled_points_ncs = np.vstack(sampled_points_list).astype(np.float32)
try:
proximity_query = ProximityQuery(trimesh_mesh_ncs)

1400
brep2sdf/data/utils.py

File diff suppressed because it is too large

31
brep2sdf/networks/octree.py

@ -6,6 +6,10 @@ import torch.nn.functional as F
import numpy as np
from brep2sdf.utils.logger import logger
from brep2sdf.networks.patch_graph import PatchGraph
def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
"""判断两个轴对齐包围盒(AABB)是否相交
@ -27,7 +31,7 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
return torch.all((max1 >= min2) & (max2 >= min1))
class OctreeNode(nn.Module):
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None):
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None):
super().__init__()
# 静态张量存储节点信息
self.register_buffer('bbox', bbox) # 当前节点的边界框
@ -38,6 +42,9 @@ class OctreeNode(nn.Module):
self.register_buffer('face_indices', torch.from_numpy(face_indices).to(bbox.device)) # 面片索引张量
self.register_buffer('surf_bbox', surf_bbox) # 面片边界框
# PatchGraph作为普通属性
self.patch_graph = patch_graph # 不再使用register_buffer
self.max_depth = max_depth
# 将param_key改为张量
self.register_buffer('param_key', torch.tensor(-1, dtype=torch.long))
@ -51,7 +58,7 @@ class OctreeNode(nn.Module):
k: 参数索引值
"""
self.param_key.fill_(k)
@torch.jit.export
def build_static_tree(self) -> None:
"""构建静态八叉树结构"""
@ -72,7 +79,8 @@ class OctreeNode(nn.Module):
node_idx, bbox, faces = queue.pop(0)
self.node_bboxes[node_idx] = bbox
if faces.shape[0] <= 2 or current_idx >= self.max_depth:
# 判断 要不要继续分裂
if not self._should_split_node(current_idx):
self.is_leaf_mask[node_idx] = True
continue
@ -103,7 +111,20 @@ class OctreeNode(nn.Module):
# 将子节点加入队列
if intersecting_faces:
queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.bbox.device)))
def _should_split_node(self, current_depth: int) -> bool:
"""判断节点是否需要分裂"""
# 检查是否达到最大深度
if current_depth >= self.max_depth:
return False
# 检查是否为完全图
is_clique = self.patch_graph.is_clique(self.face_indices)
if is_clique:
return False
return True
@torch.jit.export
def _generate_child_bboxes(self, min_coords: torch.Tensor, mid_coords: torch.Tensor, max_coords: torch.Tensor) -> torch.Tensor:
"""生成8个子节点的边界框"""
@ -208,6 +229,7 @@ class OctreeNode(nn.Module):
'is_leaf_mask': self.is_leaf_mask,
'face_indices': self.face_indices,
'surf_bbox': self.surf_bbox,
'patch_graph': self.patch_graph,
'max_depth': self.max_depth,
'param_key': self.param_key,
'_is_leaf': self._is_leaf
@ -223,6 +245,7 @@ class OctreeNode(nn.Module):
self.is_leaf_mask = state['is_leaf_mask']
self.face_indices = state['face_indices']
self.surf_bbox = state['surf_bbox']
self.patch_graph = state['patch_graph']
self.max_depth = state['max_depth']
self.param_key = state['param_key']
self._is_leaf = state['_is_leaf']

182
brep2sdf/networks/patch_graph.py

@ -0,0 +1,182 @@
from typing import Tuple, Optional
import torch
import torch.nn as nn
import numpy as np
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.TopoDS import TopoDS_Edge, TopoDS_Face, topods_Edge, topods_Face
from OCC.Core.BRep import BRep_Tool
from OCC.Core.GeomLProp import GeomLProp_SLProps
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
class PatchGraph(nn.Module):
def __init__(self, num_patches: int, device: torch.device = None):
super().__init__()
self.num_patches = num_patches
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 注册缓冲区
self.register_buffer('edge_index', None) # 边的连接关系 (2, E)
self.register_buffer('edge_type', None) # 边的类型 (E,) 0:凹边 1:凸边
self.register_buffer('patch_features', None) # 面片特征 (N, F)
def set_edges(self, edge_index: torch.Tensor, edge_type: torch.Tensor) -> None:
"""设置边的信息
参数:
edge_index: 形状为 (2, E) 的张量表示边的连接关系
edge_type: 形状为 (E,) 的张量0表示凹边1表示凸边
"""
if edge_index.shape[0] != 2:
raise ValueError(f"edge_index 必须是形状为 (2, E) 的张量,但得到 {edge_index.shape}")
if edge_index.shape[1] != edge_type.shape[0]:
raise ValueError("edge_index 和 edge_type 的边数量不匹配")
self.edge_index = edge_index.to(self.device)
self.edge_type = edge_type.to(self.device)
def get_subgraph(self, node_faces: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""获取子图的边和类型"""
if self.edge_index is None:
return None, None
node_faces = node_faces.to(self.device)
mask = torch.isin(self.edge_index[0], node_faces) & torch.isin(self.edge_index[1], node_faces)
subgraph_edges = self.edge_index[:, mask]
subgraph_types = self.edge_type[mask]
return subgraph_edges, subgraph_types
@staticmethod
def from_preprocessed_data(surf_wcs: np.ndarray, edgeFace_adj: np.ndarray, edge_types: np.ndarray, device: torch.device = None) -> 'PatchGraph':
num_faces = len(surf_wcs)
graph = PatchGraph(num_faces, device)
edge_pairs = []
edge_types_list = []
for edge_idx in range(len(edgeFace_adj)):
connected_faces = np.where(edgeFace_adj[edge_idx] == 1)[0]
if len(connected_faces) == 2:
face1, face2 = connected_faces
edge_pairs.extend([[face1, face2], [face2, face1]])
edge_type = edge_types[edge_idx]
edge_types_list.extend([edge_type, edge_type])
if edge_pairs:
edge_index = torch.tensor(edge_pairs, dtype=torch.long, device=graph.device).t()
edge_type = torch.tensor(edge_types_list, dtype=torch.long, device=graph.device)
graph.set_edges(edge_index, edge_type)
return graph
def set_features(self, features: torch.Tensor) -> None:
"""设置面片特征
参数:
features: 形状为 (N, F) 的张量表示面片的特征向量
"""
if features.shape[0] != self.num_patches:
raise ValueError(f"特征数量 {features.shape[0]} 与面片数量 {self.num_patches} 不匹配")
self.patch_features = features
def is_clique(self, node_faces: torch.Tensor) -> bool:
"""检查给定面片集合是否构成完全图
参数:
node_faces: 要检查的面片索引集合
返回:
bool: 是否为完全图
"""
if self.edge_index is None:
return False
# 获取子图的边
mask = torch.isin(self.edge_index[0], node_faces) & torch.isin(self.edge_index[1], node_faces)
subgraph_edges = self.edge_index[:, mask]
# 计算完全图应有的边数
n = len(node_faces)
expected_edges = n * (n - 1) // 2
# 计算实际的边数(考虑无向图)
actual_edges = len(subgraph_edges[0]) // 2
return actual_edges == expected_edges
def combine_sdf(self, sdf_values: torch.Tensor) -> torch.Tensor:
"""根据邻接关系组合SDF值
参数:
sdf_values: 形状为 (N,) 的张量表示每个面片的SDF值
返回:
torch.Tensor: 组合后的SDF值
"""
if self.edge_index is None or self.edge_type is None:
raise RuntimeError("请先设置边的信息")
# 获取所有相连面片对的SDF值
sdf_i = sdf_values[self.edge_index[0]] # (E,)
sdf_j = sdf_values[self.edge_index[1]] # (E,)
# 根据边的类型选择组合方式
concave_mask = self.edge_type == 0
convex_mask = self.edge_type == 1
# 初始化结果为第一个SDF值
result = sdf_values[0].clone()
# 凹边取最大值,凸边取最小值
if torch.any(concave_mask):
result = torch.max(result, torch.max(torch.stack([sdf_i[concave_mask],
sdf_j[concave_mask]])))
if torch.any(convex_mask):
result = torch.min(result, torch.min(torch.stack([sdf_i[convex_mask],
sdf_j[convex_mask]])))
return result
@staticmethod
def from_preprocessed_data(
surf_wcs: np.ndarray, # 形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组
edgeFace_adj: np.ndarray, # 形状为(num_edges, num_faces)的int32数组
edge_types: np.ndarray # 形状为(num_edges,)的int32数组
) -> 'PatchGraph':
"""从预处理的数据直接构建面片邻接图
参数:
surf_wcs: 世界坐标系下的曲面几何数据形状为(N,)的对象数组每个元素是形状为(M, 3)的float32数组
edgeFace_adj: -面邻接矩阵形状为(num_edges, num_faces)的int32数组1表示边与面相邻
edge_types: 边的类型数组形状为(num_edges,)的int32数组0表示凹边1表示凸边
返回:
PatchGraph: 初始化好的面片邻接图包含
- edge_index: 形状为(2, num_edges*2)的torch.long张量表示双向边的连接关系
- edge_type: 形状为(num_edges*2,)的torch.long张量表示每条边的类型
"""
num_faces = len(surf_wcs)
graph = PatchGraph(num_faces)
# 构建边的索引和类型
edge_pairs = []
edge_types_list = []
# 遍历边-面邻接矩阵
for edge_idx in range(len(edgeFace_adj)):
connected_faces = np.where(edgeFace_adj[edge_idx] == 1)[0]
if len(connected_faces) == 2:
face1, face2 = connected_faces
# 添加双向边
edge_pairs.extend([[face1, face2], [face2, face1]])
# 使用预计算的边类型
edge_type = edge_types[edge_idx]
edge_types_list.extend([edge_type, edge_type]) # 双向边使用相同的类型
if edge_pairs: # 确保有边存在
edge_index = torch.tensor(edge_pairs, dtype=torch.long).t()
edge_type = torch.tensor(edge_types_list, dtype=torch.long)
graph.set_edges(edge_index, edge_type)
return graph

161
brep2sdf/test.py

@ -1,161 +1,4 @@
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from brep2sdf.data.data import BRepSDFDataset
from brep2sdf.networks.network import BRepToSDF
from brep2sdf.utils.logger import logger
from brep2sdf.config.default_config import get_default_config
import matplotlib.pyplot as plt
from tqdm import tqdm
class Tester:
def __init__(self, config, checkpoint_path):
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化测试数据集
self.test_dataset = BRepSDFDataset(
brep_dir=config.data.brep_dir,
sdf_dir=config.data.sdf_dir,
valid_data_dir=config.data.valid_data_dir,
split='test'
)
# 初始化数据加载器
self.test_loader = DataLoader(
self.test_dataset,
batch_size=1, # 测试时使用batch_size=1
shuffle=False,
num_workers=config.train.num_workers,
pin_memory=False
)
# 加载模型
self.model = BRepToSDF(config).to(self.device)
self.load_checkpoint(checkpoint_path)
# 创建结果保存目录
self.result_dir = os.path.join(config.data.result_save_dir, 'test_results')
os.makedirs(self.result_dir, exist_ok=True)
def load_checkpoint(self, checkpoint_path):
"""加载检查点"""
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
logger.info(f"Loaded checkpoint from {checkpoint_path}")
def compute_metrics(self, pred_sdf, gt_sdf):
"""计算评估指标"""
mse = torch.mean((pred_sdf - gt_sdf) ** 2).item()
mae = torch.mean(torch.abs(pred_sdf - gt_sdf)).item()
max_error = torch.max(torch.abs(pred_sdf - gt_sdf)).item()
return {
'mse': mse,
'mae': mae,
'max_error': max_error
}
def visualize_results(self, pred_sdf, gt_sdf, points, save_path):
"""可视化预测结果"""
fig = plt.figure(figsize=(15, 5))
# 绘制预测SDF
ax1 = fig.add_subplot(131, projection='3d')
scatter = ax1.scatter(points[:, 0], points[:, 1], points[:, 2],
c=pred_sdf.squeeze().cpu(), cmap='coolwarm')
ax1.set_title('Predicted SDF')
plt.colorbar(scatter)
# 绘制真实SDF
ax2 = fig.add_subplot(132, projection='3d')
scatter = ax2.scatter(points[:, 0], points[:, 1], points[:, 2],
c=gt_sdf.squeeze().cpu(), cmap='coolwarm')
ax2.set_title('Ground Truth SDF')
plt.colorbar(scatter)
# 绘制误差图
ax3 = fig.add_subplot(133, projection='3d')
error = torch.abs(pred_sdf - gt_sdf)
scatter = ax3.scatter(points[:, 0], points[:, 1], points[:, 2],
c=error.squeeze().cpu(), cmap='Reds')
ax3.set_title('Absolute Error')
plt.colorbar(scatter)
plt.tight_layout()
plt.savefig(save_path)
plt.close()
def test(self):
"""执行测试"""
self.model.eval()
total_metrics = {'mse': 0, 'mae': 0, 'max_error': 0}
logger.info("Starting testing...")
with torch.no_grad():
for idx, batch in enumerate(tqdm(self.test_loader)):
# 获取数据并移动到设备
surf_ncs = batch['surf_ncs'].to(self.device)
edge_ncs = batch['edge_ncs'].to(self.device)
surf_pos = batch['surf_pos'].to(self.device)
edge_pos = batch['edge_pos'].to(self.device)
vertex_pos = batch['vertex_pos'].to(self.device)
edge_mask = batch['edge_mask'].to(self.device)
points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
# 前向传播
pred_sdf = self.model(
surf_ncs=surf_ncs, edge_ncs=edge_ncs,
surf_pos=surf_pos, edge_pos=edge_pos,
vertex_pos=vertex_pos, edge_mask=edge_mask,
query_points=points
)
# 计算指标
metrics = self.compute_metrics(pred_sdf, gt_sdf)
for k, v in metrics.items():
total_metrics[k] += v
# 可视化结果
if idx % self.config.test.vis_freq == 0:
save_path = os.path.join(self.result_dir, f'result_{idx}.png')
self.visualize_results(pred_sdf, gt_sdf, points[0].cpu(), save_path)
# 计算平均指标
num_samples = len(self.test_loader)
avg_metrics = {k: v / num_samples for k, v in total_metrics.items()}
# 保存测试结果
logger.info("Test Results:")
for k, v in avg_metrics.items():
logger.info(f"{k}: {v:.6f}")
# 保存指标到文件
with open(os.path.join(self.result_dir, 'test_metrics.txt'), 'w') as f:
for k, v in avg_metrics.items():
f.write(f"{k}: {v:.6f}\n")
return avg_metrics
def main():
# 获取配置
config = get_default_config()
# 设置检查点路径
checkpoint_path = os.path.join(
config.data.model_save_dir,
config.data.best_model_name.format(model_name=config.data.model_name)
)
# 初始化测试器并执行测试
tester = Tester(config, checkpoint_path)
metrics = tester.test()
if __name__ == '__main__':
main()
model = torch.jit.load("/home/wch/brep2sdf/data/output_data/00000054.pt")
print(model)

59
brep2sdf/train.py

@ -13,6 +13,7 @@ from brep2sdf.data.pre_process_by_mesh import process_single_step
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.loss import LossManager
from brep2sdf.networks.patch_graph import PatchGraph
from brep2sdf.utils.logger import logger
@ -31,10 +32,17 @@ parser.add_argument(
help='只采样零表面点 SDF 训练'
)
parser.add_argument(
'--force-reprocess',
'--force-reprocess','-f',
action='store_true', # 默认为 False,如果用户指定该参数,则为 True
help='强制重新进行数据预处理,忽略缓存或已有结果'
)
parser.add_argument(
'--resume-checkpoint-path',
type=str,
default=None,
help='从指定的checkpoint文件继续训练'
)
args = parser.parse_args()
@ -86,8 +94,13 @@ class Trainer:
#logger.info( self.brep_data )
#self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device)
# 构建面片邻接图
graph = PatchGraph.from_preprocessed_data(
surf_wcs=self.data['surf_wcs'],
edgeFace_adj=self.data['edgeFace_adj'],
edge_types=self.data['edge_types']
)
# 初始化网络
surf_bbox=torch.tensor(
self.data['surf_bbox_ncs'],
dtype=torch.float32,
@ -95,7 +108,7 @@ class Trainer:
)
self.build_tree(surf_bbox=surf_bbox, max_depth=4)
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=4)
self.model = Net(
@ -115,12 +128,13 @@ class Trainer:
logger.info(f"初始化完成,正在处理模型 {self.model_name}")
def build_tree(self,surf_bbox, max_depth=6):
def build_tree(self,surf_bbox, graph, max_depth=6):
num_faces = surf_bbox.shape[0]
bbox = self._calculate_global_bbox(surf_bbox)
self.root = OctreeNode(
bbox=bbox,
face_indices=np.arange(num_faces), # 初始包含所有面
patch_graph=graph,
max_depth=max_depth,
surf_bbox=surf_bbox
)
@ -302,8 +316,13 @@ class Trainer:
best_val_loss = float('inf')
logger.info("Starting training...")
start_time = time.time()
start_epoch = 1
if args.resume_checkpoint_path:
start_epoch = self._load_checkpoint(args.resume_checkpoint_path)
logger.info(f"Loaded model from {args.resume_checkpoint_path}")
for epoch in range(1, self.config.train.num_epochs + 1):
for epoch in range(start_epoch, self.config.train.num_epochs + 1):
# 训练一个epoch
train_loss = self.train_epoch(epoch)
@ -329,7 +348,8 @@ class Trainer:
# 训练完成
total_time = time.time() - start_time
self._tracing_model()
self._tracing_model_by_script()
#self._tracing_model()
logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
logger.info(f'Best validation loss: {best_val_loss:.6f}')
#self.test_load()
@ -349,8 +369,8 @@ class Trainer:
self.model.eval()
# 确保模型中的所有逻辑都兼容 TorchScript
scripted_model = torch.jit.script(self.model)
optimized_model = optimize_for_mobile(scripted_model)
torch.jit.save(optimized_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
#optimized_model = optimize_for_mobile(scripted_model)
torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
def _tracing_model(self):
"""保存模型"""
@ -375,11 +395,6 @@ class Trainer:
except Exception as e:
logger.error(f"模型验证失败:{e}")
def _load_checkpoint(self, checkpoint_path):
"""从检查点恢复训练状态"""
model = torch.load(checkpoint_path)
return model
def _save_checkpoint(self, epoch: int, train_loss: float):
"""保存训练检查点"""
checkpoint_dir = os.path.join(
@ -387,18 +402,26 @@ class Trainer:
self.model_name
)
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir,f"epoch_{epoch:03d}.pth")
'''
checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth")
# 只保存状态字典
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'loss': train_loss,
'config': self.config
}, checkpoint_path)
'''
torch.save(self.model,checkpoint_path)
def _load_checkpoint(self, checkpoint_path):
"""从检查点恢复训练状态"""
try:
checkpoint = torch.load(checkpoint_path)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch'] + 1
except Exception as e:
logger.error(f"加载checkpoint失败: {str(e)}")
raise
def main():
# 这里需要初始化配置
config = get_default_config()

Loading…
Cancel
Save