1 changed files with 161 additions and 0 deletions
			
			
		@ -0,0 +1,161 @@ | 
				
			|||
import os | 
				
			|||
import torch | 
				
			|||
import numpy as np | 
				
			|||
from torch.utils.data import DataLoader | 
				
			|||
from brep2sdf.data.data import BRepSDFDataset | 
				
			|||
from brep2sdf.networks.network import BRepToSDF | 
				
			|||
from brep2sdf.utils.logger import logger | 
				
			|||
from brep2sdf.config.default_config import get_default_config | 
				
			|||
import matplotlib.pyplot as plt | 
				
			|||
from tqdm import tqdm | 
				
			|||
 | 
				
			|||
class Tester: | 
				
			|||
    def __init__(self, config, checkpoint_path): | 
				
			|||
        self.config = config | 
				
			|||
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | 
				
			|||
         | 
				
			|||
        # 初始化测试数据集 | 
				
			|||
        self.test_dataset = BRepSDFDataset( | 
				
			|||
            brep_dir=config.data.brep_dir, | 
				
			|||
            sdf_dir=config.data.sdf_dir, | 
				
			|||
            valid_data_dir=config.data.valid_data_dir, | 
				
			|||
            split='test' | 
				
			|||
        ) | 
				
			|||
         | 
				
			|||
        # 初始化数据加载器 | 
				
			|||
        self.test_loader = DataLoader( | 
				
			|||
            self.test_dataset, | 
				
			|||
            batch_size=1,  # 测试时使用batch_size=1 | 
				
			|||
            shuffle=False, | 
				
			|||
            num_workers=config.train.num_workers, | 
				
			|||
            pin_memory=False | 
				
			|||
        ) | 
				
			|||
         | 
				
			|||
        # 加载模型 | 
				
			|||
        self.model = BRepToSDF(config).to(self.device) | 
				
			|||
        self.load_checkpoint(checkpoint_path) | 
				
			|||
         | 
				
			|||
        # 创建结果保存目录 | 
				
			|||
        self.result_dir = os.path.join(config.data.result_save_dir, 'test_results') | 
				
			|||
        os.makedirs(self.result_dir, exist_ok=True) | 
				
			|||
         | 
				
			|||
    def load_checkpoint(self, checkpoint_path): | 
				
			|||
        """加载检查点""" | 
				
			|||
        if not os.path.exists(checkpoint_path): | 
				
			|||
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}") | 
				
			|||
             | 
				
			|||
        checkpoint = torch.load(checkpoint_path, map_location=self.device) | 
				
			|||
        self.model.load_state_dict(checkpoint['model_state_dict']) | 
				
			|||
        logger.info(f"Loaded checkpoint from {checkpoint_path}") | 
				
			|||
         | 
				
			|||
    def compute_metrics(self, pred_sdf, gt_sdf): | 
				
			|||
        """计算评估指标""" | 
				
			|||
        mse = torch.mean((pred_sdf - gt_sdf) ** 2).item() | 
				
			|||
        mae = torch.mean(torch.abs(pred_sdf - gt_sdf)).item() | 
				
			|||
        max_error = torch.max(torch.abs(pred_sdf - gt_sdf)).item() | 
				
			|||
         | 
				
			|||
        return { | 
				
			|||
            'mse': mse, | 
				
			|||
            'mae': mae, | 
				
			|||
            'max_error': max_error | 
				
			|||
        } | 
				
			|||
     | 
				
			|||
    def visualize_results(self, pred_sdf, gt_sdf, points, save_path): | 
				
			|||
        """可视化预测结果""" | 
				
			|||
        fig = plt.figure(figsize=(15, 5)) | 
				
			|||
         | 
				
			|||
        # 绘制预测SDF | 
				
			|||
        ax1 = fig.add_subplot(131, projection='3d') | 
				
			|||
        scatter = ax1.scatter(points[:, 0], points[:, 1], points[:, 2],  | 
				
			|||
                            c=pred_sdf.squeeze().cpu(), cmap='coolwarm') | 
				
			|||
        ax1.set_title('Predicted SDF') | 
				
			|||
        plt.colorbar(scatter) | 
				
			|||
         | 
				
			|||
        # 绘制真实SDF | 
				
			|||
        ax2 = fig.add_subplot(132, projection='3d') | 
				
			|||
        scatter = ax2.scatter(points[:, 0], points[:, 1], points[:, 2],  | 
				
			|||
                            c=gt_sdf.squeeze().cpu(), cmap='coolwarm') | 
				
			|||
        ax2.set_title('Ground Truth SDF') | 
				
			|||
        plt.colorbar(scatter) | 
				
			|||
         | 
				
			|||
        # 绘制误差图 | 
				
			|||
        ax3 = fig.add_subplot(133, projection='3d') | 
				
			|||
        error = torch.abs(pred_sdf - gt_sdf) | 
				
			|||
        scatter = ax3.scatter(points[:, 0], points[:, 1], points[:, 2],  | 
				
			|||
                            c=error.squeeze().cpu(), cmap='Reds') | 
				
			|||
        ax3.set_title('Absolute Error') | 
				
			|||
        plt.colorbar(scatter) | 
				
			|||
         | 
				
			|||
        plt.tight_layout() | 
				
			|||
        plt.savefig(save_path) | 
				
			|||
        plt.close() | 
				
			|||
         | 
				
			|||
    def test(self): | 
				
			|||
        """执行测试""" | 
				
			|||
        self.model.eval() | 
				
			|||
        total_metrics = {'mse': 0, 'mae': 0, 'max_error': 0} | 
				
			|||
         | 
				
			|||
        logger.info("Starting testing...") | 
				
			|||
         | 
				
			|||
        with torch.no_grad(): | 
				
			|||
            for idx, batch in enumerate(tqdm(self.test_loader)): | 
				
			|||
                # 获取数据并移动到设备 | 
				
			|||
                surf_ncs = batch['surf_ncs'].to(self.device) | 
				
			|||
                edge_ncs = batch['edge_ncs'].to(self.device) | 
				
			|||
                surf_pos = batch['surf_pos'].to(self.device) | 
				
			|||
                edge_pos = batch['edge_pos'].to(self.device) | 
				
			|||
                vertex_pos = batch['vertex_pos'].to(self.device) | 
				
			|||
                edge_mask = batch['edge_mask'].to(self.device) | 
				
			|||
                points = batch['points'].to(self.device) | 
				
			|||
                gt_sdf = batch['sdf'].to(self.device) | 
				
			|||
                 | 
				
			|||
                # 前向传播 | 
				
			|||
                pred_sdf = self.model( | 
				
			|||
                    surf_ncs=surf_ncs, edge_ncs=edge_ncs, | 
				
			|||
                    surf_pos=surf_pos, edge_pos=edge_pos, | 
				
			|||
                    vertex_pos=vertex_pos, edge_mask=edge_mask, | 
				
			|||
                    query_points=points | 
				
			|||
                ) | 
				
			|||
                 | 
				
			|||
                # 计算指标 | 
				
			|||
                metrics = self.compute_metrics(pred_sdf, gt_sdf) | 
				
			|||
                for k, v in metrics.items(): | 
				
			|||
                    total_metrics[k] += v | 
				
			|||
                 | 
				
			|||
                # 可视化结果 | 
				
			|||
                if idx % self.config.test.vis_freq == 0: | 
				
			|||
                    save_path = os.path.join(self.result_dir, f'result_{idx}.png') | 
				
			|||
                    self.visualize_results(pred_sdf, gt_sdf, points[0].cpu(), save_path) | 
				
			|||
         | 
				
			|||
        # 计算平均指标 | 
				
			|||
        num_samples = len(self.test_loader) | 
				
			|||
        avg_metrics = {k: v / num_samples for k, v in total_metrics.items()} | 
				
			|||
         | 
				
			|||
        # 保存测试结果 | 
				
			|||
        logger.info("Test Results:") | 
				
			|||
        for k, v in avg_metrics.items(): | 
				
			|||
            logger.info(f"{k}: {v:.6f}") | 
				
			|||
             | 
				
			|||
        # 保存指标到文件 | 
				
			|||
        with open(os.path.join(self.result_dir, 'test_metrics.txt'), 'w') as f: | 
				
			|||
            for k, v in avg_metrics.items(): | 
				
			|||
                f.write(f"{k}: {v:.6f}\n") | 
				
			|||
         | 
				
			|||
        return avg_metrics | 
				
			|||
 | 
				
			|||
def main(): | 
				
			|||
    # 获取配置 | 
				
			|||
    config = get_default_config() | 
				
			|||
     | 
				
			|||
    # 设置检查点路径 | 
				
			|||
    checkpoint_path = os.path.join( | 
				
			|||
        config.data.model_save_dir, | 
				
			|||
        config.data.best_model_name.format(model_name=config.data.model_name) | 
				
			|||
    ) | 
				
			|||
     | 
				
			|||
    # 初始化测试器并执行测试 | 
				
			|||
    tester = Tester(config, checkpoint_path) | 
				
			|||
    metrics = tester.test() | 
				
			|||
 | 
				
			|||
if __name__ == '__main__': | 
				
			|||
    main() | 
				
			|||
					Loading…
					
					
				
		Reference in new issue