1 changed files with 162 additions and 0 deletions
			
			
		@ -0,0 +1,162 @@ | 
				
			|||||
 | 
					import torch | 
				
			||||
 | 
					import torch.nn as nn | 
				
			||||
 | 
					from brep2sdf.networks.network import BRepToSDF | 
				
			||||
 | 
					from brep2sdf.config.default_config import get_default_config | 
				
			||||
 | 
					from brep2sdf.utils.logger import logger | 
				
			||||
 | 
					from brep2sdf.networks.loss import Brep2SDFLoss | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					class ModelDiagnostics: | 
				
			||||
 | 
					    def __init__(self, model, config): | 
				
			||||
 | 
					        self.model = model | 
				
			||||
 | 
					        self.config = config | 
				
			||||
 | 
					        self.device = next(model.parameters()).device | 
				
			||||
 | 
					         | 
				
			||||
 | 
					    def check_model_architecture(self): | 
				
			||||
 | 
					        """检查模型架构""" | 
				
			||||
 | 
					        logger.info("\n=== 模型架构检查 ===") | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 检查每个子模块 | 
				
			||||
 | 
					        modules = { | 
				
			||||
 | 
					            'query_encoder': self.model.query_encoder, | 
				
			||||
 | 
					            'brep_embedder': self.model.brep_embedder, | 
				
			||||
 | 
					            'transformer': self.model.transformer, | 
				
			||||
 | 
					            'sdf_head': self.model.sdf_head | 
				
			||||
 | 
					        } | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        for name, module in modules.items(): | 
				
			||||
 | 
					            logger.info(f"\n{name}:") | 
				
			||||
 | 
					            logger.info(f"  参数数量: {sum(p.numel() for p in module.parameters()):,}") | 
				
			||||
 | 
					            logger.info(f"  需要梯度的参数: {sum(p.numel() for p in module.parameters() if p.requires_grad):,}") | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            # 检查模块中的关键层 | 
				
			||||
 | 
					            for layer_name, layer in module.named_modules(): | 
				
			||||
 | 
					                if isinstance(layer, (nn.Linear, nn.LayerNorm)): | 
				
			||||
 | 
					                    logger.info(f"  {layer_name}:") | 
				
			||||
 | 
					                    logger.info(f"    输入维度: {layer.in_features if hasattr(layer, 'in_features') else 'N/A'}") | 
				
			||||
 | 
					                    logger.info(f"    输出维度: {layer.out_features if hasattr(layer, 'out_features') else 'N/A'}") | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def check_forward_pass(self, batch): | 
				
			||||
 | 
					        """检查前向传播过程""" | 
				
			||||
 | 
					        logger.info("\n=== 前向传播检查 ===") | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        self.model.eval() | 
				
			||||
 | 
					        with torch.no_grad(): | 
				
			||||
 | 
					            # 1. 查询点编码 | 
				
			||||
 | 
					            query_features = self.model.query_encoder(batch['query_points']) | 
				
			||||
 | 
					            logger.info("\n查询点编码:") | 
				
			||||
 | 
					            logger.info(f"  输入形状: {batch['query_points'].shape}") | 
				
			||||
 | 
					            logger.info(f"  输出形状: {query_features.shape}") | 
				
			||||
 | 
					            logger.info(f"  特征统计:") | 
				
			||||
 | 
					            logger.info(f"    均值: {query_features.mean():.4f}") | 
				
			||||
 | 
					            logger.info(f"    标准差: {query_features.std():.4f}") | 
				
			||||
 | 
					            logger.info(f"    最大值: {query_features.max():.4f}") | 
				
			||||
 | 
					            logger.info(f"    最小值: {query_features.min():.4f}") | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            # 2. B-rep特征编码 | 
				
			||||
 | 
					            brep_features = self.model.brep_embedder( | 
				
			||||
 | 
					                edge_ncs=batch['edge_ncs'], | 
				
			||||
 | 
					                edge_pos=batch['edge_pos'], | 
				
			||||
 | 
					                edge_mask=batch['edge_mask'], | 
				
			||||
 | 
					                surf_ncs=batch['surf_ncs'], | 
				
			||||
 | 
					                surf_pos=batch['surf_pos'], | 
				
			||||
 | 
					                vertex_pos=batch['vertex_pos'] | 
				
			||||
 | 
					            ) | 
				
			||||
 | 
					            logger.info("\nB-rep特征编码:") | 
				
			||||
 | 
					            logger.info(f"  输出形状: {brep_features.shape}") | 
				
			||||
 | 
					            logger.info(f"  特征统计:") | 
				
			||||
 | 
					            logger.info(f"    均值: {brep_features.mean():.4f}") | 
				
			||||
 | 
					            logger.info(f"    标准差: {brep_features.std():.4f}") | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            # 3. 全局特征 | 
				
			||||
 | 
					            global_features = brep_features.mean(dim=1) | 
				
			||||
 | 
					            logger.info("\n全局特征:") | 
				
			||||
 | 
					            logger.info(f"  形状: {global_features.shape}") | 
				
			||||
 | 
					            logger.info(f"  统计:") | 
				
			||||
 | 
					            logger.info(f"    均值: {global_features.mean():.4f}") | 
				
			||||
 | 
					            logger.info(f"    标准差: {global_features.std():.4f}") | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            # 4. 最终预测 | 
				
			||||
 | 
					            sdf = self.model(**batch) | 
				
			||||
 | 
					            logger.info("\nSDF预测:") | 
				
			||||
 | 
					            logger.info(f"  形状: {sdf.shape}") | 
				
			||||
 | 
					            logger.info(f"  统计:") | 
				
			||||
 | 
					            logger.info(f"    均值: {sdf.mean():.4f}") | 
				
			||||
 | 
					            logger.info(f"    标准差: {sdf.std():.4f}") | 
				
			||||
 | 
					            logger.info(f"    最大值: {sdf.max():.4f}") | 
				
			||||
 | 
					            logger.info(f"    最小值: {sdf.min():.4f}") | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def check_gradients(self, batch, criterion): | 
				
			||||
 | 
					        """检查梯度流动""" | 
				
			||||
 | 
					        logger.info("\n=== 梯度检查 ===") | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        self.model.train() | 
				
			||||
 | 
					        self.model.zero_grad() | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 前向传播 | 
				
			||||
 | 
					        pred = self.model(**batch) | 
				
			||||
 | 
					        loss = criterion(pred, batch['gt_sdf']) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 反向传播 | 
				
			||||
 | 
					        loss.backward() | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 检查每个模块的梯度 | 
				
			||||
 | 
					        modules = { | 
				
			||||
 | 
					            'query_encoder': self.model.query_encoder, | 
				
			||||
 | 
					            'brep_embedder': self.model.brep_embedder, | 
				
			||||
 | 
					            'transformer': self.model.transformer, | 
				
			||||
 | 
					            'sdf_head': self.model.sdf_head | 
				
			||||
 | 
					        } | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        for name, module in modules.items(): | 
				
			||||
 | 
					            logger.info(f"\n{name} 梯度:") | 
				
			||||
 | 
					            total_grad_norm = 0 | 
				
			||||
 | 
					            for param_name, param in module.named_parameters(): | 
				
			||||
 | 
					                if param.grad is not None: | 
				
			||||
 | 
					                    grad_norm = param.grad.norm().item() | 
				
			||||
 | 
					                    total_grad_norm += grad_norm | 
				
			||||
 | 
					                    logger.info(f"  {param_name}:") | 
				
			||||
 | 
					                    logger.info(f"    梯度范数: {grad_norm:.6f}") | 
				
			||||
 | 
					                    logger.info(f"    梯度均值: {param.grad.mean().item():.6f}") | 
				
			||||
 | 
					                    logger.info(f"    梯度标准差: {param.grad.std().item():.6f}") | 
				
			||||
 | 
					                else: | 
				
			||||
 | 
					                    logger.warning(f"  {param_name}: 没有梯度!") | 
				
			||||
 | 
					            logger.info(f"  总梯度范数: {total_grad_norm:.6f}") | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					def main(): | 
				
			||||
 | 
					    # 获取配置 | 
				
			||||
 | 
					    config = get_default_config() | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    # 初始化模型 | 
				
			||||
 | 
					    model = BRepToSDF(config=config) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    # 创建诊断器 | 
				
			||||
 | 
					    diagnostics = ModelDiagnostics(model, config) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    # 生成测试数据 | 
				
			||||
 | 
					    batch = { | 
				
			||||
 | 
					        'edge_ncs': torch.randn(2, config.data.max_face, config.data.max_edge,  | 
				
			||||
 | 
					                               config.model.num_edge_points, 3), | 
				
			||||
 | 
					        'edge_pos': torch.randn(2, config.data.max_face, config.data.max_edge, 6), | 
				
			||||
 | 
					        'edge_mask': torch.ones(2, config.data.max_face, config.data.max_edge,  | 
				
			||||
 | 
					                               dtype=torch.bool), | 
				
			||||
 | 
					        'surf_ncs': torch.randn(2, config.data.max_face, config.model.num_surf_points, 3), | 
				
			||||
 | 
					        'surf_pos': torch.randn(2, config.data.max_face, 6), | 
				
			||||
 | 
					        'vertex_pos': torch.randn(2, config.data.max_face, config.data.max_edge, 2, 3), | 
				
			||||
 | 
					        'query_points': torch.randn(2, 1000, 3), | 
				
			||||
 | 
					        #'sdf': torch.randn(2, 1000, 1) | 
				
			||||
 | 
					    } | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    # 初始化损失函数 | 
				
			||||
 | 
					    criterion = Brep2SDFLoss( | 
				
			||||
 | 
					        batch_size=config.train.batch_size, | 
				
			||||
 | 
					        enforce_minmax=(config.train.clamping_distance > 0), | 
				
			||||
 | 
					        clamping_distance=config.train.clamping_distance | 
				
			||||
 | 
					    ) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    # 运行诊断 | 
				
			||||
 | 
					    diagnostics.check_model_architecture() | 
				
			||||
 | 
					    diagnostics.check_forward_pass(batch) | 
				
			||||
 | 
					    diagnostics.check_gradients(batch, criterion) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					if __name__ == '__main__': | 
				
			||||
 | 
					    main()  | 
				
			||||
					Loading…
					
					
				
		Reference in new issue