diff --git a/brep2sdf/IsoSurfacing.py b/brep2sdf/IsoSurfacing.py
index 5645f5c..c1e6d9e 100644
--- a/brep2sdf/IsoSurfacing.py
+++ b/brep2sdf/IsoSurfacing.py
@@ -111,8 +111,8 @@ def main():
     parser = argparse.ArgumentParser(description='IsoSurface Generator')
     parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)')
     parser.add_argument('-o', '--output', type=str, required=True, help='Output mesh file (.ply)')
-    parser.add_argument('--depth', type=int, default=7, help='Grid depth (resolution)')
-    parser.add_argument('--box_size', type=float, default=2.0, help='Bounding box size')
+    parser.add_argument('--depth', type=int, default=3, help='Grid depth (resolution)')
+    parser.add_argument('--box_size', type=float, default=1.0, help='Bounding box size')
     parser.add_argument('--method', type=str, default='MC', choices=['MC'], help='Surface extraction method')
     parser.add_argument('--use-gpu', action='store_true', help='Use GPU')
     parser.add_argument('--compare', type=str, help='Ground-truth mesh file (.ply)')
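Aside, not part of the diff: the new defaults extract at a much coarser grid inside a smaller box. A sketch of an invocation that restores the previous behavior (the file names are placeholders):

    python brep2sdf/IsoSurfacing.py -i model.pt -o mesh.ply --depth 7 --box_size 2.0 --method MC --use-gpu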
Skipping...") - continue - - # 假设我们只处理第一个匹配的文件 - input_step = step_files[0] - - # 构造子进程命令 - command = [ - "python", "train.py", - "--use-normal", - "-i", input_step, # 输入文件路径 - ] - - # 调用子进程运行命令 + # 准备 train.py 的通用参数 + # 注意:从命令行参数或其他配置中获取这些参数通常更好 + common_train_args = [ + "--use-normal", + "--only-zero-surface", + #"--force-reprocess", + # 可以添加更多通用参数 + ] + if args.train_args: + common_train_args.extend(args.train_args) + + tasks = [] + skipped_count = 0 + # 准备所有任务 + for name in names: + step_dir = os.path.join(args.step_root_dir, name) + if not os.path.isdir(step_dir): + logger.warning(f"目录 '{step_dir}' 不存在。跳过 '{name}'。") + skipped_count += 1 + continue + + step_files = [] try: - result = subprocess.run( - command, - capture_output=True, - text=True, - check=True # 如果返回非零退出码,则抛出 CalledProcessError - ) - print(f"Processed '{input_step}' successfully.") - print("STDOUT:", result.stdout) - print("STDERR:", result.stderr) - except subprocess.CalledProcessError as e: - print(f"Error processing '{input_step}': {e}") - print("STDOUT:", e.stdout) - print("STDERR:", e.stderr) - except Exception as e: - print(f"Unexpected error processing '{input_step}': {e}") + step_files = [ + os.path.join(step_dir, f) + for f in os.listdir(step_dir) + if f.lower().endswith((".step", ".stp")) and f.startswith(name) # 支持 .stp 并忽略大小写 + ] + except OSError as e: + logger.warning(f"无法访问目录 '{step_dir}': {e}。跳过 '{name}'。") + skipped_count += 1 + continue + + if len(step_files) == 0: + logger.warning(f"在目录 '{step_dir}' 中未找到匹配的 STEP 文件。跳过 '{name}'。") + skipped_count += 1 + elif len(step_files) > 1: + logger.warning(f"在目录 '{step_dir}' 中找到多个匹配的 STEP 文件,将使用第一个: {step_files[0]}。") + tasks.append(step_files[0]) + else: + tasks.append(step_files[0]) + if not tasks: + logger.info("没有找到需要处理的有效 STEP 文件。") + return + + logger.info(f"准备处理 {len(tasks)} 个 STEP 文件,跳过了 {skipped_count} 个名称。") + + success_count = 0 + failure_count = 0 + + # 使用 ProcessPoolExecutor 进行并行处理 + with ProcessPoolExecutor(max_workers=args.workers) as executor: + # 提交所有任务 + futures = { + executor.submit(run_training_process, task_path, args.train_script, common_train_args): task_path + for task_path in tasks + } + + # 使用 tqdm 显示进度并处理结果 + for future in tqdm(as_completed(futures), total=len(tasks), desc="运行训练"): + input_path = futures[future] + try: + input_step, success, stdout, stderr = future.result() + if success: + success_count += 1 + # 可以选择记录成功的 stdout/stderr,但通常只记录失败的更有用 + # logger.debug(f"成功处理 '{input_step}'. 
diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py
index aac078e..a4b9d1e 100644
--- a/brep2sdf/config/default_config.py
+++ b/brep2sdf/config/default_config.py
@@ -27,7 +27,7 @@ class ModelConfig:
 @dataclass
 class DataConfig:
     """Data-related configuration"""
-    max_face: int = 80
+    max_face: int = 400
     max_edge: int = 16
     num_query_points: int = 32*32*32  # cap on the number of SDF query/sample points; originally 128*128*128, randomly subsampled at data-load time
     bbox_scaled: float = 1.0
@@ -47,8 +47,8 @@ class TrainConfig:
     # Basic training parameters
     batch_size: int = 8
     num_workers: int = 4
-    num_epochs: int = 200
-    learning_rate: float = 0.01
+    num_epochs: int = 1
+    learning_rate: float = 0.0001
     min_lr: float = 1e-5
     weight_decay: float = 0.01
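Aside, not part of the diff: a minimal sketch of how these fields are consumed, assuming get_default_config() (imported this way in train.py) returns an object exposing the dataclasses above:

    from brep2sdf.config.default_config import get_default_config

    config = get_default_config()
    assert config.data.max_face == 400          # raised from 80
    assert config.train.num_epochs == 1         # smoke-test value, was 200
    assert config.train.learning_rate == 1e-4   # lowered from 0.01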
diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py
index 117fc69..8761992 100644
--- a/brep2sdf/data/data.py
+++ b/brep2sdf/data/data.py
@@ -318,6 +318,149 @@ def load_sdf_file(sdf_path: str, num_query_points: int = 4096) -> torch.Tensor:
         logger.error(f"Error message: {str(e)}")
         raise
 
+def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
+    """
+    Prepare SDF data: merge the surface point clouds, optionally include normals, and downsample.
+
+    Args:
+        surf_data: list[np.ndarray], each element a surface point cloud of shape (M, 3).
+        normals: list[np.ndarray] | None, each element a normal array of shape (M, 3) matching surf_data.
+        max_points: int, maximum number of points after downsampling.
+        device: str | torch.device, PyTorch device.
+
+    Returns:
+        torch.Tensor: tensor of shape (N, 4) or (N, 7), with N <= max_points.
+        Columns are [x, y, z, sdf=0] or [x, y, z, nx, ny, nz, sdf=0].
+    """
+    total_points = sum(len(s) for s in surf_data)
+    has_normals = normals is not None
+
+    # Decide the output array shape
+    num_features = 7 if has_normals else 4
+    output_size = min(total_points, max_points)
+    sdf_array = np.zeros((output_size, num_features), dtype=np.float32)
+
+    if total_points > max_points:
+        # --- Downsample ---
+        logger.debug(f"Total point count {total_points} exceeds {max_points}; downsampling...")
+        indices = []
+        for i, points in enumerate(surf_data):
+            indices.extend([(i, j) for j in range(len(points))])
+
+        np.random.shuffle(indices)
+        selected_indices = indices[:max_points]  # take the first max_points
+
+        # Fill sdf_array according to the selected indices
+        for idx, (surf_idx, pnt_idx) in enumerate(selected_indices):
+            sdf_array[idx, :3] = surf_data[surf_idx][pnt_idx]
+            if has_normals:
+                # Check that the normal data exists and is valid
+                if surf_idx < len(normals) and pnt_idx < len(normals[surf_idx]):
+                    sdf_array[idx, 3:6] = normals[surf_idx][pnt_idx]
+                else:
+                    logger.warning(f"Invalid normal index during downsampling: surf_idx={surf_idx}, pnt_idx={pnt_idx}")
+                    # Fall back to a default value, e.g. [0, 0, 1]
+                    sdf_array[idx, 3:6] = np.array([0.0, 0.0, 1.0], dtype=np.float32)
+
+    else:
+        # --- No downsampling; use all points ---
+        logger.debug(f"Total point count {total_points} does not exceed {max_points}; using all points.")
+        # Concatenate all points directly
+        all_points = np.concatenate(surf_data, axis=0)
+        sdf_array[:, :3] = all_points
+
+        if has_normals:
+            # Check that the normal data matches the point data
+            total_normal_points = sum(len(n) for n in normals)
+            if total_normal_points == total_points:
+                all_normals = np.concatenate(normals, axis=0)
+                sdf_array[:, 3:6] = all_normals
+            else:
+                logger.error(f"Point count ({total_points}) does not match normal count ({total_normal_points})!")
+                # Handle the mismatch: fill coordinates only, or raise an error.
+                # Here we keep coordinates only and degrade to [x, y, z, sdf].
+                sdf_array = sdf_array[:, :4]
+                sdf_array[:, -1] = 0.0  # make sure the SDF is 0
+                # Alternatively, fill a default normal:
+                # sdf_array[:, 3:6] = np.tile(np.array([0.0, 0.0, 1.0]), (total_points, 1))
+
+    # Note: the SDF value of surface points is conventionally 0
+    sdf_array[:, -1] = 0.0
+
+    return torch.tensor(sdf_array, dtype=torch.float32, device=device)
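Aside, not part of the diff: a quick usage sketch for prepare_sdf_data with dummy data (CPU device so it runs anywhere). Three 50-point patches exceed max_points=100, so the result is downsampled to (100, 7), and the SDF column is all zeros because these are surface points.

    import numpy as np
    from brep2sdf.data.data import prepare_sdf_data

    surf = [np.random.rand(50, 3).astype(np.float32) for _ in range(3)]
    nrms = [np.tile([0.0, 0.0, 1.0], (50, 1)).astype(np.float32) for _ in range(3)]
    t = prepare_sdf_data(surf, normals=nrms, max_points=100, device='cpu')
    assert t.shape == (100, 7) and bool((t[:, -1] == 0).all())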
NaN detected in '{name}' !!!") + # nan_indices = torch.where(torch.isnan(tensor)) + # logger.error(f"NaN indices: {nan_indices}") + return has_inf or has_nan +# --- 辅助函数结束 --- + + + + + + + + + def test_dataset(): """测试数据集功能""" try: diff --git a/brep2sdf/data/pre_process_by_mesh.py b/brep2sdf/data/pre_process_by_mesh.py index fcae98b..9629e4c 100644 --- a/brep2sdf/data/pre_process_by_mesh.py +++ b/brep2sdf/data/pre_process_by_mesh.py @@ -18,6 +18,8 @@ from scipy.spatial import cKDTree from brep2sdf.utils.logger import logger import tempfile import trimesh +from trimesh.proximity import ProximityQuery + # 导入OpenCASCADE相关库 from OCC.Core.STEPControl import STEPControl_Reader # STEP文件读取器 from OCC.Core.TopExp import TopExp_Explorer, topexp # 拓扑结构遍历 @@ -167,7 +169,7 @@ def get_bbox(shape, subshape): -def parse_solid(step_path,sample_normal_vector=False): +def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False): """ 解析STEP文件中的CAD模型数据 @@ -260,7 +262,7 @@ def parse_solid(step_path,sample_normal_vector=False): while edge_explorer.More(): edge = topods.Edge(edge_explorer.Current()) edges.append(edge) - logger.debug(len(edges)) + #logger.debug(len(edges)) curve_info = BRep_Tool.Curve(edge) if curve_info is None: continue # 跳过无效边 @@ -378,19 +380,76 @@ def parse_solid(step_path,sample_normal_vector=False): } } - if sample_normal_vector: - # 从 mesh 读 法向量 - mesh.Perform() - # 导出为STL临时文件 + trimesh_mesh = None + trimesh_mesh_ncs = None + + # --- Trimesh 加载和处理 (如果需要) --- + if sample_normal_vector or sample_sdf_points: + logger.debug("加载 Trimesh 用于法线/SDF 采样...") + # 注意:这里的 mesh (BRepMesh_IncrementalMesh) 与 trimesh 不同 + # 需要从原始 shape 导出 STL stl_writer = StlAPI_Writer() stl_writer.SetASCIIMode(False) - with tempfile.NamedTemporaryFile(suffix='.stl') as tmp: - stl_writer.Write(shape, tmp.name) - trimesh_mesh = trimesh.load(tmp.name) - data['surf_pnt_normals']= batch_compute_normals(trimesh_mesh,surfs_wcs) - + tmp_stl_path = "" + try: + with tempfile.NamedTemporaryFile(suffix='.stl', delete=True) as tmp: + tmp_stl_path = tmp.name + # 检查 shape 是否有效 + if shape.IsNull(): + raise ValueError("OCC Shape is Null, cannot write STL.") + success = stl_writer.Write(shape, tmp_stl_path) + if not success: + raise RuntimeError(f"StlAPI_Writer failed to write {tmp_stl_path}") + + trimesh_mesh = trimesh.load(tmp_stl_path) + + # 创建归一化 Trimesh + vertices_wcs = trimesh_mesh.vertices.astype(np.float32) + vertices_ncs = (vertices_wcs - data['normalization_params']['center']) / data['normalization_params']['scale'] + trimesh_mesh_ncs = trimesh.Trimesh(vertices=vertices_ncs, faces=trimesh_mesh.faces, process=False) + + if not trimesh_mesh_ncs.is_watertight: + logger.debug(f"{step_path} 的归一化网格不是 watertight,尝试修复。") + trimesh.repair.fill_holes(trimesh_mesh_ncs) + if not trimesh_mesh_ncs.is_watertight: + logger.warning(f"{step_path} 的归一化网格修复后仍不是 watertight。") + + except Exception as e: + logger.error(f"为 {step_path} 加载/处理 Trimesh 失败: {e}") + trimesh_mesh = None + trimesh_mesh_ncs = None + + # --- 计算表面点法线 --- + if sample_normal_vector and trimesh_mesh_ncs is not None: + logger.debug("计算表面点法线...") + # 使用 data['surf_ncs'] 因为它们已经是归一化后的点云 + if data['surf_ncs'].shape[0] > 0: + # 确保 batch_compute_normals 使用归一化的 mesh + data['surf_pnt_normals'] = batch_compute_normals(trimesh_mesh_ncs, data['surf_ncs']) + else: + logger.warning("没有有效的归一化表面点云用于法线计算。") + data['surf_pnt_normals'] = np.array([], dtype=object) + elif sample_normal_vector: + logger.warning("请求了表面法线计算,但 Trimesh 加载失败。") + data['surf_pnt_normals'] = np.array([], 
diff --git a/brep2sdf/data/pre_process_by_mesh.py b/brep2sdf/data/pre_process_by_mesh.py
index fcae98b..9629e4c 100644
--- a/brep2sdf/data/pre_process_by_mesh.py
+++ b/brep2sdf/data/pre_process_by_mesh.py
@@ -18,6 +18,8 @@ from scipy.spatial import cKDTree
 from brep2sdf.utils.logger import logger
 import tempfile
 import trimesh
+from trimesh.proximity import ProximityQuery
+
 # Import the OpenCASCADE modules
 from OCC.Core.STEPControl import STEPControl_Reader  # STEP file reader
 from OCC.Core.TopExp import TopExp_Explorer, topexp  # topology traversal
@@ -167,7 +169,7 @@ def get_bbox(shape, subshape):
 
 
 
-def parse_solid(step_path,sample_normal_vector=False):
+def parse_solid(step_path, sample_normal_vector=False, sample_sdf_points=False):
     """
     Parse the CAD model data in a STEP file
@@ -260,7 +262,7 @@ def parse_solid(step_path, sample_normal_vector=False, sample_sdf_points=False):
     while edge_explorer.More():
         edge = topods.Edge(edge_explorer.Current())
         edges.append(edge)
-        logger.debug(len(edges))
+        #logger.debug(len(edges))
         curve_info = BRep_Tool.Curve(edge)
         if curve_info is None:
             continue  # skip invalid edges
@@ -378,19 +380,76 @@ def parse_solid(step_path, sample_normal_vector=False, sample_sdf_points=False):
         }
     }
 
-    if sample_normal_vector:
-        # Read normals from the mesh
-        mesh.Perform()
-        # Export to a temporary STL file
+    trimesh_mesh = None
+    trimesh_mesh_ncs = None
+
+    # --- Load and process with Trimesh (if needed) ---
+    if sample_normal_vector or sample_sdf_points:
+        logger.debug("Loading Trimesh for normal/SDF sampling...")
+        # Note: the mesh here (BRepMesh_IncrementalMesh) is not a trimesh;
+        # an STL has to be exported from the original shape
         stl_writer = StlAPI_Writer()
         stl_writer.SetASCIIMode(False)
-        with tempfile.NamedTemporaryFile(suffix='.stl') as tmp:
-            stl_writer.Write(shape, tmp.name)
-            trimesh_mesh = trimesh.load(tmp.name)
-        data['surf_pnt_normals']= batch_compute_normals(trimesh_mesh,surfs_wcs)
-
+        tmp_stl_path = ""
+        try:
+            with tempfile.NamedTemporaryFile(suffix='.stl', delete=True) as tmp:
+                tmp_stl_path = tmp.name
+                # Check that the shape is valid
+                if shape.IsNull():
+                    raise ValueError("OCC Shape is Null, cannot write STL.")
+                success = stl_writer.Write(shape, tmp_stl_path)
+                if not success:
+                    raise RuntimeError(f"StlAPI_Writer failed to write {tmp_stl_path}")
+
+                trimesh_mesh = trimesh.load(tmp_stl_path)
+
+                # Build the normalized Trimesh
+                vertices_wcs = trimesh_mesh.vertices.astype(np.float32)
+                vertices_ncs = (vertices_wcs - data['normalization_params']['center']) / data['normalization_params']['scale']
+                trimesh_mesh_ncs = trimesh.Trimesh(vertices=vertices_ncs, faces=trimesh_mesh.faces, process=False)
+
+                if not trimesh_mesh_ncs.is_watertight:
+                    logger.debug(f"Normalized mesh for {step_path} is not watertight; attempting repair.")
+                    trimesh.repair.fill_holes(trimesh_mesh_ncs)
+                    if not trimesh_mesh_ncs.is_watertight:
+                        logger.warning(f"Normalized mesh for {step_path} is still not watertight after repair.")
+
+        except Exception as e:
+            logger.error(f"Failed to load/process Trimesh for {step_path}: {e}")
+            trimesh_mesh = None
+            trimesh_mesh_ncs = None
+
+    # --- Compute surface point normals ---
+    if sample_normal_vector and trimesh_mesh_ncs is not None:
+        logger.debug("Computing surface point normals...")
+        # Use data['surf_ncs'] because those point clouds are already normalized
+        if data['surf_ncs'].shape[0] > 0:
+            # Make sure batch_compute_normals uses the normalized mesh
+            data['surf_pnt_normals'] = batch_compute_normals(trimesh_mesh_ncs, data['surf_ncs'])
+        else:
+            logger.warning("No valid normalized surface point clouds for normal computation.")
+            data['surf_pnt_normals'] = np.array([], dtype=object)
+    elif sample_normal_vector:
+        logger.warning("Surface normal computation was requested, but Trimesh loading failed.")
+        data['surf_pnt_normals'] = np.array([], dtype=object)  # add an empty key
+
+    # --- SDF point sampling ---
+    data['sampled_points_normals_sdf'] = None  # initialize the key
+    if sample_sdf_points:
+        if trimesh_mesh_ncs is not None:
+            # Call the dedicated function with a fixed sample count
+            data['sampled_points_normals_sdf'] = sample_sdf_points_and_normals(
+                trimesh_mesh_ncs=trimesh_mesh_ncs,
+                surf_bbox_ncs=data['surf_bbox_ncs'],
+                num_sdf_samples=4096,  # <-- fixed sample count
+                sdf_sampling_std_dev=0.0001
+            )
+        else:
+            logger.warning("SDF point sampling was requested, but Trimesh loading failed.")
     return data
 
 
 def load_step(step_path):
     """Load STEP file and return solids"""
     reader = STEPControl_Reader()
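Aside, not part of the diff: the load-and-repair step above, reduced to a standalone trimesh sketch ('part.stl' is a placeholder path). fill_holes mutates the mesh in place and may still fail on badly tessellated parts:

    import trimesh

    mesh = trimesh.load('part.stl')          # placeholder file
    if not mesh.is_watertight:
        trimesh.repair.fill_holes(mesh)      # in-place repair attempt
    print(mesh.is_watertight, mesh.bounds)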
@@ -474,6 +533,206 @@ def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3):
     return normals_output
 
 
+def sample_sdf_points_and_normals(
+    trimesh_mesh_ncs: trimesh.Trimesh,
+    surf_bbox_ncs: np.ndarray,
+    num_sdf_samples: int = 4096,
+    sdf_sampling_std_dev: float = 0.01
+) -> np.ndarray | None:
+    """
+    Sample a fixed number of points in the normalized coordinate system (NCS) and
+    compute their SDF values and nearest-surface normals.
+    Uses a mixed strategy of uniform and near-surface sampling.
+
+    Args:
+        trimesh_mesh_ncs: the normalized Trimesh object.
+        surf_bbox_ncs: per-face bounding boxes in NCS [num_faces, 6].
+        num_sdf_samples: total number of points to sample.
+        sdf_sampling_std_dev: standard deviation of the along-normal perturbation for near-surface sampling.
+
+    Returns:
+        np.ndarray | None: array of shape (N, 7) [x, y, z, nx, ny, nz, sdf],
+        or None if sampling or computation fails.
+    """
+    logger.debug("Sampling points for SDF computation (fixed-count strategy)...")
+    if trimesh_mesh_ncs is None or not isinstance(trimesh_mesh_ncs, trimesh.Trimesh):
+        logger.error("Invalid Trimesh object passed to SDF sampling.")
+        return None
+    if num_sdf_samples <= 0:
+        logger.warning("Requested SDF sample count is zero or negative; skipping sampling.")
+        return None
+
+    # Use the fixed sampling bounds [-0.5, 0.5]; the mesh is already normalized
+    min_bound_ncs = np.array([-0.5, -0.5, -0.5], dtype=np.float32)
+    max_bound_ncs = np.array([0.5, 0.5, 0.5], dtype=np.float32)
+    bbox_size_ncs = max_bound_ncs - min_bound_ncs
+
+    # --- Split the fixed total sample count ---
+    num_uniform_samples = num_sdf_samples // 2
+    num_near_surface_samples = num_sdf_samples - num_uniform_samples
+    logger.debug(f"Fixed SDF sample count: {num_sdf_samples} (uniform: {num_uniform_samples}, near-surface: {num_near_surface_samples})")
+
+    # --- Perform the sampling ---
+    sampled_points_list = []
+
+    # Uniform sampling (within [-0.5, 0.5])
+    if num_uniform_samples > 0:
+        uniform_points = np.random.uniform(-0.5, 0.5, (num_uniform_samples, 3))
+        sampled_points_list.append(uniform_points)
+
+    # Near-surface sampling
+    if num_near_surface_samples > 0:
+        if trimesh_mesh_ncs.faces.shape[0] > 0:
+            try:
+                near_points_on_surface = trimesh_mesh_ncs.sample(num_near_surface_samples)
+                proximity_query_near = ProximityQuery(trimesh_mesh_ncs)
+                closest_points_near, distances_near, face_indices_near = proximity_query_near.on_surface(near_points_on_surface)
+                if np.any(face_indices_near >= len(trimesh_mesh_ncs.face_normals)):
+                    raise IndexError("Face index out of bounds during near-surface normal lookup")
+                normals_near = trimesh_mesh_ncs.face_normals[face_indices_near]
+                perturbations = np.random.randn(num_near_surface_samples, 1) * sdf_sampling_std_dev
+                near_points = near_points_on_surface + normals_near * perturbations
+                # Keep the near-surface points inside [-0.5, 0.5] as well
+                near_points = np.clip(near_points, -0.5, 0.5)
+                sampled_points_list.append(near_points)
+            except Exception as e:
+                logger.warning(f"Near-surface sampling failed: {e}. Falling back to uniform samples.")
+                fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3))
+                sampled_points_list.append(fallback_uniform)
+        else:
+            logger.warning("Mesh has no faces; near-surface sampling is impossible. Falling back to uniform samples.")
+            fallback_uniform = np.random.uniform(-0.5, 0.5, (num_near_surface_samples, 3))
+            sampled_points_list.append(fallback_uniform)
+
+    # --- Merge the sampled points ---
+    if not sampled_points_list:
+        logger.warning("No points were sampled for the SDF.")
+        return None
+
+    sampled_points_ncs = np.vstack(sampled_points_list).astype(np.float32)
+
+    try:
+        proximity_query = ProximityQuery(trimesh_mesh_ncs)
+
+        # Compute the SDF in batches to avoid memory problems
+        batch_size = 1000
+        sdf_values = []
+        closest_points = []
+        face_indices = []
+
+        for i in range(0, len(sampled_points_ncs), batch_size):
+            batch_points = sampled_points_ncs[i:i + batch_size]
+
+            # Closest points and faces for the current batch
+            batch_closest, batch_distances, batch_faces = proximity_query.on_surface(batch_points)
+
+            # Vector from the closest surface point to the query point
+            direction_vectors = batch_points - batch_closest
+
+            # Use batch_compute_normals to get normals at the closest points.
+            # It expects an (N,) object array whose elements are (M, 3) arrays, so
+            # wrap the batch in a one-element object array (np.empty avoids NumPy
+            # trying to broadcast the 2D array into an object grid).
+            closest_points_reshaped = np.empty(1, dtype=object)
+            closest_points_reshaped[0] = batch_closest
+
+            # Compute the normals
+            normals_batch = batch_compute_normals(
+                trimesh_mesh_ncs,
+                closest_points_reshaped,
+                normal_type='vertex',  # use vertex normals
+                k_neighbors=3
+            )[0]  # take the first element, since only one batch was passed in
+
+            # Dot product between direction vectors and normals
+            dot_products = np.sum(direction_vectors * normals_batch, axis=1)
+            signs = np.sign(dot_products)
+
+            # Handle the sign at (numerically) zero distance
+            zero_mask = np.abs(batch_distances) < 1e-6
+            signs[zero_mask] = 0.0
+
+            # Signed distance
+            batch_sdf = batch_distances * signs
+
+            # Clamp the SDF range
+            batch_sdf = np.clip(batch_sdf, -1.414, 1.414)  # sqrt(2)
+
+            # Debug information (first batch only)
+            if i == 0:
+                logger.debug(f"Batch statistics (first batch):")
+                logger.debug(f"  normal range: [{normals_batch.min():.4f}, {normals_batch.max():.4f}]")
+                logger.debug(f"  normal length: {np.linalg.norm(normals_batch, axis=1).mean():.4f}")
+                logger.debug(f"  distance range: [{batch_distances.min():.4f}, {batch_distances.max():.4f}]")
+                logger.debug(f"  dot product range: [{dot_products.min():.4f}, {dot_products.max():.4f}]")
+                logger.debug(f"  sign distribution: pos={np.sum(signs > 0)}, neg={np.sum(signs < 0)}, zero={np.sum(signs == 0)}")
+                logger.debug(f"  SDF range: [{batch_sdf.min():.4f}, {batch_sdf.max():.4f}]")
+
+            sdf_values.append(batch_sdf)
+            closest_points.append(batch_closest)
+            face_indices.append(batch_faces)
+
+        # Merge the batch results
+        sdf_values = np.concatenate(sdf_values)
+        closest_points = np.concatenate(closest_points)
+
+        # Compute normals for all points (same one-element object-array wrapping)
+        all_points_reshaped = np.empty(1, dtype=object)
+        all_points_reshaped[0] = closest_points
+        sampled_normals = batch_compute_normals(
+            trimesh_mesh_ncs,
+            all_points_reshaped,
+            normal_type='vertex',
+            k_neighbors=3
+        )[0]
+
+        # Validate the normals
+        normal_lengths = np.linalg.norm(sampled_normals, axis=1)
+        logger.debug(f"Final normal statistics:")
+        logger.debug(f"  shape: {sampled_normals.shape}")
+        logger.debug(f"  length: min={normal_lengths.min():.4f}, max={normal_lengths.max():.4f}, mean={normal_lengths.mean():.4f}")
+        logger.debug(f"  component range: x=[{sampled_normals[:,0].min():.4f}, {sampled_normals[:,0].max():.4f}]")
+        logger.debug(f"  component range: y=[{sampled_normals[:,1].min():.4f}, {sampled_normals[:,1].max():.4f}]")
+        logger.debug(f"  component range: z=[{sampled_normals[:,2].min():.4f}, {sampled_normals[:,2].max():.4f}]")
+
+        # Filter out invalid entries
+        valid_mask = (
+            ~np.isnan(sdf_values) & ~np.isinf(sdf_values) &
+            ~np.isnan(sampled_normals).any(axis=1) & ~np.isinf(sampled_normals).any(axis=1) &
+            ~np.isnan(sampled_points_ncs).any(axis=1) & ~np.isinf(sampled_points_ncs).any(axis=1)
+        )
+
+        if not np.all(valid_mask):
+            num_invalid = sampled_points_ncs.shape[0] - np.sum(valid_mask)
+            logger.warning(f"Found {num_invalid} invalid entries during SDF/normal computation. Filtering them out.")
+            sampled_points_ncs = sampled_points_ncs[valid_mask]
+            sampled_normals = sampled_normals[valid_mask]
+            sdf_values = sdf_values[valid_mask]
+
+        if sampled_points_ncs.shape[0] > 0:
+            combined_data = np.hstack((
+                sampled_points_ncs,
+                sampled_normals,
+                sdf_values[:, np.newaxis]
+            )).astype(np.float32)
+
+            # Validate the SDF distribution
+            final_sdf = combined_data[:, -1]
+            logger.debug(f"Final SDF distribution check:")
+            logger.debug(f"  positive points: {np.sum(final_sdf > 0)}")
+            logger.debug(f"  negative points: {np.sum(final_sdf < 0)}")
+            logger.debug(f"  zero points: {np.sum(np.abs(final_sdf) < 1e-6)}")
+            logger.debug(f"  SDF range: [{final_sdf.min():.4f}, {final_sdf.max():.4f}]")
+
+            # Sanity-check the distribution
+            if np.sum(final_sdf > 0) == 0 or np.sum(final_sdf < 0) == 0:
+                logger.warning("Warning: abnormal SDF distribution — no positive or no negative values!")
+
+            return combined_data
+        else:
+            logger.warning("No valid points remain after filtering the SDF/normal results.")
+            return None
+    except Exception as e:
+        logger.error(f"Failed to compute SDF or normals: {str(e)}")
+        return None
+
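Aside, not part of the diff: the sign convention used above, in isolation. The sign of the SDF at a query point p is sign(dot(p - closest_surface_point, outward_normal)), so points outside the solid get positive distances. A NumPy-only sketch:

    import numpy as np

    p       = np.array([0.0, 0.0, 0.30])   # query point
    closest = np.array([0.0, 0.0, 0.25])   # nearest point on the surface
    normal  = np.array([0.0, 0.0, 1.0])    # outward unit normal at 'closest'
    dist = np.linalg.norm(p - closest)
    sdf = dist * np.sign(np.dot(p - closest, normal))   # +0.05 -> outside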
 def check_data_format(data, step_file):
     """Check that the data format is correct"""
     required_keys = [
@@ -513,7 +772,7 @@ def check_data_format(data, step_file):
     return True, ""
 
 
-def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, timeout:int=300) -> dict:
+def process_single_step(step_path:str, output_path:str=None, sample_normal_vector=False, sample_sdf_points=False, timeout:int=300) -> dict:
     """Process a single STEP file (brep -> pkl)
     return data = {
         'surf_wcs': np.array(surfs_wcs, dtype=object),  # surface geometry in world coordinates (object array)
@@ -537,7 +796,7 @@ def process_single_step(step_path:str, output_path:str=None, sample_normal_vecto
         logger.error(f"Unsupported file format; must be a .step or .stp file: {step_path}")
         return None
     # Parse the STEP file
-    data = parse_solid(step_path, sample_normal_vector)
+    data = parse_solid(step_path, sample_normal_vector, sample_sdf_points)
     if data is None:
         logger.error(f"Failed to parse STEP file: {step_path}")
         return None
@@ -551,12 +810,11 @@ def process_single_step(step_path:str, output_path:str=None, sample_normal_vecto
     # Save the results
     if output_path:
         try:
-            logger.info(f"Saving results to: {output_path}")
+            logger.debug(f"Saving results to: {output_path}")
             os.makedirs(os.path.dirname(output_path), exist_ok=True)
             with open(output_path, 'wb') as f:
                 pickle.dump(data, f)
-            logger.info("Data preprocessing finished")
-            logger.info(f"Results saved successfully: {output_path}")
+            logger.debug(f"Results saved successfully: {output_path}")
             return data
         except Exception as e:
             logger.error(f'Failed to save {output_path}: {str(e)}')
@@ -588,19 +846,19 @@ def test(step_file_path, output_path=None):
         return None
 
     # Print statistics
-    logger.info("\nStatistics:")
-    logger.info(f"Number of surfaces: {len(data['surf_wcs'])}")
-    logger.info(f"Number of edges: {len(data['edge_wcs'])}")
-    logger.info(f"Number of corners: {len(data['corner_wcs'])}")  # fixed to corner_wcs
+    logger.debug("\nStatistics:")
+    logger.debug(f"Number of surfaces: {len(data['surf_wcs'])}")
+    logger.debug(f"Number of edges: {len(data['edge_wcs'])}")
+    logger.debug(f"Number of corners: {len(data['corner_wcs'])}")  # fixed to corner_wcs
 
     # Save the results
     if output_path:
         try:
-            logger.info(f"Saving results to: {output_path}")
+            logger.debug(f"Saving results to: {output_path}")
             os.makedirs(os.path.dirname(output_path), exist_ok=True)
             with open(output_path, 'wb') as f:
                 pickle.dump(data, f)
-            logger.info(f"Results saved successfully: {output_path}")
+            logger.debug(f"Results saved successfully: {output_path}")
         except Exception as e:
             logger.error(f"Failed to save {output_path}: {str(e)}")
             return None
diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py
index 9dfb0e8..81e5337 100644
--- a/brep2sdf/networks/encoder.py
+++ b/brep2sdf/networks/encoder.py
@@ -81,7 +81,6 @@ class Encoder(nn.Module):
 
     def _init_parameters(self):
         """Initialize feature parameters for all leaf nodes"""
-        # Use a stack to simulate the recursive traversal (avoids recursion)
         stack = [(self.octree, "root")]  # (current node, current path)
         param_index = 0  # parameter index counter
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index e589753..20ef0ea 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -7,7 +7,7 @@ import numpy as np
 import argparse
 
 from brep2sdf.config.default_config import get_default_config
-from brep2sdf.data.data import load_brep_file,load_sdf_file
+from brep2sdf.data.data import load_brep_file, load_sdf_file, prepare_sdf_data, print_data_distribution, check_tensor
 from brep2sdf.data.pre_process_by_mesh import process_single_step
 from brep2sdf.networks.network import Net
 from brep2sdf.networks.octree import OctreeNode
@@ -24,6 +24,11 @@ parser.add_argument(
     action='store_true',  # defaults to False; True when the user passes the flag
     help='Force sampled points to carry normals'
 )
+parser.add_argument(
+    '--only-zero-surface',
+    action='store_true',  # defaults to False; True when the user passes the flag
+    help='Train the SDF with zero-surface samples only'
+)
 parser.add_argument(
     '--force-reprocess',
     action='store_true',  # defaults to False; True when the user passes the flag
@@ -32,43 +37,6 @@ parser.add_argument(
 )
 args = parser.parse_args()
 
 
-def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
-    total_points = sum(len(s) for s in surf_data)
-
-    # Downsampling logic ("fixed" version)
-    if total_points > max_points:
-        # Build the indices
-        indices = []
-        for i, points in enumerate(surf_data):
-            indices.extend([(i, j) for j in range(len(points))])
-
-        # Shuffle the indices
-        np.random.shuffle(indices)
-
-        # Take the first max_points indices
-        selected_indices = indices[:max_points]
-        if not normals is None:
-            # Build sdf_array from the indices
-            sdf_array = np.zeros((max_points, 4), dtype=np.float32)
-            for idx, (i, j) in enumerate(selected_indices):
-                sdf_array[idx, :3] = surf_data[i][j]
-        else:
-            sdf_array = np.zeros((max_points, 7), dtype=np.float32)
-            for idx, (i, j) in enumerate(selected_indices):
-                sdf_array[idx, :3] = surf_data[i][j]
-                sdf_array[idx, 3:6] = normals[i][j]
-    else:
-        if not normals is None:
-            sdf_array = np.zeros((total_points, 4), dtype=np.float32)
-            sdf_array[:, :3] = np.concatenate(surf_data)
-            sdf_array = np.zeros((max_points, 7), dtype=np.float32)
-        else:
-            for idx, (i, j) in enumerate(selected_indices):
-                sdf_array[idx, :3] = surf_data[i][j]
-                sdf_array[idx, 3:6] = normals[i][j]
-
-    return torch.tensor(sdf_array, dtype=torch.float32, device=device)
-
 class Trainer:
     def __init__(self, config, input_step):
@@ -85,18 +53,33 @@ class Trainer:
             logger.error(f"fail to load {data_path}, {str(e)}")
             raise e
         if args.use_normal and self.data.get("surf_pnt_normals", None) is None:
-            self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
+            self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal, sample_sdf_points=not args.only_zero_surface)
         else:
-            self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
+            self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal, sample_sdf_points=not args.only_zero_surface)
 
         # Convert the list of surface point clouds into an (N*M, 4) array
         surfs = self.data["surf_ncs"]
-        self.sdf_data = prepare_sdf_data(
+
+        # Prepare the SDF data for the surface points
+        surface_sdf_data = prepare_sdf_data(
             surfs,
-            normals = self.data["surf_pnt_normals"],
+            normals=self.data["surf_pnt_normals"],
             max_points=4096,
             device=self.device
         )
+
+        # Unless we train on the zero surface only, merge in the sampled SDF points
+        if not args.only_zero_surface:
+            # Load the sampled point data
+            sampled_sdf_data = torch.tensor(
+                self.data['sampled_points_normals_sdf'],
+                dtype=torch.float32,
+                device=self.device
+            )
+            # Concatenate surface data and sampled data
+            self.sdf_data = torch.cat([surface_sdf_data, sampled_sdf_data], dim=0)
+        else:
+            self.sdf_data = surface_sdf_data
+
+        print_data_distribution(self.sdf_data)
         # Initialize the dataset
         #self.brep_data = load_brep_file(self.config.data.pkl_path)
         #logger.info( self.brep_data )
@@ -146,71 +129,154 @@ class Trainer:
 
     def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
         """
-        Compute the global bounding box of the whole dataset, combining the surface bounding boxes and the sampled points
-
+        Return a fixed global bounding box (the unit cube).
+
         Args:
-            surf_bbox: tensor of shape (num_edges, 6), the bounding box of each edge
-                       [xmin, ymin, zmin, xmax, ymax, zmax]
-
+            surf_bbox: (unused in this implementation)
+
         Returns:
-            tensor of shape (6,), formatted as [x_min, y_min, z_min, x_max, y_max, z_max]
+            tensor of shape (6,), the fixed bounding box [-0.5, -0.5, -0.5, 0.5, 0.5, 0.5]
         """
-        # Validate the input
-        if not isinstance(surf_bbox, torch.Tensor):
-            raise TypeError(f"surf_bbox must be a torch.Tensor, got {type(surf_bbox)}")
-        if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
-            raise ValueError(f"surf_bbox should have shape (num_edges, 6), got {surf_bbox.shape}")
-
-        # Global extent of the surface bounding boxes
-        global_min = surf_bbox[:, :3].min(dim=0).values
-        global_max = surf_bbox[:, 3:].max(dim=0).values
+        # Define the fixed unit-cube bounding box directly.
+        # Note: make sure the tensor is created on the right device
+        fixed_bbox = torch.tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5],
+                                  dtype=torch.float32,
+                                  device=self.device)  # assumes self.device holds the target device
+        logger.debug(f"Using the fixed global bounding box: {fixed_bbox.cpu().numpy()}")
+        return fixed_bbox
 
-        # Return the merged bounding box
-        return torch.cat([global_min, global_max])
+        # --- Old computation logic (commented out or to be deleted) ---
+        # # Validate the input
+        # if not isinstance(surf_bbox, torch.Tensor):
+        #     raise TypeError(f"surf_bbox must be a torch.Tensor, got {type(surf_bbox)}")
+        # if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
+        #     raise ValueError(f"surf_bbox should have shape (num_edges, 6), got {surf_bbox.shape}")
+        #
+        # # Global extent of the surface bounding boxes
+        # global_min = surf_bbox[:, :3].min(dim=0).values
+        # global_max = surf_bbox[:, 3:].max(dim=0).values
+        #
+        # # Return the merged bounding box
+        # return torch.cat([global_min, global_max])
+        # return [-0.5,]  # this one was wrong
 
     def train_epoch(self, epoch: int) -> float:
         self.model.train()
         total_loss = 0.0
-
+        step = 0  # if training is batched, this should be the batch index
+
+        # --- 1. Check the input data ---
+        # Note: self.sdf_data is assumed to hold [x, y, z, nx, ny, nz, sdf] (7 columns)
+        # or [x, y, z, sdf] (4 columns), with the SDF value always in the last column
+        if self.sdf_data is None:
+            logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
+            return float('inf')
 
-        # Fetch the data and move it to the device
-        points = self.sdf_data[:,0:3]
-        points.requires_grad_(True)
+        points = self.sdf_data[:, 0:3].clone().detach()  # first 3 columns: the points
+        gt_sdf = self.sdf_data[:, -1].clone().detach()   # last column: the ground-truth SDF
+
+        normals = None
         if args.use_normal:
-            normals = self.sdf_data[:,3:6]
-            gt_sdf = self.sdf_data[:,6]
-
-        else:
-            gt_sdf = self.sdf_data[:,3]
-
-        # Forward pass
+            if self.sdf_data.shape[1] < 7:  # make sure there are enough columns for normals
+                logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.")
+                return float('inf')
+            normals = self.sdf_data[:, 3:6].clone().detach()  # middle 3 columns: the normals
+
+        # Run the checks
+        if check_tensor(points, "Input Points", epoch, step): return float('inf')
+        if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf')
+        if args.use_normal:
+            # Only check normals when they were requested
+            if check_tensor(normals, "Input Normals", epoch, step): return float('inf')
+
+
+        # --- Prepare the model input with gradients enabled ---
+        points.requires_grad_(True)  # enable gradients after the checks
+
+        # --- Forward pass ---
         self.optimizer.zero_grad()
         pred_sdf = self.model(points)
 
-        # Compute the loss
-        if args.use_normal:
-            loss,loss_details = self.loss_manager.compute_loss(
-                points,
-                normals,
-                gt_sdf,
-                pred_sdf
-            )  # compute the loss
-        else:
-            loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
-
-        # Backward pass and optimization
-        loss.backward()
-        self.optimizer.step()
-
-        total_loss += loss.item()
-
-        # Log training progress
+        # --- 2. Check the model output ---
+        if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf')
+
+        # --- Compute the loss ---
+        loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
+        loss_details = {}
+        try:
+            # --- 3. Check the inputs right before the loss computation ---
+            # (points already has gradients enabled; no need to re-check it for inf/nan, but check pred_sdf and gt_sdf)
+            if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
+            if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
+            if args.use_normal:
+                # Check the normals and the gradient-enabled points
+                if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
+                if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
+
+                loss, loss_details = self.loss_manager.compute_loss(
+                    points,
+                    normals,  # pass the checked normals
+                    gt_sdf,
+                    pred_sdf
+                )
+            else:
+                loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
+
+            # --- 4. Check the computed loss ---
+            if check_tensor(loss, "Calculated Loss", epoch, step):
+                logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
+                if loss_details: logger.error(f"Loss Details: {loss_details}")
+                return float('inf')  # stop this epoch if the loss is invalid
+
+        except Exception as loss_e:
+            logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
+            return float('inf')  # stop this epoch on a computation error
+
+
+        # --- Backward pass and optimization ---
+        try:
+            loss.backward()
+
+            # --- 5. (Optional) check the gradients ---
+            # for name, param in self.model.named_parameters():
+            #     if param.grad is not None:
+            #         if check_tensor(param.grad, f"Gradient/{name}", epoch, step):
+            #             logger.warning(f"Epoch {epoch} Step {step}: Bad gradient for {name}. Consider clipping or zeroing.")
+            #             # e.g. param.grad.data.clamp_(-1, 1)  # value clipping
+            #             # or norm clipping before optimizer.step():
+            #             # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+            # --- (Recommended) gradient clipping ---
+            # Guards against exploding gradients, one possible source of inf/nan
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # norm clipping
+
+            self.optimizer.step()
+
+        except Exception as backward_e:
+            logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
+            # To see which operation is responsible, enable anomaly detection:
+            # torch.autograd.set_detect_anomaly(True)  # place before training starts
+            return float('inf')  # stop this epoch if backward or the optimizer step fails
+
+
+        # --- Log and accumulate the loss ---
+        current_loss = loss.item()
+        if not np.isfinite(current_loss):  # double-check the loss is a finite number
+            logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
+            return float('inf')
+
+        total_loss += current_loss
+
+        # Log training progress (valid losses only)
         logger.info(f'Train Epoch: {epoch:4d}\t'
-                    f'Loss: {loss.item():.6f}')
-
-        return total_loss
+                    f'Loss: {current_loss:.6f}')
+        if loss_details: logger.info(f"Loss Details: {loss_details}")
+
+        # (if training is batched, continue with the next batch here)
+        # step += 1
+
+        return total_loss  # for single-batch training this is just the current loss
 
     def validate(self, epoch: int) -> float:
         self.model.eval()
@@ -259,10 +325,10 @@ class Trainer:
 
         # Training finished
         total_time = time.time() - start_time
+
+        self._tracing_model()
         logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
         logger.info(f'Best validation loss: {best_val_loss:.6f}')
-        self._tracing_model()
-        #self.test_load()
 
     def test_load(self):
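Aside, not part of the diff: the hardening pattern added to train_epoch, shown in isolation with a throwaway model (standard PyTorch API):

    import torch

    torch.autograd.set_detect_anomaly(True)  # enable once, before training, to localize inf/nan ops

    model = torch.nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    pred = model(torch.randn(8, 3))
    loss = torch.nn.functional.mse_loss(pred, torch.zeros(8, 1))

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip before stepping
    optimizer.step()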