diff --git a/.gitignore b/.gitignore index 53c5944..b1e4919 100644 --- a/.gitignore +++ b/.gitignore @@ -171,6 +171,8 @@ test_data/ logs/ wandb/ *.pth +*.pt +*.csv checkpoints/ data/gt_mesh @@ -178,4 +180,5 @@ data/gt_point data/step data/input_data data/output_data -data/name_list.txt \ No newline at end of file +data/name_list.txt +data/scripts/IsoSurfacing \ No newline at end of file diff --git a/brep2sdf/IsoSurfacing.py b/brep2sdf/IsoSurfacing.py index d4a90d4..5645f5c 100644 --- a/brep2sdf/IsoSurfacing.py +++ b/brep2sdf/IsoSurfacing.py @@ -1,66 +1,161 @@ -import os -import subprocess -from tqdm import tqdm +import numpy as np +import torch +import argparse +from skimage import measure +import time +import trimesh -# 使用一个 c++ 程序处理,这里只是调用,注意要在docker里面运行。宿主机编译失败 +def create_grid(depth, box_size): + """ + 创建三维网格点 + :param depth: 网格深度(决定分辨率) + :param box_size: 边界框大小(边长) + :return: 网格点数组和坐标网格 + """ + grid_size = 2**depth + 1 + start = -box_size / 2 + end = box_size / 2 + x = np.linspace(start, end, grid_size) + y = np.linspace(start, end, grid_size) + z = np.linspace(start, end, grid_size) + xx, yy, zz = np.meshgrid(x, y, z, indexing='ij') + points = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1) + return points, xx, yy, zz + +def predict_sdf(model, points, device): + """ + 使用模型预测SDF值 + :param model: PyTorch模型 + :param points: 输入点坐标 (N, 3) + :param device: 设备(CPU/GPU) + :return: SDF值数组 (N,) + """ + points_t = torch.from_numpy(points).float().to(device) + + with torch.no_grad(): + sdf = model(points_t).cpu().numpy().flatten() + return sdf + +def extract_surface(sdf, xx, yy, zz, method='MC'): + """ + 提取零表面 + :param sdf: SDF值三维数组 + :param xx/yy/zz: 网格坐标 + :param method: 提取方法(MC: Marching Cubes) + :return: 顶点和面片 + """ + if method == 'MC': + verts, faces, _, _ = measure.marching_cubes(sdf, level=0) + else: + raise NotImplementedError("仅支持Marching Cubes方法") + return verts, faces + +def save_ply(vertices, faces, filename): + """ + 保存顶点和面片为PLY文件 + :param vertices: 顶点数组 (N, 3) + :param faces: 面片数组 (M, 3) + :param filename: 输出文件名 + """ + with open(filename, 'w') as f: + f.write("ply\n") + f.write("format ascii 1.0\n") + f.write(f"element vertex {len(vertices)}\n") + f.write("property float x\n") + f.write("property float y\n") + f.write("property float z\n") + f.write(f"element face {len(faces)}\n") + f.write("property list uchar int vertex_indices\n") + f.write("end_header\n") + for v in vertices: + f.write(f"{v[0]} {v[1]} {v[2]}\n") + for face in faces: + f.write(f"3 {face[0]} {face[1]} {face[2]}\n") + +def compute_sdf_error(model, gt_mesh, res, device): + """ + 计算预测SDF与GT网格的误差 + :param model: PyTorch模型 + :param gt_mesh: GT网格(Trimesh格式) + :param res: 误差计算分辨率 + :param device: 设备 + :return: 平均误差和最大误差 + """ + # 生成均匀采样点 + box_size = max(gt_mesh.extents) + start = -box_size / 2 + end = box_size / 2 + x = np.linspace(start, end, res) + y = np.linspace(start, end, res) + z = np.linspace(start, end, res) + points = np.array(np.meshgrid(x, y, z)).T.reshape(-1, 3) + + # 预测SDF + pred_sdf = predict_sdf(model, points, device) + + # 计算GT距离 + distances = gt_mesh.nearest.on_surface(points)[1] + gt_sdf = np.abs(distances) + + # 计算误差 + abs_error = np.abs(pred_sdf - gt_sdf) + rel_error = abs_error / (np.abs(gt_sdf) + 1e-9) + avg_abs = np.mean(abs_error) + avg_rel = np.mean(rel_error) + max_abs = np.max(abs_error) + max_rel = np.max(rel_error) + + return avg_abs, avg_rel, max_abs, max_rel def main(): - # 定义 STEP 文件目录和名称列表文件路径 - output_data_root_dir = "/workspace/home/wch/brep2sdf/data/output_data" - name_list_path = "/workspace/home/wch/brep2sdf/data/name_list.txt" + parser = argparse.ArgumentParser(description='IsoSurface Generator') + parser.add_argument('-i', '--input', type=str, required=True, help='Input model file (.pt)') + parser.add_argument('-o', '--output', type=str, required=True, help='Output mesh file (.ply)') + parser.add_argument('--depth', type=int, default=7, help='网格深度(分辨率)') + parser.add_argument('--box_size', type=float, default=2.0, help='边界框大小') + parser.add_argument('--method', type=str, default='MC', choices=['MC'], help='表面提取方法') + parser.add_argument('--use-gpu', action='store_true', help='使用GPU') + parser.add_argument('--compare', type=str, help='GT网格文件(.ply)') + parser.add_argument('--compres', type=int, default=32, help='误差计算分辨率') + args = parser.parse_args() + + # 设置设备 + device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") - # 读取名称列表 - try: - with open(name_list_path, 'r') as f: - names = [line.strip() for line in f if line.strip()] # 去除空行 - except FileNotFoundError: - print(f"Error: File '{name_list_path}' not found.") - return - except Exception as e: - print(f"Error reading file '{name_list_path}': {e}") - return + model = torch.jit.load(args.input).to(device) + #model = torch.load(args.input).to(device) + model.eval() - # 遍历名称列表并处理每个 STEP 文件 - for name in tqdm(names, desc="ISOsurfing pt files"): - pt_file = os.path.join(output_data_root_dir, f"{name}.pt") - - if not pt_file: - print(f"Warning: No pt files found in directory '{output_data_root_dir}'. Skipping...") - continue - + # 创建网格并预测SDF + points, xx, yy, zz = create_grid(args.depth, args.box_size) + sdf = predict_sdf(model, points, device) + print(points.shape) + print(sdf.shape) + print(sdf) + sdf_grid = sdf.reshape(xx.shape) - # ./ISG_console_pytorch -i ./test/teaser.pt -o outputmesh.ply -v -0.01 -d 8 - # 构造子进程命令 - command = [ - "python", "/workspace/home/wch/brep2sdf/data/scripts/IsoSurfacing/build/App/console_pytorch/ISG_console_pytorch", - "-i", pt_file, # 使用当前遍历的pt文件 - "-o", os.path.join(output_data_root_dir, f"{name}_outputmesh.ply"), # 动态生成输出文件路径 - "-v", "-0.01", "-d", "8" - ] + # 提取表面 + print("Extracting surface...") + start_time = time.time() + verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method) + print(f"Surface extraction took {time.time() - start_time:.2f} seconds") - # 调用子进程运行命令 - try: - result = subprocess.run( - command, - capture_output=True, - text=True, - check=True # 如果返回非零退出码,则抛出 CalledProcessError - ) - print(f"Successfully processed '{name}'") - print("STDOUT:", result.stdout) - print("STDERR:", result.stderr) - except subprocess.CalledProcessError as e: - print(f"Error processing '{name}': Command failed with return code {e.returncode}") - print(f"Command: {e.cmd}") - print(f"Error type: {type(e).__name__}") - print("STDOUT:", e.stdout) - print("STDERR:", e.stderr) - print("Traceback:", e.__traceback__) - except Exception as e: - print(f"Unexpected error processing '{name}': {str(e)}") - print(f"Command: {command}") - print("Traceback:", traceback.format_exc()) + # 保存网格 + save_ply(verts, faces, args.output) + print(f"Mesh saved to {args.output}") + # 误差评估(可选) + if args.compare: + print("Computing SDF error...") + gt_mesh = trimesh.load(args.compare) + avg_abs, avg_rel, max_abs, max_rel = compute_sdf_error( + model, gt_mesh, args.compres, device + ) + print(f"Average Absolute Error: {avg_abs:.4f}") + print(f"Average Relative Error: {avg_rel:.4f}") + print(f"Max Absolute Error: {max_abs:.4f}") + print(f"Max Relative Error: {max_rel:.4f}") -if __name__ == '__main__': +if __name__ == "__main__": main() \ No newline at end of file diff --git a/brep2sdf/batch_train.py b/brep2sdf/batch_train.py new file mode 100644 index 0000000..5fd3124 --- /dev/null +++ b/brep2sdf/batch_train.py @@ -0,0 +1,67 @@ +import os +import subprocess +from tqdm import tqdm + + +def main(): + # 定义 STEP 文件目录和名称列表文件路径 + step_root_dir = "/home/wch/brep2sdf/data/step" + name_list_path = "/home/wch/brep2sdf/data/name_list.txt" + + # 读取名称列表 + try: + with open(name_list_path, 'r') as f: + names = [line.strip() for line in f if line.strip()] # 去除空行 + except FileNotFoundError: + print(f"Error: File '{name_list_path}' not found.") + return + except Exception as e: + print(f"Error reading file '{name_list_path}': {e}") + return + + # 遍历名称列表并处理每个 STEP 文件 + for name in tqdm(names, desc="Processing STEP files"): + step_dir = os.path.join(step_root_dir, name) + + # 动态生成 STEP 文件路径(假设只有一个文件) + step_files = [ + os.path.join(step_dir, f) + for f in os.listdir(step_dir) + if f.endswith(".step") and f.startswith(name) + ] + + if not step_files: + print(f"Warning: No STEP files found in directory '{step_dir}'. Skipping...") + continue + + # 假设我们只处理第一个匹配的文件 + input_step = step_files[0] + + # 构造子进程命令 + command = [ + "python", "train.py", + "--use-normal", + "-i", input_step, # 输入文件路径 + ] + + # 调用子进程运行命令 + try: + result = subprocess.run( + command, + capture_output=True, + text=True, + check=True # 如果返回非零退出码,则抛出 CalledProcessError + ) + print(f"Processed '{input_step}' successfully.") + print("STDOUT:", result.stdout) + print("STDERR:", result.stderr) + except subprocess.CalledProcessError as e: + print(f"Error processing '{input_step}': {e}") + print("STDOUT:", e.stdout) + print("STDERR:", e.stderr) + except Exception as e: + print(f"Unexpected error processing '{input_step}': {e}") + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index c41db00..aac078e 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -47,7 +47,7 @@ class TrainConfig: # 基本训练参数 batch_size: int = 8 num_workers: int = 4 - num_epochs: int = 20 + num_epochs: int = 200 learning_rate: float = 0.01 min_lr: float = 1e-5 weight_decay: float = 0.01 @@ -89,7 +89,7 @@ class LogConfig: # 本地日志 log_dir: str = '/home/wch/brep2sdf/logs' # 日志保存目录 log_level: str = 'INFO' # 日志级别 - console_level: str = 'INFO' # 控制台日志级别 + console_level: str = 'DEBUG' # 控制台日志级别 file_level: str = 'DEBUG' # 文件日志级别 @dataclass diff --git a/brep2sdf/evaluation.py b/brep2sdf/evaluation.py new file mode 100644 index 0000000..a82cc69 --- /dev/null +++ b/brep2sdf/evaluation.py @@ -0,0 +1,280 @@ +import os +import sys +from brep2sdf.utils.logger import logger + + +# 导入日志系统 +from brep2sdf.utils.logger import logger +import numpy as np +from scipy.spatial import cKDTree +from scipy.spatial.distance import directed_hausdorff +import trimesh +import pandas as pd +import csv +import math +import pickle + +import argparse + +project_dir = "/home/wch/brep2sdf" +# parse args first and set gpu id +parser = argparse.ArgumentParser() +parser.add_argument('--gt_path', type=str, + default=os.path.join(project_dir, 'data/gt_point'), + help='ground truth data path') +parser.add_argument('--pred_path', type=str, + default=os.path.join(project_dir, 'data/output_data'), + help='converted data path') +parser.add_argument('--name_list', type=str, default='name_list.txt', help='names of models to be evaluated, if you want to evaluate the whole dataset, please set it as all_names.txt') +parser.add_argument('--nsample', type=int, default=50000, help='point batch size') +parser.add_argument('--regen', default = False, action="store_true", help = 'regenerate feature curves') +parser.add_argument('--csv_name', type=str, default='eval_results.csv', help='csv file name') +args = parser.parse_args() + +def distance_p2p(points_src, normals_src, points_tgt, normals_tgt): + ''' Computes minimal distances of each point in points_src to points_tgt. + + Args: + points_src (numpy array [N, 3]): source points + normals_src (numpy array [N, 3]): source normals + points_tgt (numpy array [M, 3]): target points + normals_tgt (numpy array [M, 3]): target + Returns: + dist (numpy array [N]): minimal distances of each point in points_src to points_tgt + normals_dot_product (numpy array [N]): dot product of normals of points_src and points_tgt + ''' + kdtree = cKDTree(points_tgt) + dist, idx = kdtree.query(points_src) + + if normals_src is not None and normals_tgt is not None: + normals_src = \ + normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True) + normals_tgt = \ + normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True) + + normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1) + # Handle normals that point into wrong direction gracefully + # (mostly due to mehtod not caring about this in generation) + normals_dot_product = np.abs(normals_dot_product) + return dist, normals_dot_product + +def distance_feature2mesh(points, mesh): + prox = trimesh.proximity.ProximityQuery(mesh) + signed_distance = prox.signed_distance(points) + return np.abs(signed_distance) + +def distance_p2mesh(points_src, normals_src, mesh): + points_tgt, idx = mesh.sample(args.nsample, return_index=True) + points_tgt = points_tgt.astype(np.float32) + normals_tgt = mesh.face_normals[idx] + cd1, nc1 = distance_p2p(points_src, normals_src, points_tgt, normals_tgt) #pred2gt + hd1 = cd1.max() + cd1 = cd1.mean() + + nc1 = np.clip(nc1, -1.0, 1.0) + angles1 = np.arccos(nc1) / math.pi * 180.0 + angles1_mean = angles1.mean() + angles1_std = np.std(angles1) + + cd2, nc2 = distance_p2p(points_tgt, normals_tgt, points_src, normals_src) #gt2pred + hd2 = cd2.max() + cd2 = cd2.mean() + nc2 = np.clip(nc2, -1.0, 1.0) + + angles2 = np.arccos(nc2)/ math.pi * 180.0 + angles2_mean = angles2.mean() + angles2_std = np.std(angles2) + + + cd = 0.5 * (cd1 + cd2) + hd = max(hd1, hd2) + angles_mean = 0.5 * (angles1_mean + angles2_mean) + angles_std = 0.5 * (angles1_std + angles2_std) + return cd, hd, angles_mean, angles_std, hd1, hd2 + + +def distance_fea(gt_pa, pred_pa): + """计算特征点之间的距离和角度差异 + Args: + gt_pa: 真实特征点和角度 [N, 4] + pred_pa: 预测特征点和角度 [N, 4] + Returns: + dfg2p: 真实到预测的距离 + dfp2g: 预测到真实的距离 + fag2p: 真实到预测的角度差 + fap2g: 预测到真实的角度差 + """ + gt_points = gt_pa[:,:3] + pred_points = pred_pa[:,:3] + gt_angle = gt_pa[:,3] + pred_angle = pred_pa[:,3] + dfg2p = 0.0 + dfp2g = 0.0 + fag2p = 0.0 + fap2g = 0.0 + pred_kdtree = cKDTree(pred_points) + dist1, idx1 = pred_kdtree.query(gt_points) + dfg2p = dist1.mean() + assert(idx1.shape[0] == gt_points.shape[0]) + fag2p = np.abs(gt_angle - pred_angle[idx1]) + + gt_kdtree = cKDTree(gt_points) + dist2, idx2 = gt_kdtree.query(pred_points) + dfp2g = dist2.mean() + fap2g = np.abs(pred_angle - gt_angle[idx2]) + + fag2p = fag2p.mean() + fap2g = fap2g.mean() + + return dfg2p, dfp2g, fag2p, fap2g + +def load_and_process_single_model(line, gt_path, pred_mesh_path, args): + """处理单个模型的评估 + Args: + line (str): 模型名称 + gt_path (str): 真值路径 + pred_mesh_path (str): 预测网格路径 + args: 参数配置 + Returns: + dict: 包含该模型所有评估指标的字典 + """ + try: + #line = line.strip()[:-4] # 不用去 _50k + result = {'name': line} + + # 加载点云数据 + test_xyz = os.path.join(gt_path, line+'_50k.xyz') + try: + ptnormal = np.loadtxt(test_xyz) + except FileNotFoundError: + logger.error(f"XYZ file not found: {test_xyz}") + return None + except IOError as e: + logger.error(f"Error reading XYZ file {test_xyz}: {str(e)}") + return None + except ValueError as e: + logger.error(f"Invalid data format in XYZ file {test_xyz}: {str(e)}") + return None + except Exception as e: + logger.error(f"Unexpected error loading {test_xyz}: {str(e)}") + return None + logger.debug("successfully load gt points.") + + # 加载预测网格 + meshfile = os.path.join(pred_mesh_path, '{}.ply'.format(line)) + if not os.path.exists(meshfile): + logger.warning(f'File not exists: {meshfile}, try to generate it...') + pt_file = os.path.join(pred_mesh_path, '{}.pt'.format(line)) + try: + # 记录开始执行命令 + logger.debug(f"Executing IsoSurfacing: python ./IsoSurfacing.py -i {pt_file} -o {meshfile} --use-gpu") + + # 执行命令并检查返回值 + ret = os.system(f"python ./IsoSurfacing.py -i {pt_file} -o {meshfile} --use-gpu") + + if ret != 0: + raise RuntimeError(f"IsoSurfacing failed with return code {ret}") + + # 检查输出文件是否生成 + if not os.path.exists(meshfile): + raise FileNotFoundError(f"Output mesh file not created: {meshfile}") + + logger.debug("IsoSurfacing completed successfully") + + except FileNotFoundError as e: + logger.error(f"IsoSurfacing input file not found: {str(e)}") + return None + except RuntimeError as e: + logger.error(f"IsoSurfacing execution failed: {str(e)}") + return None + except Exception as e: + logger.error(f"Unexpected error in IsoSurfacing: {str(e)}") + return None + + # 检查缓存 + stat_file = meshfile + "_stat" + if not args.regen and os.path.exists(stat_file) and os.path.getsize(stat_file) > 0: + with open(stat_file, 'rb') as f: + return pickle.load(f) + + # 计算网格距离指标 + mesh = trimesh.load(meshfile) + logger.debug("successfully load pred mesh.") + cd, hd, adm, ads, hd_pred2gt, hd_gt2pred = distance_p2mesh( + ptnormal[:,:3], ptnormal[:,3:6], mesh) + + result.update({ + 'CD': cd, 'HD': hd, 'HDpred2gt': hd_pred2gt, + 'HDgt2pred': hd_gt2pred, 'AngleDiffMean': adm, + 'AngleDiffStd': ads + }) + + # 计算特征点指标 + gt_ptangle = np.loadtxt(os.path.join(gt_path, line + '.ptangle')) + pred_ptangle_path = meshfile[:-4]+'.ptangle' + + if not os.path.exists(pred_ptangle_path) or args.regen: + os.system('/home/wch/brep2sdf/data/scripts/MeshFeatureSample/build/SimpleSample -i {} -o {} -s 4e-3'.format(meshfile, pred_ptangle_path)) + + pred_ptangle = np.loadtxt(pred_ptangle_path).reshape(-1,4) + + # 处理特征点结果 + if len(gt_ptangle) == 0 or len(pred_ptangle) == 0: + result.update({ + 'FeaDfgt2pred': 0.0, 'FeaDfpred2gt': 0.0, + 'FeaAnglegt2pred': 0.0, 'FeaAnglepred2gt': 0.0, + 'FeaDf': 0.0, 'FeaAngle': 0.0 + }) + else: + dfg2p, dfp2g, fag2p, fap2g = distance_fea(gt_ptangle, pred_ptangle) + result.update({ + 'FeaDfgt2pred': dfg2p, 'FeaDfpred2gt': dfp2g, + 'FeaAnglegt2pred': fag2p, 'FeaAnglepred2gt': fap2g, + 'FeaDf': (dfg2p + dfp2g) / 2.0, + 'FeaAngle': (fag2p + fap2g) / 2.0 + }) + + # 保存缓存 + with open(stat_file, "wb") as f: + pickle.dump(result, f) + + return result + except Exception as e: + logger.error(f"Error processing {line}: {str(e)}") + return None + +def compute_all(): + """计算所有模型的评估指标""" + try: + # 初始化结果字典 + results = [] + + # 读取模型列表 + with open(os.path.join(project_dir, 'data', args.name_list), 'r') as f: + lines = f.readlines() + print(lines) + # 处理每个模型 + for line in lines: + result = load_and_process_single_model(line, args.gt_path, args.pred_path, args) + if result: + results.append(result) + logger.info(result) + # 计算平均值 + mean_result = {'name': 'mean'} + for key in results[0].keys(): + if key != 'name': + mean_result[key] = sum(r[key] for r in results) / len(results) + results.append(mean_result) + + # 保存结果 + df = pd.DataFrame(results) + df.to_csv(args.csv_name, index=False) + + logger.info(f"Evaluation completed. Results saved to {os.path.abspath(args.csv_name)}") + + except Exception as e: + logger.error(f"Error in compute_all: {str(e)}") + raise + +if __name__ == '__main__': + compute_all() \ No newline at end of file diff --git a/brep2sdf/train.py b/brep2sdf/train.py index a77a35c..d6e71b8 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -80,6 +80,8 @@ class Trainer: data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name) if os.path.exists(data_path) and not args.force_reprocess: self.data = load_brep_file(data_path) + if args.use_normal and self.data.get("surf_pnt_normals", None) is None: + self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal) else: self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal) @@ -276,7 +278,7 @@ class Trainer: # 3. 在no_grad上下文中执行追踪 with torch.no_grad(): traced_model = torch.jit.trace(self.model, example_input) - torch.jit.save(traced_model, f"{self.model_name}.pt") + torch.jit.save(traced_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt") def _load_checkpoint(self, checkpoint_path): """从检查点恢复训练状态"""