
Add non-zero-surface sampling points; optimize batch training

final
mckay 3 months ago
parent commit a3a2a2c7a1
  1. brep2sdf/IsoSurfacing.py (4 changed lines)
  2. brep2sdf/batch_train.py (177 changed lines)
  3. brep2sdf/config/default_config.py (6 changed lines)
  4. brep2sdf/data/data.py (143 changed lines)
  5. brep2sdf/data/pre_process_by_mesh.py (300 changed lines)
  6. brep2sdf/networks/encoder.py (1 changed line)
  7. brep2sdf/train.py (242 changed lines)

brep2sdf/IsoSurfacing.py (4 changed lines)

@@ -111,8 +111,8 @@ def main():
parser = argparse.ArgumentParser(description='IsoSurface Generator')
parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)')
parser.add_argument('-o', '--output', type=str, required=True, help='Output mesh file (.ply)')
parser.add_argument('--depth', type=int, default=7, help='网格深度(分辨率)')
parser.add_argument('--box_size', type=float, default=2.0, help='边界框大小')
parser.add_argument('--depth', type=int, default=3, help='网格深度(分辨率)')
parser.add_argument('--box_size', type=float, default=1.0, help='边界框大小')
parser.add_argument('--method', type=str, default='MC', choices=['MC'], help='表面提取方法')
parser.add_argument('--use-gpu', action='store_true', help='使用GPU')
parser.add_argument('--compare', type=str, help='GT网格文件(.ply)')

brep2sdf/batch_train.py (177 changed lines)

@@ -1,67 +1,152 @@
import os
import subprocess
import argparse
import logging
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
# 配置日志记录
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def main():
# 定义 STEP 文件目录和名称列表文件路径
step_root_dir = "/home/wch/brep2sdf/data/step"
name_list_path = "/home/wch/brep2sdf/data/name_list.txt"
def run_training_process(input_step: str, train_script: str, common_args: list) -> tuple[str, bool, str, str]:
"""
为单个 STEP 文件运行 train.py 子进程
Args:
input_step: 输入 STEP 文件的路径
train_script: train.py 脚本的路径
common_args: 传递给 train.py 的通用参数列表
Returns:
Tuple: (输入文件路径, 是否成功, stdout, stderr)
"""
command = [
"python", train_script,
*common_args,
"-i", input_step,
]
try:
result = subprocess.run(
command,
capture_output=True,
text=True,
check=True, # 如果返回非零退出码,则抛出 CalledProcessError
timeout=600 # 添加超时设置(例如10分钟)
)
return input_step, True, result.stdout, result.stderr
except subprocess.CalledProcessError as e:
logger.error(f"命令 '{' '.join(command)}' 执行失败。")
return input_step, False, e.stdout, e.stderr
except subprocess.TimeoutExpired as e:
logger.error(f"处理 '{input_step}' 超时。")
return input_step, False, e.stdout or "", e.stderr or ""
except Exception as e:
logger.error(f"处理 '{input_step}' 时发生意外错误: {e}")
return input_step, False, "", str(e)
def main(args):
# 读取名称列表
try:
with open(name_list_path, 'r') as f:
names = [line.strip() for line in f if line.strip()] # 去除空行
with open(args.name_list_path, 'r') as f:
names = [line.strip() for line in f if line.strip()]
logger.info(f"'{args.name_list_path}' 读取了 {len(names)} 个名称。")
except FileNotFoundError:
print(f"Error: File '{name_list_path}' not found.")
logger.error(f"错误: 文件 '{args.name_list_path}' 未找到。")
return
except Exception as e:
print(f"Error reading file '{name_list_path}': {e}")
logger.error(f"读取文件 '{args.name_list_path}' 时出错: {e}")
return
# 遍历名称列表并处理每个 STEP 文件
for name in tqdm(names, desc="Processing STEP files"):
step_dir = os.path.join(step_root_dir, name)
# 准备 train.py 的通用参数
# 注意:从命令行参数或其他配置中获取这些参数通常更好
common_train_args = [
"--use-normal",
"--only-zero-surface",
#"--force-reprocess",
# 可以添加更多通用参数
]
if args.train_args:
common_train_args.extend(args.train_args)
# 动态生成 STEP 文件路径(假设只有一个文件)
step_files = [
os.path.join(step_dir, f)
for f in os.listdir(step_dir)
if f.endswith(".step") and f.startswith(name)
]
tasks = []
skipped_count = 0
# 准备所有任务
for name in names:
step_dir = os.path.join(args.step_root_dir, name)
if not os.path.isdir(step_dir):
logger.warning(f"目录 '{step_dir}' 不存在。跳过 '{name}'")
skipped_count += 1
continue
if not step_files:
print(f"Warning: No STEP files found in directory '{step_dir}'. Skipping...")
continue
step_files = []
try:
step_files = [
os.path.join(step_dir, f)
for f in os.listdir(step_dir)
if f.lower().endswith((".step", ".stp")) and f.startswith(name) # 支持 .stp 并忽略大小写
]
except OSError as e:
logger.warning(f"无法访问目录 '{step_dir}': {e}。跳过 '{name}'")
skipped_count += 1
continue
# 假设我们只处理第一个匹配的文件
input_step = step_files[0]
if len(step_files) == 0:
logger.warning(f"在目录 '{step_dir}' 中未找到匹配的 STEP 文件。跳过 '{name}'")
skipped_count += 1
elif len(step_files) > 1:
logger.warning(f"在目录 '{step_dir}' 中找到多个匹配的 STEP 文件,将使用第一个: {step_files[0]}")
tasks.append(step_files[0])
else:
tasks.append(step_files[0])
# 构造子进程命令
command = [
"python", "train.py",
"--use-normal",
"-i", input_step, # 输入文件路径
]
if not tasks:
logger.info("没有找到需要处理的有效 STEP 文件。")
return
# 调用子进程运行命令
try:
result = subprocess.run(
command,
capture_output=True,
text=True,
check=True # 如果返回非零退出码,则抛出 CalledProcessError
)
print(f"Processed '{input_step}' successfully.")
print("STDOUT:", result.stdout)
print("STDERR:", result.stderr)
except subprocess.CalledProcessError as e:
print(f"Error processing '{input_step}': {e}")
print("STDOUT:", e.stdout)
print("STDERR:", e.stderr)
except Exception as e:
print(f"Unexpected error processing '{input_step}': {e}")
logger.info(f"准备处理 {len(tasks)} 个 STEP 文件,跳过了 {skipped_count} 个名称。")
success_count = 0
failure_count = 0
# 使用 ProcessPoolExecutor 进行并行处理
with ProcessPoolExecutor(max_workers=args.workers) as executor:
# 提交所有任务
futures = {
executor.submit(run_training_process, task_path, args.train_script, common_train_args): task_path
for task_path in tasks
}
# 使用 tqdm 显示进度并处理结果
for future in tqdm(as_completed(futures), total=len(tasks), desc="运行训练"):
input_path = futures[future]
try:
input_step, success, stdout, stderr = future.result()
if success:
success_count += 1
# 可以选择记录成功的 stdout/stderr,但通常只记录失败的更有用
# logger.debug(f"成功处理 '{input_step}'. STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
else:
failure_count += 1
logger.error(f"处理 '{input_step}' 失败。STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
except Exception as e:
failure_count += 1
logger.error(f"处理任务 '{input_path}' 时获取结果失败: {e}")
logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}, 跳过: {skipped_count}")
if __name__ == '__main__':
main()
parser = argparse.ArgumentParser(description="批量运行 train.py 处理 STEP 文件。")
parser.add_argument('--step-root-dir', type=str, default="/home/wch/brep2sdf/data/step",
help="包含 STEP 子目录的根目录。")
parser.add_argument('--name-list-path', type=str, default="/home/wch/brep2sdf/data/name_list.txt",
help="包含要处理的名称列表的文件路径。")
parser.add_argument('--train-script', type=str, default="train.py",
help="要执行的训练脚本路径。")
parser.add_argument('--workers', type=int, default=os.cpu_count(),
help="用于并行处理的工作进程数。")
parser.add_argument('--train-args', nargs='*',
help="传递给 train.py 的额外参数 (例如 --epochs 10 --batch-size 32)。")
args = parser.parse_args()
main(args)
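For reference, a minimal sketch of driving the reworked batch CLI from Python; both paths below are placeholders, not values taken from this repository:

import subprocess

# Illustrative invocation of the new batch driver; adjust the paths to your data layout.
cmd = [
    "python", "batch_train.py",
    "--step-root-dir", "/path/to/data/step",
    "--name-list-path", "/path/to/data/name_list.txt",
    "--workers", "4",
]
subprocess.run(cmd, check=True)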

brep2sdf/config/default_config.py (6 changed lines)

@@ -27,7 +27,7 @@ class ModelConfig:
@dataclass
class DataConfig:
"""数据相关配置"""
max_face: int = 80
max_face: int = 400
max_edge: int = 16
num_query_points: int = 32*32*32 # 限制查询点数量,sdf 采样点数 本来是 128*128*128 ,在data load时随机采样
bbox_scaled: float = 1.0
@@ -47,8 +47,8 @@ class TrainConfig:
# 基本训练参数
batch_size: int = 8
num_workers: int = 4
num_epochs: int = 200
learning_rate: float = 0.01
num_epochs: int = 1
learning_rate: float = 0.0001
min_lr: float = 1e-5
weight_decay: float = 0.01

brep2sdf/data/data.py (143 changed lines)

@@ -318,6 +318,149 @@ def load_sdf_file(sdf_path: str, num_query_points: int = 4096) -> torch.Tensor:
logger.error(f"Error message: {str(e)}")
raise
def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
"""
准备SDF数据合并表面点云可选地包含法线并进行降采样
Args:
surf_data: list[np.ndarray] 每个元素是一个形状为(M, 3)的表面点云数组
normals: list[np.ndarray] | None 每个元素是形状为(M, 3)的法线数组与surf_data对应
max_points: int 降采样后的最大点数
device: str | torch.device Pytorch设备
Returns:
torch.Tensor: 形状为 (N, 4) (N, 7) 的张量N <= max_points
列为 [x, y, z, sdf=0] [x, y, z, nx, ny, nz, sdf=0]
"""
total_points = sum(len(s) for s in surf_data)
has_normals = normals is not None
# 确定输出数组的形状
num_features = 7 if has_normals else 4
output_size = min(total_points, max_points)
sdf_array = np.zeros((output_size, num_features), dtype=np.float32)
if total_points > max_points:
# --- 执行降采样 ---
logger.debug(f"总点数 {total_points} 超过 {max_points},执行降采样...")
indices = []
for i, points in enumerate(surf_data):
indices.extend([(i, j) for j in range(len(points))])
np.random.shuffle(indices)
selected_indices = indices[:max_points] # 选择前max_points个
# 根据索引填充sdf_array
for idx, (surf_idx, pnt_idx) in enumerate(selected_indices):
sdf_array[idx, :3] = surf_data[surf_idx][pnt_idx]
if has_normals:
# 检查法线数据是否存在且有效
if surf_idx < len(normals) and pnt_idx < len(normals[surf_idx]):
sdf_array[idx, 3:6] = normals[surf_idx][pnt_idx]
else:
logger.warning(f"降采样时发现无效的法线索引: surf_idx={surf_idx}, pnt_idx={pnt_idx}")
# 可以选择填充默认值,例如 [0, 0, 1]
sdf_array[idx, 3:6] = np.array([0.0, 0.0, 1.0], dtype=np.float32)
else:
# --- 不执行降采样,使用所有点 ---
logger.debug(f"总点数 {total_points} 未超过 {max_points},使用所有点。")
# 直接拼接所有点
all_points = np.concatenate(surf_data, axis=0)
sdf_array[:, :3] = all_points
if has_normals:
# 检查法线数据是否与点数据匹配
total_normal_points = sum(len(n) for n in normals)
if total_normal_points == total_points:
all_normals = np.concatenate(normals, axis=0)
sdf_array[:, 3:6] = all_normals
else:
logger.error(f"点数 ({total_points}) 与法线点数 ({total_normal_points}) 不匹配!")
# 处理不匹配的情况,例如只填充坐标,或者抛出错误
# 这里选择只填充坐标,并将法线部分保留为0
sdf_array = sdf_array[:, :4] # 退化为只有坐标和SDF
sdf_array[:, -1] = 0.0 # 确保SDF为0
# 或者可以填充默认法线
# sdf_array[:, 3:6] = np.tile(np.array([0.0, 0.0, 1.0]), (total_points, 1))
# 注意:表面点的SDF值通常设为0
sdf_array[:, -1] = 0.0
return torch.tensor(sdf_array, dtype=torch.float32, device=device)
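For illustration, a minimal usage sketch of prepare_sdf_data with toy inputs; the shapes and values are made up, and the import path follows the one used in train.py:

import numpy as np
import torch
from brep2sdf.data.data import prepare_sdf_data

# Two toy surfaces with 5 and 3 points each, plus constant unit normals.
surf_data = [np.random.rand(5, 3).astype(np.float32),
             np.random.rand(3, 3).astype(np.float32)]
normals = [np.tile([0.0, 0.0, 1.0], (5, 1)).astype(np.float32),
           np.tile([0.0, 0.0, 1.0], (3, 1)).astype(np.float32)]

sdf = prepare_sdf_data(surf_data, normals=normals, max_points=6, device='cpu')
print(sdf.shape)  # torch.Size([6, 7]): [x, y, z, nx, ny, nz, sdf=0], downsampled from 8 points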
def print_data_distribution(data: torch.Tensor) -> None:
"""打印数据分布统计信息
Args:
data: 形状为 (N, 7) 的张量 [x, y, z, nx, ny, nz, sdf] (N, 4) 的张量 [x, y, z, sdf]
"""
# 检查数据维度
n_features = data.shape[1]
has_normals = n_features == 7
# 统计坐标信息
coords = data[:, :3]
logger.debug("坐标分布统计:")
logger.debug(f" 范围: min={coords.min(dim=0)[0]}, max={coords.max(dim=0)[0]}")
logger.debug(f" 均值: mean={coords.mean(dim=0)}")
logger.debug(f" 标准差: std={coords.std(dim=0)}")
# 如果有法向量,统计法向量信息
if has_normals:
normals = data[:, 3:6]
normal_lengths = torch.norm(normals, dim=1)
logger.debug("\n法向量分布统计:")
logger.debug(f" 范围: min={normals.min(dim=0)[0]}, max={normals.max(dim=0)[0]}")
logger.debug(f" 均值: mean={normals.mean(dim=0)}")
logger.debug(f" 标准差: std={normals.std(dim=0)}")
logger.debug(f" 法向量长度: mean={normal_lengths.mean():.4f}, std={normal_lengths.std():.4f}")
# 统计SDF值信息
sdf = data[:, -1]
logger.debug("\nSDF值分布统计:")
logger.debug(f" 范围: min={sdf.min():.4f}, max={sdf.max():.4f}")
logger.debug(f" 均值: mean={sdf.mean():.4f}")
logger.debug(f" 标准差: std={sdf.std():.4f}")
logger.debug(f" 零值附近(|sdf|<1e-4)的点数量: {torch.sum(torch.abs(sdf) < 1e-4)}")
# --- 添加一个辅助函数用于检查 ---
def check_tensor(tensor: torch.Tensor | None, name: str, epoch: int, step: int = -1) -> bool:
"""检查张量是否包含 inf 或 nan"""
prefix = f"Epoch {epoch}" + (f" Step {step}" if step >= 0 else "")
if tensor is None:
# 对于可选的张量(如 normals),None 是有效的,但对于其他张量可能是问题
# logger.warning(f"{prefix}: Tensor '{name}' is None.")
return False # 返回 False 表示没有检测到 inf/nan (但要注意 None 本身)
if not isinstance(tensor, torch.Tensor):
logger.info(f"{prefix}: '{name}' is not a Tensor, but {type(tensor)}.")
return True # 类型错误,视为问题
has_inf = torch.isinf(tensor).any()
has_nan = torch.isnan(tensor).any()
if has_inf:
logger.info(f"{prefix}: !!! Infinity detected in '{name}' !!!")
# 可以选择性地打印更多信息
# inf_indices = torch.where(torch.isinf(tensor))
# logger.error(f"Inf indices: {inf_indices}")
# logger.error(f"Inf values sample: {tensor[inf_indices][:5]}")
if has_nan:
logger.info(f"{prefix}: !!! NaN detected in '{name}' !!!")
# nan_indices = torch.where(torch.isnan(tensor))
# logger.error(f"NaN indices: {nan_indices}")
return has_inf or has_nan
# --- 辅助函数结束 ---
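A quick illustrative use of the helper above with toy tensors (the import path follows the one used in train.py):

import torch
from brep2sdf.data.data import check_tensor

bad = torch.tensor([1.0, float('nan')])
print(check_tensor(bad, "demo_tensor", epoch=0))            # True: NaN detected and logged
print(check_tensor(torch.ones(3), "demo_tensor", epoch=0))  # False: all values finite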
def test_dataset():
"""测试数据集功能"""
try:

brep2sdf/data/pre_process_by_mesh.py (300 changed lines)

@ -18,6 +18,8 @@ from scipy.spatial import cKDTree
from brep2sdf.utils.logger import logger
import tempfile
import trimesh
from trimesh.proximity import ProximityQuery
# 导入OpenCASCADE相关库
from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器
from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历
@@ -167,7 +169,7 @@ def get_bbox(shape, subshape):
def parse_solid(step_path,sample_normal_vector=False):
def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False):
"""
解析STEP文件中的CAD模型数据
@@ -260,7 +262,7 @@ def parse_solid(step_path,sample_normal_vector=False):
while edge_explorer.More():
edge = topods.Edge(edge_explorer.Current())
edges.append(edge)
logger.debug(len(edges))
#logger.debug(len(edges))
curve_info = BRep_Tool.Curve(edge)
if curve_info is None:
continue # 跳过无效边
@@ -378,19 +380,76 @@ def parse_solid(step_path,sample_normal_vector=False):
}
}
if sample_normal_vector:
# 从 mesh 读 法向量
mesh.Perform()
# 导出为STL临时文件
trimesh_mesh = None
trimesh_mesh_ncs = None
# --- Trimesh 加载和处理 (如果需要) ---
if sample_normal_vector or sample_sdf_points:
logger.debug("加载 Trimesh 用于法线/SDF 采样...")
# 注意:这里的 mesh (BRepMesh_IncrementalMesh) 与 trimesh 不同
# 需要从原始 shape 导出 STL
stl_writer = StlAPI_Writer()
stl_writer.SetASCIIMode(False)
with tempfile.NamedTemporaryFile(suffix='.stl') as tmp:
stl_writer.Write(shape, tmp.name)
trimesh_mesh = trimesh.load(tmp.name)
data['surf_pnt_normals']= batch_compute_normals(trimesh_mesh,surfs_wcs)
tmp_stl_path = ""
try:
with tempfile.NamedTemporaryFile(suffix='.stl', delete=True) as tmp:
tmp_stl_path = tmp.name
# 检查 shape 是否有效
if shape.IsNull():
raise ValueError("OCC Shape is Null, cannot write STL.")
success = stl_writer.Write(shape, tmp_stl_path)
if not success:
raise RuntimeError(f"StlAPI_Writer failed to write {tmp_stl_path}")
trimesh_mesh = trimesh.load(tmp_stl_path)
# 创建归一化 Trimesh
vertices_wcs = trimesh_mesh.vertices.astype(np.float32)
vertices_ncs = (vertices_wcs - data['normalization_params']['center']) / data['normalization_params']['scale']
trimesh_mesh_ncs = trimesh.Trimesh(vertices=vertices_ncs, faces=trimesh_mesh.faces, process=False)
if not trimesh_mesh_ncs.is_watertight:
logger.debug(f"{step_path} 的归一化网格不是 watertight,尝试修复。")
trimesh.repair.fill_holes(trimesh_mesh_ncs)
if not trimesh_mesh_ncs.is_watertight:
logger.warning(f"{step_path} 的归一化网格修复后仍不是 watertight。")
except Exception as e:
logger.error(f"{step_path} 加载/处理 Trimesh 失败: {e}")
trimesh_mesh = None
trimesh_mesh_ncs = None
# --- 计算表面点法线 ---
if sample_normal_vector and trimesh_mesh_ncs is not None:
logger.debug("计算表面点法线...")
# 使用 data['surf_ncs'] 因为它们已经是归一化后的点云
if data['surf_ncs'].shape[0] > 0:
# 确保 batch_compute_normals 使用归一化的 mesh
data['surf_pnt_normals'] = batch_compute_normals(trimesh_mesh_ncs, data['surf_ncs'])
else:
logger.warning("没有有效的归一化表面点云用于法线计算。")
data['surf_pnt_normals'] = np.array([], dtype=object)
elif sample_normal_vector:
logger.warning("请求了表面法线计算,但 Trimesh 加载失败。")
data['surf_pnt_normals'] = np.array([], dtype=object) # 添加空键
# --- SDF 点采样 ---
data['sampled_points_normals_sdf'] = None # 初始化键
if sample_sdf_points:
if trimesh_mesh_ncs is not None:
# 调用封装的函数,传递固定数量参数
data['sampled_points_normals_sdf'] = sample_sdf_points_and_normals(
trimesh_mesh_ncs=trimesh_mesh_ncs,
surf_bbox_ncs=data['surf_bbox_ncs'],
num_sdf_samples=4096, # <-- 传递固定数量
sdf_sampling_std_dev=0.0001
)
else:
logger.warning("请求了 SDF 点采样,但 Trimesh 加载失败。")
return data
def load_step(step_path):
"""Load STEP file and return solids"""
reader = STEPControl_Reader()
@@ -474,6 +533,206 @@ def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3):
return normals_output
def sample_sdf_points_and_normals(
trimesh_mesh_ncs: trimesh.Trimesh,
surf_bbox_ncs: np.ndarray,
num_sdf_samples: int = 4096,
sdf_sampling_std_dev: float = 0.01
) -> np.ndarray | None:
"""
在归一化坐标系(NCS)下采样固定数量的点并计算它们的SDF值和最近表面法线
采用均匀采样和近表面采样的混合策略
参数:
trimesh_mesh_ncs: 归一化的 Trimesh 对象
surf_bbox_ncs: 归一化坐标系下各面的包围盒 [num_faces, 6]
num_sdf_samples: 要采样的总点数
sdf_sampling_std_dev: 近表面采样时沿法线扰动的标准差
返回:
np.ndarray | None: 形状为 (N, 7) 的数组 [x, y, z, nx, ny, nz, sdf]
如果采样或计算失败则返回 None
"""
logger.debug("为 SDF 计算采样点 (固定数量策略)...")
if trimesh_mesh_ncs is None or not isinstance(trimesh_mesh_ncs, trimesh.Trimesh):
logger.error("无效的 Trimesh 对象提供给 SDF 采样。")
return None
if num_sdf_samples <= 0:
logger.warning("请求的 SDF 采样点数为零或负数,不进行采样。")
return None
# 使用固定的采样边界 [-0.5, 0.5],因为我们已经做过归一化
min_bound_ncs = np.array([-0.5, -0.5, -0.5], dtype=np.float32)
max_bound_ncs = np.array([0.5, 0.5, 0.5], dtype=np.float32)
bbox_size_ncs = max_bound_ncs - min_bound_ncs
# --- 使用固定的总样本数分配点数 ---
num_uniform_samples = num_sdf_samples // 2
num_near_surface_samples = num_sdf_samples - num_uniform_samples
logger.debug(f"固定 SDF 采样点数: {num_sdf_samples} (均匀: {num_uniform_samples}, 近表面: {num_near_surface_samples})")
# --- 执行采样 ---
sampled_points_list = []
# 均匀采样 (在 [-0.5, 0.5] 范围内)
if num_uniform_samples > 0:
uniform_points = np.random.uniform(-0.5, 0.5, (num_uniform_samples, 3))
sampled_points_list.append(uniform_points)
# 近表面采样
if num_near_surface_samples > 0:
if trimesh_mesh_ncs.faces.shape[0] > 0:
try:
near_points_on_surface = trimesh_mesh_ncs.sample(num_near_surface_samples)
proximity_query_near = ProximityQuery(trimesh_mesh_ncs)
closest_points_near, distances_near, face_indices_near = proximity_query_near.on_surface(near_points_on_surface)
if np.any(face_indices_near >= len(trimesh_mesh_ncs.face_normals)):
raise IndexError("Face index out of bounds during near-surface normal lookup")
normals_near = trimesh_mesh_ncs.face_normals[face_indices_near]
perturbations = np.random.randn(num_near_surface_samples, 1) * sdf_sampling_std_dev
near_points = near_points_on_surface + normals_near * perturbations
# 确保近表面点也在 [-0.5, 0.5] 范围内
near_points = np.clip(near_points, -0.5, 0.5)
sampled_points_list.append(near_points)
except Exception as e:
logger.warning(f"近表面采样失败: {e}。将替换为均匀采样点。")
fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3))
sampled_points_list.append(fallback_uniform)
else:
logger.warning("网格没有面,无法执行近表面采样。将替换为均匀采样点。")
fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3))
sampled_points_list.append(fallback_uniform)
# --- 合并采样点 ---
if not sampled_points_list:
logger.warning("没有为SDF采样到任何点。")
return None
sampled_points_ncs = np.vstack(sampled_points_list).astype(np.float32)
try:
proximity_query = ProximityQuery(trimesh_mesh_ncs)
# 分批计算SDF以避免内存问题
batch_size = 1000
sdf_values = []
closest_points = []
face_indices = []
for i in range(0, len(sampled_points_ncs), batch_size):
batch_points = sampled_points_ncs[i:i + batch_size]
# 计算当前批次的最近点和面
batch_closest, batch_distances, batch_faces = proximity_query.on_surface(batch_points)
# 计算点到最近面的向量
direction_vectors = batch_points - batch_closest
# 使用batch_compute_normals计算最近点的法向量
# 将batch_closest重新组织为所需的格式 (N,) 数组,每个元素是 (M, 3) 数组
closest_points_reshaped = np.array([batch_closest], dtype=object)
closest_points_reshaped[0] = batch_closest
# 计算法向量
normals_batch = batch_compute_normals(
trimesh_mesh_ncs,
closest_points_reshaped,
normal_type='vertex', # 使用顶点法向量
k_neighbors=3
)[0] # 取第一个元素因为我们只传入了一个批次
# 计算方向向量与法向量的点积
dot_products = np.sum(direction_vectors * normals_batch, axis=1)
signs = np.sign(dot_products)
# 确保零点处的符号处理
zero_mask = np.abs(batch_distances) < 1e-6
signs[zero_mask] = 0.0
# 计算带符号距离
batch_sdf = batch_distances * signs
# 限制SDF值的范围
batch_sdf = np.clip(batch_sdf, -1.414, 1.414) # sqrt(2)
# 添加调试信息
if i == 0: # 只打印第一个批次的统计信息
logger.debug(f"批次统计 (首批次):")
logger.debug(f" 法向量范围: [{normals_batch.min():.4f}, {normals_batch.max():.4f}]")
logger.debug(f" 法向量长度: {np.linalg.norm(normals_batch, axis=1).mean():.4f}")
logger.debug(f" 距离范围: [{batch_distances.min():.4f}, {batch_distances.max():.4f}]")
logger.debug(f" 点积范围: [{dot_products.min():.4f}, {dot_products.max():.4f}]")
logger.debug(f" 符号分布: 正={np.sum(signs > 0)}, 负={np.sum(signs < 0)}, 零={np.sum(signs == 0)}")
logger.debug(f" SDF范围: [{batch_sdf.min():.4f}, {batch_sdf.max():.4f}]")
sdf_values.append(batch_sdf)
closest_points.append(batch_closest)
face_indices.append(batch_faces)
# 合并批次结果
sdf_values = np.concatenate(sdf_values)
closest_points = np.concatenate(closest_points)
# 为所有点计算法向量
all_points_reshaped = np.array([closest_points], dtype=object)
all_points_reshaped[0] = closest_points
sampled_normals = batch_compute_normals(
trimesh_mesh_ncs,
all_points_reshaped,
normal_type='vertex',
k_neighbors=3
)[0]
# 验证法向量
normal_lengths = np.linalg.norm(sampled_normals, axis=1)
logger.debug(f"最终法向量统计:")
logger.debug(f" 形状: {sampled_normals.shape}")
logger.debug(f" 长度: min={normal_lengths.min():.4f}, max={normal_lengths.max():.4f}, mean={normal_lengths.mean():.4f}")
logger.debug(f" 分量范围: x=[{sampled_normals[:,0].min():.4f}, {sampled_normals[:,0].max():.4f}]")
logger.debug(f" 分量范围: y=[{sampled_normals[:,1].min():.4f}, {sampled_normals[:,1].max():.4f}]")
logger.debug(f" 分量范围: z=[{sampled_normals[:,2].min():.4f}, {sampled_normals[:,2].max():.4f}]")
# 添加验证
valid_mask = (
~np.isnan(sdf_values) & ~np.isinf(sdf_values) &
~np.isnan(sampled_normals).any(axis=1) & ~np.isinf(sampled_normals).any(axis=1) &
~np.isnan(sampled_points_ncs).any(axis=1) & ~np.isinf(sampled_points_ncs).any(axis=1)
)
if not np.all(valid_mask):
num_invalid = sampled_points_ncs.shape[0] - np.sum(valid_mask)
logger.warning(f"在 SDF/法线计算中发现 {num_invalid} 个无效条目。将它们过滤掉。")
sampled_points_ncs = sampled_points_ncs[valid_mask]
sampled_normals = sampled_normals[valid_mask]
sdf_values = sdf_values[valid_mask]
if sampled_points_ncs.shape[0] > 0:
combined_data = np.hstack((
sampled_points_ncs,
sampled_normals,
sdf_values[:, np.newaxis]
)).astype(np.float32)
# 添加SDF分布验证
final_sdf = combined_data[:, -1]
logger.debug(f"最终SDF分布验证:")
logger.debug(f" 正值点数: {np.sum(final_sdf > 0)}")
logger.debug(f" 负值点数: {np.sum(final_sdf < 0)}")
logger.debug(f" 零值点数: {np.sum(np.abs(final_sdf) < 1e-6)}")
logger.debug(f" SDF范围: [{final_sdf.min():.4f}, {final_sdf.max():.4f}]")
# 验证分布是否合理
if np.sum(final_sdf > 0) == 0 or np.sum(final_sdf < 0) == 0:
logger.warning("警告:SDF值分布异常,没有正值或负值!")
return combined_data
else:
logger.warning("过滤 SDF/法线结果后没有剩余有效点。")
return None
except Exception as e:
logger.error(f"计算 SDF 或法线时失败: {str(e)}")
return None
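For illustration of the sign convention used above, a self-contained sketch on a toy box mesh; it uses face normals directly for brevity, whereas the function above interpolates vertex normals via batch_compute_normals:

import numpy as np
import trimesh
from trimesh.proximity import ProximityQuery

# Toy watertight mesh: a unit box centered at the origin.
mesh = trimesh.creation.box(extents=(1.0, 1.0, 1.0))
query = ProximityQuery(mesh)

points = np.array([[0.0, 0.0, 0.0],   # inside the box
                   [0.0, 0.0, 1.0]])  # outside the box
closest, distances, face_ids = query.on_surface(points)
normals = mesh.face_normals[face_ids]

# Sign = sign of the dot product between (point - closest surface point) and the surface normal.
signs = np.sign(np.sum((points - closest) * normals, axis=1))
sdf = distances * signs  # negative inside, positive outside, assuming outward-facing normals
print(sdf)               # approximately [-0.5, 0.5]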
def check_data_format(data, step_file):
"""检查数据格式是否正确"""
required_keys = [
@@ -513,7 +772,7 @@ def check_data_format(data, step_file):
return True, ""
def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, timeout:int=300) -> dict:
def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, sample_sdf_points=False, timeout:int=300) -> dict:
"""处理单个STEP文件, 从 brep 2 pkl
return data = {
'surf_wcs': np.array(surfs_wcs, dtype=object), # 世界坐标系下的曲面几何数据(对象数组)
@@ -537,7 +796,7 @@ def process_single_step(step_path:str, output_path:str=None, sample_normal_vecto
logger.error(f"文件格式不支持,必须是.step或.stp文件: {step_path}")
return None
# 解析STEP文件
data = parse_solid(step_path, sample_normal_vector)
data = parse_solid(step_path, sample_normal_vector,sample_sdf_points)
if data is None:
logger.error(f"Failed to parse STEP file: {step_path}")
return None
@@ -551,12 +810,11 @@ def process_single_step(step_path:str, output_path:str=None, sample_normal_vecto
# 保存结果
if output_path:
try:
logger.info(f"Saving results to: {output_path}")
logger.debug(f"Saving results to: {output_path}")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as f:
pickle.dump(data, f)
logger.info("数据预处理完成")
logger.info(f"Results saved successfully: {output_path}")
logger.debug(f"Results saved successfully: {output_path}")
return data
except Exception as e:
logger.error(f'Failed to save {output_path}: {str(e)}')
@@ -588,19 +846,19 @@ def test(step_file_path, output_path=None):
return None
# 打印统计信息
logger.info("\nStatistics:")
logger.info(f"Number of surfaces: {len(data['surf_wcs'])}")
logger.info(f"Number of edges: {len(data['edge_wcs'])}")
logger.info(f"Number of corners: {len(data['corner_wcs'])}") # 修正为corner_wcs
logger.debug("\nStatistics:")
logger.debug(f"Number of surfaces: {len(data['surf_wcs'])}")
logger.debug(f"Number of edges: {len(data['edge_wcs'])}")
logger.debug(f"Number of corners: {len(data['corner_wcs'])}") # 修正为corner_wcs
# 保存结果
if output_path:
try:
logger.info(f"Saving results to: {output_path}")
logger.debug(f"Saving results to: {output_path}")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as f:
pickle.dump(data, f)
logger.info(f"Results saved successfully: {output_path}")
logger.debug(f"Results saved successfully: {output_path}")
except Exception as e:
logger.error(f"Failed to save {output_path}: {str(e)}")
return None

brep2sdf/networks/encoder.py (1 changed line)

@@ -81,7 +81,6 @@ class Encoder(nn.Module):
def _init_parameters(self):
"""为所有叶子节点初始化特征参数"""
# 使用栈模拟递归遍历(避免递归)
stack = [(self.octree, "root")] # (当前节点, 当前路径)
param_index = 0 # 参数索引计数器

brep2sdf/train.py (242 changed lines)

@@ -7,7 +7,7 @@ import numpy as np
import argparse
from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file,load_sdf_file
from brep2sdf.data.data import load_brep_file,load_sdf_file, prepare_sdf_data, print_data_distribution, check_tensor
from brep2sdf.data.pre_process_by_mesh import process_single_step
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
@@ -24,6 +24,11 @@ parser.add_argument(
action='store_true', # 默认为 False,如果用户指定该参数,则为 True
help='强制采样点有法向量'
)
parser.add_argument(
'--only-zero-surface',
action='store_true', # 默认为 False,如果用户指定该参数,则为 True
help='只采样零表面点 SDF 训练'
)
parser.add_argument(
'--force-reprocess',
action='store_true', # 默认为 False,如果用户指定该参数,则为 True
@@ -32,43 +37,6 @@ parser.add_argument(
args = parser.parse_args()
def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
total_points = sum(len(s) for s in surf_data)
# 降采样逻辑(修复版)
if total_points > max_points:
# 生成索引
indices = []
for i, points in enumerate(surf_data):
indices.extend([(i, j) for j in range(len(points))])
# 随机打乱索引
np.random.shuffle(indices)
# 选择前max_points个索引
selected_indices = indices[:max_points]
if not normals is None:
# 根据索引构建sdf_array
sdf_array = np.zeros((max_points, 4), dtype=np.float32)
for idx, (i, j) in enumerate(selected_indices):
sdf_array[idx, :3] = surf_data[i][j]
else:
sdf_array = np.zeros((max_points, 7), dtype=np.float32)
for idx, (i, j) in enumerate(selected_indices):
sdf_array[idx, :3] = surf_data[i][j]
sdf_array[idx, 3:6] = normals[i][j]
else:
if not normals is None:
sdf_array = np.zeros((total_points, 4), dtype=np.float32)
sdf_array[:, :3] = np.concatenate(surf_data)
sdf_array = np.zeros((max_points, 7), dtype=np.float32)
else:
for idx, (i, j) in enumerate(selected_indices):
sdf_array[idx, :3] = surf_data[i][j]
sdf_array[idx, 3:6] = normals[i][j]
return torch.tensor(sdf_array, dtype=torch.float32, device=device)
class Trainer:
def __init__(self, config, input_step):
@@ -85,18 +53,33 @@ class Trainer:
logger.error(f"fail to load {data_path}, {str(e)}")
raise e
if args.use_normal and self.data.get("surf_pnt_normals", None) is None:
self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal,sample_sdf_points=not args.only_zero_surface)
else:
self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal,sample_sdf_points=not args.only_zero_surface)
# 将曲面点云列表转换为 (N*M, 4) 数组
surfs = self.data["surf_ncs"]
self.sdf_data = prepare_sdf_data(
# 准备表面点的SDF数据
surface_sdf_data = prepare_sdf_data(
surfs,
normals = self.data["surf_pnt_normals"],
normals=self.data["surf_pnt_normals"],
max_points=4096,
device=self.device
)
# 如果不是仅使用零表面,则合并采样点数据
if not args.only_zero_surface:
# 加载采样点数据
sampled_sdf_data = torch.tensor(
self.data['sampled_points_normals_sdf'],
dtype=torch.float32,
device=self.device
)
# 合并表面点数据和采样点数据
self.sdf_data = torch.cat([surface_sdf_data, sampled_sdf_data], dim=0)
else:
self.sdf_data = surface_sdf_data
print_data_distribution(self.sdf_data)
# 初始化数据集
#self.brep_data = load_brep_file(self.config.data.pkl_path)
#logger.info( self.brep_data )
@@ -146,71 +129,154 @@ class Trainer:
def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
"""
计算整个数据集的全局边界框综合考虑表面包围盒和采样点
返回一个固定的全局边界框单位立方体
参数:
surf_bbox: 形状为 (num_edges, 6) 的Tensor表示每条边的包围盒
[xmin, ymin, zmin, xmax, ymax, zmax]
surf_bbox: (此参数在此实现中未使用)
返回:
形状为 (6,) 的Tensor格式为 [x_min, y_min, z_min, x_max, y_max, z_max]
形状为 (6,) 的Tensor表示固定的边界框 [-0.5, -0.5, -0.5, 0.5, 0.5, 0.5]
"""
# 验证输入
if not isinstance(surf_bbox, torch.Tensor):
raise TypeError(f"surf_bbox 必须是 torch.Tensor,但得到 {type(surf_bbox)}")
if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
raise ValueError(f"surf_bbox 形状应为 (num_edges, 6),但得到 {surf_bbox.shape}")
# 计算表面包围盒的全局范围
global_min = surf_bbox[:, :3].min(dim=0).values
global_max = surf_bbox[:, 3:].max(dim=0).values
# 返回合并后的边界框
return torch.cat([global_min, global_max])
# 直接定义固定的单位立方体边界框
# 注意:确保张量在正确的设备上创建
fixed_bbox = torch.tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5],
dtype=torch.float32,
device=self.device) # 假设 self.device 存储了目标设备
logger.debug(f"使用固定的全局边界框: {fixed_bbox.cpu().numpy()}")
return fixed_bbox
# --- 旧的计算逻辑 (注释掉或删除) ---
# # 验证输入
# if not isinstance(surf_bbox, torch.Tensor):
# raise TypeError(f"surf_bbox 必须是 torch.Tensor,但得到 {type(surf_bbox)}")
# if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
# raise ValueError(f"surf_bbox 形状应为 (num_edges, 6),但得到 {surf_bbox.shape}")
#
# # 计算表面包围盒的全局范围
# global_min = surf_bbox[:, :3].min(dim=0).values
# global_max = surf_bbox[:, 3:].max(dim=0).values
#
# # 返回合并后的边界框
# return torch.cat([global_min, global_max])
# return [-0.5,] # 这个是错误的
def train_epoch(self, epoch: int) -> float:
self.model.train()
total_loss = 0.0
# 获取数据并移动到设备
points = self.sdf_data[:,0:3]
points.requires_grad_(True)
step = 0 # 如果你的训练是分批次的,这里应该用批次索引
# --- 1. 检查输入数据 ---
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列)
# 并且 SDF 值总是在最后一列
if self.sdf_data is None:
logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
return float('inf')
points = self.sdf_data[:, 0:3].clone().detach() # 取前3列作为点
gt_sdf = self.sdf_data[:, -1].clone().detach() # 取最后一列作为SDF真值
normals = None
if args.use_normal:
normals = self.sdf_data[:,3:6]
gt_sdf = self.sdf_data[:,6]
if self.sdf_data.shape[1] < 7: # 检查是否有足够的列给法线
logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.")
return float('inf')
normals = self.sdf_data[:, 3:6].clone().detach() # 取中间3列作为法线
# 执行检查
if check_tensor(points, "Input Points", epoch, step): return float('inf')
if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf')
if args.use_normal:
# 只有在请求法线时才检查 normals
if check_tensor(normals, "Input Normals", epoch, step): return float('inf')
else:
gt_sdf = self.sdf_data[:,3]
# 前向传播
# --- 准备模型输入,启用梯度 ---
points.requires_grad_(True) # 在检查之后启用梯度
# --- 前向传播 ---
self.optimizer.zero_grad()
pred_sdf = self.model(points)
# 计算损失
if args.use_normal:
loss,loss_details = self.loss_manager.compute_loss(
points,
normals,
gt_sdf,
pred_sdf
) # 计算损失
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# --- 2. 检查模型输出 ---
if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf')
# --- 计算损失 ---
loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败
loss_details = {}
try:
# --- 3. 检查损失计算前的输入 ---
# (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf)
if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
if args.use_normal:
# 检查法线和带梯度的点
if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
loss, loss_details = self.loss_manager.compute_loss(
points,
normals, # 传递检查过的 normals
gt_sdf,
pred_sdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# --- 4. 检查损失计算结果 ---
if check_tensor(loss, "Calculated Loss", epoch, step):
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
if loss_details: logger.error(f"Loss Details: {loss_details}")
return float('inf') # 如果损失无效,停止这个epoch
except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # 如果计算出错,停止这个epoch
# --- 反向传播和优化 ---
try:
loss.backward()
# --- 5. (可选) 检查梯度 ---
# for name, param in self.model.named_parameters():
# if param.grad is not None:
# if check_tensor(param.grad, f"Gradient/{name}", epoch, step):
# logger.warning(f"Epoch {epoch} Step {step}: Bad gradient for {name}. Consider clipping or zeroing.")
# # 例如:param.grad.data.clamp_(-1, 1) # 值裁剪
# # 或在 optimizer.step() 前进行范数裁剪:
# # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# 反向传播和优化
loss.backward()
self.optimizer.step()
# --- (推荐) 添加梯度裁剪 ---
# 防止梯度爆炸,这可能是导致 inf/nan 的原因之一
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # 范数裁剪
total_loss += loss.item()
self.optimizer.step()
# 记录训练进度
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
# 如果你想看是哪个操作导致的,可以启用 anomaly detection
# torch.autograd.set_detect_anomaly(True) # 放在训练开始前
return float('inf') # 如果反向传播或优化出错,停止这个epoch
# --- 记录和累加损失 ---
current_loss = loss.item()
if not np.isfinite(current_loss): # 再次确认损失是有效的数值
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
total_loss += current_loss
# 记录训练进度 (只记录有效的损失)
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {loss.item():.6f}')
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
# (如果你的训练分批次,这里应该继续循环下一批次)
# step += 1
return total_loss
return total_loss # 对于单批次训练,直接返回当前损失
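A small self-contained sketch of the update order recommended in the comments above (backward pass, then gradient norm clipping, then the optimizer step); the model and data are toy stand-ins, not the project's network:

import torch

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
points = torch.randn(16, 3)
gt_sdf = torch.zeros(16, 1)

optimizer.zero_grad()
pred_sdf = model(points)
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip before stepping
optimizer.step()
print(loss.item())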
def validate(self, epoch: int) -> float:
self.model.eval()
@@ -259,10 +325,10 @@ class Trainer:
# 训练完成
total_time = time.time() - start_time
self._tracing_model()
logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
logger.info(f'Best validation loss: {best_val_loss:.6f}')
self._tracing_model()
#self.test_load()
def test_load(self):
