From 424ae340679d1c9053f0ab7516d981a6dc2cb46c Mon Sep 17 00:00:00 2001 From: mckay Date: Sat, 30 Nov 2024 16:50:40 +0800 Subject: [PATCH] feat: log for param in debug mode --- brep2sdf/networks/encoder.py | 161 ++++++++++++++++++++++++++++++++--- 1 file changed, 147 insertions(+), 14 deletions(-) diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 50becc4..c7645e0 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -221,17 +221,9 @@ class BRepFeatureEmbedder(nn.Module): self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim) # 添加 transformer 初始化 - self.transformer = nn.TransformerEncoder( - encoder_layer=nn.TransformerEncoderLayer( - d_model=self.embed_dim, - nhead=8, # 注意力头数,通常是embed_dim的因子 - dim_feedforward=4*self.embed_dim, # 前馈网络维度,通常是embed_dim的4倍 - dropout=0.1, - activation='gelu', - batch_first=False # 因为我们用了transpose(0,1) - ), - num_layers=6 # transformer层数 - ) + layer = nn.TransformerEncoderLayer(d_model=self.embed_dim, nhead=8, norm_first=True, + dim_feedforward=1024, dropout=0.1) + self.net = nn.TransformerEncoder(layer, 6, nn.LayerNorm(self.embed_dim)) def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs): """ @@ -245,27 +237,48 @@ class BRepFeatureEmbedder(nn.Module): """ B, F, E = edge_pos.shape[:3] + # 检查输入张量 + input_tensors = { + 'edge_ncs': edge_ncs, + 'edge_pos': edge_pos, + 'surf_ncs': surf_ncs, + 'surf_pos': surf_pos, + 'vertex_pos': vertex_pos + } + + logger.info("\n=== 输入张量检查 ===") + for name, tensor in input_tensors.items(): + print_tensor_stats(name, tensor) + # 1. 处理顶点特征 vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim] + print_tensor_stats('vertex_embed', vertex_embed) vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim] + print_tensor_stats('vertex_embed(after proj)', vertex_embed) vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim] # 2. 处理边特征 edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim] + print_tensor_stats('edge_embeds', edge_embeds) edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim] + print_tensor_stats('edge_p_embeds', edge_p_embeds) # 3. 处理面特征 surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim] + print_tensor_stats('surf_embeds', surf_embeds) surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim] + print_tensor_stats('surf_p_embeds', surf_p_embeds) # 4. 组合特征 if self.use_cf: # 组合边特征 edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim] edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim] + print_tensor_stats('edge_features', edge_features) # 组合面特征 surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim] + print_tensor_stats('surf_features', surf_features) # 拼接所有特征 embeds = torch.cat([ @@ -287,9 +300,11 @@ class BRepFeatureEmbedder(nn.Module): mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, F*E+F] else: mask = None - + logger.debug(f"embeds shape: {embeds.shape}") # 6. Transformer处理 - output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask) + output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask) + print_tensor_stats('output', output) + logger.debug(f"output shape: {output.shape}") return output.transpose(0, 1) # [B, F*E+F, embed_dim] class SDFTransformer(nn.Module): @@ -501,6 +516,22 @@ class BRepToSDF(nn.Module): logger.error(f" query_points: {query_points.shape}") raise +def print_tensor_stats(name: str, tensor: torch.Tensor): + """打印张量的统计信息""" + logger.info(f"\n=== {name} 统计信息 ===") + logger.info(f" shape: {tensor.shape}") + logger.info(f" norm: {tensor.norm().item():.6f}") + logger.info(f" mean: {tensor.mean().item():.6f}") + logger.info(f" std: {tensor.std().item():.6f}") + logger.info(f" min: {tensor.min().item():.6f}") + logger.info(f" max: {tensor.max().item():.6f}") + logger.info(f" requires_grad: {tensor.requires_grad}") + if tensor.requires_grad: + if not tensor.grad_fn: + logger.warning(f"⚠️ {name} requires_grad=True 但没有梯度函数!") + else: + logger.warning(f"⚠️ {name} requires_grad=False!") + def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): """SDF损失函数""" # 确保points需要梯度 @@ -580,5 +611,107 @@ def main(): logger.error(f"Error during forward pass: {str(e)}") raise +def test_brep_embedder(): + """测试BRepFeatureEmbedder的参数初始化和梯度流动""" + # 1. 初始化配置和模型 + config = get_default_config() + embedder = BRepFeatureEmbedder(config) + + # 2. 生成测试数据 + B, F, E = 2, 8, 16 # batch_size, max_face, max_edge + test_data = { + 'edge_ncs': torch.randn(B, F, E, config.model.num_edge_points, 3, requires_grad=True), + 'edge_pos': torch.randn(B, F, E, 6, requires_grad=True), + 'edge_mask': torch.ones(B, F, E, dtype=torch.bool), + 'surf_ncs': torch.randn(B, F, config.model.num_surf_points, 3, requires_grad=True), + 'surf_pos': torch.randn(B, F, 6, requires_grad=True), + 'vertex_pos': torch.randn(B, F, E, 2, 3, requires_grad=True) + } + + # 3. 检查初始参数 + logger.info("\n=== 初始参数检查 ===") + for name, param in embedder.named_parameters(): + logger.info(f"\n{name}:") + logger.info(f" shape: {param.shape}") + logger.info(f" requires_grad: {param.requires_grad}") + logger.info(f" norm: {param.norm().item():.6f}") + logger.info(f" mean: {param.mean().item():.6f}") + logger.info(f" std: {param.std().item():.6f}") + + # 4. 前向传播 + logger.info("\n=== 前向传播 ===") + outputs = embedder(**test_data) + logger.info(f"Output shape: {outputs.shape}") + + # 5. 检查中间特征 + def check_tensor(tensor, name): + logger.info(f"\n{name}:") + logger.info(f" shape: {tensor.shape}") + logger.info(f" requires_grad: {tensor.requires_grad}") + logger.info(f" has_grad_fn: {tensor.grad_fn is not None}") + if tensor.grad_fn: + logger.info(f" grad_fn: {type(tensor.grad_fn).__name__}") + logger.info(f" norm: {tensor.norm().item():.6f}") + logger.info(f" mean: {tensor.mean().item():.6f}") + logger.info(f" std: {tensor.std().item():.6f}") + + # 6. 反向传播 + logger.info("\n=== 反向传播 ===") + loss = outputs.mean() # 简单的损失函数 + loss.backward() + + # 7. 检查梯度 + logger.info("\n=== 梯度检查 ===") + + # 7.1 检查输入梯度 + for name, tensor in test_data.items(): + if tensor.requires_grad: + logger.info(f"\n{name} gradient:") + if tensor.grad is not None: + logger.info(f" grad norm: {tensor.grad.norm().item():.6f}") + logger.info(f" grad mean: {tensor.grad.mean().item():.6f}") + logger.info(f" grad std: {tensor.grad.std().item():.6f}") + else: + logger.info(" No gradient!") + + # 7.2 检查模型参数梯度 + logger.info("\n=== 模型参数梯度 ===") + for name, param in embedder.named_parameters(): + logger.info(f"\n{name}:") + if param.grad is not None: + logger.info(f" grad norm: {param.grad.norm().item():.6f}") + logger.info(f" grad mean: {param.grad.mean().item():.6f}") + logger.info(f" grad std: {param.grad.std().item():.6f}") + # 检查是否有任何梯度为NaN或inf + if torch.isnan(param.grad).any(): + logger.warning(" Contains NaN gradients!") + if torch.isinf(param.grad).any(): + logger.warning(" Contains Inf gradients!") + else: + logger.warning(" No gradient!") + + # 8. 特别检查transformer层 + logger.info("\n=== Transformer层检查 ===") + for i, layer in enumerate(embedder.net.layers): + logger.info(f"\nTransformer Layer {i}:") + # 检查自注意力层 + logger.info(" Self Attention:") + logger.info(f" in_proj_weight norm: {layer.self_attn.in_proj_weight.norm().item():.6f}") + logger.info(f" in_proj_bias norm: {layer.self_attn.in_proj_bias.norm().item():.6f}") + logger.info(f" out_proj.weight norm: {layer.self_attn.out_proj.weight.norm().item():.6f}") + logger.info(f" out_proj.bias norm: {layer.self_attn.out_proj.bias.norm().item():.6f}") + + if layer.self_attn.in_proj_weight.grad is not None: + logger.info(f" in_proj_weight grad norm: {layer.self_attn.in_proj_weight.grad.norm().item():.6f}") + else: + logger.warning(" in_proj_weight has no gradient!") + + # 检查LayerNorm层 + logger.info(" LayerNorm:") + logger.info(f" norm1.weight norm: {layer.norm1.weight.norm().item():.6f}") + logger.info(f" norm1.bias norm: {layer.norm1.bias.norm().item():.6f}") + logger.info(f" norm2.weight norm: {layer.norm2.weight.norm().item():.6f}") + logger.info(f" norm2.bias norm: {layer.norm2.bias.norm().item():.6f}") + if __name__ == "__main__": - main() \ No newline at end of file + test_brep_embedder() \ No newline at end of file