From 794b2bb220bf58127f87c5984cd160ef1ce59b47 Mon Sep 17 00:00:00 2001 From: mckay Date: Sat, 3 May 2025 20:27:59 +0800 Subject: [PATCH] =?UTF-8?q?=E8=84=9A=E6=9C=AC=E6=96=B0=E5=A2=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/IsoSurfacing.py | 7 +- brep2sdf/IsoSurfacing_sp.py | 204 ++++++++++++++++++++++++ brep2sdf/eval_pos.py | 143 +++++++++++++++++ brep2sdf/scripts/abc_vis.py | 124 ++++++++++++++ brep2sdf/scripts/clean_csv_filenames.py | 0 brep2sdf/scripts/npz2points.py | 25 +++ brep2sdf/scripts/paint.py | 44 +++++ brep2sdf/sdf_vis.py | 69 ++++++++ brep2sdf/test.py | 175 ++++++++++++-------- 9 files changed, 723 insertions(+), 68 deletions(-) create mode 100644 brep2sdf/IsoSurfacing_sp.py create mode 100644 brep2sdf/eval_pos.py create mode 100644 brep2sdf/scripts/abc_vis.py create mode 100644 brep2sdf/scripts/clean_csv_filenames.py create mode 100644 brep2sdf/scripts/npz2points.py create mode 100644 brep2sdf/scripts/paint.py create mode 100644 brep2sdf/sdf_vis.py diff --git a/brep2sdf/IsoSurfacing.py b/brep2sdf/IsoSurfacing.py index c9a7650..30d3499 100644 --- a/brep2sdf/IsoSurfacing.py +++ b/brep2sdf/IsoSurfacing.py @@ -4,6 +4,7 @@ import argparse from skimage import measure import time import trimesh +from brep2sdf.utils.logger import logger def create_grid(depth, box_size): """ @@ -37,7 +38,7 @@ def predict_sdf(model, points, device): points_t = torch.from_numpy(points).float().to(device) with torch.no_grad(): - sdf = model(points_t).cpu().numpy().flatten() + sdf = model.forward_background(points_t).cpu().numpy().flatten() return sdf def extract_surface(sdf, xx, yy, zz, method='MC', feature_angle=30.0, voxel_size=0.01): @@ -129,7 +130,7 @@ def main(): parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)') parser.add_argument('-o', '--output', type=str, required=True, help='Output mesh file (.ply)') parser.add_argument('--depth', type=int, default=7, help='网格深度(分辨率)') - parser.add_argument('--box_size', type=float, default=2.0, # 从1.0改为2.0 + parser.add_argument('--box_size', type=float, default=1.0, # 从1.0改为2.0 help='边界框大小(建议设为2.0以得到[-1,1]范围)') parser.add_argument('--method', type=str, default='MC', choices=['MC', 'EMC', 'DC'], # 新增算法选项 @@ -153,10 +154,12 @@ def main(): # 创建网格并预测SDF points, xx, yy, zz = create_grid(args.depth, args.box_size) + print(points.shape) sdf = predict_sdf(model, points, device) print(points.shape) print(sdf.shape) print(sdf) + logger.print_tensor_stats("sdf",torch.tensor(sdf)) sdf_grid = sdf.reshape(xx.shape) # 提取表面 diff --git a/brep2sdf/IsoSurfacing_sp.py b/brep2sdf/IsoSurfacing_sp.py new file mode 100644 index 0000000..a394a3a --- /dev/null +++ b/brep2sdf/IsoSurfacing_sp.py @@ -0,0 +1,204 @@ +import numpy as np +import torch +import argparse +from skimage import measure +import time +import trimesh +from brep2sdf.utils.logger import logger +from brep2sdf.networks.octree import OctreeNode + +def create_grid_with_octree(octree, model, device): + """ + 使用八叉树创建三维网格点 + :param octree: 八叉树对象 + :param model: PyTorch模型 + :param device: 设备(CPU/GPU) + :return: 网格点数组和SDF值数组 + """ + leaf_indices = (octree.is_leaf_mask & octree.is_valid_leaf_mask).nonzero().flatten() + print(leaf_indices.shape) + points = [] + for idx in leaf_indices: + bbox = octree.node_bboxes[idx] + min_coords = bbox[:3].cpu().numpy() + max_coords = bbox[3:].cpu().numpy() + # 在叶子节点的边界框内采样 + num_samples = 1 # 可根据需要调整采样点数 + x = np.linspace(min_coords[0], max_coords[0], num_samples) + y = np.linspace(min_coords[1], max_coords[1], num_samples) + z = np.linspace(min_coords[2], max_coords[2], num_samples) + xx, yy, zz = np.meshgrid(x, y, z, indexing='ij') + node_points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1) + points.append(node_points) + + points = np.vstack(points) + sdf = predict_sdf(model, points, device) + return points, sdf + +def predict_sdf(model, points, device): + """ + 使用模型预测SDF值 + :param model: PyTorch模型 + :param points: 输入点坐标 (N, 3) + :param device: 设备(CPU/GPU) + :return: SDF值数组 (N,) + """ + points_t = torch.from_numpy(points).float().to(device) + + with torch.no_grad(): + sdf = model.forward_background(points_t).cpu().numpy().flatten() + return sdf + +def extract_surface(sdf, xx, yy, zz, method='MC', feature_angle=30.0, voxel_size=0.01): + """ + 提取零表面 + :param sdf: SDF值三维数组 + :param xx/yy/zz: 网格坐标 + :param method: 提取方法(MC: Marching Cubes) + :return: 顶点和面片 + """ + if method == 'MC': + verts, faces, _, _ = measure.marching_cubes(sdf, level=0) + elif method == 'EMC': + from iso_algorithms import enhanced_marching_cubes + verts, faces = enhanced_marching_cubes( + sdf, + feature_angle=feature_angle, + gradient_direction='descent' + ) + elif method == 'DC': + from iso_algorithms import dual_contouring + verts, faces = dual_contouring(sdf, voxel_size=voxel_size) + else: + raise ValueError(f"不支持的算法: {method}") + + # 新增顶点后处理 + verts = (verts - sdf.shape[0]//2) / (sdf.shape[0]//2) # 归一化到[-1,1] + return verts, faces + +def save_ply(vertices, faces, filename): + """ + 保存顶点和面片为PLY文件 + :param vertices: 顶点数组 (N, 3) + :param faces: 面片数组 (M, 3) + :param filename: 输出文件名 + """ + with open(filename, 'w') as f: + f.write("ply\n") + f.write("format ascii 1.0\n") + f.write(f"element vertex {len(vertices)}\n") + f.write("property float x\n") + f.write("property float y\n") + f.write("property float z\n") + f.write(f"element face {len(faces)}\n") + f.write("property list uchar int vertex_indices\n") + f.write("end_header\n") + for v in vertices: + f.write(f"{v[0]} {v[1]} {v[2]}\n") + for face in faces: + f.write(f"3 {face[0]} {face[1]} {face[2]}\n") + +def compute_sdf_error(model, gt_mesh, res, device): + """ + 计算预测SDF与GT网格的误差 + :param model: PyTorch模型 + :param gt_mesh: GT网格(Trimesh格式) + :param res: 误差计算分辨率 + :param device: 设备 + :return: 平均误差和最大误差 + """ + # 生成均匀采样点 + box_size = max(gt_mesh.extents) + start = -box_size / 2 + end = box_size / 2 + x = np.linspace(start, end, res) + y = np.linspace(start, end, res) + z = np.linspace(start, end, res) + points = np.array(np.meshgrid(x, y, z)).T.reshape(-1, 3) + + # 预测SDF + pred_sdf = predict_sdf(model, points, device) + + # 计算GT距离 + distances = gt_mesh.nearest.on_surface(points)[1] + gt_sdf = np.abs(distances) + + # 计算误差 + abs_error = np.abs(pred_sdf - gt_sdf) + rel_error = abs_error / (np.abs(gt_sdf) + 1e-9) + avg_abs = np.mean(abs_error) + avg_rel = np.mean(rel_error) + max_abs = np.max(abs_error) + max_rel = np.max(rel_error) + + return avg_abs, avg_rel, max_abs, max_rel + +def main(): + parser = argparse.ArgumentParser(description='IsoSurface Generator') + parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)') + parser.add_argument('-o', '--output', type=str, required=True, help='Output mesh file (.ply)') + parser.add_argument('--depth', type=int, default=7, help='网格深度(分辨率)') + parser.add_argument('--box_size', type=float, default=1.0, # 从1.0改为2.0 + help='边界框大小(建议设为2.0以得到[-1,1]范围)') + parser.add_argument('--method', type=str, default='MC', + choices=['MC', 'EMC', 'DC'], # 新增算法选项 + help='表面提取方法: MC-MarchingCubes, EMC-EnhancedMC, DC-DualContouring') + parser.add_argument('--feature_angle', type=float, default=30.0, + help='特征角度阈值(EMC算法专用)') + parser.add_argument('--voxel_size', type=float, default=0.01, + help='体素尺寸(DC算法专用)') + parser.add_argument('--use-gpu', action='store_true', help='使用GPU') + parser.add_argument('--compare', type=str, help='GT网格文件(.ply)') + parser.add_argument('--compres', type=int, default=32, help='误差计算分辨率') + args = parser.parse_args() + + # 设置设备 + device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + model = torch.jit.load(args.input).to(device) + #model = torch.load(args.input).to(device) + model.eval() + + octree = model.octree_module + + # 使用八叉树创建网格并预测SDF + points, sdf = create_grid_with_octree(octree, model, device) + print(1) + # 这里需要根据实际情况将points转换为网格坐标xx, yy, zz + # 简单示例:假设points是均匀采样的 + grid_size = int(np.ceil(len(points) ** (1/3))) + xx = points[:, 0].reshape(grid_size, grid_size, grid_size) + yy = points[:, 1].reshape(grid_size, grid_size, grid_size) + zz = points[:, 2].reshape(grid_size, grid_size, grid_size) + sdf_grid = sdf.reshape(xx.shape) + + # 提取表面 + print("Extracting surface...") + start_time = time.time() + verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method) + + # 新增顶点归一化校验 + max_val = np.max(np.abs(verts)) + if max_val > 1.0 + 1e-6: # 允许微小误差 + verts = verts / max_val + print(f"Surface extraction took {time.time() - start_time:.2f} seconds") + + # 保存网格 + save_ply(verts, faces, args.output) + print(f"Mesh saved to {args.output}") + + # 误差评估(可选) + if args.compare: + print("Computing SDF error...") + gt_mesh = trimesh.load(args.compare) + avg_abs, avg_rel, max_abs, max_rel = compute_sdf_error( + model, gt_mesh, args.compres, device + ) + print(f"Average Absolute Error: {avg_abs:.4f}") + print(f"Average Relative Error: {avg_rel:.4f}") + print(f"Max Absolute Error: {max_abs:.4f}") + print(f"Max Relative Error: {max_rel:.4f}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/brep2sdf/eval_pos.py b/brep2sdf/eval_pos.py new file mode 100644 index 0000000..9c4bf0a --- /dev/null +++ b/brep2sdf/eval_pos.py @@ -0,0 +1,143 @@ +import trimesh +import numpy as np +from brep2sdf.data.sampler import sample_zero_surface_points_and_normals +from brep2sdf.utils.load import get_namelist, get_step_paths +from brep2sdf.networks.network import gradient +import torch +import os +from brep2sdf.utils.logger import logger + +def position_loss(pred_sdfs: torch.Tensor) -> torch.Tensor: + """位置损失函数""" + # 保持梯度流 + squared_diff = torch.pow(pred_sdfs, 2) + return torch.mean(squared_diff) + +def average_normal_error(normals1: torch.Tensor, normals2: torch.Tensor) -> torch.Tensor: + """ + 计算平均法向量误差 (NAE) + :param normals1: 形状为 (B, 3) 的法向量张量 + :param normals2: 形状为 (B, 3) 的法向量张量 + :return: NAE 值 + """ + dot_products = torch.sum(normals1 * normals2, dim=-1) + absolute_dot_products = torch.abs(dot_products) + angle_errors = 1 - absolute_dot_products + return torch.mean(angle_errors) + + + +def load_model(model_path): + """加载模型的通用函数""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + try: + model = torch.jit.load(model_path).to(device) + logger.info(f"成功加载模型: {model_path}") + return model + except Exception as e: + logger.error(f"加载模型 {model_path} 时出错: {e}") + return None + +def nh(model_path, points): + model = load_model(model_path) + if model is None: + return None + try: + return model(points) + except Exception as e: + logger.error(f"调用 NH 模型时出错: {e}") + return None + +def mine(model_path, points): + model = load_model(model_path) + if model is None: + return None + try: + return model.forward_background(points) + except Exception as e: + logger.error(f"调用 mine 模型时出错: {e}") + return None + +def run(name): + # 替换为实际的 obj 文件路径 + obj_file_path = f"/home/wch/brep2sdf/data/gt_mesh/{name}.obj" + model_path = f"/home/wch/brep2sdf/data/output_data/{name}.pt" + nh_model = f"/home/wch/NH-Rep/data_backup/output_data/extracted/output_data/{name}_0_50k_model_h.pt" + + # 检查文件是否存在 + if not os.path.isfile(obj_file_path): + logger.error(f"OBJ 文件 {obj_file_path} 不存在。") + return + + try: + # 读取 obj 文件 + mesh = trimesh.load_mesh(obj_file_path) + logger.info(f"成功读取 OBJ 文件: {obj_file_path}") + except Exception as e: + logger.error(f"读取 OBJ 文件 {obj_file_path} 时出错: {e}") + return + + try: + # 调用采样函数 + result1 = sample_zero_surface_points_and_normals(mesh, num_samples=4096) + if result1 is None: + logger.error("采样失败,返回 None") + return + # 提取前 3 列作为坐标点 + coordinates = result1[:, :3] + # 将 ndarray 转换为 Tensor 并移动到设备上 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + coordinates_tensor = torch.from_numpy(coordinates).float().to(device).requires_grad_(True) + + sdf1 = nh(nh_model, coordinates_tensor) / 2 + sdf2 = mine(model_path, coordinates_tensor) + + loss1, loss2 = {}, {} + if sdf1 is not None and sdf2 is not None: + loss1["de"] = position_loss(sdf1).item() + loss2["de"] = position_loss(sdf2).item() + logger.info(f"NH 模型位置损失: {loss1}") + logger.info(f"Mine 模型位置损失: {loss2}") + + # 将 gt_normal 转换为 torch.Tensor 并移动到设备上 + gt_normal = torch.from_numpy(result1[:, 3:6]).float().to(device) + # 假设 gradient 函数已正确导入 + normal1 = gradient(coordinates_tensor, sdf1) + normal2 = gradient(coordinates_tensor, sdf2) + + loss1["nae"] = average_normal_error(gt_normal, normal1).item() + loss2["nae"] = average_normal_error(gt_normal, normal2).item() + + print("NH 模型的平均法向量误差 (NAE):", loss1["nae"]) + print("Mine 模型的平均法向量误差 (NAE):", loss2["nae"]) + + return loss1, loss2 + else: + logger.error("无法计算损失,SDF 结果为 None") + + except Exception as e: + logger.error(f"处理过程中出现错误: {e}") + +def main(): + names = get_namelist("/home/wch/brep2sdf/data/name_list.txt") + tl1_de, tl1_nae, tl2_de, tl2_nae = 0.0, 0.0, 0.0, 0.0 + valid_count = 0 + for name in names: + result = run(name) + if result is not None: + l1, l2 = result + tl1_de += l1["de"] + tl1_nae += l1["nae"] + tl2_de += l2["de"] + tl2_nae += l2["nae"] + valid_count += 1 + if valid_count > 0: + print(f"NH 模型平均位置损失 (de): {tl1_de/valid_count}") + print(f"NH 模型平均法向量误差 (nae): {tl1_nae/valid_count}") + print(f"Mine 模型平均位置损失 (de): {tl2_de/valid_count}") + print(f"Mine 模型平均法向量误差 (nae): {tl2_nae/valid_count}") + else: + print("没有有效的结果,无法计算平均值。") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/brep2sdf/scripts/abc_vis.py b/brep2sdf/scripts/abc_vis.py new file mode 100644 index 0000000..be51850 --- /dev/null +++ b/brep2sdf/scripts/abc_vis.py @@ -0,0 +1,124 @@ +import os +import glob +import argparse +from OCC.Core.STEPControl import STEPControl_Reader +from OCC.Core.TopExp import TopExp_Explorer +from OCC.Core.TopAbs import TopAbs_FACE +import matplotlib.pyplot as plt +import csv +from tqdm import tqdm # 导入 tqdm 库 +import concurrent.futures +from multiprocessing import Process, Queue + +# 定义一个新的函数,用于在子进程中执行计数操作 +def count_faces_task(file_path, result_queue): + try: + # 创建 STEP 读取器 + reader = STEPControl_Reader() + # 读取 STEP 文件 + status = reader.ReadFile(file_path) + if status == 1: + reader.TransferRoots() + shape = reader.OneShape() + # 遍历所有面 + explorer = TopExp_Explorer(shape, TopAbs_FACE) + face_count = 0 + while explorer.More(): + face_count += 1 + explorer.Next() + #print(face_count) + result_queue.put(face_count) + else: + print(f"无法读取文件 {file_path}") + result_queue.put(None) + except Exception as e: + print(f"处理文件 {file_path} 时出错: {e}") + result_queue.put(None) + +def count_faces_in_step_file(file_path, timeout=30): + result_queue = Queue() + p = Process(target=count_faces_task, args=(file_path, result_queue)) + p.start() + p.join(timeout) + + if p.is_alive(): + print(f"处理文件 {file_path} 超时,已终止") + p.terminate() + p.join() + return None + + result = result_queue.get() + return result + +def main(): + parser = argparse.ArgumentParser(description='统计 ABC 数据集模型面的数量并可视化') + parser.add_argument('-i','--input_dir', type=str, required=True, help='包含 STEP 文件的输入目录') + parser.add_argument('-o', '--output_file', type=str, default='face_counts.csv', help='保存面数量数据的 CSV 文件路径') + # 新增参数,用于指定进程数 + parser.add_argument('--processes', type=int, default=os.cpu_count()-1, help='并行处理的进程数,默认为 CPU 核心数') + args = parser.parse_args() + + # 读取已处理的文件名 + processed_files = set() + if os.path.exists(args.output_file): + with open(args.output_file, 'r', newline='') as csvfile: + reader = csv.reader(csvfile) + try: + next(reader) # 尝试跳过表头 + except StopIteration: + # 如果文件为空,直接跳过 + pass + for row in reader: + processed_files.add(row[0]) + + # 获取所有 STEP 文件并过滤掉已处理的文件 + step_files = glob.glob(os.path.join(args.input_dir, "**/*.step"), recursive=True) + step_files = [file for file in step_files if os.path.basename(file) not in processed_files] + + # 划分批次 + num_processes = args.processes + batch_size = len(step_files) // num_processes + 1 + batches = [step_files[i:i + batch_size] for i in range(0, len(step_files), batch_size)] + + face_counts = [] + # 打开 CSV 文件,准备逐批次写入 + with open(args.output_file, 'a+', newline='') as csvfile: + writer = csv.writer(csvfile) + # 写入表头 + writer.writerow(['文件名', '面的数量']) + + for batch in tqdm(batches, desc="处理批次进度"): + batch_results = [] + with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: + future_to_file = {executor.submit(count_faces_in_step_file, file_path): file_path for file_path in batch} + for future in concurrent.futures.as_completed(future_to_file): + file_path = future_to_file[future] + try: + # 指定超时时间 + result = future.result(timeout=30) + except concurrent.futures.TimeoutError: + print(f'{file_path} 处理超时,已终止') + continue + except Exception as exc: + print(f'{file_path} 产生了异常: {exc}') + else: + if result is not None: + batch_results.append((file_path, result)) + # 逐批次写入 CSV + writer.writerow([os.path.basename(file_path), result]) + + face_counts.extend(batch_results) + + if face_counts: + # 绘制直方图,需要提取面数 + face_counts_only = [count for _, count in face_counts] + plt.hist(face_counts_only, bins=50, edgecolor='black') + plt.title('ABC 数据集模型面数量直方图') + plt.xlabel('面的数量') + plt.ylabel('模型数量') + plt.show() + else: + print("未找到有效的 STEP 文件或处理过程中出现错误。") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/brep2sdf/scripts/clean_csv_filenames.py b/brep2sdf/scripts/clean_csv_filenames.py new file mode 100644 index 0000000..e69de29 diff --git a/brep2sdf/scripts/npz2points.py b/brep2sdf/scripts/npz2points.py new file mode 100644 index 0000000..cd2d260 --- /dev/null +++ b/brep2sdf/scripts/npz2points.py @@ -0,0 +1,25 @@ +import torch +import numpy as np +import pickle + + + +def load_brep_file(brep_path): + with open(brep_path, 'rb') as f: + brep_raw = pickle.load(f) + return brep_raw + + +if __name__ == "__main__": + data=load_brep_file("/home/wch/brep2sdf/data/output_data/00000031.xyz") + surfs =data["train_surf_ncs"] + print(surfs) + with open("0031_t.xyz","w") as f: + for point in surfs: + #f.write(f"{point[0]} {point[1]} {point[2]}\n") + f.write(f"{point[0]} {point[1]} {point[2]} {point[3]} {point[4]} {point[5]}\n") + ''' + for surf in surfs: + for point in surf: + f.write(f"{point[0]} {point[1]} {point[2]}\n") + ''' \ No newline at end of file diff --git a/brep2sdf/scripts/paint.py b/brep2sdf/scripts/paint.py new file mode 100644 index 0000000..e073683 --- /dev/null +++ b/brep2sdf/scripts/paint.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +from torchviz import make_dot + +class SimpleEncoder(nn.Module): + def __init__(self, feature_dim): + super(SimpleEncoder, self).__init__() + self.simple_encoder = nn.Sequential( + nn.Linear(3, 256), + nn.BatchNorm1d(256), + nn.ReLU(), + nn.Linear(256, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, 256), + nn.BatchNorm1d(256), + nn.ReLU(), + nn.Linear(256, feature_dim) + ) + + def forward(self, x): + return self.simple_encoder(x) + +# 创建模型实例 +feature_dim = 8 # 根据你的需求设定 +model = SimpleEncoder(feature_dim) + +# 方法一:将模型设置为评估模式 +model.eval() + +# 方法二:增加输入数据的批次大小 +# x = torch.randn(2, 3) # 将批次大小从 1 改为 2 + +# 创建随机输入张量(根据实际情况调整大小) +x = torch.randn(1, 3) + +# 获取模型输出 +output = model(x) + +# 使用torchviz生成模型图 +dot = make_dot(output, params=dict(list(model.named_parameters()))) + +# 保存图像文件 +dot.render("simple_encoder", format="png") \ No newline at end of file diff --git a/brep2sdf/sdf_vis.py b/brep2sdf/sdf_vis.py new file mode 100644 index 0000000..d95e5c2 --- /dev/null +++ b/brep2sdf/sdf_vis.py @@ -0,0 +1,69 @@ +import numpy as np +import torch +import argparse +import time +from brep2sdf.utils.logger import logger + +def create_grid(depth, box_size): + """ + 创建三维网格点 + :param depth: 网格深度(决定分辨率) + :param box_size: 边界框大小(边长) + :return: 网格点数组和坐标网格 + """ + grid_size = 2**depth + 1 + start = -box_size / 2 + end = box_size / 2 + x = np.linspace(start, end, grid_size) + y = np.linspace(start, end, grid_size) + z = np.linspace(start, end, grid_size) + xx, yy, zz = np.meshgrid(x, y, z, indexing='ij') + points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1) + + # 新增归一化处理 + max_coord = np.max(np.abs(points)) + points = points / max_coord # 归一化到[-1,1] + return points, xx, yy, zz + +def predict_sdf(model, points, device): + """ + 使用模型预测SDF值 + :param model: PyTorch模型 + :param points: 输入点坐标 (N, 3) + :param device: 设备(CPU/GPU) + :return: SDF值数组 (N,) + """ + points_t = torch.from_numpy(points).float().to(device) + + with torch.no_grad(): + sdf = model.forward_background(points_t).cpu().numpy().flatten() + return sdf + +def main(): + parser = argparse.ArgumentParser(description='SDF Visualization') + parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)') + parser.add_argument('--depth', type=int, default=7, help='网格深度(分辨率)') + parser.add_argument('--box_size', type=float, default=2.0, + help='边界框大小(建议设为2.0以得到[-1,1]范围)') + parser.add_argument('--use-gpu', action='store_true', help='使用GPU') + parser.add_argument('-o', '--output', type=str, default='sdf_data.npz', help='输出SDF数据文件(.npz格式)') + args = parser.parse_args() + + # 设置设备 + device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + model = torch.jit.load(args.input).to(device) + model.eval() + + # 创建网格并预测SDF + points, xx, yy, zz = create_grid(args.depth, args.box_size) + sdf = predict_sdf(model, points, device) + sdf_grid = sdf.reshape(xx.shape) + + # 保存SDF数据到文件 + np.savez(args.output, xx=xx, yy=yy, zz=zz, sdf_grid=sdf_grid) + print(f"SDF数据已保存到 {args.output}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/brep2sdf/test.py b/brep2sdf/test.py index 4266e2c..35784c9 100644 --- a/brep2sdf/test.py +++ b/brep2sdf/test.py @@ -1,69 +1,112 @@ +import trimesh +import numpy as np +from brep2sdf.data.sampler import sample_zero_surface_points_and_normals +from brep2sdf.networks.network import gradient import torch +import logging + +# 配置日志记录 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +def position_loss(pred_sdfs: torch.Tensor) -> torch.Tensor: + """位置损失函数""" + # 保持梯度流 + squared_diff = torch.pow(pred_sdfs, 2) + return torch.mean(squared_diff) + +def average_normal_error(normals1: torch.Tensor, normals2: torch.Tensor) -> torch.Tensor: + """ + 计算平均法向量误差 (NAE) + :param normals1: 形状为 (B, 3) 的法向量张量 + :param normals2: 形状为 (B, 3) 的法向量张量 + :return: NAE 值 + """ + dot_products = torch.sum(normals1 * normals2, dim=-1) + absolute_dot_products = torch.abs(dot_products) + angle_errors = 1 - absolute_dot_products + return torch.mean(angle_errors) + +def + +# ========== +def load_model(model_path): + """加载模型的通用函数""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + try: + model = torch.jit.load(model_path).to(device) + logging.info(f"成功加载模型: {model_path}") + return model + except Exception as e: + logging.error(f"加载模型 {model_path} 时出错: {e}") + return None + + +#========== +def nh(model_path, points): + model = load_model(model_path) + if model is None: + return None + try: + return model(points) + except Exception as e: + logging.error(f"调用 NH 模型时出错: {e}") + return None + +def mine(model_path, points): + model = load_model(model_path) + if model is None: + return None + try: + return model.forward_background(points) + except Exception as e: + logging.error(f"调用 mine 模型时出错: {e}") + return None + +def main(): + # 替换为实际的 obj 文件路径 + obj_file_path = "/home/wch/brep2sdf/data/gt_mesh/00000031.obj" + model_path = "/home/wch/brep2sdf/data/output_data/00000031.pt" + nh_model = "/home/wch/NH-Rep/data/output_data/00000031_0_50k_model_h.pt" + + try: + # 读取 obj 文件 + mesh = trimesh.load_mesh(obj_file_path) + logging.info(f"成功读取 OBJ 文件: {obj_file_path}") + except Exception as e: + logging.error(f"读取 OBJ 文件 {obj_file_path} 时出错: {e}") + return + + try: + # 调用采样函数 + result1 = sample_zero_surface_points_and_normals(mesh, num_samples=4096) + if result1 is None: + logging.error("采样失败,返回 None") + return + # 提取前 3 列作为坐标点 + coordinates = result1[:, :3] + # 将 ndarray 转换为 Tensor 并移动到设备上 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + coordinates_tensor = torch.from_numpy(coordinates).float().to(device) + + sdf1 = nh(nh_model, coordinates_tensor) / 2 + sdf2 = mine(model_path, coordinates_tensor) + + if sdf1 is not None and sdf2 is not None: + loss1_ = position_loss(sdf1) + loss2 = position_loss(sdf2) + logging.info(f"NH 模型位置损失: {loss1.item()}") + logging.info(f"Mine 模型位置损失: {loss2.item()}") + + gt_normal = result1[:, 3:6] + normal1 = gradient(coordinates, sdf1) + normal2 = gradient(coordinates, sdf2) + nae1=average_normal_error(gt_normal, normal1) + nae2=average_normal_error(gt_normal, normal2) + else: + logging.error("无法计算损失,SDF 结果为 None") + + except Exception as e: + logging.error(f"处理过程中出现错误: {e}") -from typing import List, Tuple - -def bbox_intersect(surf_bboxes: torch.Tensor, indices: torch.Tensor, child_bboxes: torch.Tensor) -> torch.Tensor: - ''' - args: - surf_bboxes: [B, 6] - 表示多个包围盒的张量,每个包围盒由其最小和最大坐标定义。 - indices: 选择bbox, [N], N <= B - 用于选择特定包围盒的索引张量。 - child_bboxes: [8, 6] - 一个包围盒被分成八个子包围盒后的结果。 - return: - intersect_mask: [8, N] - 布尔掩码,表示每个子包围盒与选择的包围盒是否相交。 - ''' - # 提取选中的边界框 - selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6] - min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3] - min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3] - - # 确保广播机制正常工作 - intersect_mask = torch.all( - (max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3] - (max2.unsqueeze(1) >= min1.unsqueeze(0)), # 形状为 [8, N, 3] - dim=-1 - ) # 最终形状为 [8, N] - - return intersect_mask - -# 测试程序 if __name__ == "__main__": - # 构造输入数据 - surf_bboxes = torch.tensor([ - [0, 0, 0, 1, 1, 1], # 立方体 1 - [0.5, 0.5, 0.5, 1.5, 1.5, 1.5], # 立方体 2 - [2, 2, 2, 3, 3, 3] # 立方体 3 - ]) # [B=3, 6] - - indices = torch.tensor([0, 1]) # 选择前两个立方体 - - # 假设父边界框为 [0, 0, 0, 2, 2, 2],生成其八个子边界框 - parent_bbox = torch.tensor([0, 0, 0, 2, 2, 2]) - center = (parent_bbox[:3] + parent_bbox[3:]) / 2 - child_bboxes = torch.tensor([ - [parent_bbox[0], parent_bbox[1], parent_bbox[2], center[0], center[1], center[2]], # 左下前 - [center[0], parent_bbox[1], parent_bbox[2], parent_bbox[3], center[1], center[2]], # 右下前 - [parent_bbox[0], center[1], parent_bbox[2], center[0], parent_bbox[4], center[2]], # 左上前 - [center[0], center[1], parent_bbox[2], parent_bbox[3], parent_bbox[4], center[2]], # 右上前 - [parent_bbox[0], parent_bbox[1], center[2], center[0], center[1], parent_bbox[5]], # 左下后 - [center[0], parent_bbox[1], center[2], parent_bbox[3], center[1], parent_bbox[5]], # 右下后 - [parent_bbox[0], center[1], center[2], center[0], parent_bbox[4], parent_bbox[5]], # 左上后 - [center[0], center[1], center[2], parent_bbox[3], parent_bbox[4], parent_bbox[5]] # 右上后 - ]) # [8, 6] - - # 调用函数 - intersect_mask = bbox_intersect(surf_bboxes, indices, child_bboxes) - - # 输出结果 - print("Intersect Mask:") - print(intersect_mask) - - # 将布尔掩码转换为索引列表 - child_indices = [] - for i in range(8): # 遍历每个子节点 - intersecting_faces = indices[intersect_mask[i]] # 获取当前子节点的相交面片索引 - child_indices.append(intersecting_faces) - - # 打印每个子节点对应的相交索引 - print("\nChild Indices:") - for i, indices in enumerate(child_indices): - print(f"Child {i}: {indices}") \ No newline at end of file + main() \ No newline at end of file