1 changed files with 161 additions and 0 deletions
@ -0,0 +1,161 @@ |
|||||
|
import os |
||||
|
import torch |
||||
|
import numpy as np |
||||
|
from torch.utils.data import DataLoader |
||||
|
from brep2sdf.data.data import BRepSDFDataset |
||||
|
from brep2sdf.networks.network import BRepToSDF |
||||
|
from brep2sdf.utils.logger import logger |
||||
|
from brep2sdf.config.default_config import get_default_config |
||||
|
import matplotlib.pyplot as plt |
||||
|
from tqdm import tqdm |
||||
|
|
||||
|
class Tester: |
||||
|
def __init__(self, config, checkpoint_path): |
||||
|
self.config = config |
||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
||||
|
|
||||
|
# 初始化测试数据集 |
||||
|
self.test_dataset = BRepSDFDataset( |
||||
|
brep_dir=config.data.brep_dir, |
||||
|
sdf_dir=config.data.sdf_dir, |
||||
|
valid_data_dir=config.data.valid_data_dir, |
||||
|
split='test' |
||||
|
) |
||||
|
|
||||
|
# 初始化数据加载器 |
||||
|
self.test_loader = DataLoader( |
||||
|
self.test_dataset, |
||||
|
batch_size=1, # 测试时使用batch_size=1 |
||||
|
shuffle=False, |
||||
|
num_workers=config.train.num_workers, |
||||
|
pin_memory=False |
||||
|
) |
||||
|
|
||||
|
# 加载模型 |
||||
|
self.model = BRepToSDF(config).to(self.device) |
||||
|
self.load_checkpoint(checkpoint_path) |
||||
|
|
||||
|
# 创建结果保存目录 |
||||
|
self.result_dir = os.path.join(config.data.result_save_dir, 'test_results') |
||||
|
os.makedirs(self.result_dir, exist_ok=True) |
||||
|
|
||||
|
def load_checkpoint(self, checkpoint_path): |
||||
|
"""加载检查点""" |
||||
|
if not os.path.exists(checkpoint_path): |
||||
|
raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}") |
||||
|
|
||||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device) |
||||
|
self.model.load_state_dict(checkpoint['model_state_dict']) |
||||
|
logger.info(f"Loaded checkpoint from {checkpoint_path}") |
||||
|
|
||||
|
def compute_metrics(self, pred_sdf, gt_sdf): |
||||
|
"""计算评估指标""" |
||||
|
mse = torch.mean((pred_sdf - gt_sdf) ** 2).item() |
||||
|
mae = torch.mean(torch.abs(pred_sdf - gt_sdf)).item() |
||||
|
max_error = torch.max(torch.abs(pred_sdf - gt_sdf)).item() |
||||
|
|
||||
|
return { |
||||
|
'mse': mse, |
||||
|
'mae': mae, |
||||
|
'max_error': max_error |
||||
|
} |
||||
|
|
||||
|
def visualize_results(self, pred_sdf, gt_sdf, points, save_path): |
||||
|
"""可视化预测结果""" |
||||
|
fig = plt.figure(figsize=(15, 5)) |
||||
|
|
||||
|
# 绘制预测SDF |
||||
|
ax1 = fig.add_subplot(131, projection='3d') |
||||
|
scatter = ax1.scatter(points[:, 0], points[:, 1], points[:, 2], |
||||
|
c=pred_sdf.squeeze().cpu(), cmap='coolwarm') |
||||
|
ax1.set_title('Predicted SDF') |
||||
|
plt.colorbar(scatter) |
||||
|
|
||||
|
# 绘制真实SDF |
||||
|
ax2 = fig.add_subplot(132, projection='3d') |
||||
|
scatter = ax2.scatter(points[:, 0], points[:, 1], points[:, 2], |
||||
|
c=gt_sdf.squeeze().cpu(), cmap='coolwarm') |
||||
|
ax2.set_title('Ground Truth SDF') |
||||
|
plt.colorbar(scatter) |
||||
|
|
||||
|
# 绘制误差图 |
||||
|
ax3 = fig.add_subplot(133, projection='3d') |
||||
|
error = torch.abs(pred_sdf - gt_sdf) |
||||
|
scatter = ax3.scatter(points[:, 0], points[:, 1], points[:, 2], |
||||
|
c=error.squeeze().cpu(), cmap='Reds') |
||||
|
ax3.set_title('Absolute Error') |
||||
|
plt.colorbar(scatter) |
||||
|
|
||||
|
plt.tight_layout() |
||||
|
plt.savefig(save_path) |
||||
|
plt.close() |
||||
|
|
||||
|
def test(self): |
||||
|
"""执行测试""" |
||||
|
self.model.eval() |
||||
|
total_metrics = {'mse': 0, 'mae': 0, 'max_error': 0} |
||||
|
|
||||
|
logger.info("Starting testing...") |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
for idx, batch in enumerate(tqdm(self.test_loader)): |
||||
|
# 获取数据并移动到设备 |
||||
|
surf_ncs = batch['surf_ncs'].to(self.device) |
||||
|
edge_ncs = batch['edge_ncs'].to(self.device) |
||||
|
surf_pos = batch['surf_pos'].to(self.device) |
||||
|
edge_pos = batch['edge_pos'].to(self.device) |
||||
|
vertex_pos = batch['vertex_pos'].to(self.device) |
||||
|
edge_mask = batch['edge_mask'].to(self.device) |
||||
|
points = batch['points'].to(self.device) |
||||
|
gt_sdf = batch['sdf'].to(self.device) |
||||
|
|
||||
|
# 前向传播 |
||||
|
pred_sdf = self.model( |
||||
|
surf_ncs=surf_ncs, edge_ncs=edge_ncs, |
||||
|
surf_pos=surf_pos, edge_pos=edge_pos, |
||||
|
vertex_pos=vertex_pos, edge_mask=edge_mask, |
||||
|
query_points=points |
||||
|
) |
||||
|
|
||||
|
# 计算指标 |
||||
|
metrics = self.compute_metrics(pred_sdf, gt_sdf) |
||||
|
for k, v in metrics.items(): |
||||
|
total_metrics[k] += v |
||||
|
|
||||
|
# 可视化结果 |
||||
|
if idx % self.config.test.vis_freq == 0: |
||||
|
save_path = os.path.join(self.result_dir, f'result_{idx}.png') |
||||
|
self.visualize_results(pred_sdf, gt_sdf, points[0].cpu(), save_path) |
||||
|
|
||||
|
# 计算平均指标 |
||||
|
num_samples = len(self.test_loader) |
||||
|
avg_metrics = {k: v / num_samples for k, v in total_metrics.items()} |
||||
|
|
||||
|
# 保存测试结果 |
||||
|
logger.info("Test Results:") |
||||
|
for k, v in avg_metrics.items(): |
||||
|
logger.info(f"{k}: {v:.6f}") |
||||
|
|
||||
|
# 保存指标到文件 |
||||
|
with open(os.path.join(self.result_dir, 'test_metrics.txt'), 'w') as f: |
||||
|
for k, v in avg_metrics.items(): |
||||
|
f.write(f"{k}: {v:.6f}\n") |
||||
|
|
||||
|
return avg_metrics |
||||
|
|
||||
|
def main(): |
||||
|
# 获取配置 |
||||
|
config = get_default_config() |
||||
|
|
||||
|
# 设置检查点路径 |
||||
|
checkpoint_path = os.path.join( |
||||
|
config.data.model_save_dir, |
||||
|
config.data.best_model_name.format(model_name=config.data.model_name) |
||||
|
) |
||||
|
|
||||
|
# 初始化测试器并执行测试 |
||||
|
tester = Tester(config, checkpoint_path) |
||||
|
metrics = tester.test() |
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
main() |
Loading…
Reference in new issue