diff --git a/.gitignore b/.gitignore index fc7678f..a3ec676 100644 --- a/.gitignore +++ b/.gitignore @@ -169,6 +169,7 @@ cython_debug/ *.step test_data/ logs/ +nohup.out wandb/ *.pth *.pt diff --git a/brep2sdf/IsoSurfacing.py b/brep2sdf/IsoSurfacing.py index 30d3499..61fc1a7 100644 --- a/brep2sdf/IsoSurfacing.py +++ b/brep2sdf/IsoSurfacing.py @@ -38,7 +38,9 @@ def predict_sdf(model, points, device): points_t = torch.from_numpy(points).float().to(device) with torch.no_grad(): - sdf = model.forward_background(points_t).cpu().numpy().flatten() + sdf = model(points_t).cpu().numpy().flatten() + # 替换 inf 值为 2 + #sdf[np.isinf(sdf)] = 2 return sdf def extract_surface(sdf, xx, yy, zz, method='MC', feature_angle=30.0, voxel_size=0.01): @@ -129,7 +131,7 @@ def main(): parser = argparse.ArgumentParser(description='IsoSurface Generator') parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)') parser.add_argument('-o', '--output', type=str, required=True, help='Output mesh file (.ply)') - parser.add_argument('--depth', type=int, default=7, help='网格深度(分辨率)') + parser.add_argument('--depth', type=int, default=5, help='网格深度(分辨率)') parser.add_argument('--box_size', type=float, default=1.0, # 从1.0改为2.0 help='边界框大小(建议设为2.0以得到[-1,1]范围)') parser.add_argument('--method', type=str, default='MC', diff --git a/brep2sdf/batch_train.py b/brep2sdf/batch_train.py index 94cb2e1..14d4a0c 100644 --- a/brep2sdf/batch_train.py +++ b/brep2sdf/batch_train.py @@ -37,18 +37,21 @@ def run_training_process(input_step: str, train_script: str, common_args: list) Returns: Tuple: (输入文件路径, 是否成功, stdout, stderr) """ + name_id = input_step.split("/")[-2] command = [ "python", train_script, *common_args, "-i", input_step, + "--resume-checkpoint-path", f"/home/wch/brep2sdf/checkpoints/{name_id}/epoch_11000.pth" ] try: + logger.info(f"即将执行的命令: {' '.join(command)}") result = subprocess.run( command, capture_output=True, text=True, check=True, # 如果返回非零退出码,则抛出 CalledProcessError - timeout=600 # 添加超时设置(例如10分钟) + timeout=14400 ) return input_step, True, result.stdout, result.stderr except subprocess.CalledProcessError as e: @@ -61,6 +64,78 @@ def run_training_process(input_step: str, train_script: str, common_args: list) logger.error(f"处理 '{input_step}' 时发生意外错误: {e}") return input_step, False, "", str(e) + +def batch_train_max_workers_1(args): + # 读取名称列表 + names = get_namelist(args.name_list_path) + + # 准备 train.py 的通用参数 + # 注意:从命令行参数或其他配置中获取这些参数通常更好 + common_train_args = [ + "--use-normal", + "--only-zero-surface", + "--octree-cuda", + + #"--force-reprocess", + # 可以添加更多通用参数 + ] + if args.train_args: + common_train_args.extend(args.train_args) + + tasks = [] + skipped_count = 0 + # 准备所有任务 + for name in names: + step_dir = os.path.join(args.step_root_dir, name) + if not os.path.isdir(step_dir): + logger.warning(f"目录 '{step_dir}' 不存在。跳过 '{name}'。") + skipped_count += 1 + continue + + step_files = [] + try: + step_files = [ + os.path.join(step_dir, f) + for f in os.listdir(step_dir) + if f.lower().endswith((".step", ".stp")) and f.startswith(name) # 支持 .stp 并忽略大小写 + ] + except OSError as e: + logger.warning(f"无法访问目录 '{step_dir}': {e}。跳过 '{name}'。") + skipped_count += 1 + continue + + if len(step_files) == 0: + logger.warning(f"在目录 '{step_dir}' 中未找到匹配的 STEP 文件。跳过 '{name}'。") + skipped_count += 1 + elif len(step_files) > 1: + logger.warning(f"在目录 '{step_dir}' 中找到多个匹配的 STEP 文件,将使用第一个: {step_files[0]}。") + tasks.append(step_files[0]) + else: + tasks.append(step_files[0]) + + if not tasks: + logger.info("没有找到需要处理的有效 STEP 文件。") + return + + logger.info(f"准备处理 {len(tasks)} 个 STEP 文件,跳过了 {skipped_count} 个名称。") + + success_count = 0 + failure_count = 0 + + # 使用 for 循环顺序执行任务 + for task_path in tqdm(tasks, desc="运行训练"): + input_step, success, stdout, stderr = run_training_process(task_path, args.train_script, common_train_args) + if success: + success_count += 1 + # 可以选择记录成功的 stdout/stderr,但通常只记录失败的更有用 + # logger.debug(f"成功处理 '{input_step}'. STDOUT:\n{stdout}\nSTDERR:\n{stderr}") + else: + failure_count += 1 + logger.error(f"处理 '{input_step}' 失败。STDOUT:\n{stdout}\nSTDERR:\n{stderr}") + + logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}, 跳过: {skipped_count}") + + def batch_train(args): # 读取名称列表 names = get_namelist(args.name_list_path) @@ -71,7 +146,7 @@ def batch_train(args): "--use-normal", "--only-zero-surface", "--octree-cuda", - "--force-reprocess", + #"--force-reprocess", # 可以添加更多通用参数 ] if args.train_args: @@ -145,7 +220,7 @@ def batch_train(args): #=========================== # Iso -def run_isosurfacing_process(input_path, output_dir, use_gpu=True): +def run_isosurfacing_process(input_path, output_dir, use_gpu=True,if_nh=False): """ 运行 IsoSurfacing.py 脚本生成等值面。 @@ -163,8 +238,8 @@ def run_isosurfacing_process(input_path, output_dir, use_gpu=True): try: # 构造输出文件名 base_name = os.path.splitext(os.path.basename(input_path))[0] - output_path = os.path.join(output_dir, f"{base_name}.ply") - + output_path = os.path.join(output_dir, f"{base_name}_nh.ply") if if_nh else os.path.join(output_dir, f"{base_name}.ply") + #print(output_path) # 构造命令 command = [ "python", "IsoSurfacing.py", @@ -249,9 +324,68 @@ def batch_Iso(args): logger.error(f"处理任务 '{input_path}' 时获取结果失败: {e}") logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}") + +def batch_nh_Iso(args): + # python IsoSurfacing.py -i /home/wch/NH-Rep/data/output_data/00000003_0_50k_model_h.pt -o /home/wch/NH-Rep/data/output_data/00000003.ply --use-gpu + """ + 批量处理 .pt 文件,生成等值面并保存为 .ply 文件。 + """ + names = get_namelist(args.name_list_path) + # 检查输入目录是否存在 + if not os.path.exists(args.pt_dir): + logger.error(f"错误: 输入目录 '{args.pt_dir}' 不存在。") + return + # 获取所有 .pt 文件 + try: + pt_files = [ + os.path.join(args.pt_dir, f"{name}_0_50k_model_h.pt") + for name in names + ] + pt_files = [f for f in pt_files if os.path.exists(f)] + except OSError as e: + logger.error(f"无法访问输入目录 '{args.pt_dir}': {e}") + return + + if not pt_files: + logger.info(f"输入目录 '{args.pt_dir}' 中未找到任何 .pt 文件。") + return + + logger.info(f"在输入目录中找到 {len(pt_files)} 个 .pt 文件。") + + success_count = 0 + failure_count = 0 + + # 使用 ProcessPoolExecutor 进行并行处理 + with ProcessPoolExecutor(max_workers=args.workers) as executor: + # 提交所有任务 + futures = { + executor.submit(run_isosurfacing_process, pt_path, args.pt_dir, True, True): pt_path + for pt_path in pt_files + } + + # 使用 tqdm 显示进度并处理结果 + for future in tqdm(as_completed(futures), total=len(pt_files), desc="生成等值面"): + input_path = futures[future] + try: + input_pt, success, stdout, stderr = future.result() + if success: + success_count += 1 + # 可以选择记录成功的 stdout/stderr,但通常只记录失败的更有用 + # logger.debug(f"成功处理 '{input_pt}'. STDOUT:\n{stdout}\nSTDERR:\n{stderr}") + else: + failure_count += 1 + logger.error(f"处理 '{input_pt}' 失败。STDOUT:\n{stdout}\nSTDERR:\n{stderr}") + except Exception as e: + failure_count += 1 + logger.error(f"处理任务 '{input_path}' 时获取结果失败: {e}") + + logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}") + def main(args): - batch_train(args) + batch_train_max_workers_1(args) + #batch_train(args) #batch_Iso(args) + #batch_nh_Iso(args) if __name__ == '__main__': @@ -264,10 +398,12 @@ if __name__ == '__main__': help="包含要处理的名称列表的文件路径。") parser.add_argument('--train-script', type=str, default="train.py", help="要执行的训练脚本路径。") - parser.add_argument('--workers', type=int, default=os.cpu_count(), + parser.add_argument('--workers', type=int, default=1, help="用于并行处理的工作进程数。") parser.add_argument('--train-args', nargs='*', help="传递给 train.py 的额外参数 (例如 --epochs 10 --batch-size 32)。") args = parser.parse_args() - main(args) \ No newline at end of file + main(args) + +# python batch_train.py --pt-dir /home/wch/NH-Rep/data_backup/output_data/extracted/output_data \ No newline at end of file diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index e7cfb4e..e0fa122 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -49,16 +49,18 @@ class TrainConfig: # 基本训练参数 batch_size: int = 8 num_workers: int = 4 - num_epochs: int = 50 - learning_rate: float = 0.005 + num_epochs1: int = 10000 + num_epochs2: int = 1000 + num_epochs3: int = 1000 + learning_rate: float = 0.1 learning_rate_schedule: List = field(default_factory=lambda: [{ "Type": "Step", # 学习率调度类型。"Step"表示在指定迭代次数后将学习率乘以因子 - "Initial": 0.005, + "Initial": 0.01, "Interval": 2000, - "Factor": 0.5 + "Factor": 0.3 }]) min_lr: float = 1e-5 - weight_decay: float = 0.01 + weight_decay: float = 0.0001 # 梯度和损失相关 max_grad_norm: float = 1.0 @@ -71,7 +73,7 @@ class TrainConfig: warmup_epochs: int = 5 # 保存和验证 - save_freq: int = 10 # 每多少个epoch保存一次 + save_freq: int = 1000 # 每多少个epoch保存一次 val_freq: int = 1 # 每多少个epoch验证一次 # 保存路径 diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py index d609f98..7a85f08 100644 --- a/brep2sdf/data/utils.py +++ b/brep2sdf/data/utils.py @@ -246,4 +246,33 @@ def batch_compute_normals(mesh, surf_wcs, normal_type='vertex', k_neighbors=3): normals_output[i] = nearest_normals[start:end] start = end - return normals_output \ No newline at end of file + return normals_output + + +def points_in_box(points: torch.Tensor, bbox: torch.Tensor) -> torch.Tensor: + """ + 返回落在AABB包围盒内的点(保留所有7个通道) + + 参数: + points: 形状为 (N, 7) 的张量,其中前3维是坐标(x,y,z),其余为属性(如法线、颜色等) + bbox: 形状为 (6,) 的张量,表示AABB包围盒的坐标 [x_min, y_min, z_min, x_max, y_max, z_max] + + 返回: + torch.Tensor: 形状为 (K, 7),其中 K 是落在包围盒内的点的数量 + """ + assert points.shape[1] == 7, f"points 必须有7个通道,但得到 {points.shape[1]}" + assert bbox.shape == (6,), f"bbox 必须是长度为6的一维张量,但得到 {bbox.shape}" + + min_coords = bbox[:3] + max_coords = bbox[3:] + + # 检查每个点的 xyz 是否在包围盒内 (N, 3) + within_box = (points[:, :3] >= min_coords) & (points[:, :3] <= max_coords) + + # 所有轴都满足条件的点 (N,) + inside_mask = within_box.all(dim=1) + + # 提取符合条件的完整点(包括所有7个维度) + points_inside = points[inside_mask] + + return points_inside.detach().clone() diff --git a/brep2sdf/eval_pos.py b/brep2sdf/eval_pos.py index 9c4bf0a..53330d8 100644 --- a/brep2sdf/eval_pos.py +++ b/brep2sdf/eval_pos.py @@ -1,5 +1,6 @@ import trimesh import numpy as np +from brep2sdf.data.data import prepare_sdf_data,load_brep_file from brep2sdf.data.sampler import sample_zero_surface_points_and_normals from brep2sdf.utils.load import get_namelist, get_step_paths from brep2sdf.networks.network import gradient @@ -7,6 +8,21 @@ import torch import os from brep2sdf.utils.logger import logger +# 全局变量用于保存采样点 +GLOBAL_SAMPLED_POINTS = None + +def sample_grid_points(mesh, nh_mesh, our_mesh, num_samples_iou): + global GLOBAL_SAMPLED_POINTS + if GLOBAL_SAMPLED_POINTS is not None: + return GLOBAL_SAMPLED_POINTS + # 从一个较大的空间范围采样点 + bounds = np.vstack([mesh.bounds, nh_mesh.bounds, our_mesh.bounds]) + min_bound = np.min(bounds, axis=0) + max_bound = np.max(bounds, axis=0) + points = np.random.uniform(min_bound, max_bound, (num_samples_iou, 3)) + GLOBAL_SAMPLED_POINTS = points + return points + def position_loss(pred_sdfs: torch.Tensor) -> torch.Tensor: """位置损失函数""" # 保持梯度流 @@ -25,8 +41,54 @@ def average_normal_error(normals1: torch.Tensor, normals2: torch.Tensor) -> torc angle_errors = 1 - absolute_dot_products return torch.mean(angle_errors) +def chamfer_distance(points_A: torch.Tensor, points_B: torch.Tensor) -> torch.Tensor: + """ + 计算两个点集之间的 Chamfer Distance (CD) + :param points_A: 形状为 (N, 3) 的点集 A + :param points_B: 形状为 (M, 3) 的点集 B + :return: CD 值 + """ + N = points_A.shape[0] + M = points_B.shape[0] + points_A_expanded = points_A.unsqueeze(1).expand(N, M, 3) + points_B_expanded = points_B.unsqueeze(0).expand(N, M, 3) + distances = torch.sum((points_A_expanded - points_B_expanded) ** 2, dim=-1) # (N, M) + dist_A_to_B = torch.min(distances, dim=1)[0] # (N,) + dist_B_to_A = torch.min(distances, dim=0)[0] # (M,) + return (torch.mean(dist_A_to_B) + torch.mean(dist_B_to_A)) / 2 + +def hausdorff_distance(points_A: torch.Tensor, points_B: torch.Tensor) -> torch.Tensor: + """ + 计算两个点集之间的 Two-Side Hausdorff Distance (HD) + :param points_A: 形状为 (N, 3) 的点集 A + :param points_B: 形状为 (M, 3) 的点集 B + :return: HD 值 + """ + N = points_A.shape[0] + M = points_B.shape[0] + points_A_expanded = points_A.unsqueeze(1).expand(N, M, 3) + points_B_expanded = points_B.unsqueeze(0).expand(N, M, 3) + distances = torch.sum((points_A_expanded - points_B_expanded) ** 2, dim=-1) # (N, M) + dist_A_to_B = torch.min(distances, dim=1)[0] # (N,) + dist_B_to_A = torch.min(distances, dim=0)[0] # (M,) + return torch.max(torch.max(dist_A_to_B), torch.max(dist_B_to_A)) +def compute_iou(sdf1, sdf2, threshold=0.0): + """ + 计算两个 SDF 之间的 IoU + :param sdf1: 第一个 SDF 数组 + :param sdf2: 第二个 SDF 数组 + :param threshold: 阈值,用于判断点是否在表面内 + :return: IoU 值 + """ + inside1 = sdf1 <= threshold + inside2 = sdf2 <= threshold + intersection = np.logical_and(inside1, inside2).sum() + union = np.logical_or(inside1, inside2).sum() + iou = intersection / union if union > 0 else 0.0 + return iou +# load def load_model(model_path): """加载模型的通用函数""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -38,6 +100,9 @@ def load_model(model_path): logger.error(f"加载模型 {model_path} 时出错: {e}") return None + + +# model def nh(model_path, points): model = load_model(model_path) if model is None: @@ -53,16 +118,22 @@ def mine(model_path, points): if model is None: return None try: - return model.forward_background(points) + return model(points) except Exception as e: logger.error(f"调用 mine 模型时出错: {e}") return None + + def run(name): # 替换为实际的 obj 文件路径 obj_file_path = f"/home/wch/brep2sdf/data/gt_mesh/{name}.obj" model_path = f"/home/wch/brep2sdf/data/output_data/{name}.pt" nh_model = f"/home/wch/NH-Rep/data_backup/output_data/extracted/output_data/{name}_0_50k_model_h.pt" + ply_nh = f"/home/wch/NH-Rep/data_backup/output_data/extracted/output_data/{name}_0_50k_model_h_nh.ply" + ply_our = f"/home/wch/brep2sdf/data/output_data/{name}.ply" + npz_path = f"/home/wch/brep2sdf/data/output_data/{name}.xyz" + num_samples=4096 # 检查文件是否存在 if not os.path.isfile(obj_file_path): @@ -79,7 +150,7 @@ def run(name): try: # 调用采样函数 - result1 = sample_zero_surface_points_and_normals(mesh, num_samples=4096) + result1 = sample_zero_surface_points_and_normals(mesh, num_samples) if result1 is None: logger.error("采样失败,返回 None") return @@ -98,29 +169,115 @@ def run(name): loss2["de"] = position_loss(sdf2).item() logger.info(f"NH 模型位置损失: {loss1}") logger.info(f"Mine 模型位置损失: {loss2}") - + # 将 gt_normal 转换为 torch.Tensor 并移动到设备上 gt_normal = torch.from_numpy(result1[:, 3:6]).float().to(device) # 假设 gradient 函数已正确导入 normal1 = gradient(coordinates_tensor, sdf1) normal2 = gradient(coordinates_tensor, sdf2) - + loss1["nae"] = average_normal_error(gt_normal, normal1).item() loss2["nae"] = average_normal_error(gt_normal, normal2).item() - + print("NH 模型的平均法向量误差 (NAE):", loss1["nae"]) print("Mine 模型的平均法向量误差 (NAE):", loss2["nae"]) - - return loss1, loss2 + + else: logger.error("无法计算损失,SDF 结果为 None") + # 读取 ply 文件 + try: + nh_mesh = trimesh.load_mesh(ply_nh) + our_mesh = trimesh.load_mesh(ply_our) + logger.info(f"成功读取 PLY 文件: {ply_nh} 和 {ply_our}") + except Exception as e: + logger.error(f"读取 PLY 文件时出错: {e}") + return + + # 从网格中采样点 + nh_points = torch.from_numpy(nh_mesh.sample(num_samples)).float().to(device) + our_points = torch.from_numpy(our_mesh.sample(num_samples)).float().to(device) + + # 确保 coordinates 是 torch.Tensor 类型 + loss1["cd"] = chamfer_distance(coordinates_tensor, nh_points).item() + loss2["cd"] = chamfer_distance(coordinates_tensor, our_points).item() + loss1["hd"] = hausdorff_distance(coordinates_tensor, nh_points).item() + loss2["hd"] = hausdorff_distance(coordinates_tensor, our_points).item() + + # fea + data = load_brep_file(npz_path) + sampled_pnts=prepare_sdf_data(data["surf_ncs"],normals=data["surf_pnt_normals"],max_points=num_samples) + + # 展平处理 + flattened_pnts = sampled_pnts.flatten() + + + + # 修改此处,使用 clone().detach() + if isinstance(flattened_pnts[0:3], torch.Tensor): + f_pnts = flattened_pnts[0:3].clone().detach().to(device).view(-1, 3) + else: + f_pnts = torch.from_numpy(flattened_pnts[0:3]).clone().detach().to(device).view(-1, 3) + + if isinstance(flattened_pnts[3:6], torch.Tensor): + f_normals = flattened_pnts[3:6].clone().detach().to(device).view(-1, 3) + else: + f_normals = torch.from_numpy(flattened_pnts[3:6]).clone().detach().to(device).view(-1, 3) + + # 检查 f_pnts 和 f_normals 的形状 + if f_pnts.shape[-1] != 3 or f_normals.shape[-1] != 3: + logger.error(f"f_pnts 形状: {f_pnts.shape}, f_normals 形状: {f_normals.shape},期望最后一维尺寸为 3") + return + loss1["fcd"] = chamfer_distance(f_pnts, nh_points).item() + loss2["fcd"] = chamfer_distance(f_pnts, our_points).item() + loss1["fae"] = hausdorff_distance(f_normals, nh_points).item() + loss2["fae"] = hausdorff_distance(f_normals, our_points).item() + + # 计算 IoU,从obj文件计算 + # ... existing code ... + + # 计算 IoU,使用采样点方法 + try: + num_samples_iou = 10000 # 采样点数量,可以根据需要调整 + # 调用封装的采样函数 + points = sample_grid_points(mesh, nh_mesh, our_mesh, num_samples_iou) + + # 判断点是否在各个网格内部 + inside_mesh = mesh.contains(points) + inside_nh = nh_mesh.contains(points) + inside_our = our_mesh.contains(points) + + # 计算 nh_mesh 与 mesh 的交集和并集 + intersection_nh = np.logical_and(inside_mesh, inside_nh).sum() + union_nh = np.logical_or(inside_mesh, inside_nh).sum() + + # 计算 our_mesh 与 mesh 的交集和并集 + intersection_our = np.logical_and(inside_mesh, inside_our).sum() + union_our = np.logical_or(inside_mesh, inside_our).sum() + + # 计算 IoU + iou_nh = intersection_nh / union_nh if union_nh > 0 else 0.0 + iou_our = intersection_our / union_our if union_our > 0 else 0.0 + + loss1["iou"] = iou_nh + loss2["iou"] = iou_our + except Exception as e: + print(f"使用采样点计算 IoU 时出错: {e}") + + + return loss1, loss2 + except Exception as e: logger.error(f"处理过程中出现错误: {e}") def main(): names = get_namelist("/home/wch/brep2sdf/data/name_list.txt") tl1_de, tl1_nae, tl2_de, tl2_nae = 0.0, 0.0, 0.0, 0.0 + tl1_cd, tl1_hd, tl2_cd, tl2_hd = 0.0, 0.0, 0.0, 0.0 + tl1_fcd, tl1_fae, tl2_fcd, tl2_fae = 0.0, 0.0, 0.0, 0.0 + # 新增累加 IoU 的变量 + tl1_iou, tl2_iou = 0.0, 0.0 valid_count = 0 for name in names: result = run(name) @@ -130,14 +287,61 @@ def main(): tl1_nae += l1["nae"] tl2_de += l2["de"] tl2_nae += l2["nae"] + tl1_cd += l1["cd"] + tl1_hd += l1["hd"] + tl2_cd += l2["cd"] + tl2_hd += l2["hd"] + tl1_fcd += l1["fcd"] + tl1_fae += l1["fae"] + tl2_fcd += l2["fcd"] + tl2_fae += l2["fae"] + # 累加 IoU 的值 + tl1_iou += l1["iou"] + tl2_iou += l2["iou"] valid_count += 1 if valid_count > 0: - print(f"NH 模型平均位置损失 (de): {tl1_de/valid_count}") - print(f"NH 模型平均法向量误差 (nae): {tl1_nae/valid_count}") - print(f"Mine 模型平均位置损失 (de): {tl2_de/valid_count}") - print(f"Mine 模型平均法向量误差 (nae): {tl2_nae/valid_count}") + avg_l1_de = tl1_de / valid_count + avg_l1_nae = tl1_nae / valid_count + avg_l2_de = tl2_de / valid_count + avg_l2_nae = tl2_nae / valid_count + avg_l1_cd = tl1_cd / valid_count + avg_l1_hd = tl1_hd / valid_count + avg_l2_cd = tl2_cd / valid_count + avg_l2_hd = tl2_hd / valid_count + avg_l1_fcd = tl1_fcd / valid_count + avg_l1_fae = tl1_fae / valid_count + avg_l2_fcd = tl2_fcd / valid_count + avg_l2_fae = tl2_fae / valid_count + # 计算 IoU 的平均值 + avg_l1_iou = tl1_iou / valid_count + avg_l2_iou = tl2_iou / valid_count + # 打印表格表头 + print("| 模型 | Chamfer Distance (CD) | Hausdorff Distance (HD) | 平均法向量误差 (NAE) | Feature Chamfer Distance (FCD) | Feature Angle Error (FAE) | 位置损失 (DE) | IoU |") + print("|------|-----------------------|-------------------------|----------------------|--------------------------------|---------------------------|--------------|-----|") + # 打印 NH 模型数据 + print(f"| NH 模型 | {avg_l1_cd} | {avg_l1_hd} | {avg_l1_nae} | {avg_l1_fcd} | {avg_l1_fae} | {avg_l1_de} | {avg_l1_iou} |") + # 打印 Mine 模型数据 + print(f"| Mine 模型 | {avg_l2_cd} | {avg_l2_hd} | {avg_l2_nae} | {avg_l2_fcd} | {avg_l2_fae} | {avg_l2_de} | {avg_l2_iou} |") else: print("没有有效的结果,无法计算平均值。") +def test(name_id): + result = run(name_id) # 修正参数使用错误 + if result is not None: + l1, l2 = result + # 打印表格表头 + print("| 模型 | Chamfer Distance (CD) | Hausdorff Distance (HD) | 平均法向量误差 (NAE) | Feature Chamfer Distance (FCD) | Feature Angle Error (FAE) | 位置损失 (DE) | IoU |") + print("|------|-----------------------|-------------------------|----------------------|--------------------------------|---------------------------|--------------|-----|") + # 假设 IoU 数据存在于 l1 和 l2 中,如果不存在可以先忽略或者设置默认值 + # 打印 NH 模型数据 + print(f"| NH 模型 | {l1['cd']} | {l1['hd']} | {l1['nae']} | {l1['fcd']} | {l1['fae']} | {l1['de']} | {l1['iou']} |") + # 打印 Mine 模型数据 + print(f"| Mine 模型 | {l2['cd']} | {l2['hd']} | {l2['nae']} | {l2['fcd']} | {l2['fae']} | {l2['de']} | {l2['iou']} |") + else: + print("没有有效的结果。") + + + if __name__ == "__main__": - main() \ No newline at end of file + #main() + test("00000031") \ No newline at end of file diff --git a/brep2sdf/scripts/npz2points.py b/brep2sdf/scripts/npz2points.py index cd2d260..ef27e58 100644 --- a/brep2sdf/scripts/npz2points.py +++ b/brep2sdf/scripts/npz2points.py @@ -11,10 +11,10 @@ def load_brep_file(brep_path): if __name__ == "__main__": - data=load_brep_file("/home/wch/brep2sdf/data/output_data/00000031.xyz") + data=load_brep_file("/home/wch/brep2sdf/data/output_data/00000003.xyz") surfs =data["train_surf_ncs"] print(surfs) - with open("0031_t.xyz","w") as f: + with open("0003_t.xyz","w") as f: for point in surfs: #f.write(f"{point[0]} {point[1]} {point[2]}\n") f.write(f"{point[0]} {point[1]} {point[2]} {point[3]} {point[4]} {point[5]}\n") diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 9a9c267..14ac897 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -610,7 +610,7 @@ class Trainer: self.model.train() total_loss = 0.0 step = 0 # 如果你的训练是分批次的,这里应该用批次索引 - batch_size = 8192 * 2 # 设置合适的batch大小 + batch_size = 4096 # 设置合适的batch大小 # 数据处理 # manfld diff --git a/data/name_list copy.txt b/data/name_list copy.txt new file mode 100644 index 0000000..2126251 --- /dev/null +++ b/data/name_list copy.txt @@ -0,0 +1,24 @@ +00000003 +00000008 +00000009 +00000029 +00000031 +00000032 +00000047 +00000049 +00000057 +00000058 +00000060 +00000061 +00000065 +00000066 +00000067 +00000068 +00000070 +00000072 +00000076 +00000077 +00000078 +00000079 +00000088 +00000093 \ No newline at end of file