You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

73 lines
2.2 KiB

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from brep2sdf.config.default_config import get_default_config
class Brep2SDFLoss:
    """Clamped L1 loss between predicted and ground-truth SDF values.

    When ``enforce_minmax`` is set, both SDFs are clamped to
    ``[-clamping_distance, clamping_distance]`` before comparison, so values
    beyond the window on the same side contribute nothing. The summed absolute
    error is then divided by the number of rows in ``pred_sdf`` and by
    ``batch_size``.
    """

    def __init__(self, batch_size: float, enforce_minmax: bool = True, clamping_distance: float = 0.1):
        """
        Args:
            batch_size: Extra divisor applied to the per-row loss.
            enforce_minmax: Whether to clamp both SDFs before comparing them.
            clamping_distance: Half-width of the symmetric clamp window.
        """
        self.batch_size = batch_size
        # Sum reduction: normalization is done explicitly in forward().
        self.l1_loss = nn.L1Loss(reduction='sum')
        self.enforce_minmax = enforce_minmax
        self.minT = -clamping_distance
        self.maxT = clamping_distance

    def __call__(self, pred_sdf, gt_sdf):
        """Make instances directly callable; delegates to :meth:`forward`."""
        return self.forward(pred_sdf, gt_sdf)

    def forward(self, pred_sdf, gt_sdf):
        """Compute the (optionally clamped) L1 SDF loss.

        Args:
            pred_sdf: Predicted SDF values (tensor with at least one dimension).
            gt_sdf: Ground-truth SDF values, same shape as ``pred_sdf``.

        Returns:
            Scalar tensor: ``sum|pred - gt| / pred_sdf.shape[0] / batch_size``.
            An empty ``pred_sdf`` yields 0 instead of NaN.
        """
        if self.enforce_minmax:
            pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT)
            gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT)
        # L1 rather than L2: more robust to outliers and better at preserving
        # fine surface detail (design rationale from the original author).
        total = self.l1_loss(pred_sdf, gt_sdf)
        # Guard the empty case: 0 / 0 would return NaN and poison training.
        num_rows = max(pred_sdf.shape[0], 1)
        # NOTE(review): if shape[0] already spans the whole batch, dividing by
        # batch_size as well normalizes twice — confirm this is intended.
        return total / num_rows / self.batch_size
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
    """Return an L1 SDF regression loss plus a weighted eikonal penalty.

    The eikonal term is the mean squared deviation of the gradient norm
    ``|d pred_sdf / d points|`` (taken over the last dim of ``points``) from 1.
    It is built with ``create_graph=True``, so it is itself differentiable.

    Args:
        pred_sdf: Predicted SDF values. For the eikonal term to be non-zero it
            must have been computed from ``points`` while ``points`` required
            grad. Presumably shaped (..., 1) — confirm against callers.
        gt_sdf: Ground-truth SDF values, same shape as ``pred_sdf``.
        points: Query points (presumably (..., 3) xyz — confirm).
        grad_weight: Weight of the eikonal term.

    Returns:
        Scalar tensor ``l1 + grad_weight * eikonal``. The eikonal term is 0
        when the gradient is unavailable: ``points`` or ``pred_sdf`` does not
        require grad, ``points`` is not part of ``pred_sdf``'s graph, or
        autograd raises (the error is logged as a warning).
    """
    # Mean absolute error (F.l1_loss defaults to reduction='mean').
    l1_loss = F.l1_loss(pred_sdf, gt_sdf)
    grad_constraint = pred_sdf.new_zeros(())

    # d(pred_sdf)/d(points) only exists if `points` already required grad when
    # pred_sdf was computed. Re-enabling grad here via
    # points.detach().requires_grad_() would create a fresh leaf outside
    # pred_sdf's graph, so no gradient could ever flow. Skip the term instead.
    # This also avoids an autograd error under torch.no_grad() evaluation.
    if points.requires_grad and pred_sdf.requires_grad:
        try:
            # Summing gives per-point gradients, assuming each SDF value depends
            # only on its own query point.
            grad = torch.autograd.grad(
                pred_sdf.sum(),
                points,
                create_graph=True,  # keep the penalty differentiable
                retain_graph=True,  # leave the graph intact for the caller's backward()
                allow_unused=True,  # return None instead of raising if points is off-graph
            )[0]
        except RuntimeError as err:
            logging.getLogger(__name__).warning("Gradient computation failed: %s", err)
        else:
            if grad is not None:
                grad_norm = torch.norm(grad, dim=-1)
                # Build the target from grad_norm itself so the shapes always
                # match, rather than relying on pred_sdf.squeeze(-1).
                grad_constraint = F.mse_loss(grad_norm, torch.ones_like(grad_norm))

    return l1_loss + grad_weight * grad_constraint