Browse Source

feat: diagnose script

main
mckay 3 months ago
parent
commit
6263f9d353
  1. 162
      brep2sdf/scripts/diagnose.py

162
brep2sdf/scripts/diagnose.py

@ -0,0 +1,162 @@
import torch
import torch.nn as nn
from brep2sdf.networks.network import BRepToSDF
from brep2sdf.config.default_config import get_default_config
from brep2sdf.utils.logger import logger
from brep2sdf.networks.loss import Brep2SDFLoss
class ModelDiagnostics:
def __init__(self, model, config):
self.model = model
self.config = config
self.device = next(model.parameters()).device
def check_model_architecture(self):
"""检查模型架构"""
logger.info("\n=== 模型架构检查 ===")
# 检查每个子模块
modules = {
'query_encoder': self.model.query_encoder,
'brep_embedder': self.model.brep_embedder,
'transformer': self.model.transformer,
'sdf_head': self.model.sdf_head
}
for name, module in modules.items():
logger.info(f"\n{name}:")
logger.info(f" 参数数量: {sum(p.numel() for p in module.parameters()):,}")
logger.info(f" 需要梯度的参数: {sum(p.numel() for p in module.parameters() if p.requires_grad):,}")
# 检查模块中的关键层
for layer_name, layer in module.named_modules():
if isinstance(layer, (nn.Linear, nn.LayerNorm)):
logger.info(f" {layer_name}:")
logger.info(f" 输入维度: {layer.in_features if hasattr(layer, 'in_features') else 'N/A'}")
logger.info(f" 输出维度: {layer.out_features if hasattr(layer, 'out_features') else 'N/A'}")
def check_forward_pass(self, batch):
"""检查前向传播过程"""
logger.info("\n=== 前向传播检查 ===")
self.model.eval()
with torch.no_grad():
# 1. 查询点编码
query_features = self.model.query_encoder(batch['query_points'])
logger.info("\n查询点编码:")
logger.info(f" 输入形状: {batch['query_points'].shape}")
logger.info(f" 输出形状: {query_features.shape}")
logger.info(f" 特征统计:")
logger.info(f" 均值: {query_features.mean():.4f}")
logger.info(f" 标准差: {query_features.std():.4f}")
logger.info(f" 最大值: {query_features.max():.4f}")
logger.info(f" 最小值: {query_features.min():.4f}")
# 2. B-rep特征编码
brep_features = self.model.brep_embedder(
edge_ncs=batch['edge_ncs'],
edge_pos=batch['edge_pos'],
edge_mask=batch['edge_mask'],
surf_ncs=batch['surf_ncs'],
surf_pos=batch['surf_pos'],
vertex_pos=batch['vertex_pos']
)
logger.info("\nB-rep特征编码:")
logger.info(f" 输出形状: {brep_features.shape}")
logger.info(f" 特征统计:")
logger.info(f" 均值: {brep_features.mean():.4f}")
logger.info(f" 标准差: {brep_features.std():.4f}")
# 3. 全局特征
global_features = brep_features.mean(dim=1)
logger.info("\n全局特征:")
logger.info(f" 形状: {global_features.shape}")
logger.info(f" 统计:")
logger.info(f" 均值: {global_features.mean():.4f}")
logger.info(f" 标准差: {global_features.std():.4f}")
# 4. 最终预测
sdf = self.model(**batch)
logger.info("\nSDF预测:")
logger.info(f" 形状: {sdf.shape}")
logger.info(f" 统计:")
logger.info(f" 均值: {sdf.mean():.4f}")
logger.info(f" 标准差: {sdf.std():.4f}")
logger.info(f" 最大值: {sdf.max():.4f}")
logger.info(f" 最小值: {sdf.min():.4f}")
def check_gradients(self, batch, criterion):
"""检查梯度流动"""
logger.info("\n=== 梯度检查 ===")
self.model.train()
self.model.zero_grad()
# 前向传播
pred = self.model(**batch)
loss = criterion(pred, batch['gt_sdf'])
# 反向传播
loss.backward()
# 检查每个模块的梯度
modules = {
'query_encoder': self.model.query_encoder,
'brep_embedder': self.model.brep_embedder,
'transformer': self.model.transformer,
'sdf_head': self.model.sdf_head
}
for name, module in modules.items():
logger.info(f"\n{name} 梯度:")
total_grad_norm = 0
for param_name, param in module.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm().item()
total_grad_norm += grad_norm
logger.info(f" {param_name}:")
logger.info(f" 梯度范数: {grad_norm:.6f}")
logger.info(f" 梯度均值: {param.grad.mean().item():.6f}")
logger.info(f" 梯度标准差: {param.grad.std().item():.6f}")
else:
logger.warning(f" {param_name}: 没有梯度!")
logger.info(f" 总梯度范数: {total_grad_norm:.6f}")
def main():
# 获取配置
config = get_default_config()
# 初始化模型
model = BRepToSDF(config=config)
# 创建诊断器
diagnostics = ModelDiagnostics(model, config)
# 生成测试数据
batch = {
'edge_ncs': torch.randn(2, config.data.max_face, config.data.max_edge,
config.model.num_edge_points, 3),
'edge_pos': torch.randn(2, config.data.max_face, config.data.max_edge, 6),
'edge_mask': torch.ones(2, config.data.max_face, config.data.max_edge,
dtype=torch.bool),
'surf_ncs': torch.randn(2, config.data.max_face, config.model.num_surf_points, 3),
'surf_pos': torch.randn(2, config.data.max_face, 6),
'vertex_pos': torch.randn(2, config.data.max_face, config.data.max_edge, 2, 3),
'query_points': torch.randn(2, 1000, 3),
#'sdf': torch.randn(2, 1000, 1)
}
# 初始化损失函数
criterion = Brep2SDFLoss(
batch_size=config.train.batch_size,
enforce_minmax=(config.train.clamping_distance > 0),
clamping_distance=config.train.clamping_distance
)
# 运行诊断
diagnostics.check_model_architecture()
diagnostics.check_forward_pass(batch)
diagnostics.check_gradients(batch, criterion)
if __name__ == '__main__':
main()
Loading…
Cancel
Save