1 changed files with 161 additions and 0 deletions
@ -0,0 +1,161 @@ |
|||
import os |
|||
import torch |
|||
import numpy as np |
|||
from torch.utils.data import DataLoader |
|||
from brep2sdf.data.data import BRepSDFDataset |
|||
from brep2sdf.networks.network import BRepToSDF |
|||
from brep2sdf.utils.logger import logger |
|||
from brep2sdf.config.default_config import get_default_config |
|||
import matplotlib.pyplot as plt |
|||
from tqdm import tqdm |
|||
|
|||
class Tester: |
|||
def __init__(self, config, checkpoint_path): |
|||
self.config = config |
|||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|||
|
|||
# 初始化测试数据集 |
|||
self.test_dataset = BRepSDFDataset( |
|||
brep_dir=config.data.brep_dir, |
|||
sdf_dir=config.data.sdf_dir, |
|||
valid_data_dir=config.data.valid_data_dir, |
|||
split='test' |
|||
) |
|||
|
|||
# 初始化数据加载器 |
|||
self.test_loader = DataLoader( |
|||
self.test_dataset, |
|||
batch_size=1, # 测试时使用batch_size=1 |
|||
shuffle=False, |
|||
num_workers=config.train.num_workers, |
|||
pin_memory=False |
|||
) |
|||
|
|||
# 加载模型 |
|||
self.model = BRepToSDF(config).to(self.device) |
|||
self.load_checkpoint(checkpoint_path) |
|||
|
|||
# 创建结果保存目录 |
|||
self.result_dir = os.path.join(config.data.result_save_dir, 'test_results') |
|||
os.makedirs(self.result_dir, exist_ok=True) |
|||
|
|||
def load_checkpoint(self, checkpoint_path): |
|||
"""加载检查点""" |
|||
if not os.path.exists(checkpoint_path): |
|||
raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}") |
|||
|
|||
checkpoint = torch.load(checkpoint_path, map_location=self.device) |
|||
self.model.load_state_dict(checkpoint['model_state_dict']) |
|||
logger.info(f"Loaded checkpoint from {checkpoint_path}") |
|||
|
|||
def compute_metrics(self, pred_sdf, gt_sdf): |
|||
"""计算评估指标""" |
|||
mse = torch.mean((pred_sdf - gt_sdf) ** 2).item() |
|||
mae = torch.mean(torch.abs(pred_sdf - gt_sdf)).item() |
|||
max_error = torch.max(torch.abs(pred_sdf - gt_sdf)).item() |
|||
|
|||
return { |
|||
'mse': mse, |
|||
'mae': mae, |
|||
'max_error': max_error |
|||
} |
|||
|
|||
def visualize_results(self, pred_sdf, gt_sdf, points, save_path): |
|||
"""可视化预测结果""" |
|||
fig = plt.figure(figsize=(15, 5)) |
|||
|
|||
# 绘制预测SDF |
|||
ax1 = fig.add_subplot(131, projection='3d') |
|||
scatter = ax1.scatter(points[:, 0], points[:, 1], points[:, 2], |
|||
c=pred_sdf.squeeze().cpu(), cmap='coolwarm') |
|||
ax1.set_title('Predicted SDF') |
|||
plt.colorbar(scatter) |
|||
|
|||
# 绘制真实SDF |
|||
ax2 = fig.add_subplot(132, projection='3d') |
|||
scatter = ax2.scatter(points[:, 0], points[:, 1], points[:, 2], |
|||
c=gt_sdf.squeeze().cpu(), cmap='coolwarm') |
|||
ax2.set_title('Ground Truth SDF') |
|||
plt.colorbar(scatter) |
|||
|
|||
# 绘制误差图 |
|||
ax3 = fig.add_subplot(133, projection='3d') |
|||
error = torch.abs(pred_sdf - gt_sdf) |
|||
scatter = ax3.scatter(points[:, 0], points[:, 1], points[:, 2], |
|||
c=error.squeeze().cpu(), cmap='Reds') |
|||
ax3.set_title('Absolute Error') |
|||
plt.colorbar(scatter) |
|||
|
|||
plt.tight_layout() |
|||
plt.savefig(save_path) |
|||
plt.close() |
|||
|
|||
def test(self): |
|||
"""执行测试""" |
|||
self.model.eval() |
|||
total_metrics = {'mse': 0, 'mae': 0, 'max_error': 0} |
|||
|
|||
logger.info("Starting testing...") |
|||
|
|||
with torch.no_grad(): |
|||
for idx, batch in enumerate(tqdm(self.test_loader)): |
|||
# 获取数据并移动到设备 |
|||
surf_ncs = batch['surf_ncs'].to(self.device) |
|||
edge_ncs = batch['edge_ncs'].to(self.device) |
|||
surf_pos = batch['surf_pos'].to(self.device) |
|||
edge_pos = batch['edge_pos'].to(self.device) |
|||
vertex_pos = batch['vertex_pos'].to(self.device) |
|||
edge_mask = batch['edge_mask'].to(self.device) |
|||
points = batch['points'].to(self.device) |
|||
gt_sdf = batch['sdf'].to(self.device) |
|||
|
|||
# 前向传播 |
|||
pred_sdf = self.model( |
|||
surf_ncs=surf_ncs, edge_ncs=edge_ncs, |
|||
surf_pos=surf_pos, edge_pos=edge_pos, |
|||
vertex_pos=vertex_pos, edge_mask=edge_mask, |
|||
query_points=points |
|||
) |
|||
|
|||
# 计算指标 |
|||
metrics = self.compute_metrics(pred_sdf, gt_sdf) |
|||
for k, v in metrics.items(): |
|||
total_metrics[k] += v |
|||
|
|||
# 可视化结果 |
|||
if idx % self.config.test.vis_freq == 0: |
|||
save_path = os.path.join(self.result_dir, f'result_{idx}.png') |
|||
self.visualize_results(pred_sdf, gt_sdf, points[0].cpu(), save_path) |
|||
|
|||
# 计算平均指标 |
|||
num_samples = len(self.test_loader) |
|||
avg_metrics = {k: v / num_samples for k, v in total_metrics.items()} |
|||
|
|||
# 保存测试结果 |
|||
logger.info("Test Results:") |
|||
for k, v in avg_metrics.items(): |
|||
logger.info(f"{k}: {v:.6f}") |
|||
|
|||
# 保存指标到文件 |
|||
with open(os.path.join(self.result_dir, 'test_metrics.txt'), 'w') as f: |
|||
for k, v in avg_metrics.items(): |
|||
f.write(f"{k}: {v:.6f}\n") |
|||
|
|||
return avg_metrics |
|||
|
|||
def main(): |
|||
# 获取配置 |
|||
config = get_default_config() |
|||
|
|||
# 设置检查点路径 |
|||
checkpoint_path = os.path.join( |
|||
config.data.model_save_dir, |
|||
config.data.best_model_name.format(model_name=config.data.model_name) |
|||
) |
|||
|
|||
# 初始化测试器并执行测试 |
|||
tester = Tester(config, checkpoint_path) |
|||
metrics = tester.test() |
|||
|
|||
if __name__ == '__main__': |
|||
main() |
Loading…
Reference in new issue