From 5eef6c4290e37435311912d2e24d5fde11ead88c Mon Sep 17 00:00:00 2001 From: mckay Date: Tue, 3 Dec 2024 20:41:53 +0800 Subject: [PATCH] =?UTF-8?q?refctor:=20print=5Ftensor=5Fstats=20=E7=A7=BB?= =?UTF-8?q?=E5=85=A5log?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/encoder.py | 36 ++++++++++-------------------------- brep2sdf/utils/logger.py | 20 +++++++++++++++++++- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index c7645e0..b4fdb91 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -248,37 +248,37 @@ class BRepFeatureEmbedder(nn.Module): logger.info("\n=== 输入张量检查 ===") for name, tensor in input_tensors.items(): - print_tensor_stats(name, tensor) + logger.print_tensor_stats(name, tensor) # 1. 处理顶点特征 vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim] - print_tensor_stats('vertex_embed', vertex_embed) + logger.print_tensor_stats('vertex_embed', vertex_embed) vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim] - print_tensor_stats('vertex_embed(after proj)', vertex_embed) + logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed) vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim] # 2. 处理边特征 edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim] - print_tensor_stats('edge_embeds', edge_embeds) + logger.print_tensor_stats('edge_embeds', edge_embeds) edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim] - print_tensor_stats('edge_p_embeds', edge_p_embeds) + logger.print_tensor_stats('edge_p_embeds', edge_p_embeds) # 3. 处理面特征 surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim] - print_tensor_stats('surf_embeds', surf_embeds) + logger.print_tensor_stats('surf_embeds', surf_embeds) surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim] - print_tensor_stats('surf_p_embeds', surf_p_embeds) + logger.print_tensor_stats('surf_p_embeds', surf_p_embeds) # 4. 组合特征 if self.use_cf: # 组合边特征 edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim] edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim] - print_tensor_stats('edge_features', edge_features) + logger.print_tensor_stats('edge_features', edge_features) # 组合面特征 surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim] - print_tensor_stats('surf_features', surf_features) + logger.print_tensor_stats('surf_features', surf_features) # 拼接所有特征 embeds = torch.cat([ @@ -303,7 +303,7 @@ class BRepFeatureEmbedder(nn.Module): logger.debug(f"embeds shape: {embeds.shape}") # 6. Transformer处理 output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask) - print_tensor_stats('output', output) + logger.print_tensor_stats('output', output) logger.debug(f"output shape: {output.shape}") return output.transpose(0, 1) # [B, F*E+F, embed_dim] @@ -516,22 +516,6 @@ class BRepToSDF(nn.Module): logger.error(f" query_points: {query_points.shape}") raise -def print_tensor_stats(name: str, tensor: torch.Tensor): - """打印张量的统计信息""" - logger.info(f"\n=== {name} 统计信息 ===") - logger.info(f" shape: {tensor.shape}") - logger.info(f" norm: {tensor.norm().item():.6f}") - logger.info(f" mean: {tensor.mean().item():.6f}") - logger.info(f" std: {tensor.std().item():.6f}") - logger.info(f" min: {tensor.min().item():.6f}") - logger.info(f" max: {tensor.max().item():.6f}") - logger.info(f" requires_grad: {tensor.requires_grad}") - if tensor.requires_grad: - if not tensor.grad_fn: - logger.warning(f"⚠️ {name} requires_grad=True 但没有梯度函数!") - else: - logger.warning(f"⚠️ {name} requires_grad=False!") - def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): """SDF损失函数""" # 确保points需要梯度 diff --git a/brep2sdf/utils/logger.py b/brep2sdf/utils/logger.py index cc51235..52848a2 100644 --- a/brep2sdf/utils/logger.py +++ b/brep2sdf/utils/logger.py @@ -5,8 +5,10 @@ import traceback from datetime import datetime from brep2sdf.config.default_config import get_default_config from functools import wraps +import torch import time + class ColoredFormatter(logging.Formatter): """自定义彩色格式化器""" @@ -176,6 +178,22 @@ class BRepLogger: # 记录初始信息 self.logger.info("BRep Logger initialized") self.logger.debug(f"Log file: {log_file}") + + def print_tensor_stats(self, name: str, tensor: torch.Tensor): + """打印张量的统计信息""" + logger.info(f"\n=== {name} 统计信息 ===") + logger.info(f" shape: {tensor.shape}") + logger.info(f" norm: {tensor.norm().item():.6f}") + logger.info(f" mean: {tensor.mean().item():.6f}") + logger.info(f" std: {tensor.std().item():.6f}") + logger.info(f" min: {tensor.min().item():.6f}") + logger.info(f" max: {tensor.max().item():.6f}") + logger.info(f" requires_grad: {tensor.requires_grad}") + if tensor.requires_grad: + if not tensor.grad_fn: + logger.warning(f"⚠️ {name} requires_grad=True 但没有梯度函数!") + else: + logger.warning(f"⚠️ {name} requires_grad=False!") def timeit(func): """计时装饰器""" @@ -194,4 +212,4 @@ def setup_logger(config=None): return BRepLogger(config) # 使用默认配置创建全局logger实例 -logger = setup_logger() \ No newline at end of file +logger = setup_logger()