From e6e1aa5008174af10cb98a7f0efe156c9f3e39cb Mon Sep 17 00:00:00 2001 From: mckay Date: Fri, 13 Dec 2024 23:08:55 +0800 Subject: [PATCH] feat: optimaize loss, by torch.norm and mead --- brep2sdf/networks/loss.py | 73 ++++++++++++++++++++++++------------ brep2sdf/networks/network.py | 1 + 2 files changed, 50 insertions(+), 24 deletions(-) diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py index d43e1ed..540488c 100644 --- a/brep2sdf/networks/loss.py +++ b/brep2sdf/networks/loss.py @@ -2,41 +2,66 @@ import torch import torch.nn as nn from brep2sdf.config.default_config import get_default_config +from brep2sdf.utils.logger import logger - - -class Brep2SDFLoss: - """解释Brep2SDF的loss设计原理""" - def __init__(self, batch_size:float, enforce_minmax: bool=True, clamping_distance: float = 0.1): +class Brep2SDFLoss(nn.Module): + def __init__(self, batch_size:float=1, enforce_minmax: bool=True, clamping_distance: float=0.1, + grad_weight=0.1, warmup_epochs=10): + super().__init__() self.batch_size = batch_size - self.l1_loss = nn.L1Loss(reduction='sum') self.enforce_minmax = enforce_minmax self.minT = -clamping_distance self.maxT = clamping_distance - - def __call__(self, pred_sdf, gt_sdf): - """使类可直接调用""" - return self.forward(pred_sdf, gt_sdf) - - def forward(self, pred_sdf, gt_sdf): - """ - pred_sdf: 预测的SDF值 - gt_sdf: 真实的SDF值 - latent_vecs: 形状编码, 来自 brep - epoch: 当前训练轮次 + self.grad_weight = grad_weight + self.warmup_epochs = warmup_epochs + + self.l1_loss = nn.L1Loss(reduction='mean') + + def forward(self, pred_sdf, gt_sdf, points=None, epoch=None): + """计算SDF预测的损失 + + Args: + pred_sdf: 预测的SDF值 [B, N, 1] + gt_sdf: 真实的SDF值 [B, N, 1] + points: 查询点坐标 [B, N, 3],用于计算梯度损失 + epoch: 当前训练轮次,用于损失权重调整 """ - + # 1. 对SDF值进行clamp if self.enforce_minmax: pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT) gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT) - - # 1. L1 Loss的优势 - # - 对异常值更鲁棒 - # - 能更好地保持表面细节 - base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0] + # 2. 计算基础L1损失 + l1_loss = self.l1_loss(pred_sdf, gt_sdf) + + # 3. 计算梯度损失(如果提供了points) + grad_loss = 0 + if points is not None: + grad_loss = self.gradient_loss(pred_sdf, points) + + # 4. 根据epoch调整权重 + if epoch is not None: + weight = min(1.0, epoch / self.warmup_epochs) + else: + weight = 1.0 + + # 5. 组合损失 + total_loss = l1_loss + self.grad_weight * grad_loss * weight + + return total_loss + + def gradient_loss(self, pred_sdf, points): + """计算SDF梯度损失""" + grad_pred = torch.autograd.grad( + pred_sdf.sum(), + points, + create_graph=True + )[0] + + grad_norm = torch.norm(grad_pred, dim=-1) + grad_loss = torch.mean((grad_norm - 1.0) ** 2) - return base_loss / self.batch_size + return grad_loss def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index 5c36c91..ebdd60d 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -251,6 +251,7 @@ def train(model, config, num_epochs=10): # 初始化损失函数 clamping_distance = config.train.clamping_distance criterion = Brep2SDFLoss( + batch_size=config.train.batch_size, enforce_minmax= (clamping_distance > 0), clamping_distance= clamping_distance )