8 changed files with 576 additions and 1686 deletions
File diff suppressed because it is too large
@ -0,0 +1,182 @@ |
|||
from typing import Tuple, Optional |
|||
import torch |
|||
import torch.nn as nn |
|||
import numpy as np |
|||
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE |
|||
from OCC.Core.TopExp import TopExp_Explorer |
|||
from OCC.Core.TopoDS import TopoDS_Edge, TopoDS_Face, topods_Edge, topods_Face |
|||
from OCC.Core.BRep import BRep_Tool |
|||
from OCC.Core.GeomLProp import GeomLProp_SLProps |
|||
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface |
|||
|
|||
class PatchGraph(nn.Module): |
|||
def __init__(self, num_patches: int, device: torch.device = None): |
|||
super().__init__() |
|||
self.num_patches = num_patches |
|||
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|||
|
|||
# 注册缓冲区 |
|||
self.register_buffer('edge_index', None) # 边的连接关系 (2, E) |
|||
self.register_buffer('edge_type', None) # 边的类型 (E,) 0:凹边 1:凸边 |
|||
self.register_buffer('patch_features', None) # 面片特征 (N, F) |
|||
|
|||
def set_edges(self, edge_index: torch.Tensor, edge_type: torch.Tensor) -> None: |
|||
"""设置边的信息 |
|||
|
|||
参数: |
|||
edge_index: 形状为 (2, E) 的张量,表示边的连接关系 |
|||
edge_type: 形状为 (E,) 的张量,0表示凹边,1表示凸边 |
|||
""" |
|||
if edge_index.shape[0] != 2: |
|||
raise ValueError(f"edge_index 必须是形状为 (2, E) 的张量,但得到 {edge_index.shape}") |
|||
if edge_index.shape[1] != edge_type.shape[0]: |
|||
raise ValueError("edge_index 和 edge_type 的边数量不匹配") |
|||
|
|||
self.edge_index = edge_index.to(self.device) |
|||
self.edge_type = edge_type.to(self.device) |
|||
|
|||
def get_subgraph(self, node_faces: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
|||
"""获取子图的边和类型""" |
|||
if self.edge_index is None: |
|||
return None, None |
|||
|
|||
node_faces = node_faces.to(self.device) |
|||
mask = torch.isin(self.edge_index[0], node_faces) & torch.isin(self.edge_index[1], node_faces) |
|||
subgraph_edges = self.edge_index[:, mask] |
|||
subgraph_types = self.edge_type[mask] |
|||
|
|||
return subgraph_edges, subgraph_types |
|||
|
|||
@staticmethod |
|||
def from_preprocessed_data(surf_wcs: np.ndarray, edgeFace_adj: np.ndarray, edge_types: np.ndarray, device: torch.device = None) -> 'PatchGraph': |
|||
num_faces = len(surf_wcs) |
|||
graph = PatchGraph(num_faces, device) |
|||
|
|||
edge_pairs = [] |
|||
edge_types_list = [] |
|||
|
|||
for edge_idx in range(len(edgeFace_adj)): |
|||
connected_faces = np.where(edgeFace_adj[edge_idx] == 1)[0] |
|||
if len(connected_faces) == 2: |
|||
face1, face2 = connected_faces |
|||
edge_pairs.extend([[face1, face2], [face2, face1]]) |
|||
edge_type = edge_types[edge_idx] |
|||
edge_types_list.extend([edge_type, edge_type]) |
|||
|
|||
if edge_pairs: |
|||
edge_index = torch.tensor(edge_pairs, dtype=torch.long, device=graph.device).t() |
|||
edge_type = torch.tensor(edge_types_list, dtype=torch.long, device=graph.device) |
|||
graph.set_edges(edge_index, edge_type) |
|||
|
|||
return graph |
|||
|
|||
def set_features(self, features: torch.Tensor) -> None: |
|||
"""设置面片特征 |
|||
|
|||
参数: |
|||
features: 形状为 (N, F) 的张量,表示面片的特征向量 |
|||
""" |
|||
if features.shape[0] != self.num_patches: |
|||
raise ValueError(f"特征数量 {features.shape[0]} 与面片数量 {self.num_patches} 不匹配") |
|||
self.patch_features = features |
|||
|
|||
def is_clique(self, node_faces: torch.Tensor) -> bool: |
|||
"""检查给定面片集合是否构成完全图 |
|||
|
|||
参数: |
|||
node_faces: 要检查的面片索引集合 |
|||
|
|||
返回: |
|||
bool: 是否为完全图 |
|||
""" |
|||
if self.edge_index is None: |
|||
return False |
|||
|
|||
# 获取子图的边 |
|||
mask = torch.isin(self.edge_index[0], node_faces) & torch.isin(self.edge_index[1], node_faces) |
|||
subgraph_edges = self.edge_index[:, mask] |
|||
|
|||
# 计算完全图应有的边数 |
|||
n = len(node_faces) |
|||
expected_edges = n * (n - 1) // 2 |
|||
|
|||
# 计算实际的边数(考虑无向图) |
|||
actual_edges = len(subgraph_edges[0]) // 2 |
|||
|
|||
return actual_edges == expected_edges |
|||
|
|||
def combine_sdf(self, sdf_values: torch.Tensor) -> torch.Tensor: |
|||
"""根据邻接关系组合SDF值 |
|||
|
|||
参数: |
|||
sdf_values: 形状为 (N,) 的张量,表示每个面片的SDF值 |
|||
|
|||
返回: |
|||
torch.Tensor: 组合后的SDF值 |
|||
""" |
|||
if self.edge_index is None or self.edge_type is None: |
|||
raise RuntimeError("请先设置边的信息") |
|||
|
|||
# 获取所有相连面片对的SDF值 |
|||
sdf_i = sdf_values[self.edge_index[0]] # (E,) |
|||
sdf_j = sdf_values[self.edge_index[1]] # (E,) |
|||
|
|||
# 根据边的类型选择组合方式 |
|||
concave_mask = self.edge_type == 0 |
|||
convex_mask = self.edge_type == 1 |
|||
|
|||
# 初始化结果为第一个SDF值 |
|||
result = sdf_values[0].clone() |
|||
|
|||
# 凹边取最大值,凸边取最小值 |
|||
if torch.any(concave_mask): |
|||
result = torch.max(result, torch.max(torch.stack([sdf_i[concave_mask], |
|||
sdf_j[concave_mask]]))) |
|||
if torch.any(convex_mask): |
|||
result = torch.min(result, torch.min(torch.stack([sdf_i[convex_mask], |
|||
sdf_j[convex_mask]]))) |
|||
|
|||
return result |
|||
|
|||
@staticmethod |
|||
def from_preprocessed_data( |
|||
surf_wcs: np.ndarray, # 形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组 |
|||
edgeFace_adj: np.ndarray, # 形状为(num_edges, num_faces)的int32数组 |
|||
edge_types: np.ndarray # 形状为(num_edges,)的int32数组 |
|||
) -> 'PatchGraph': |
|||
"""从预处理的数据直接构建面片邻接图 |
|||
|
|||
参数: |
|||
surf_wcs: 世界坐标系下的曲面几何数据,形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组 |
|||
edgeFace_adj: 边-面邻接矩阵,形状为(num_edges, num_faces)的int32数组,1表示边与面相邻 |
|||
edge_types: 边的类型数组,形状为(num_edges,)的int32数组,0表示凹边,1表示凸边 |
|||
|
|||
返回: |
|||
PatchGraph: 初始化好的面片邻接图,包含: |
|||
- edge_index: 形状为(2, num_edges*2)的torch.long张量,表示双向边的连接关系 |
|||
- edge_type: 形状为(num_edges*2,)的torch.long张量,表示每条边的类型 |
|||
""" |
|||
num_faces = len(surf_wcs) |
|||
graph = PatchGraph(num_faces) |
|||
|
|||
# 构建边的索引和类型 |
|||
edge_pairs = [] |
|||
edge_types_list = [] |
|||
|
|||
# 遍历边-面邻接矩阵 |
|||
for edge_idx in range(len(edgeFace_adj)): |
|||
connected_faces = np.where(edgeFace_adj[edge_idx] == 1)[0] |
|||
if len(connected_faces) == 2: |
|||
face1, face2 = connected_faces |
|||
# 添加双向边 |
|||
edge_pairs.extend([[face1, face2], [face2, face1]]) |
|||
# 使用预计算的边类型 |
|||
edge_type = edge_types[edge_idx] |
|||
edge_types_list.extend([edge_type, edge_type]) # 双向边使用相同的类型 |
|||
|
|||
if edge_pairs: # 确保有边存在 |
|||
edge_index = torch.tensor(edge_pairs, dtype=torch.long).t() |
|||
edge_type = torch.tensor(edge_types_list, dtype=torch.long) |
|||
graph.set_edges(edge_index, edge_type) |
|||
|
|||
return graph |
|||
@ -1,161 +1,4 @@ |
|||
import os |
|||
import torch |
|||
import numpy as np |
|||
from torch.utils.data import DataLoader |
|||
from brep2sdf.data.data import BRepSDFDataset |
|||
from brep2sdf.networks.network import BRepToSDF |
|||
from brep2sdf.utils.logger import logger |
|||
from brep2sdf.config.default_config import get_default_config |
|||
import matplotlib.pyplot as plt |
|||
from tqdm import tqdm |
|||
|
|||
class Tester: |
|||
def __init__(self, config, checkpoint_path): |
|||
self.config = config |
|||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|||
|
|||
# 初始化测试数据集 |
|||
self.test_dataset = BRepSDFDataset( |
|||
brep_dir=config.data.brep_dir, |
|||
sdf_dir=config.data.sdf_dir, |
|||
valid_data_dir=config.data.valid_data_dir, |
|||
split='test' |
|||
) |
|||
|
|||
# 初始化数据加载器 |
|||
self.test_loader = DataLoader( |
|||
self.test_dataset, |
|||
batch_size=1, # 测试时使用batch_size=1 |
|||
shuffle=False, |
|||
num_workers=config.train.num_workers, |
|||
pin_memory=False |
|||
) |
|||
|
|||
# 加载模型 |
|||
self.model = BRepToSDF(config).to(self.device) |
|||
self.load_checkpoint(checkpoint_path) |
|||
|
|||
# 创建结果保存目录 |
|||
self.result_dir = os.path.join(config.data.result_save_dir, 'test_results') |
|||
os.makedirs(self.result_dir, exist_ok=True) |
|||
|
|||
def load_checkpoint(self, checkpoint_path): |
|||
"""加载检查点""" |
|||
if not os.path.exists(checkpoint_path): |
|||
raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}") |
|||
|
|||
checkpoint = torch.load(checkpoint_path, map_location=self.device) |
|||
self.model.load_state_dict(checkpoint['model_state_dict']) |
|||
logger.info(f"Loaded checkpoint from {checkpoint_path}") |
|||
|
|||
def compute_metrics(self, pred_sdf, gt_sdf): |
|||
"""计算评估指标""" |
|||
mse = torch.mean((pred_sdf - gt_sdf) ** 2).item() |
|||
mae = torch.mean(torch.abs(pred_sdf - gt_sdf)).item() |
|||
max_error = torch.max(torch.abs(pred_sdf - gt_sdf)).item() |
|||
|
|||
return { |
|||
'mse': mse, |
|||
'mae': mae, |
|||
'max_error': max_error |
|||
} |
|||
|
|||
def visualize_results(self, pred_sdf, gt_sdf, points, save_path): |
|||
"""可视化预测结果""" |
|||
fig = plt.figure(figsize=(15, 5)) |
|||
|
|||
# 绘制预测SDF |
|||
ax1 = fig.add_subplot(131, projection='3d') |
|||
scatter = ax1.scatter(points[:, 0], points[:, 1], points[:, 2], |
|||
c=pred_sdf.squeeze().cpu(), cmap='coolwarm') |
|||
ax1.set_title('Predicted SDF') |
|||
plt.colorbar(scatter) |
|||
|
|||
# 绘制真实SDF |
|||
ax2 = fig.add_subplot(132, projection='3d') |
|||
scatter = ax2.scatter(points[:, 0], points[:, 1], points[:, 2], |
|||
c=gt_sdf.squeeze().cpu(), cmap='coolwarm') |
|||
ax2.set_title('Ground Truth SDF') |
|||
plt.colorbar(scatter) |
|||
|
|||
# 绘制误差图 |
|||
ax3 = fig.add_subplot(133, projection='3d') |
|||
error = torch.abs(pred_sdf - gt_sdf) |
|||
scatter = ax3.scatter(points[:, 0], points[:, 1], points[:, 2], |
|||
c=error.squeeze().cpu(), cmap='Reds') |
|||
ax3.set_title('Absolute Error') |
|||
plt.colorbar(scatter) |
|||
|
|||
plt.tight_layout() |
|||
plt.savefig(save_path) |
|||
plt.close() |
|||
|
|||
def test(self): |
|||
"""执行测试""" |
|||
self.model.eval() |
|||
total_metrics = {'mse': 0, 'mae': 0, 'max_error': 0} |
|||
|
|||
logger.info("Starting testing...") |
|||
|
|||
with torch.no_grad(): |
|||
for idx, batch in enumerate(tqdm(self.test_loader)): |
|||
# 获取数据并移动到设备 |
|||
surf_ncs = batch['surf_ncs'].to(self.device) |
|||
edge_ncs = batch['edge_ncs'].to(self.device) |
|||
surf_pos = batch['surf_pos'].to(self.device) |
|||
edge_pos = batch['edge_pos'].to(self.device) |
|||
vertex_pos = batch['vertex_pos'].to(self.device) |
|||
edge_mask = batch['edge_mask'].to(self.device) |
|||
points = batch['points'].to(self.device) |
|||
gt_sdf = batch['sdf'].to(self.device) |
|||
|
|||
# 前向传播 |
|||
pred_sdf = self.model( |
|||
surf_ncs=surf_ncs, edge_ncs=edge_ncs, |
|||
surf_pos=surf_pos, edge_pos=edge_pos, |
|||
vertex_pos=vertex_pos, edge_mask=edge_mask, |
|||
query_points=points |
|||
) |
|||
|
|||
# 计算指标 |
|||
metrics = self.compute_metrics(pred_sdf, gt_sdf) |
|||
for k, v in metrics.items(): |
|||
total_metrics[k] += v |
|||
|
|||
# 可视化结果 |
|||
if idx % self.config.test.vis_freq == 0: |
|||
save_path = os.path.join(self.result_dir, f'result_{idx}.png') |
|||
self.visualize_results(pred_sdf, gt_sdf, points[0].cpu(), save_path) |
|||
|
|||
# 计算平均指标 |
|||
num_samples = len(self.test_loader) |
|||
avg_metrics = {k: v / num_samples for k, v in total_metrics.items()} |
|||
|
|||
# 保存测试结果 |
|||
logger.info("Test Results:") |
|||
for k, v in avg_metrics.items(): |
|||
logger.info(f"{k}: {v:.6f}") |
|||
|
|||
# 保存指标到文件 |
|||
with open(os.path.join(self.result_dir, 'test_metrics.txt'), 'w') as f: |
|||
for k, v in avg_metrics.items(): |
|||
f.write(f"{k}: {v:.6f}\n") |
|||
|
|||
return avg_metrics |
|||
|
|||
def main(): |
|||
# 获取配置 |
|||
config = get_default_config() |
|||
|
|||
# 设置检查点路径 |
|||
checkpoint_path = os.path.join( |
|||
config.data.model_save_dir, |
|||
config.data.best_model_name.format(model_name=config.data.model_name) |
|||
) |
|||
|
|||
# 初始化测试器并执行测试 |
|||
tester = Tester(config, checkpoint_path) |
|||
metrics = tester.test() |
|||
|
|||
if __name__ == '__main__': |
|||
main() |
|||
model = torch.jit.load("/home/wch/brep2sdf/data/output_data/00000054.pt") |
|||
print(model) |
|||
Loading…
Reference in new issue