Browse Source

fix:之前采样点集中在边缘,现在是表面均匀采样

final
mckay 1 month ago
parent
commit
81742489ca
  1. 19
      brep2sdf/data/pre_process_by_mesh.py
  2. 35
      brep2sdf/data/sampler.py

19
brep2sdf/data/pre_process_by_mesh.py

@ -25,7 +25,7 @@ from OCC.Core.StlAPI import StlAPI_Writer
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
from OCC.Core.gp import gp_Pnt, gp_Vec from OCC.Core.gp import gp_Pnt, gp_Vec
from brep2sdf.data.sampler import sample_sdf_points_and_normals, sample_face_points_brep, sample_edge_points_brep from brep2sdf.data.sampler import sample_sdf_points_and_normals, sample_face_points_brep, sample_edge_points_brep,sample_zero_surface_points_and_normals
from brep2sdf.data.data import check_data_format from brep2sdf.data.data import check_data_format
from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,batch_compute_normals from brep2sdf.data.utils import get_bbox, normalize, get_adjacency_info,batch_compute_normals
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
@ -44,7 +44,8 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
返回: 返回:
dict: 包含以下键值对的字典: dict: 包含以下键值对的字典:
# 几何数据 # 几何数据
'surf_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示面的点云坐标 'train_surf_ncs' np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示归一化后的面点云
'surf_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示面的点云坐标,很多是边缘点,不合适训练
'edge_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示边的采样点坐标 'edge_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示边的采样点坐标
'surf_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示归一化后的面点云 'surf_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示归一化后的面点云
'edge_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示归一化后的边采样点 'edge_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示归一化后的边采样点
@ -86,6 +87,7 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE) edge_explorer = TopExp_Explorer(shape, TopAbs_EDGE)
vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX) vertex_explorer = TopExp_Explorer(shape, TopAbs_VERTEX)
#tarin_surf_pnts = []
face_pnts = [] face_pnts = []
edge_pnts = [] edge_pnts = []
corner_pnts = [] corner_pnts = []
@ -149,7 +151,7 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
if face_idx < len(original_points): if face_idx < len(original_points):
points = original_points[face_idx] points = original_points[face_idx]
target_points = points_per_face[face_idx] target_points = points_per_face[face_idx]
#tarin_surf_pnts.append(sample_face_points_brep(face, min_points=target_points))
# 如果需要补充采样 # 如果需要补充采样
if len(points) < target_points: if len(points) < target_points:
try: try:
@ -343,6 +345,7 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
# 创建结果字典并确保所有数组都有正确的类型 # 创建结果字典并确保所有数组都有正确的类型
data = { data = {
#'train_surf_ncs': np.array(train_surf_ncs, dtype=object), # 保持对象数组
'surf_wcs': np.array(surfs_wcs, dtype=object), # 保持对象数组 'surf_wcs': np.array(surfs_wcs, dtype=object), # 保持对象数组
'edge_wcs': np.array(edges_wcs, dtype=object), # 保持对象数组 'edge_wcs': np.array(edges_wcs, dtype=object), # 保持对象数组
'surf_ncs': np.array(surfs_ncs, dtype=object), # 保持对象数组 'surf_ncs': np.array(surfs_ncs, dtype=object), # 保持对象数组
@ -388,7 +391,10 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
# 创建归一化 Trimesh # 创建归一化 Trimesh
vertices_wcs = trimesh_mesh.vertices.astype(np.float32) vertices_wcs = trimesh_mesh.vertices.astype(np.float32)
vertices_ncs = (vertices_wcs - data['normalization_params']['center']) / data['normalization_params']['scale'] logger.debug(f"vertices_wcs:{vertices_wcs}")
logger.debug(f"center:{data['normalization_params']['center']},scale:{data['normalization_params']['scale']}")
vertices_ncs = (vertices_wcs - data['normalization_params']['center']) * data['normalization_params']['scale']
logger.debug(f"vertices_ncs:{vertices_ncs}")
trimesh_mesh_ncs = trimesh.Trimesh(vertices=vertices_ncs, faces=trimesh_mesh.faces, process=False) trimesh_mesh_ncs = trimesh.Trimesh(vertices=vertices_ncs, faces=trimesh_mesh.faces, process=False)
if not trimesh_mesh_ncs.is_watertight: if not trimesh_mesh_ncs.is_watertight:
@ -397,11 +403,13 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
if not trimesh_mesh_ncs.is_watertight: if not trimesh_mesh_ncs.is_watertight:
logger.warning(f"{step_path} 的归一化网格修复后仍不是 watertight。") logger.warning(f"{step_path} 的归一化网格修复后仍不是 watertight。")
data["train_surf_ncs"] = sample_zero_surface_points_and_normals(trimesh_mesh_ncs, config.data.num_surf_points) # 归一化网格的顶点
except Exception as e: except Exception as e:
logger.error(f"{step_path} 加载/处理 Trimesh 失败: {e}") logger.error(f"{step_path} 加载/处理 Trimesh 失败: {e}")
trimesh_mesh = None trimesh_mesh = None
trimesh_mesh_ncs = None trimesh_mesh_ncs = None
# 如果你需要归一化后的表面点
# --- 计算表面点法线 --- # --- 计算表面点法线 ---
if sample_normal_vector and trimesh_mesh_ncs is not None: if sample_normal_vector and trimesh_mesh_ncs is not None:
logger.debug("计算表面点法线...") logger.debug("计算表面点法线...")
@ -435,6 +443,7 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, sample_sdf_points=False, timeout:int=300) -> dict: def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, sample_sdf_points=False, timeout:int=300) -> dict:
"""处理单个STEP文件, 从 brep 2 pkl """处理单个STEP文件, 从 brep 2 pkl
return data = { return data = {
'train_surf_ncs' np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示归一化后的面点云
'surf_wcs': np.array(surfs_wcs, dtype=object), # 世界坐标系下的曲面几何数据(对象数组) 'surf_wcs': np.array(surfs_wcs, dtype=object), # 世界坐标系下的曲面几何数据(对象数组)
'edge_wcs': np.array(edges_wcs, dtype=object), # 世界坐标系下的边几何数据(对象数组) 'edge_wcs': np.array(edges_wcs, dtype=object), # 世界坐标系下的边几何数据(对象数组)
'surf_ncs': np.array(surfs_ncs, dtype=object), # 归一化坐标系下的曲面几何数据(对象数组) 面归一化点云 [num_faces, num_surf_sample_points, 3] 'surf_ncs': np.array(surfs_ncs, dtype=object), # 归一化坐标系下的曲面几何数据(对象数组) 面归一化点云 [num_faces, num_surf_sample_points, 3]

35
brep2sdf/data/sampler.py

@ -5,12 +5,13 @@ CAD模型处理脚本
- 拓扑信息--顶点的邻接关系 - 拓扑信息--顶点的邻接关系
- 空间信息包围盒数据 - 空间信息包围盒数据
""" """
from typing import Optional
import numpy as np import numpy as np
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
import trimesh import trimesh
from trimesh.proximity import ProximityQuery from trimesh.proximity import ProximityQuery
# 导入OpenCASCADE相关库 # 导入OpenCASCADE相关库
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface, BRepAdaptor_Curve from OCC.Core.BRepAdaptor import BRepAdaptor_Surface, BRepAdaptor_Curve
from OCC.Core.GeomLProp import GeomLProp_SLProps from OCC.Core.GeomLProp import GeomLProp_SLProps
@ -185,10 +186,42 @@ def sample_edge_points_brep(edge, num_samples=50):
return None return None
def sample_zero_surface_points_and_normals(
trimesh_mesh_ncs: trimesh.Trimesh,
num_samples: int = 4096,
) -> Optional[np.ndarray]:
"""
从归一化 Trimesh 网格中均匀采样表面点及其对应法线并附加 SDF=0
返回形状为 (N, 7) 的数组[x, y, z, nx, ny, nz, sdf=0.0]
参数:
trimesh_mesh_ncs: 归一化的 Trimesh 网格
num_samples: 要采样的表面点数量
返回:
points_with_normals_sdf: (N, 7) 的数组包含坐标法线SDF=0
"""
if not isinstance(trimesh_mesh_ncs, trimesh.Trimesh):
raise ValueError("输入必须是 trimesh.Trimesh 类型")
try:
# 表面采样点及对应的 face index
points, face_indices = trimesh_mesh_ncs.sample(num_samples, return_index=True)
# 获取每个点所在的三角面片的法线
normals = trimesh_mesh_ncs.face_normals[face_indices]
# 构造 sdf 标签为 0 的列
sdf_zeros = np.zeros((num_samples, 1), dtype=np.float32)
# 合并为 (N, 7) 的数组 [xyz, normal, sdf=0]
points_with_normals_sdf = np.hstack([points, normals, sdf_zeros], dtype=np.float32)
return points_with_normals_sdf
except Exception as e:
print(f"表面点采样失败: {e}")
return None
def sample_sdf_points_and_normals( def sample_sdf_points_and_normals(
trimesh_mesh_ncs: trimesh.Trimesh, trimesh_mesh_ncs: trimesh.Trimesh,

Loading…
Cancel
Save