1 changed files with 162 additions and 0 deletions
@ -0,0 +1,162 @@ |
|||
import torch |
|||
import torch.nn as nn |
|||
from brep2sdf.networks.network import BRepToSDF |
|||
from brep2sdf.config.default_config import get_default_config |
|||
from brep2sdf.utils.logger import logger |
|||
from brep2sdf.networks.loss import Brep2SDFLoss |
|||
|
|||
class ModelDiagnostics: |
|||
def __init__(self, model, config): |
|||
self.model = model |
|||
self.config = config |
|||
self.device = next(model.parameters()).device |
|||
|
|||
def check_model_architecture(self): |
|||
"""检查模型架构""" |
|||
logger.info("\n=== 模型架构检查 ===") |
|||
|
|||
# 检查每个子模块 |
|||
modules = { |
|||
'query_encoder': self.model.query_encoder, |
|||
'brep_embedder': self.model.brep_embedder, |
|||
'transformer': self.model.transformer, |
|||
'sdf_head': self.model.sdf_head |
|||
} |
|||
|
|||
for name, module in modules.items(): |
|||
logger.info(f"\n{name}:") |
|||
logger.info(f" 参数数量: {sum(p.numel() for p in module.parameters()):,}") |
|||
logger.info(f" 需要梯度的参数: {sum(p.numel() for p in module.parameters() if p.requires_grad):,}") |
|||
|
|||
# 检查模块中的关键层 |
|||
for layer_name, layer in module.named_modules(): |
|||
if isinstance(layer, (nn.Linear, nn.LayerNorm)): |
|||
logger.info(f" {layer_name}:") |
|||
logger.info(f" 输入维度: {layer.in_features if hasattr(layer, 'in_features') else 'N/A'}") |
|||
logger.info(f" 输出维度: {layer.out_features if hasattr(layer, 'out_features') else 'N/A'}") |
|||
|
|||
def check_forward_pass(self, batch): |
|||
"""检查前向传播过程""" |
|||
logger.info("\n=== 前向传播检查 ===") |
|||
|
|||
self.model.eval() |
|||
with torch.no_grad(): |
|||
# 1. 查询点编码 |
|||
query_features = self.model.query_encoder(batch['query_points']) |
|||
logger.info("\n查询点编码:") |
|||
logger.info(f" 输入形状: {batch['query_points'].shape}") |
|||
logger.info(f" 输出形状: {query_features.shape}") |
|||
logger.info(f" 特征统计:") |
|||
logger.info(f" 均值: {query_features.mean():.4f}") |
|||
logger.info(f" 标准差: {query_features.std():.4f}") |
|||
logger.info(f" 最大值: {query_features.max():.4f}") |
|||
logger.info(f" 最小值: {query_features.min():.4f}") |
|||
|
|||
# 2. B-rep特征编码 |
|||
brep_features = self.model.brep_embedder( |
|||
edge_ncs=batch['edge_ncs'], |
|||
edge_pos=batch['edge_pos'], |
|||
edge_mask=batch['edge_mask'], |
|||
surf_ncs=batch['surf_ncs'], |
|||
surf_pos=batch['surf_pos'], |
|||
vertex_pos=batch['vertex_pos'] |
|||
) |
|||
logger.info("\nB-rep特征编码:") |
|||
logger.info(f" 输出形状: {brep_features.shape}") |
|||
logger.info(f" 特征统计:") |
|||
logger.info(f" 均值: {brep_features.mean():.4f}") |
|||
logger.info(f" 标准差: {brep_features.std():.4f}") |
|||
|
|||
# 3. 全局特征 |
|||
global_features = brep_features.mean(dim=1) |
|||
logger.info("\n全局特征:") |
|||
logger.info(f" 形状: {global_features.shape}") |
|||
logger.info(f" 统计:") |
|||
logger.info(f" 均值: {global_features.mean():.4f}") |
|||
logger.info(f" 标准差: {global_features.std():.4f}") |
|||
|
|||
# 4. 最终预测 |
|||
sdf = self.model(**batch) |
|||
logger.info("\nSDF预测:") |
|||
logger.info(f" 形状: {sdf.shape}") |
|||
logger.info(f" 统计:") |
|||
logger.info(f" 均值: {sdf.mean():.4f}") |
|||
logger.info(f" 标准差: {sdf.std():.4f}") |
|||
logger.info(f" 最大值: {sdf.max():.4f}") |
|||
logger.info(f" 最小值: {sdf.min():.4f}") |
|||
|
|||
def check_gradients(self, batch, criterion): |
|||
"""检查梯度流动""" |
|||
logger.info("\n=== 梯度检查 ===") |
|||
|
|||
self.model.train() |
|||
self.model.zero_grad() |
|||
|
|||
# 前向传播 |
|||
pred = self.model(**batch) |
|||
loss = criterion(pred, batch['gt_sdf']) |
|||
|
|||
# 反向传播 |
|||
loss.backward() |
|||
|
|||
# 检查每个模块的梯度 |
|||
modules = { |
|||
'query_encoder': self.model.query_encoder, |
|||
'brep_embedder': self.model.brep_embedder, |
|||
'transformer': self.model.transformer, |
|||
'sdf_head': self.model.sdf_head |
|||
} |
|||
|
|||
for name, module in modules.items(): |
|||
logger.info(f"\n{name} 梯度:") |
|||
total_grad_norm = 0 |
|||
for param_name, param in module.named_parameters(): |
|||
if param.grad is not None: |
|||
grad_norm = param.grad.norm().item() |
|||
total_grad_norm += grad_norm |
|||
logger.info(f" {param_name}:") |
|||
logger.info(f" 梯度范数: {grad_norm:.6f}") |
|||
logger.info(f" 梯度均值: {param.grad.mean().item():.6f}") |
|||
logger.info(f" 梯度标准差: {param.grad.std().item():.6f}") |
|||
else: |
|||
logger.warning(f" {param_name}: 没有梯度!") |
|||
logger.info(f" 总梯度范数: {total_grad_norm:.6f}") |
|||
|
|||
def main(): |
|||
# 获取配置 |
|||
config = get_default_config() |
|||
|
|||
# 初始化模型 |
|||
model = BRepToSDF(config=config) |
|||
|
|||
# 创建诊断器 |
|||
diagnostics = ModelDiagnostics(model, config) |
|||
|
|||
# 生成测试数据 |
|||
batch = { |
|||
'edge_ncs': torch.randn(2, config.data.max_face, config.data.max_edge, |
|||
config.model.num_edge_points, 3), |
|||
'edge_pos': torch.randn(2, config.data.max_face, config.data.max_edge, 6), |
|||
'edge_mask': torch.ones(2, config.data.max_face, config.data.max_edge, |
|||
dtype=torch.bool), |
|||
'surf_ncs': torch.randn(2, config.data.max_face, config.model.num_surf_points, 3), |
|||
'surf_pos': torch.randn(2, config.data.max_face, 6), |
|||
'vertex_pos': torch.randn(2, config.data.max_face, config.data.max_edge, 2, 3), |
|||
'query_points': torch.randn(2, 1000, 3), |
|||
#'sdf': torch.randn(2, 1000, 1) |
|||
} |
|||
|
|||
# 初始化损失函数 |
|||
criterion = Brep2SDFLoss( |
|||
batch_size=config.train.batch_size, |
|||
enforce_minmax=(config.train.clamping_distance > 0), |
|||
clamping_distance=config.train.clamping_distance |
|||
) |
|||
|
|||
# 运行诊断 |
|||
diagnostics.check_model_architecture() |
|||
diagnostics.check_forward_pass(batch) |
|||
diagnostics.check_gradients(batch, criterion) |
|||
|
|||
if __name__ == '__main__': |
|||
main() |
Loading…
Reference in new issue