Browse Source

feat: log for param in debug mode

main
mckay 4 months ago
parent
commit
424ae34067
  1. 161
      brep2sdf/networks/encoder.py

161
brep2sdf/networks/encoder.py

@ -221,17 +221,9 @@ class BRepFeatureEmbedder(nn.Module):
self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim)
# 添加 transformer 初始化
self.transformer = nn.TransformerEncoder(
encoder_layer=nn.TransformerEncoderLayer(
d_model=self.embed_dim,
nhead=8, # 注意力头数,通常是embed_dim的因子
dim_feedforward=4*self.embed_dim, # 前馈网络维度,通常是embed_dim的4倍
dropout=0.1,
activation='gelu',
batch_first=False # 因为我们用了transpose(0,1)
),
num_layers=6 # transformer层数
)
layer = nn.TransformerEncoderLayer(d_model=self.embed_dim, nhead=8, norm_first=True,
dim_feedforward=1024, dropout=0.1)
self.net = nn.TransformerEncoder(layer, 6, nn.LayerNorm(self.embed_dim))
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs):
"""
@ -245,27 +237,48 @@ class BRepFeatureEmbedder(nn.Module):
"""
B, F, E = edge_pos.shape[:3]
# 检查输入张量
input_tensors = {
'edge_ncs': edge_ncs,
'edge_pos': edge_pos,
'surf_ncs': surf_ncs,
'surf_pos': surf_pos,
'vertex_pos': vertex_pos
}
logger.info("\n=== 输入张量检查 ===")
for name, tensor in input_tensors.items():
print_tensor_stats(name, tensor)
# 1. 处理顶点特征
vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim]
print_tensor_stats('vertex_embed', vertex_embed)
vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim]
print_tensor_stats('vertex_embed(after proj)', vertex_embed)
vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim]
# 2. 处理边特征
edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim]
print_tensor_stats('edge_embeds', edge_embeds)
edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim]
print_tensor_stats('edge_p_embeds', edge_p_embeds)
# 3. 处理面特征
surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim]
print_tensor_stats('surf_embeds', surf_embeds)
surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim]
print_tensor_stats('surf_p_embeds', surf_p_embeds)
# 4. 组合特征
if self.use_cf:
# 组合边特征
edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim]
edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim]
print_tensor_stats('edge_features', edge_features)
# 组合面特征
surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim]
print_tensor_stats('surf_features', surf_features)
# 拼接所有特征
embeds = torch.cat([
@ -287,9 +300,11 @@ class BRepFeatureEmbedder(nn.Module):
mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, F*E+F]
else:
mask = None
logger.debug(f"embeds shape: {embeds.shape}")
# 6. Transformer处理
output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask)
output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask)
print_tensor_stats('output', output)
logger.debug(f"output shape: {output.shape}")
return output.transpose(0, 1) # [B, F*E+F, embed_dim]
class SDFTransformer(nn.Module):
@ -501,6 +516,22 @@ class BRepToSDF(nn.Module):
logger.error(f" query_points: {query_points.shape}")
raise
def print_tensor_stats(name: str, tensor: torch.Tensor):
"""打印张量的统计信息"""
logger.info(f"\n=== {name} 统计信息 ===")
logger.info(f" shape: {tensor.shape}")
logger.info(f" norm: {tensor.norm().item():.6f}")
logger.info(f" mean: {tensor.mean().item():.6f}")
logger.info(f" std: {tensor.std().item():.6f}")
logger.info(f" min: {tensor.min().item():.6f}")
logger.info(f" max: {tensor.max().item():.6f}")
logger.info(f" requires_grad: {tensor.requires_grad}")
if tensor.requires_grad:
if not tensor.grad_fn:
logger.warning(f"⚠️ {name} requires_grad=True 但没有梯度函数!")
else:
logger.warning(f"⚠️ {name} requires_grad=False!")
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
"""SDF损失函数"""
# 确保points需要梯度
@ -580,5 +611,107 @@ def main():
logger.error(f"Error during forward pass: {str(e)}")
raise
def test_brep_embedder():
"""测试BRepFeatureEmbedder的参数初始化和梯度流动"""
# 1. 初始化配置和模型
config = get_default_config()
embedder = BRepFeatureEmbedder(config)
# 2. 生成测试数据
B, F, E = 2, 8, 16 # batch_size, max_face, max_edge
test_data = {
'edge_ncs': torch.randn(B, F, E, config.model.num_edge_points, 3, requires_grad=True),
'edge_pos': torch.randn(B, F, E, 6, requires_grad=True),
'edge_mask': torch.ones(B, F, E, dtype=torch.bool),
'surf_ncs': torch.randn(B, F, config.model.num_surf_points, 3, requires_grad=True),
'surf_pos': torch.randn(B, F, 6, requires_grad=True),
'vertex_pos': torch.randn(B, F, E, 2, 3, requires_grad=True)
}
# 3. 检查初始参数
logger.info("\n=== 初始参数检查 ===")
for name, param in embedder.named_parameters():
logger.info(f"\n{name}:")
logger.info(f" shape: {param.shape}")
logger.info(f" requires_grad: {param.requires_grad}")
logger.info(f" norm: {param.norm().item():.6f}")
logger.info(f" mean: {param.mean().item():.6f}")
logger.info(f" std: {param.std().item():.6f}")
# 4. 前向传播
logger.info("\n=== 前向传播 ===")
outputs = embedder(**test_data)
logger.info(f"Output shape: {outputs.shape}")
# 5. 检查中间特征
def check_tensor(tensor, name):
logger.info(f"\n{name}:")
logger.info(f" shape: {tensor.shape}")
logger.info(f" requires_grad: {tensor.requires_grad}")
logger.info(f" has_grad_fn: {tensor.grad_fn is not None}")
if tensor.grad_fn:
logger.info(f" grad_fn: {type(tensor.grad_fn).__name__}")
logger.info(f" norm: {tensor.norm().item():.6f}")
logger.info(f" mean: {tensor.mean().item():.6f}")
logger.info(f" std: {tensor.std().item():.6f}")
# 6. 反向传播
logger.info("\n=== 反向传播 ===")
loss = outputs.mean() # 简单的损失函数
loss.backward()
# 7. 检查梯度
logger.info("\n=== 梯度检查 ===")
# 7.1 检查输入梯度
for name, tensor in test_data.items():
if tensor.requires_grad:
logger.info(f"\n{name} gradient:")
if tensor.grad is not None:
logger.info(f" grad norm: {tensor.grad.norm().item():.6f}")
logger.info(f" grad mean: {tensor.grad.mean().item():.6f}")
logger.info(f" grad std: {tensor.grad.std().item():.6f}")
else:
logger.info(" No gradient!")
# 7.2 检查模型参数梯度
logger.info("\n=== 模型参数梯度 ===")
for name, param in embedder.named_parameters():
logger.info(f"\n{name}:")
if param.grad is not None:
logger.info(f" grad norm: {param.grad.norm().item():.6f}")
logger.info(f" grad mean: {param.grad.mean().item():.6f}")
logger.info(f" grad std: {param.grad.std().item():.6f}")
# 检查是否有任何梯度为NaN或inf
if torch.isnan(param.grad).any():
logger.warning(" Contains NaN gradients!")
if torch.isinf(param.grad).any():
logger.warning(" Contains Inf gradients!")
else:
logger.warning(" No gradient!")
# 8. 特别检查transformer层
logger.info("\n=== Transformer层检查 ===")
for i, layer in enumerate(embedder.net.layers):
logger.info(f"\nTransformer Layer {i}:")
# 检查自注意力层
logger.info(" Self Attention:")
logger.info(f" in_proj_weight norm: {layer.self_attn.in_proj_weight.norm().item():.6f}")
logger.info(f" in_proj_bias norm: {layer.self_attn.in_proj_bias.norm().item():.6f}")
logger.info(f" out_proj.weight norm: {layer.self_attn.out_proj.weight.norm().item():.6f}")
logger.info(f" out_proj.bias norm: {layer.self_attn.out_proj.bias.norm().item():.6f}")
if layer.self_attn.in_proj_weight.grad is not None:
logger.info(f" in_proj_weight grad norm: {layer.self_attn.in_proj_weight.grad.norm().item():.6f}")
else:
logger.warning(" in_proj_weight has no gradient!")
# 检查LayerNorm层
logger.info(" LayerNorm:")
logger.info(f" norm1.weight norm: {layer.norm1.weight.norm().item():.6f}")
logger.info(f" norm1.bias norm: {layer.norm1.bias.norm().item():.6f}")
logger.info(f" norm2.weight norm: {layer.norm2.weight.norm().item():.6f}")
logger.info(f" norm2.bias norm: {layer.norm2.bias.norm().item():.6f}")
if __name__ == "__main__":
main()
test_brep_embedder()
Loading…
Cancel
Save