From 6263f9d35320d5046c8f6cd1aae2f323e167a7fb Mon Sep 17 00:00:00 2001 From: mckay Date: Fri, 13 Dec 2024 23:09:40 +0800 Subject: [PATCH] feat: diagnose script --- brep2sdf/scripts/diagnose.py | 162 +++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 brep2sdf/scripts/diagnose.py diff --git a/brep2sdf/scripts/diagnose.py b/brep2sdf/scripts/diagnose.py new file mode 100644 index 0000000..01a0be0 --- /dev/null +++ b/brep2sdf/scripts/diagnose.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn +from brep2sdf.networks.network import BRepToSDF +from brep2sdf.config.default_config import get_default_config +from brep2sdf.utils.logger import logger +from brep2sdf.networks.loss import Brep2SDFLoss + +class ModelDiagnostics: + def __init__(self, model, config): + self.model = model + self.config = config + self.device = next(model.parameters()).device + + def check_model_architecture(self): + """检查模型架构""" + logger.info("\n=== 模型架构检查 ===") + + # 检查每个子模块 + modules = { + 'query_encoder': self.model.query_encoder, + 'brep_embedder': self.model.brep_embedder, + 'transformer': self.model.transformer, + 'sdf_head': self.model.sdf_head + } + + for name, module in modules.items(): + logger.info(f"\n{name}:") + logger.info(f" 参数数量: {sum(p.numel() for p in module.parameters()):,}") + logger.info(f" 需要梯度的参数: {sum(p.numel() for p in module.parameters() if p.requires_grad):,}") + + # 检查模块中的关键层 + for layer_name, layer in module.named_modules(): + if isinstance(layer, (nn.Linear, nn.LayerNorm)): + logger.info(f" {layer_name}:") + logger.info(f" 输入维度: {layer.in_features if hasattr(layer, 'in_features') else 'N/A'}") + logger.info(f" 输出维度: {layer.out_features if hasattr(layer, 'out_features') else 'N/A'}") + + def check_forward_pass(self, batch): + """检查前向传播过程""" + logger.info("\n=== 前向传播检查 ===") + + self.model.eval() + with torch.no_grad(): + # 1. 查询点编码 + query_features = self.model.query_encoder(batch['query_points']) + logger.info("\n查询点编码:") + logger.info(f" 输入形状: {batch['query_points'].shape}") + logger.info(f" 输出形状: {query_features.shape}") + logger.info(f" 特征统计:") + logger.info(f" 均值: {query_features.mean():.4f}") + logger.info(f" 标准差: {query_features.std():.4f}") + logger.info(f" 最大值: {query_features.max():.4f}") + logger.info(f" 最小值: {query_features.min():.4f}") + + # 2. B-rep特征编码 + brep_features = self.model.brep_embedder( + edge_ncs=batch['edge_ncs'], + edge_pos=batch['edge_pos'], + edge_mask=batch['edge_mask'], + surf_ncs=batch['surf_ncs'], + surf_pos=batch['surf_pos'], + vertex_pos=batch['vertex_pos'] + ) + logger.info("\nB-rep特征编码:") + logger.info(f" 输出形状: {brep_features.shape}") + logger.info(f" 特征统计:") + logger.info(f" 均值: {brep_features.mean():.4f}") + logger.info(f" 标准差: {brep_features.std():.4f}") + + # 3. 全局特征 + global_features = brep_features.mean(dim=1) + logger.info("\n全局特征:") + logger.info(f" 形状: {global_features.shape}") + logger.info(f" 统计:") + logger.info(f" 均值: {global_features.mean():.4f}") + logger.info(f" 标准差: {global_features.std():.4f}") + + # 4. 最终预测 + sdf = self.model(**batch) + logger.info("\nSDF预测:") + logger.info(f" 形状: {sdf.shape}") + logger.info(f" 统计:") + logger.info(f" 均值: {sdf.mean():.4f}") + logger.info(f" 标准差: {sdf.std():.4f}") + logger.info(f" 最大值: {sdf.max():.4f}") + logger.info(f" 最小值: {sdf.min():.4f}") + + def check_gradients(self, batch, criterion): + """检查梯度流动""" + logger.info("\n=== 梯度检查 ===") + + self.model.train() + self.model.zero_grad() + + # 前向传播 + pred = self.model(**batch) + loss = criterion(pred, batch['gt_sdf']) + + # 反向传播 + loss.backward() + + # 检查每个模块的梯度 + modules = { + 'query_encoder': self.model.query_encoder, + 'brep_embedder': self.model.brep_embedder, + 'transformer': self.model.transformer, + 'sdf_head': self.model.sdf_head + } + + for name, module in modules.items(): + logger.info(f"\n{name} 梯度:") + total_grad_norm = 0 + for param_name, param in module.named_parameters(): + if param.grad is not None: + grad_norm = param.grad.norm().item() + total_grad_norm += grad_norm + logger.info(f" {param_name}:") + logger.info(f" 梯度范数: {grad_norm:.6f}") + logger.info(f" 梯度均值: {param.grad.mean().item():.6f}") + logger.info(f" 梯度标准差: {param.grad.std().item():.6f}") + else: + logger.warning(f" {param_name}: 没有梯度!") + logger.info(f" 总梯度范数: {total_grad_norm:.6f}") + +def main(): + # 获取配置 + config = get_default_config() + + # 初始化模型 + model = BRepToSDF(config=config) + + # 创建诊断器 + diagnostics = ModelDiagnostics(model, config) + + # 生成测试数据 + batch = { + 'edge_ncs': torch.randn(2, config.data.max_face, config.data.max_edge, + config.model.num_edge_points, 3), + 'edge_pos': torch.randn(2, config.data.max_face, config.data.max_edge, 6), + 'edge_mask': torch.ones(2, config.data.max_face, config.data.max_edge, + dtype=torch.bool), + 'surf_ncs': torch.randn(2, config.data.max_face, config.model.num_surf_points, 3), + 'surf_pos': torch.randn(2, config.data.max_face, 6), + 'vertex_pos': torch.randn(2, config.data.max_face, config.data.max_edge, 2, 3), + 'query_points': torch.randn(2, 1000, 3), + #'sdf': torch.randn(2, 1000, 1) + } + + # 初始化损失函数 + criterion = Brep2SDFLoss( + batch_size=config.train.batch_size, + enforce_minmax=(config.train.clamping_distance > 0), + clamping_distance=config.train.clamping_distance + ) + + # 运行诊断 + diagnostics.check_model_architecture() + diagnostics.check_forward_pass(batch) + diagnostics.check_gradients(batch, criterion) + +if __name__ == '__main__': + main() \ No newline at end of file