Browse Source

feat: test脚本

main
mckay 4 months ago
parent
commit
5d0d76f5eb
  1. 161
      brep2sdf/test.py

161
brep2sdf/test.py

@ -0,0 +1,161 @@
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from brep2sdf.data.data import BRepSDFDataset
from brep2sdf.networks.network import BRepToSDF
from brep2sdf.utils.logger import logger
from brep2sdf.config.default_config import get_default_config
import matplotlib.pyplot as plt
from tqdm import tqdm
class Tester:
def __init__(self, config, checkpoint_path):
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化测试数据集
self.test_dataset = BRepSDFDataset(
brep_dir=config.data.brep_dir,
sdf_dir=config.data.sdf_dir,
valid_data_dir=config.data.valid_data_dir,
split='test'
)
# 初始化数据加载器
self.test_loader = DataLoader(
self.test_dataset,
batch_size=1, # 测试时使用batch_size=1
shuffle=False,
num_workers=config.train.num_workers,
pin_memory=False
)
# 加载模型
self.model = BRepToSDF(config).to(self.device)
self.load_checkpoint(checkpoint_path)
# 创建结果保存目录
self.result_dir = os.path.join(config.data.result_save_dir, 'test_results')
os.makedirs(self.result_dir, exist_ok=True)
def load_checkpoint(self, checkpoint_path):
"""加载检查点"""
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
logger.info(f"Loaded checkpoint from {checkpoint_path}")
def compute_metrics(self, pred_sdf, gt_sdf):
"""计算评估指标"""
mse = torch.mean((pred_sdf - gt_sdf) ** 2).item()
mae = torch.mean(torch.abs(pred_sdf - gt_sdf)).item()
max_error = torch.max(torch.abs(pred_sdf - gt_sdf)).item()
return {
'mse': mse,
'mae': mae,
'max_error': max_error
}
def visualize_results(self, pred_sdf, gt_sdf, points, save_path):
"""可视化预测结果"""
fig = plt.figure(figsize=(15, 5))
# 绘制预测SDF
ax1 = fig.add_subplot(131, projection='3d')
scatter = ax1.scatter(points[:, 0], points[:, 1], points[:, 2],
c=pred_sdf.squeeze().cpu(), cmap='coolwarm')
ax1.set_title('Predicted SDF')
plt.colorbar(scatter)
# 绘制真实SDF
ax2 = fig.add_subplot(132, projection='3d')
scatter = ax2.scatter(points[:, 0], points[:, 1], points[:, 2],
c=gt_sdf.squeeze().cpu(), cmap='coolwarm')
ax2.set_title('Ground Truth SDF')
plt.colorbar(scatter)
# 绘制误差图
ax3 = fig.add_subplot(133, projection='3d')
error = torch.abs(pred_sdf - gt_sdf)
scatter = ax3.scatter(points[:, 0], points[:, 1], points[:, 2],
c=error.squeeze().cpu(), cmap='Reds')
ax3.set_title('Absolute Error')
plt.colorbar(scatter)
plt.tight_layout()
plt.savefig(save_path)
plt.close()
def test(self):
"""执行测试"""
self.model.eval()
total_metrics = {'mse': 0, 'mae': 0, 'max_error': 0}
logger.info("Starting testing...")
with torch.no_grad():
for idx, batch in enumerate(tqdm(self.test_loader)):
# 获取数据并移动到设备
surf_ncs = batch['surf_ncs'].to(self.device)
edge_ncs = batch['edge_ncs'].to(self.device)
surf_pos = batch['surf_pos'].to(self.device)
edge_pos = batch['edge_pos'].to(self.device)
vertex_pos = batch['vertex_pos'].to(self.device)
edge_mask = batch['edge_mask'].to(self.device)
points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
# 前向传播
pred_sdf = self.model(
surf_ncs=surf_ncs, edge_ncs=edge_ncs,
surf_pos=surf_pos, edge_pos=edge_pos,
vertex_pos=vertex_pos, edge_mask=edge_mask,
query_points=points
)
# 计算指标
metrics = self.compute_metrics(pred_sdf, gt_sdf)
for k, v in metrics.items():
total_metrics[k] += v
# 可视化结果
if idx % self.config.test.vis_freq == 0:
save_path = os.path.join(self.result_dir, f'result_{idx}.png')
self.visualize_results(pred_sdf, gt_sdf, points[0].cpu(), save_path)
# 计算平均指标
num_samples = len(self.test_loader)
avg_metrics = {k: v / num_samples for k, v in total_metrics.items()}
# 保存测试结果
logger.info("Test Results:")
for k, v in avg_metrics.items():
logger.info(f"{k}: {v:.6f}")
# 保存指标到文件
with open(os.path.join(self.result_dir, 'test_metrics.txt'), 'w') as f:
for k, v in avg_metrics.items():
f.write(f"{k}: {v:.6f}\n")
return avg_metrics
def main():
# 获取配置
config = get_default_config()
# 设置检查点路径
checkpoint_path = os.path.join(
config.data.model_save_dir,
config.data.best_model_name.format(model_name=config.data.model_name)
)
# 初始化测试器并执行测试
tester = Tester(config, checkpoint_path)
metrics = tester.test()
if __name__ == '__main__':
main()
Loading…
Cancel
Save