Browse Source

feat: optimize loss, by torch.norm and mean

main
mckay 3 months ago
parent
commit
e6e1aa5008
  1. 73
      brep2sdf/networks/loss.py
  2. 1
      brep2sdf/networks/network.py

73
brep2sdf/networks/loss.py

@@ -2,41 +2,66 @@ import torch
import torch.nn as nn
from brep2sdf.config.default_config import get_default_config
from brep2sdf.utils.logger import logger
class Brep2SDFLoss:
"""解释Brep2SDF的loss设计原理"""
def __init__(self, batch_size:float, enforce_minmax: bool=True, clamping_distance: float = 0.1):
class Brep2SDFLoss(nn.Module):
def __init__(self, batch_size:float=1, enforce_minmax: bool=True, clamping_distance: float=0.1,
grad_weight=0.1, warmup_epochs=10):
super().__init__()
self.batch_size = batch_size
self.l1_loss = nn.L1Loss(reduction='sum')
self.enforce_minmax = enforce_minmax
self.minT = -clamping_distance
self.maxT = clamping_distance
def __call__(self, pred_sdf, gt_sdf):
"""使类可直接调用"""
return self.forward(pred_sdf, gt_sdf)
def forward(self, pred_sdf, gt_sdf):
"""
pred_sdf: 预测的SDF值
gt_sdf: 真实的SDF值
latent_vecs: 形状编码, 来自 brep
epoch: 当前训练轮次
self.grad_weight = grad_weight
self.warmup_epochs = warmup_epochs
self.l1_loss = nn.L1Loss(reduction='mean')
def forward(self, pred_sdf, gt_sdf, points=None, epoch=None):
"""计算SDF预测的损失
Args:
pred_sdf: 预测的SDF值 [B, N, 1]
gt_sdf: 真实的SDF值 [B, N, 1]
points: 查询点坐标 [B, N, 3]用于计算梯度损失
epoch: 当前训练轮次用于损失权重调整
"""
# 1. 对SDF值进行clamp
if self.enforce_minmax:
pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT)
gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT)
# 1. L1 Loss的优势
# - 对异常值更鲁棒
# - 能更好地保持表面细节
base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0]
# 2. 计算基础L1损失
l1_loss = self.l1_loss(pred_sdf, gt_sdf)
# 3. 计算梯度损失(如果提供了points)
grad_loss = 0
if points is not None:
grad_loss = self.gradient_loss(pred_sdf, points)
# 4. 根据epoch调整权重
if epoch is not None:
weight = min(1.0, epoch / self.warmup_epochs)
else:
weight = 1.0
# 5. 组合损失
total_loss = l1_loss + self.grad_weight * grad_loss * weight
return total_loss
def gradient_loss(self, pred_sdf, points):
"""计算SDF梯度损失"""
grad_pred = torch.autograd.grad(
pred_sdf.sum(),
points,
create_graph=True
)[0]
grad_norm = torch.norm(grad_pred, dim=-1)
grad_loss = torch.mean((grad_norm - 1.0) ** 2)
return base_loss / self.batch_size
return grad_loss
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):

1
brep2sdf/networks/network.py

@@ -251,6 +251,7 @@ def train(model, config, num_epochs=10):
# 初始化损失函数
clamping_distance = config.train.clamping_distance
criterion = Brep2SDFLoss(
batch_size=config.train.batch_size,
enforce_minmax= (clamping_distance > 0),
clamping_distance= clamping_distance
)

Loading…
Cancel
Save