Browse Source

refctor: print_tensor_stats 移入log

main
mckay 3 months ago
parent
commit
5eef6c4290
  1. 36
      brep2sdf/networks/encoder.py
  2. 20
      brep2sdf/utils/logger.py

36
brep2sdf/networks/encoder.py

@ -248,37 +248,37 @@ class BRepFeatureEmbedder(nn.Module):
logger.info("\n=== 输入张量检查 ===")
for name, tensor in input_tensors.items():
print_tensor_stats(name, tensor)
logger.print_tensor_stats(name, tensor)
# 1. 处理顶点特征
vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim]
print_tensor_stats('vertex_embed', vertex_embed)
logger.print_tensor_stats('vertex_embed', vertex_embed)
vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim]
print_tensor_stats('vertex_embed(after proj)', vertex_embed)
logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed)
vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim]
# 2. 处理边特征
edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim]
print_tensor_stats('edge_embeds', edge_embeds)
logger.print_tensor_stats('edge_embeds', edge_embeds)
edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim]
print_tensor_stats('edge_p_embeds', edge_p_embeds)
logger.print_tensor_stats('edge_p_embeds', edge_p_embeds)
# 3. 处理面特征
surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim]
print_tensor_stats('surf_embeds', surf_embeds)
logger.print_tensor_stats('surf_embeds', surf_embeds)
surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim]
print_tensor_stats('surf_p_embeds', surf_p_embeds)
logger.print_tensor_stats('surf_p_embeds', surf_p_embeds)
# 4. 组合特征
if self.use_cf:
# 组合边特征
edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim]
edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim]
print_tensor_stats('edge_features', edge_features)
logger.print_tensor_stats('edge_features', edge_features)
# 组合面特征
surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim]
print_tensor_stats('surf_features', surf_features)
logger.print_tensor_stats('surf_features', surf_features)
# 拼接所有特征
embeds = torch.cat([
@ -303,7 +303,7 @@ class BRepFeatureEmbedder(nn.Module):
logger.debug(f"embeds shape: {embeds.shape}")
# 6. Transformer处理
output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask)
print_tensor_stats('output', output)
logger.print_tensor_stats('output', output)
logger.debug(f"output shape: {output.shape}")
return output.transpose(0, 1) # [B, F*E+F, embed_dim]
@ -516,22 +516,6 @@ class BRepToSDF(nn.Module):
logger.error(f" query_points: {query_points.shape}")
raise
def print_tensor_stats(name: str, tensor: torch.Tensor):
"""打印张量的统计信息"""
logger.info(f"\n=== {name} 统计信息 ===")
logger.info(f" shape: {tensor.shape}")
logger.info(f" norm: {tensor.norm().item():.6f}")
logger.info(f" mean: {tensor.mean().item():.6f}")
logger.info(f" std: {tensor.std().item():.6f}")
logger.info(f" min: {tensor.min().item():.6f}")
logger.info(f" max: {tensor.max().item():.6f}")
logger.info(f" requires_grad: {tensor.requires_grad}")
if tensor.requires_grad:
if not tensor.grad_fn:
logger.warning(f"⚠️ {name} requires_grad=True 但没有梯度函数!")
else:
logger.warning(f"⚠️ {name} requires_grad=False!")
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
"""SDF损失函数"""
# 确保points需要梯度

20
brep2sdf/utils/logger.py

@ -5,8 +5,10 @@ import traceback
from datetime import datetime
from brep2sdf.config.default_config import get_default_config
from functools import wraps
import torch
import time
class ColoredFormatter(logging.Formatter):
"""自定义彩色格式化器"""
@ -176,6 +178,22 @@ class BRepLogger:
# 记录初始信息
self.logger.info("BRep Logger initialized")
self.logger.debug(f"Log file: {log_file}")
def print_tensor_stats(self, name: str, tensor: torch.Tensor):
"""打印张量的统计信息"""
logger.info(f"\n=== {name} 统计信息 ===")
logger.info(f" shape: {tensor.shape}")
logger.info(f" norm: {tensor.norm().item():.6f}")
logger.info(f" mean: {tensor.mean().item():.6f}")
logger.info(f" std: {tensor.std().item():.6f}")
logger.info(f" min: {tensor.min().item():.6f}")
logger.info(f" max: {tensor.max().item():.6f}")
logger.info(f" requires_grad: {tensor.requires_grad}")
if tensor.requires_grad:
if not tensor.grad_fn:
logger.warning(f"⚠️ {name} requires_grad=True 但没有梯度函数!")
else:
logger.warning(f"⚠️ {name} requires_grad=False!")
def timeit(func):
"""计时装饰器"""
@ -194,4 +212,4 @@ def setup_logger(config=None):
return BRepLogger(config)
# 使用默认配置创建全局logger实例
logger = setup_logger()
logger = setup_logger()

Loading…
Cancel
Save