1 changed files with 162 additions and 0 deletions
@ -0,0 +1,162 @@ |
|||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
from brep2sdf.networks.network import BRepToSDF |
||||
|
from brep2sdf.config.default_config import get_default_config |
||||
|
from brep2sdf.utils.logger import logger |
||||
|
from brep2sdf.networks.loss import Brep2SDFLoss |
||||
|
|
||||
|
class ModelDiagnostics: |
||||
|
def __init__(self, model, config): |
||||
|
self.model = model |
||||
|
self.config = config |
||||
|
self.device = next(model.parameters()).device |
||||
|
|
||||
|
def check_model_architecture(self): |
||||
|
"""检查模型架构""" |
||||
|
logger.info("\n=== 模型架构检查 ===") |
||||
|
|
||||
|
# 检查每个子模块 |
||||
|
modules = { |
||||
|
'query_encoder': self.model.query_encoder, |
||||
|
'brep_embedder': self.model.brep_embedder, |
||||
|
'transformer': self.model.transformer, |
||||
|
'sdf_head': self.model.sdf_head |
||||
|
} |
||||
|
|
||||
|
for name, module in modules.items(): |
||||
|
logger.info(f"\n{name}:") |
||||
|
logger.info(f" 参数数量: {sum(p.numel() for p in module.parameters()):,}") |
||||
|
logger.info(f" 需要梯度的参数: {sum(p.numel() for p in module.parameters() if p.requires_grad):,}") |
||||
|
|
||||
|
# 检查模块中的关键层 |
||||
|
for layer_name, layer in module.named_modules(): |
||||
|
if isinstance(layer, (nn.Linear, nn.LayerNorm)): |
||||
|
logger.info(f" {layer_name}:") |
||||
|
logger.info(f" 输入维度: {layer.in_features if hasattr(layer, 'in_features') else 'N/A'}") |
||||
|
logger.info(f" 输出维度: {layer.out_features if hasattr(layer, 'out_features') else 'N/A'}") |
||||
|
|
||||
|
def check_forward_pass(self, batch): |
||||
|
"""检查前向传播过程""" |
||||
|
logger.info("\n=== 前向传播检查 ===") |
||||
|
|
||||
|
self.model.eval() |
||||
|
with torch.no_grad(): |
||||
|
# 1. 查询点编码 |
||||
|
query_features = self.model.query_encoder(batch['query_points']) |
||||
|
logger.info("\n查询点编码:") |
||||
|
logger.info(f" 输入形状: {batch['query_points'].shape}") |
||||
|
logger.info(f" 输出形状: {query_features.shape}") |
||||
|
logger.info(f" 特征统计:") |
||||
|
logger.info(f" 均值: {query_features.mean():.4f}") |
||||
|
logger.info(f" 标准差: {query_features.std():.4f}") |
||||
|
logger.info(f" 最大值: {query_features.max():.4f}") |
||||
|
logger.info(f" 最小值: {query_features.min():.4f}") |
||||
|
|
||||
|
# 2. B-rep特征编码 |
||||
|
brep_features = self.model.brep_embedder( |
||||
|
edge_ncs=batch['edge_ncs'], |
||||
|
edge_pos=batch['edge_pos'], |
||||
|
edge_mask=batch['edge_mask'], |
||||
|
surf_ncs=batch['surf_ncs'], |
||||
|
surf_pos=batch['surf_pos'], |
||||
|
vertex_pos=batch['vertex_pos'] |
||||
|
) |
||||
|
logger.info("\nB-rep特征编码:") |
||||
|
logger.info(f" 输出形状: {brep_features.shape}") |
||||
|
logger.info(f" 特征统计:") |
||||
|
logger.info(f" 均值: {brep_features.mean():.4f}") |
||||
|
logger.info(f" 标准差: {brep_features.std():.4f}") |
||||
|
|
||||
|
# 3. 全局特征 |
||||
|
global_features = brep_features.mean(dim=1) |
||||
|
logger.info("\n全局特征:") |
||||
|
logger.info(f" 形状: {global_features.shape}") |
||||
|
logger.info(f" 统计:") |
||||
|
logger.info(f" 均值: {global_features.mean():.4f}") |
||||
|
logger.info(f" 标准差: {global_features.std():.4f}") |
||||
|
|
||||
|
# 4. 最终预测 |
||||
|
sdf = self.model(**batch) |
||||
|
logger.info("\nSDF预测:") |
||||
|
logger.info(f" 形状: {sdf.shape}") |
||||
|
logger.info(f" 统计:") |
||||
|
logger.info(f" 均值: {sdf.mean():.4f}") |
||||
|
logger.info(f" 标准差: {sdf.std():.4f}") |
||||
|
logger.info(f" 最大值: {sdf.max():.4f}") |
||||
|
logger.info(f" 最小值: {sdf.min():.4f}") |
||||
|
|
||||
|
def check_gradients(self, batch, criterion): |
||||
|
"""检查梯度流动""" |
||||
|
logger.info("\n=== 梯度检查 ===") |
||||
|
|
||||
|
self.model.train() |
||||
|
self.model.zero_grad() |
||||
|
|
||||
|
# 前向传播 |
||||
|
pred = self.model(**batch) |
||||
|
loss = criterion(pred, batch['gt_sdf']) |
||||
|
|
||||
|
# 反向传播 |
||||
|
loss.backward() |
||||
|
|
||||
|
# 检查每个模块的梯度 |
||||
|
modules = { |
||||
|
'query_encoder': self.model.query_encoder, |
||||
|
'brep_embedder': self.model.brep_embedder, |
||||
|
'transformer': self.model.transformer, |
||||
|
'sdf_head': self.model.sdf_head |
||||
|
} |
||||
|
|
||||
|
for name, module in modules.items(): |
||||
|
logger.info(f"\n{name} 梯度:") |
||||
|
total_grad_norm = 0 |
||||
|
for param_name, param in module.named_parameters(): |
||||
|
if param.grad is not None: |
||||
|
grad_norm = param.grad.norm().item() |
||||
|
total_grad_norm += grad_norm |
||||
|
logger.info(f" {param_name}:") |
||||
|
logger.info(f" 梯度范数: {grad_norm:.6f}") |
||||
|
logger.info(f" 梯度均值: {param.grad.mean().item():.6f}") |
||||
|
logger.info(f" 梯度标准差: {param.grad.std().item():.6f}") |
||||
|
else: |
||||
|
logger.warning(f" {param_name}: 没有梯度!") |
||||
|
logger.info(f" 总梯度范数: {total_grad_norm:.6f}") |
||||
|
|
||||
|
def main(): |
||||
|
# 获取配置 |
||||
|
config = get_default_config() |
||||
|
|
||||
|
# 初始化模型 |
||||
|
model = BRepToSDF(config=config) |
||||
|
|
||||
|
# 创建诊断器 |
||||
|
diagnostics = ModelDiagnostics(model, config) |
||||
|
|
||||
|
# 生成测试数据 |
||||
|
batch = { |
||||
|
'edge_ncs': torch.randn(2, config.data.max_face, config.data.max_edge, |
||||
|
config.model.num_edge_points, 3), |
||||
|
'edge_pos': torch.randn(2, config.data.max_face, config.data.max_edge, 6), |
||||
|
'edge_mask': torch.ones(2, config.data.max_face, config.data.max_edge, |
||||
|
dtype=torch.bool), |
||||
|
'surf_ncs': torch.randn(2, config.data.max_face, config.model.num_surf_points, 3), |
||||
|
'surf_pos': torch.randn(2, config.data.max_face, 6), |
||||
|
'vertex_pos': torch.randn(2, config.data.max_face, config.data.max_edge, 2, 3), |
||||
|
'query_points': torch.randn(2, 1000, 3), |
||||
|
#'sdf': torch.randn(2, 1000, 1) |
||||
|
} |
||||
|
|
||||
|
# 初始化损失函数 |
||||
|
criterion = Brep2SDFLoss( |
||||
|
batch_size=config.train.batch_size, |
||||
|
enforce_minmax=(config.train.clamping_distance > 0), |
||||
|
clamping_distance=config.train.clamping_distance |
||||
|
) |
||||
|
|
||||
|
# 运行诊断 |
||||
|
diagnostics.check_model_architecture() |
||||
|
diagnostics.check_forward_pass(batch) |
||||
|
diagnostics.check_gradients(batch, criterion) |
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
main() |
Loading…
Reference in new issue