
Split data preprocessing

Branch: final
mckay, 2 months ago
commit 18b0ed99c8
Changed files:
  brep2sdf/data/data.py (538)
  brep2sdf/data/pre_process_by_mesh.py (238)
  brep2sdf/data/sampler.py (300)
  brep2sdf/train.py (31)

brep2sdf/data/data.py (538)

@@ -9,250 +9,6 @@ from brep2sdf.config.default_config import get_default_config
class BRepSDFDataset(Dataset):
    def __init__(self, brep_dir:str, sdf_dir:str, valid_data_dir:str, use_filter: bool=True, split:str='train'):
        """
        Initialize the dataset.

        Args:
            brep_dir: directory of pkl files
            sdf_dir: directory of npz files
            split: dataset split ('train', 'val', 'test')
        """
        super().__init__()
        # Use the config file
        self.config = get_default_config()
        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split
        # Replace hard-coded parameters with values from the config
        self.max_face = self.config.data.max_face
        self.max_edge = self.config.data.max_edge
        self.bbox_scaled = self.config.data.bbox_scaled
        # Check that the directories exist
        if not os.path.exists(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.exists(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")
        # Load the data lists.
        # If {split}_success.txt exists, load the list of valid samples from it.
        valid_data_file = os.path.join(valid_data_dir, f'{split}_success.txt')
        if os.path.exists(valid_data_file):
            self.valid_data_list = self._load_valid_list(valid_data_file)
        else:
            raise ValueError(f"Valid data file not found: {valid_data_file}")
        self.brep_data_list = self._load_data_list(self.brep_dir)
        self.sdf_data_list = self._load_data_list(self.sdf_dir)
        if use_filter:
            self._filter_num_faces_and_num_edges()
        # Check that the dataset is not empty
        if len(self.brep_data_list) == 0:
            raise ValueError(f"No valid brep data found in {split} set")
        if len(self.sdf_data_list) == 0:
            raise ValueError(f"No valid sdf data found in {split} set")
        logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples")
    def _load_valid_list(self, valid_data_file: str):
        with open(valid_data_file, 'r') as f:
            valid_list = [line.strip() for line in f.readlines()]
        return valid_list

    # data_dir is self.brep_dir or self.sdf_dir
    def _load_data_list(self, data_dir):
        data_list = []
        for sample_file in os.listdir(data_dir):
            if sample_file.split('.')[0] in self.valid_data_list:
                path = os.path.join(data_dir, sample_file)
                data_list.append(path)
        #logger.info(data_list)
        return data_list

    def _filter_num_faces_and_num_edges(self):
        '''
        Filter out samples whose face count or edge count exceeds max_face or max_edge.
        '''
        # Collect indices of elements that satisfy the condition. The counts
        # must be compared element-wise: a tuple comparison such as
        # (num_faces, num_edges) <= (max_face, max_edge) is lexicographic and
        # would accept samples whose edge count exceeds the limit.
        filtered_indices = []
        for idx in range(len(self.brep_data_list)):
            num_faces, num_edges = self._get_brep_face_and_edge(self.brep_data_list[idx])
            if num_faces <= self.max_face and num_edges <= self.max_edge:
                filtered_indices.append(idx)
        #filtered_indices = filtered_indices[0:8] # TODO rm
        # Use filtered_indices to update brep_data_list and sdf_data_list
        self.brep_data_list = [self.brep_data_list[idx] for idx in filtered_indices]
        self.sdf_data_list = [self.sdf_data_list[idx] for idx in filtered_indices]
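    # Quick demonstration of the pitfall (illustrative, not part of the commit):
    #   (3, 999) <= (4, 50)      # True: lexicographic, 3 < 4 decides first
    #   3 <= 4 and 999 <= 50     # False: the element-wise check rejects it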
    def __len__(self):
        return len(self.brep_data_list)

    def __getitem__(self, idx):
        """Fetch a single data sample."""
        try:
            brep_path = self.brep_data_list[idx]
            sdf_path = self.sdf_data_list[idx]
            name = os.path.splitext(os.path.basename(brep_path))[0]
            # Load the B-rep and SDF data
            brep_raw = self._load_brep_file(brep_path)
            sdf_data = self._load_sdf_file(sdf_path)
            try:
                # Process the B-rep data
                brep_features = process_brep_data(
                    data=brep_raw,
                    max_face=self.max_face,
                    max_edge=self.max_edge,
                    bbox_scaled=self.bbox_scaled
                )
                '''
                # Print data shapes
                logger.debug(f"Processed data shapes for {os.path.basename(brep_path)}:")
                for value in brep_features:
                    if isinstance(value, torch.Tensor):
                        logger.debug(f"  {value.shape}")
                # Check the type and arity of the return value
                if not isinstance(brep_features, tuple):
                    logger.error(f"process_brep_data returned {type(brep_features)}, expected tuple")
                    raise ValueError("Invalid return type from process_brep_data")
                if len(brep_features) != 6:
                    logger.error(f"Expected 6 features, got {len(brep_features)}")
                    logger.error("Features returned:")
                    for i, feat in enumerate(brep_features):
                        if isinstance(feat, torch.Tensor):
                            logger.error(f"  {i}: Tensor of shape {feat.shape}")
                        else:
                            logger.error(f"  {i}: {type(feat)}")
                    raise ValueError(f"Incorrect number of features: {len(brep_features)}")
                '''
                # Unpack the processed features
                edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features
                sdf_points = sdf_data[:, :3]
                sdf_values = sdf_data[:, 3:]
                # Build the return dictionary
                return {
                    'name': name,
                    'edge_ncs': edge_ncs,     # [max_face, max_edge, 10, 3]
                    'edge_pos': edge_pos,     # [max_face, max_edge, 6]
                    'edge_mask': edge_mask,   # [max_face, max_edge]
                    'surf_ncs': surf_ncs,     # [max_face, 100, 3]
                    'surf_pos': surf_pos,     # [max_face, 6]
                    'vertex_pos': vertex_pos, # [max_face, max_edge, 6]
                    'points': sdf_points,     # [num_queries, 3] xyz coordinates of all query points
                    'sdf': sdf_values         # [num_queries, 1] sdf values of all query points
                }
            except Exception as e:
                logger.error(f"\nError processing B-rep data for file: {brep_path}")
                logger.error(f"Error type: {type(e).__name__}")
                logger.error(f"Error message: {str(e)}")
                # Print the structure of the raw data
                logger.error("\nRaw data structure:")
                for key, value in brep_raw.items():
                    if isinstance(value, list):
                        logger.error(f"  {key}: list of length {len(value)}")
                        if value:
                            logger.error(f"  First element type: {type(value[0])}")
                            if hasattr(value[0], 'shape'):
                                logger.error(f"  First element shape: {value[0].shape}")
                    elif hasattr(value, 'shape'):
                        logger.error(f"  {key}: shape {value.shape}")
                    else:
                        logger.error(f"  {key}: {type(value)}")
                raise
        except Exception as e:
            logger.error(f"Error loading sample from {brep_path}: {str(e)}")
            logger.error("Data structure:")
            raise
    def _load_brep_file(self, brep_path):
        with open(brep_path, 'rb') as f:
            brep_raw = pickle.load(f)
        return brep_raw

    def _load_sdf_file(self, sdf_path):
        """Load and process SDF data, with random subsampling."""
        try:
            # Load the SDF values
            sdf_data = np.load(sdf_path)
            if 'pos' not in sdf_data or 'neg' not in sdf_data:
                raise ValueError("Missing pos/neg data in SDF file")
            sdf_pos = sdf_data['pos']  # (N1, 4)
            sdf_neg = sdf_data['neg']  # (N2, 4)
            # Validate the data
            if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
                raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")
            # Random sampling
            max_points = self.config.data.num_query_points  # e.g. 4096
            # Keep positive and negative samples balanced
            if max_points // 2 > sdf_pos.shape[0]:
                logger.warning(f"Too few positive samples; expected > {max_points // 2}, got {sdf_pos.shape[0]}")
            if max_points // 2 > sdf_neg.shape[0]:
                num_neg = sdf_neg.shape[0]
            else:
                num_neg = max_points // 2
            num_pos = max_points - num_neg
            # Randomly sample positive points
            if sdf_pos.shape[0] > num_pos:
                pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False)
                sdf_pos = sdf_pos[pos_indices]
            # Randomly sample negative points
            if sdf_neg.shape[0] > num_neg:
                neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False)
                sdf_neg = sdf_neg[neg_indices]
            # Merge the two sets
            sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)
            # Shuffle again
            np.random.shuffle(sdf_np)
            # If the total still exceeds the limit, subsample once more
            if sdf_np.shape[0] > max_points:
                indices = np.random.choice(sdf_np.shape[0], max_points, replace=False)
                sdf_np = sdf_np[indices]
            #logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})")
            return torch.from_numpy(sdf_np.astype(np.float32))
        except Exception as e:
            logger.error(f"Error loading SDF from {sdf_path}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            raise

    def _get_brep_face_and_edge(self, brep_path: str) -> tuple[int, int]:
        brep: dict = self._load_brep_file(brep_path)
        face_edge_adj = brep["faceEdge_adj"]
        num_faces, num_edges = face_edge_adj.shape
        return num_faces, num_edges

def load_brep_file(brep_path):
    with open(brep_path, 'rb') as f:
        brep_raw = pickle.load(f)
    return brep_raw
@@ -457,270 +213,42 @@ def check_tensor(tensor: torch.Tensor | None, name: str, epoch: int, step: int =
def check_data_format(data, step_file):
    """Check that the data format is correct."""
    required_keys = [
        'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs', 'corner_wcs',
        'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj',
        'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'
    ]
    # Check that all required keys are present
    for key in required_keys:
        if key not in data:
            return False, f"Missing key: {key}"
    # Check the geometry data
    geometry_arrays = ['surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs']
    for key in geometry_arrays:
        if not isinstance(data[key], np.ndarray):
            return False, f"{key} should be a numpy array"
        # Object arrays are allowed
        if data[key].dtype != object:
            return False, f"{key} should be a numpy array with dtype=object"
    # Check the remaining arrays
    float32_arrays = ['corner_wcs', 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique']
    for key in float32_arrays:
        if not isinstance(data[key], np.ndarray):
            return False, f"{key} should be a numpy array"
        if data[key].dtype != np.float32:
            return False, f"{key} should be a numpy array with dtype=float32"
    int32_arrays = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
    for key in int32_arrays:
        if not isinstance(data[key], np.ndarray):
            return False, f"{key} should be a numpy array"
        if data[key].dtype != np.int32:
            return False, f"{key} should be a numpy array with dtype=int32"
    return True, ""
def test_dataset():
    """Test dataset functionality."""
    try:
        # Get the config
        config = get_default_config()
        # Define the expected data dimensions
        expected_shapes = {
            'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3),
            'edge_pos': (config.data.max_face, config.data.max_edge, 6),
            'edge_mask': (config.data.max_face, config.data.max_edge),
            'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3),
            'surf_pos': (config.data.max_face, 6),
            'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3),
            'points': (config.data.num_query_points, 3),
            'sdf': (config.data.num_query_points, 1)
        }
        logger.info("="*50)
        logger.info("Testing the dataset")
        logger.info("Expected shapes:")
        for key, shape in expected_shapes.items():
            logger.info(f"  {key}: {shape}")
        # Initialize the dataset
        dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='train'
        )
        # Test data loading
        logger.info("\nTesting data loading...")
        sample = dataset[0]
        # Check data types and shapes
        logger.info("\nData type and shape checks:")
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                actual_shape = tuple(value.shape)
                expected_shape = expected_shapes.get(key)
                shape_match = "✓" if actual_shape == expected_shape else "✗"
                logger.info(f"\n{key}:")
                logger.info(f"  actual shape:   {actual_shape}")
                logger.info(f"  expected shape: {expected_shape}")
                logger.info(f"  match:          {shape_match}")
                logger.info(f"  dtype:          {value.dtype}")
                # Only compute value range, mean and std for floating-point types
                if value.dtype.is_floating_point:
                    logger.info(f"  value range: [{value.min():.3f}, {value.max():.3f}]")
                    logger.info(f"  mean: {value.mean():.3f}")
                    logger.info(f"  std:  {value.std():.3f}")
                if shape_match == "✗":
                    logger.warning(f"  shape mismatch: {key}")
                    if key in ['points', 'sdf']:
                        logger.warning(f"  query point count differs: expected {expected_shape[0]}, got {actual_shape[0]}")
                    elif key in ['edge_ncs', 'edge_pos', 'edge_mask']:
                        logger.warning(f"  edge count differs: expected {expected_shape[1]}, got {actual_shape[1]}")
                    elif key in ['surf_ncs', 'surf_pos']:
                        logger.warning(f"  face count differs: expected {expected_shape[0]}, got {actual_shape[0]}")
        # Test batching
        logger.info("\nTesting batching...")
        batch_size = 4
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0
        )
        batch = next(iter(dataloader))
        logger.info("\nBatch shape checks:")
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch_shape = tuple(value.shape)
                expected_batch_shape = (batch_size,) + expected_shapes[key]
                shape_match = "✓" if batch_shape == expected_batch_shape else "✗"
                logger.info(f"\n{key}:")
                logger.info(f"  actual shape:   {batch_shape}")
                logger.info(f"  expected shape: {expected_batch_shape}")
                logger.info(f"  match:          {shape_match}")
                logger.info(f"  dtype:          {value.dtype}")
                # Only compute value range, mean and std for floating-point types
                if value.dtype.is_floating_point:
                    logger.info(f"  value range: [{value.min():.3f}, {value.max():.3f}]")
                    logger.info(f"  mean: {value.mean():.3f}")
                    logger.info(f"  std:  {value.std():.3f}")
                if shape_match == "✗":
                    logger.warning(f"  batch shape mismatch: {key}")
        logger.info("\nTest complete!")
        logger.info("="*50)
    except Exception as e:
        logger.error(f"Error during testing: {str(e)}")
        raise
from collections import defaultdict
from tqdm import tqdm

def validate_dataset(split: str = 'train', num_samples: int = None):
    """Thoroughly validate the dataset.

    Args:
        split: dataset split ('train', 'val', 'test')
        num_samples: number of samples to check; None checks all samples
    """
    try:
        config = get_default_config()
        logger.info(f"Validating the {split} dataset...")
        # Initialize the dataset
        dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split=split
        )
        total_samples = len(dataset) if num_samples is None else min(num_samples, len(dataset))
        logger.info(f"Total samples: {total_samples}")
        # Initialize the statistics
        stats = {
            'face_counts': [],
            'edge_counts': [],
            'vertex_counts': [],
            'sdf_point_counts': [],
            'invalid_samples': [],
            'shape_mismatches': defaultdict(int),
            'value_ranges': defaultdict(lambda: {'min': float('inf'), 'max': float('-inf')}),
            'nan_counts': defaultdict(int),
            'inf_counts': defaultdict(int)
        }
        # Iterate over the dataset
        for idx in tqdm(range(total_samples), desc="Validating data"):
            try:
                sample = dataset[idx]
                # 1. Check data completeness
                required_keys = ['surf_ncs', 'surf_pos', 'edge_ncs', 'edge_pos',
                                 'vertex_pos', 'points', 'sdf', 'edge_mask']
                missing_keys = [key for key in required_keys if key not in sample]
                if missing_keys:
                    stats['invalid_samples'].append((idx, f"Missing keys: {missing_keys}"))
                    continue
                # 2. Check shapes
                expected_shapes = {
                    'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3),
                    'surf_pos': (config.data.max_face, 6),
                    'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3),
                    'edge_pos': (config.data.max_face, config.data.max_edge, 6),
                    'edge_mask': (config.data.max_face, config.data.max_edge),
                    'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3),
                    'points': (config.data.num_query_points, 3),
                    'sdf': (config.data.num_query_points, 1)
                }
                for key, expected_shape in expected_shapes.items():
                    if key in sample:
                        actual_shape = tuple(sample[key].shape)
                        if actual_shape != expected_shape:
                            stats['shape_mismatches'][key] += 1
                            stats['invalid_samples'].append(
                                (idx, f"{key} shape mismatch: expected {expected_shape}, got {actual_shape}")
                            )
                # 3. Check value ranges and invalid values
                for key, tensor in sample.items():
                    if isinstance(tensor, torch.Tensor) and tensor.dtype.is_floating_point:
                        # Update the value ranges
                        stats['value_ranges'][key]['min'] = min(stats['value_ranges'][key]['min'],
                                                                tensor.min().item())
                        stats['value_ranges'][key]['max'] = max(stats['value_ranges'][key]['max'],
                                                                tensor.max().item())
                        # Check for NaN and Inf
                        nan_count = torch.isnan(tensor).sum().item()
                        inf_count = torch.isinf(tensor).sum().item()
                        if nan_count > 0:
                            stats['nan_counts'][key] += nan_count
                        if inf_count > 0:
                            stats['inf_counts'][key] += inf_count
                # 4. Collect statistics
                stats['face_counts'].append(sample['surf_ncs'].shape[0])
                stats['edge_counts'].append(sample['edge_ncs'].shape[1])
                stats['vertex_counts'].append(len(torch.unique(sample['vertex_pos'].reshape(-1, 3), dim=0)))
                stats['sdf_point_counts'].append(sample['points'].shape[0])
            except Exception as e:
                stats['invalid_samples'].append((idx, str(e)))
        # Report the statistics
        logger.info("\n=== Dataset validation results ===")
        # 1. Basic statistics
        logger.info("\nBasic statistics:")
        logger.info(f"Total samples:   {total_samples}")
        logger.info(f"Valid samples:   {total_samples - len(stats['invalid_samples'])}")
        logger.info(f"Invalid samples: {len(stats['invalid_samples'])}")
        # 2. Shape mismatch statistics
        if stats['shape_mismatches']:
            logger.info("\nShape mismatch statistics:")
            for key, count in stats['shape_mismatches'].items():
                logger.info(f"  {key}: {count} mismatched samples")
        # 3. Value range statistics
        logger.info("\nValue range statistics:")
        for key, ranges in stats['value_ranges'].items():
            logger.info(f"  {key}:")
            logger.info(f"    min: {ranges['min']:.3f}")
            logger.info(f"    max: {ranges['max']:.3f}")
        # 4. Invalid value statistics
        if sum(stats['nan_counts'].values()) > 0 or sum(stats['inf_counts'].values()) > 0:
            logger.info("\nInvalid value statistics:")
            for key in stats['nan_counts'].keys():
                if stats['nan_counts'][key] > 0:
                    logger.info(f"  {key} contains {stats['nan_counts'][key]} NaN values")
            for key in stats['inf_counts'].keys():
                if stats['inf_counts'][key] > 0:
                    logger.info(f"  {key} contains {stats['inf_counts'][key]} Inf values")
        # 5. Geometry statistics
        logger.info("\nGeometry statistics:")
        for name, values in [
            ('faces', stats['face_counts']),
            ('edges', stats['edge_counts']),
            ('vertices', stats['vertex_counts']),
            ('SDF sample points', stats['sdf_point_counts'])
        ]:
            values = np.array(values)
            logger.info(f"  {name}:")
            logger.info(f"    min:    {np.min(values)}")
            logger.info(f"    max:    {np.max(values)}")
            logger.info(f"    mean:   {np.mean(values):.2f}")
            logger.info(f"    median: {np.median(values):.2f}")
            logger.info(f"    std:    {np.std(values):.2f}")
        # 6. Report invalid sample details
        if stats['invalid_samples']:
            logger.info("\nInvalid sample details:")
            for idx, error in stats['invalid_samples']:
                logger.info(f"  sample {idx}: {error}")
        return stats
    except Exception as e:
        logger.error(f"Error during validation: {str(e)}")
        raise
if __name__ == '__main__':
    validate_dataset(split='train', num_samples=None)  # None validates every sample
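The pos/neg balancing in `_load_sdf_file` caps the negative quota by the available supply and backfills the remainder with positives. A minimal standalone sketch of just that arithmetic (the sample counts are illustrative, not from the dataset):

import numpy as np

max_points = 4096              # config.data.num_query_points in the real code
n_pos, n_neg = 10000, 1500     # illustrative supply of positive/negative SDF rows

num_neg = min(max_points // 2, n_neg)  # 1500: capped by the available negatives
num_pos = max_points - num_neg         # 2596: positives backfill the quota
assert num_pos + num_neg == max_points
pos_indices = np.random.choice(n_pos, num_pos, replace=False)  # as in _load_sdf_file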

brep2sdf/data/pre_process_by_mesh.py (238)

@@ -34,6 +34,8 @@ from OCC.Core.Bnd import Bnd_Box  # bounding box
 from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex  # topology data structures
 from OCC.Core.StlAPI import StlAPI_Writer
+from brep2sdf.data.sampler import sample_sdf_points_and_normals
+from brep2sdf.data.data import check_data_format
 # Import the config
 from brep2sdf.config.default_config import get_default_config
 config = get_default_config()
@@ -533,244 +535,8 @@ def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3):
    return normals_output
def sample_sdf_points_and_normals(
    trimesh_mesh_ncs: trimesh.Trimesh,
    surf_bbox_ncs: np.ndarray,
    num_sdf_samples: int = 4096,
    sdf_sampling_std_dev: float = 0.01
) -> np.ndarray | None:
    """
    Sample a fixed number of points in the normalized coordinate system (NCS)
    and compute their SDF values and nearest-surface normals.
    Uses a mixed strategy of uniform and near-surface sampling.

    Args:
        trimesh_mesh_ncs: the normalized Trimesh object
        surf_bbox_ncs: per-face bounding boxes in the NCS [num_faces, 6]
        num_sdf_samples: total number of points to sample
        sdf_sampling_std_dev: std dev of the perturbation along the normal for near-surface sampling

    Returns:
        np.ndarray | None: array of shape (N, 7), [x, y, z, nx, ny, nz, sdf],
                           or None if sampling or computation fails
    """
    logger.debug("Sampling points for SDF computation (fixed-count strategy)...")
    if trimesh_mesh_ncs is None or not isinstance(trimesh_mesh_ncs, trimesh.Trimesh):
        logger.error("Invalid Trimesh object passed to SDF sampling.")
        return None
    if num_sdf_samples <= 0:
        logger.warning("Requested zero or negative SDF sample count; skipping sampling.")
        return None

    # Use the fixed sampling bounds [-0.5, 0.5], since the mesh is already normalized
    min_bound_ncs = np.array([-0.5, -0.5, -0.5], dtype=np.float32)
    max_bound_ncs = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    bbox_size_ncs = max_bound_ncs - min_bound_ncs

    # --- Split the fixed total sample count ---
    num_uniform_samples = num_sdf_samples // 2
    num_near_surface_samples = num_sdf_samples - num_uniform_samples
    logger.debug(f"Fixed SDF sample count: {num_sdf_samples} (uniform: {num_uniform_samples}, near-surface: {num_near_surface_samples})")

    # --- Perform the sampling ---
    sampled_points_list = []
    # Uniform sampling (within [-0.5, 0.5])
    if num_uniform_samples > 0:
        uniform_points = np.random.uniform(-0.5, 0.5, (num_uniform_samples, 3))
        sampled_points_list.append(uniform_points)
    # Near-surface sampling
    if num_near_surface_samples > 0:
        if trimesh_mesh_ncs.faces.shape[0] > 0:
            try:
                near_points_on_surface = trimesh_mesh_ncs.sample(num_near_surface_samples)
                proximity_query_near = ProximityQuery(trimesh_mesh_ncs)
                closest_points_near, distances_near, face_indices_near = proximity_query_near.on_surface(near_points_on_surface)
                if np.any(face_indices_near >= len(trimesh_mesh_ncs.face_normals)):
                    raise IndexError("Face index out of bounds during near-surface normal lookup")
                normals_near = trimesh_mesh_ncs.face_normals[face_indices_near]
                perturbations = np.random.randn(num_near_surface_samples, 1) * sdf_sampling_std_dev
                near_points = near_points_on_surface + normals_near * perturbations
                # Keep near-surface points inside [-0.5, 0.5] as well
                near_points = np.clip(near_points, -0.5, 0.5)
                sampled_points_list.append(near_points)
            except Exception as e:
                logger.warning(f"Near-surface sampling failed: {e}. Falling back to uniform samples.")
                fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3))
                sampled_points_list.append(fallback_uniform)
        else:
            logger.warning("Mesh has no faces; near-surface sampling is impossible. Falling back to uniform samples.")
            fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3))
            sampled_points_list.append(fallback_uniform)

    # --- Merge the sampled points ---
    if not sampled_points_list:
        logger.warning("No points were sampled for the SDF.")
        return None
    sampled_points_ncs = np.vstack(sampled_points_list).astype(np.float32)

    try:
        proximity_query = ProximityQuery(trimesh_mesh_ncs)
        # Compute the SDF in batches to avoid memory issues
        batch_size = 1000
        sdf_values = []
        closest_points = []
        face_indices = []
        for i in range(0, len(sampled_points_ncs), batch_size):
            batch_points = sampled_points_ncs[i:i + batch_size]
            # Nearest points and faces for the current batch
            batch_closest, batch_distances, batch_faces = proximity_query.on_surface(batch_points)
            # Vectors from the query points to their nearest surface points
            direction_vectors = batch_points - batch_closest
            # Use batch_compute_normals for the normals at the nearest points.
            # Repack batch_closest into the expected format: an (N,) object
            # array whose elements are (M, 3) arrays.
            closest_points_reshaped = np.array([batch_closest], dtype=object)
            closest_points_reshaped[0] = batch_closest
            # Compute the normals
            normals_batch = batch_compute_normals(
                trimesh_mesh_ncs,
                closest_points_reshaped,
                normal_type='vertex',  # use vertex normals
                k_neighbors=3
            )[0]  # take the first element, since only one batch was passed in
            # Dot products between direction vectors and normals
            dot_products = np.sum(direction_vectors * normals_batch, axis=1)
            signs = np.sign(dot_products)
            # Handle the sign at (near-)zero distances
            zero_mask = np.abs(batch_distances) < 1e-6
            signs[zero_mask] = 0.0
            # Signed distances
            batch_sdf = batch_distances * signs
            # Clamp the SDF range
            batch_sdf = np.clip(batch_sdf, -1.414, 1.414)  # sqrt(2)
            # Debug info (first batch only)
            if i == 0:
                logger.debug(f"Batch statistics (first batch):")
                logger.debug(f"  normal range: [{normals_batch.min():.4f}, {normals_batch.max():.4f}]")
                logger.debug(f"  mean normal length: {np.linalg.norm(normals_batch, axis=1).mean():.4f}")
                logger.debug(f"  distance range: [{batch_distances.min():.4f}, {batch_distances.max():.4f}]")
                logger.debug(f"  dot product range: [{dot_products.min():.4f}, {dot_products.max():.4f}]")
                logger.debug(f"  sign distribution: pos={np.sum(signs > 0)}, neg={np.sum(signs < 0)}, zero={np.sum(signs == 0)}")
                logger.debug(f"  SDF range: [{batch_sdf.min():.4f}, {batch_sdf.max():.4f}]")
            sdf_values.append(batch_sdf)
            closest_points.append(batch_closest)
            face_indices.append(batch_faces)
        # Merge the batch results
        sdf_values = np.concatenate(sdf_values)
        closest_points = np.concatenate(closest_points)
        # Compute normals for all points
        all_points_reshaped = np.array([closest_points], dtype=object)
        all_points_reshaped[0] = closest_points
        sampled_normals = batch_compute_normals(
            trimesh_mesh_ncs,
            all_points_reshaped,
            normal_type='vertex',
            k_neighbors=3
        )[0]
        # Validate the normals
        normal_lengths = np.linalg.norm(sampled_normals, axis=1)
        logger.debug(f"Final normal statistics:")
        logger.debug(f"  shape: {sampled_normals.shape}")
        logger.debug(f"  length: min={normal_lengths.min():.4f}, max={normal_lengths.max():.4f}, mean={normal_lengths.mean():.4f}")
        logger.debug(f"  component range: x=[{sampled_normals[:,0].min():.4f}, {sampled_normals[:,0].max():.4f}]")
        logger.debug(f"  component range: y=[{sampled_normals[:,1].min():.4f}, {sampled_normals[:,1].max():.4f}]")
        logger.debug(f"  component range: z=[{sampled_normals[:,2].min():.4f}, {sampled_normals[:,2].max():.4f}]")
        # Filter invalid entries
        valid_mask = (
            ~np.isnan(sdf_values) & ~np.isinf(sdf_values) &
            ~np.isnan(sampled_normals).any(axis=1) & ~np.isinf(sampled_normals).any(axis=1) &
            ~np.isnan(sampled_points_ncs).any(axis=1) & ~np.isinf(sampled_points_ncs).any(axis=1)
        )
        if not np.all(valid_mask):
            num_invalid = sampled_points_ncs.shape[0] - np.sum(valid_mask)
            logger.warning(f"Found {num_invalid} invalid entries during SDF/normal computation. Filtering them out.")
            sampled_points_ncs = sampled_points_ncs[valid_mask]
            sampled_normals = sampled_normals[valid_mask]
            sdf_values = sdf_values[valid_mask]
        if sampled_points_ncs.shape[0] > 0:
            combined_data = np.hstack((
                sampled_points_ncs,
                sampled_normals,
                sdf_values[:, np.newaxis]
            )).astype(np.float32)
            # Validate the SDF distribution
            final_sdf = combined_data[:, -1]
            logger.debug(f"Final SDF distribution check:")
            logger.debug(f"  positive points: {np.sum(final_sdf > 0)}")
            logger.debug(f"  negative points: {np.sum(final_sdf < 0)}")
            logger.debug(f"  zero points: {np.sum(np.abs(final_sdf) < 1e-6)}")
            logger.debug(f"  SDF range: [{final_sdf.min():.4f}, {final_sdf.max():.4f}]")
            # Sanity-check the distribution
            if np.sum(final_sdf > 0) == 0 or np.sum(final_sdf < 0) == 0:
                logger.warning("Warning: abnormal SDF distribution, no positive or no negative values!")
            return combined_data
        else:
            logger.warning("No valid points remain after filtering the SDF/normal results.")
            return None
    except Exception as e:
        logger.error(f"Failed to compute SDF or normals: {str(e)}")
        return None
def check_data_format(data, step_file):
    """Check that the data format is correct."""
    required_keys = [
        'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs', 'corner_wcs',
        'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj',
        'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique'
    ]
    # Check that all required keys are present
    for key in required_keys:
        if key not in data:
            return False, f"Missing key: {key}"
    # Check the geometry data
    geometry_arrays = ['surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs']
    for key in geometry_arrays:
        if not isinstance(data[key], np.ndarray):
            return False, f"{key} should be a numpy array"
        # Object arrays are allowed
        if data[key].dtype != object:
            return False, f"{key} should be a numpy array with dtype=object"
    # Check the remaining arrays
    float32_arrays = ['corner_wcs', 'surf_bbox_wcs', 'edge_bbox_wcs', 'corner_unique']
    for key in float32_arrays:
        if not isinstance(data[key], np.ndarray):
            return False, f"{key} should be a numpy array"
        if data[key].dtype != np.float32:
            return False, f"{key} should be a numpy array with dtype=float32"
    int32_arrays = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
    for key in int32_arrays:
        if not isinstance(data[key], np.ndarray):
            return False, f"{key} should be a numpy array"
        if data[key].dtype != np.int32:
            return False, f"{key} should be a numpy array with dtype=int32"
    return True, ""
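# Illustrative call site (an assumption, not part of the commit): how the
# (ok, message) return of check_data_format is typically consumed after
# loading a pkl; `data` and `step_path` are assumed to exist at the caller.
def _validate_or_log(data: dict, step_path: str) -> bool:
    ok, msg = check_data_format(data, step_path)
    if not ok:
        logger.error(f"Invalid data for {step_path}: {msg}")
    return ok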
def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, sample_sdf_points=False, timeout:int=300) -> dict:
    """Process a single STEP file: from B-rep to pkl.

brep2sdf/data/sampler.py (300)

@@ -0,0 +1,300 @@
"""
CAD模型处理脚本
功能将STEP格式的CAD模型转换为结构化数据包括
- 几何信息顶点的坐标数据
- 拓扑信息--顶点的邻接关系
- 空间信息包围盒数据
"""
import os
import pickle # 用于数据序列化
import argparse # 命令行参数解析
import numpy as np
from tqdm import tqdm # 进度条显示
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError # 并行处理
import logging
from datetime import datetime
from scipy.spatial import cKDTree
from brep2sdf.utils.logger import logger
import tempfile
import trimesh
from trimesh.proximity import ProximityQuery
# 导入OpenCASCADE相关库
from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器
from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX # 拓扑类型定义
from OCC.Core.BRep import BRep_Tool # B-rep工具
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh # 网格剖分
from OCC.Core.TopLoc import TopLoc_Location # 位置变换
from OCC.Core.IFSelect import IFSelect_RetDone,IFSelect_RetError, IFSelect_RetFail, IFSelect_RetVoid # 操作状态码
from OCC.Core.TopTools import TopTools_IndexedDataMapOfShapeListOfShape # 形状映射
from OCC.Core.BRepBndLib import brepbndlib # 包围盒计算
from OCC.Core.Bnd import Bnd_Box # 包围盒
from OCC.Core.TopoDS import TopoDS_Shape, topods, TopoDS_Vertex # 拓扑数据结构
from OCC.Core.StlAPI import StlAPI_Writer
# 导入配置
from brep2sdf.config.default_config import get_default_config
config = get_default_config()
# 设置最大面数阈值,用于加速处理
MAX_FACE = config.data.max_face
def _sample_uniform_points(num_points: int) -> np.ndarray:
    """Uniformly sample points within [-0.5, 0.5].

    Args:
        num_points: number of points to sample

    Returns:
        np.ndarray: sampled points of shape (num_points, 3)
    """
    return np.random.uniform(-0.5, 0.5, (num_points, 3))

def _sample_near_surface_points(
    mesh: trimesh.Trimesh,
    num_points: int,
    std_dev: float
) -> np.ndarray:
    """Sample points near the mesh surface.

    Args:
        mesh: input trimesh mesh
        num_points: number of points to sample
        std_dev: std dev of the perturbation along the normal direction

    Returns:
        np.ndarray: sampled points of shape (num_points, 3)
    """
    if mesh.faces.shape[0] == 0:
        logger.warning("Mesh has no faces; near-surface sampling is impossible. Falling back to uniform samples.")
        return _sample_uniform_points(num_points)
    try:
        near_points_on_surface = mesh.sample(num_points)
        proximity_query_near = ProximityQuery(mesh)
        closest_points_near, _, face_indices_near = proximity_query_near.on_surface(near_points_on_surface)
        if np.any(face_indices_near >= len(mesh.face_normals)):
            raise IndexError("Face index out of bounds during near-surface normal lookup")
        normals_near = mesh.face_normals[face_indices_near]
        perturbations = np.random.randn(num_points, 1) * std_dev
        near_points = near_points_on_surface + normals_near * perturbations
        return np.clip(near_points, -0.5, 0.5)
    except Exception as e:
        logger.warning(f"Near-surface sampling failed: {e}. Falling back to uniform samples.")
        return _sample_uniform_points(num_points)
def sample_points(
    trimesh_mesh_ncs: trimesh.Trimesh,
    num_uniform_samples: int,
    num_near_surface_samples: int,
    sdf_sampling_std_dev: float
) -> np.ndarray | None:
    """Combine uniformly sampled and near-surface sampled points.

    Args:
        trimesh_mesh_ncs: normalized trimesh mesh
        num_uniform_samples: number of uniform samples
        num_near_surface_samples: number of near-surface samples
        sdf_sampling_std_dev: std dev for near-surface sampling

    Returns:
        np.ndarray | None: merged sample array, or None on failure
    """
    sampled_points_list = []
    # Uniform sampling
    if num_uniform_samples > 0:
        uniform_points = _sample_uniform_points(num_uniform_samples)
        sampled_points_list.append(uniform_points)
    # Near-surface sampling
    if num_near_surface_samples > 0:
        near_points = _sample_near_surface_points(
            trimesh_mesh_ncs,
            num_near_surface_samples,
            sdf_sampling_std_dev
        )
        sampled_points_list.append(near_points)
    # Merge the sampled points
    if not sampled_points_list:
        logger.warning("No points were sampled.")
        return None
    return np.vstack(sampled_points_list).astype(np.float32)
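# Minimal usage sketch (illustrative, not part of the commit): draw a mixed
# sample set from a unit box, which already spans the normalized cube.
def _demo_sample_points() -> None:
    box = trimesh.creation.box(extents=(1.0, 1.0, 1.0))  # spans [-0.5, 0.5]^3
    pts = sample_points(box, num_uniform_samples=512,
                        num_near_surface_samples=512,
                        sdf_sampling_std_dev=0.01)
    assert pts is not None and pts.shape == (1024, 3)
    # Both halves are confined to the normalized cube (near-surface points are clipped).
    assert np.all(pts >= -0.5) and np.all(pts <= 0.5)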
# Use the new sampling helpers inside the original sample_sdf_points_and_normals
def sample_sdf_points_and_normals(
    trimesh_mesh_ncs: trimesh.Trimesh,
    surf_bbox_ncs: np.ndarray,
    num_sdf_samples: int = 4096,
    sdf_sampling_std_dev: float = 0.01
) -> np.ndarray | None:
    """
    Sample a fixed number of points in the normalized coordinate system (NCS)
    and compute their SDF values and nearest-surface normals.
    Uses a mixed strategy of uniform and near-surface sampling.

    Args:
        trimesh_mesh_ncs: the normalized Trimesh object
        surf_bbox_ncs: per-face bounding boxes in the NCS [num_faces, 6]
        num_sdf_samples: total number of points to sample
        sdf_sampling_std_dev: std dev of the perturbation along the normal for near-surface sampling

    Returns:
        np.ndarray | None: array of shape (N, 7), [x, y, z, nx, ny, nz, sdf],
                           or None if sampling or computation fails
    """
    logger.debug("Sampling points for SDF computation (fixed-count strategy)...")
    if trimesh_mesh_ncs is None or not isinstance(trimesh_mesh_ncs, trimesh.Trimesh):
        logger.error("Invalid Trimesh object passed to SDF sampling.")
        return None
    if num_sdf_samples <= 0:
        logger.warning("Requested zero or negative SDF sample count; skipping sampling.")
        return None

    # Use the fixed sampling bounds [-0.5, 0.5], since the mesh is already normalized
    min_bound_ncs = np.array([-0.5, -0.5, -0.5], dtype=np.float32)
    max_bound_ncs = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    bbox_size_ncs = max_bound_ncs - min_bound_ncs

    # --- Split the fixed total sample count ---
    num_uniform_samples = num_sdf_samples // 2
    num_near_surface_samples = num_sdf_samples - num_uniform_samples
    logger.debug(f"Fixed SDF sample count: {num_sdf_samples} (uniform: {num_uniform_samples}, near-surface: {num_near_surface_samples})")

    # --- Perform the sampling ---
    sampled_points_ncs = sample_points(
        trimesh_mesh_ncs,
        num_uniform_samples,
        num_near_surface_samples,
        sdf_sampling_std_dev
    )
    # sample_points returns None when nothing could be sampled
    if sampled_points_ncs is None:
        logger.warning("No points were sampled for the SDF.")
        return None
    try:
        proximity_query = ProximityQuery(trimesh_mesh_ncs)
        # Compute the SDF in batches to avoid memory issues
        batch_size = 1000
        sdf_values = []
        closest_points = []
        face_indices = []
        for i in range(0, len(sampled_points_ncs), batch_size):
            batch_points = sampled_points_ncs[i:i + batch_size]
            # Nearest points and faces for the current batch
            batch_closest, batch_distances, batch_faces = proximity_query.on_surface(batch_points)
            # Vectors from the query points to their nearest surface points
            direction_vectors = batch_points - batch_closest
            # Use batch_compute_normals for the normals at the nearest points.
            # Repack batch_closest into the expected format: an (N,) object
            # array whose elements are (M, 3) arrays.
            closest_points_reshaped = np.array([batch_closest], dtype=object)
            closest_points_reshaped[0] = batch_closest
            # Compute the normals
            normals_batch = batch_compute_normals(
                trimesh_mesh_ncs,
                closest_points_reshaped,
                normal_type='vertex',  # use vertex normals
                k_neighbors=3
            )[0]  # take the first element, since only one batch was passed in
            # Dot products between direction vectors and normals
            dot_products = np.sum(direction_vectors * normals_batch, axis=1)
            signs = np.sign(dot_products)
            # Handle the sign at (near-)zero distances
            zero_mask = np.abs(batch_distances) < 1e-6
            signs[zero_mask] = 0.0
            # Signed distances
            batch_sdf = batch_distances * signs
            # Clamp the SDF range
            batch_sdf = np.clip(batch_sdf, -1.414, 1.414)  # sqrt(2)
            # Debug info (first batch only)
            if i == 0:
                logger.debug(f"Batch statistics (first batch):")
                logger.debug(f"  normal range: [{normals_batch.min():.4f}, {normals_batch.max():.4f}]")
                logger.debug(f"  mean normal length: {np.linalg.norm(normals_batch, axis=1).mean():.4f}")
                logger.debug(f"  distance range: [{batch_distances.min():.4f}, {batch_distances.max():.4f}]")
                logger.debug(f"  dot product range: [{dot_products.min():.4f}, {dot_products.max():.4f}]")
                logger.debug(f"  sign distribution: pos={np.sum(signs > 0)}, neg={np.sum(signs < 0)}, zero={np.sum(signs == 0)}")
                logger.debug(f"  SDF range: [{batch_sdf.min():.4f}, {batch_sdf.max():.4f}]")
            sdf_values.append(batch_sdf)
            closest_points.append(batch_closest)
            face_indices.append(batch_faces)
        # Merge the batch results
        sdf_values = np.concatenate(sdf_values)
        closest_points = np.concatenate(closest_points)
        # Compute normals for all points
        all_points_reshaped = np.array([closest_points], dtype=object)
        all_points_reshaped[0] = closest_points
        sampled_normals = batch_compute_normals(
            trimesh_mesh_ncs,
            all_points_reshaped,
            normal_type='vertex',
            k_neighbors=3
        )[0]
        # Validate the normals
        normal_lengths = np.linalg.norm(sampled_normals, axis=1)
        logger.debug(f"Final normal statistics:")
        logger.debug(f"  shape: {sampled_normals.shape}")
        logger.debug(f"  length: min={normal_lengths.min():.4f}, max={normal_lengths.max():.4f}, mean={normal_lengths.mean():.4f}")
        logger.debug(f"  component range: x=[{sampled_normals[:,0].min():.4f}, {sampled_normals[:,0].max():.4f}]")
        logger.debug(f"  component range: y=[{sampled_normals[:,1].min():.4f}, {sampled_normals[:,1].max():.4f}]")
        logger.debug(f"  component range: z=[{sampled_normals[:,2].min():.4f}, {sampled_normals[:,2].max():.4f}]")
        # Filter invalid entries
        valid_mask = (
            ~np.isnan(sdf_values) & ~np.isinf(sdf_values) &
            ~np.isnan(sampled_normals).any(axis=1) & ~np.isinf(sampled_normals).any(axis=1) &
            ~np.isnan(sampled_points_ncs).any(axis=1) & ~np.isinf(sampled_points_ncs).any(axis=1)
        )
        if not np.all(valid_mask):
            num_invalid = sampled_points_ncs.shape[0] - np.sum(valid_mask)
            logger.warning(f"Found {num_invalid} invalid entries during SDF/normal computation. Filtering them out.")
            sampled_points_ncs = sampled_points_ncs[valid_mask]
            sampled_normals = sampled_normals[valid_mask]
            sdf_values = sdf_values[valid_mask]
        if sampled_points_ncs.shape[0] > 0:
            combined_data = np.hstack((
                sampled_points_ncs,
                sampled_normals,
                sdf_values[:, np.newaxis]
            )).astype(np.float32)
            # Validate the SDF distribution
            final_sdf = combined_data[:, -1]
            logger.debug(f"Final SDF distribution check:")
            logger.debug(f"  positive points: {np.sum(final_sdf > 0)}")
            logger.debug(f"  negative points: {np.sum(final_sdf < 0)}")
            logger.debug(f"  zero points: {np.sum(np.abs(final_sdf) < 1e-6)}")
            logger.debug(f"  SDF range: [{final_sdf.min():.4f}, {final_sdf.max():.4f}]")
            # Sanity-check the distribution
            if np.sum(final_sdf > 0) == 0 or np.sum(final_sdf < 0) == 0:
                logger.warning("Warning: abnormal SDF distribution, no positive or no negative values!")
            return combined_data
        else:
            logger.warning("No valid points remain after filtering the SDF/normal results.")
            return None
    except Exception as e:
        logger.error(f"Failed to compute SDF or normals: {str(e)}")
        return None
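The sign convention above comes from the dot product between the offset vector (query point minus closest surface point) and the surface normal at that point: a positive dot product means the point lies on the outside of the surface. A minimal standalone sketch of the same trick, simplified to face normals instead of batch_compute_normals (the icosphere, its radius, and the two probe points are illustrative assumptions):

import numpy as np
import trimesh
from trimesh.proximity import ProximityQuery

mesh = trimesh.creation.icosphere(radius=0.4)      # stand-in for a normalized shape
query = ProximityQuery(mesh)
points = np.array([[0.0, 0.0, 0.0],                # inside the sphere
                   [0.45, 0.0, 0.0]])              # outside the sphere
closest, dist, tri_ids = query.on_surface(points)
normals = mesh.face_normals[tri_ids]               # outward face normals
signs = np.sign(np.sum((points - closest) * normals, axis=1))
sdf = dist * signs
print(sdf)  # expected: negative for the inside point, positive for the outside one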

brep2sdf/train.py (31)

@@ -1,5 +1,6 @@
 import torch
 from torch.serialization import add_safe_globals
+from torch.utils.mobile_optimizer import optimize_for_mobile
 import torch.optim as optim
 import time
 import os
@@ -64,7 +65,7 @@ class Trainer:
         surface_sdf_data = prepare_sdf_data(
             surfs,
             normals=self.data["surf_pnt_normals"],
-            max_points=4096,
+            max_points=50000,
             device=self.device
         )
         # If not using only the zero surface, merge in the sampled point data
@@ -343,12 +344,36 @@ class Trainer:
         sdfs = model(example_input)
         logger.debug(f"sdfs:{sdfs}")

-    def _tracing_model(self):
+    def _tracing_model_by_script(self):
         """Save the model."""
         self.model.eval()
         # Make sure all logic in the model is TorchScript-compatible
         scripted_model = torch.jit.script(self.model)
-        torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
+        optimized_model = optimize_for_mobile(scripted_model)
+        torch.jit.save(optimized_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
+
+    def _tracing_model(self):
+        """Save the model."""
+        self.model.eval()
+        # Create an example input
+        example_input = torch.rand(1, 3, device=self.device)
+        # Export the model by tracing
+        traced_model = torch.jit.trace(self.model, example_input)
+        # Save the model
+        save_path = f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt"
+        torch.jit.save(traced_model, save_path)
+        # Verify the saved model
+        try:
+            loaded_model = torch.jit.load(save_path)
+            test_input = torch.rand(1, 3, device=self.device)
+            _ = loaded_model(test_input)
+            logger.info(f"Model saved and verified successfully: {save_path}")
+        except Exception as e:
+            logger.error(f"Model verification failed: {e}")

     def _load_checkpoint(self, checkpoint_path):
         """Restore the training state from a checkpoint"""
