From 5d0d76f5ebdfcbc33b86a77d40d477a9659bee18 Mon Sep 17 00:00:00 2001 From: mckay Date: Sat, 23 Nov 2024 20:19:21 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20test=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/test.py | 161 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 brep2sdf/test.py diff --git a/brep2sdf/test.py b/brep2sdf/test.py new file mode 100644 index 0000000..00f7695 --- /dev/null +++ b/brep2sdf/test.py @@ -0,0 +1,161 @@ +import os +import torch +import numpy as np +from torch.utils.data import DataLoader +from brep2sdf.data.data import BRepSDFDataset +from brep2sdf.networks.network import BRepToSDF +from brep2sdf.utils.logger import logger +from brep2sdf.config.default_config import get_default_config +import matplotlib.pyplot as plt +from tqdm import tqdm + +class Tester: + def __init__(self, config, checkpoint_path): + self.config = config + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # 初始化测试数据集 + self.test_dataset = BRepSDFDataset( + brep_dir=config.data.brep_dir, + sdf_dir=config.data.sdf_dir, + valid_data_dir=config.data.valid_data_dir, + split='test' + ) + + # 初始化数据加载器 + self.test_loader = DataLoader( + self.test_dataset, + batch_size=1, # 测试时使用batch_size=1 + shuffle=False, + num_workers=config.train.num_workers, + pin_memory=False + ) + + # 加载模型 + self.model = BRepToSDF(config).to(self.device) + self.load_checkpoint(checkpoint_path) + + # 创建结果保存目录 + self.result_dir = os.path.join(config.data.result_save_dir, 'test_results') + os.makedirs(self.result_dir, exist_ok=True) + + def load_checkpoint(self, checkpoint_path): + """加载检查点""" + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}") + + checkpoint = torch.load(checkpoint_path, map_location=self.device) + self.model.load_state_dict(checkpoint['model_state_dict']) + logger.info(f"Loaded checkpoint from {checkpoint_path}") + + def compute_metrics(self, pred_sdf, gt_sdf): + """计算评估指标""" + mse = torch.mean((pred_sdf - gt_sdf) ** 2).item() + mae = torch.mean(torch.abs(pred_sdf - gt_sdf)).item() + max_error = torch.max(torch.abs(pred_sdf - gt_sdf)).item() + + return { + 'mse': mse, + 'mae': mae, + 'max_error': max_error + } + + def visualize_results(self, pred_sdf, gt_sdf, points, save_path): + """可视化预测结果""" + fig = plt.figure(figsize=(15, 5)) + + # 绘制预测SDF + ax1 = fig.add_subplot(131, projection='3d') + scatter = ax1.scatter(points[:, 0], points[:, 1], points[:, 2], + c=pred_sdf.squeeze().cpu(), cmap='coolwarm') + ax1.set_title('Predicted SDF') + plt.colorbar(scatter) + + # 绘制真实SDF + ax2 = fig.add_subplot(132, projection='3d') + scatter = ax2.scatter(points[:, 0], points[:, 1], points[:, 2], + c=gt_sdf.squeeze().cpu(), cmap='coolwarm') + ax2.set_title('Ground Truth SDF') + plt.colorbar(scatter) + + # 绘制误差图 + ax3 = fig.add_subplot(133, projection='3d') + error = torch.abs(pred_sdf - gt_sdf) + scatter = ax3.scatter(points[:, 0], points[:, 1], points[:, 2], + c=error.squeeze().cpu(), cmap='Reds') + ax3.set_title('Absolute Error') + plt.colorbar(scatter) + + plt.tight_layout() + plt.savefig(save_path) + plt.close() + + def test(self): + """执行测试""" + self.model.eval() + total_metrics = {'mse': 0, 'mae': 0, 'max_error': 0} + + logger.info("Starting testing...") + + with torch.no_grad(): + for idx, batch in enumerate(tqdm(self.test_loader)): + # 获取数据并移动到设备 + surf_ncs = batch['surf_ncs'].to(self.device) + edge_ncs = batch['edge_ncs'].to(self.device) + surf_pos = batch['surf_pos'].to(self.device) + edge_pos = batch['edge_pos'].to(self.device) + vertex_pos = batch['vertex_pos'].to(self.device) + edge_mask = batch['edge_mask'].to(self.device) + points = batch['points'].to(self.device) + gt_sdf = batch['sdf'].to(self.device) + + # 前向传播 + pred_sdf = self.model( + surf_ncs=surf_ncs, edge_ncs=edge_ncs, + surf_pos=surf_pos, edge_pos=edge_pos, + vertex_pos=vertex_pos, edge_mask=edge_mask, + query_points=points + ) + + # 计算指标 + metrics = self.compute_metrics(pred_sdf, gt_sdf) + for k, v in metrics.items(): + total_metrics[k] += v + + # 可视化结果 + if idx % self.config.test.vis_freq == 0: + save_path = os.path.join(self.result_dir, f'result_{idx}.png') + self.visualize_results(pred_sdf, gt_sdf, points[0].cpu(), save_path) + + # 计算平均指标 + num_samples = len(self.test_loader) + avg_metrics = {k: v / num_samples for k, v in total_metrics.items()} + + # 保存测试结果 + logger.info("Test Results:") + for k, v in avg_metrics.items(): + logger.info(f"{k}: {v:.6f}") + + # 保存指标到文件 + with open(os.path.join(self.result_dir, 'test_metrics.txt'), 'w') as f: + for k, v in avg_metrics.items(): + f.write(f"{k}: {v:.6f}\n") + + return avg_metrics + +def main(): + # 获取配置 + config = get_default_config() + + # 设置检查点路径 + checkpoint_path = os.path.join( + config.data.model_save_dir, + config.data.best_model_name.format(model_name=config.data.model_name) + ) + + # 初始化测试器并执行测试 + tester = Tester(config, checkpoint_path) + metrics = tester.test() + +if __name__ == '__main__': + main() \ No newline at end of file