import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from brep2sdf.config.default_config import get_default_config

# `logger` is used by sdf_loss() below but was never defined in this file;
# a standard-library logger is assumed here.
logger = logging.getLogger(__name__)


class Brep2SDFLoss:
    """Clamped L1 loss for Brep2SDF; see inline comments for the design rationale."""
    def __init__(self, batch_size: int, enforce_minmax: bool = True, clamping_distance: float = 0.1):
        self.batch_size = batch_size
        # Sum-reduced L1; normalized explicitly in forward() by point count and batch size.
        self.l1_loss = nn.L1Loss(reduction='sum')
        self.enforce_minmax = enforce_minmax
        # Clamp predicted and target SDF values to [-clamping_distance, clamping_distance]
        # so the loss concentrates on the region near the surface.
        self.minT = -clamping_distance
        self.maxT = clamping_distance

    def __call__(self, pred_sdf, gt_sdf):
        """Make the loss object directly callable."""
        return self.forward(pred_sdf, gt_sdf)

    def forward(self, pred_sdf, gt_sdf):
        """Compute the clamped, normalized L1 loss.

        Args:
            pred_sdf: predicted SDF values.
            gt_sdf: ground-truth SDF values.
        """
        if self.enforce_minmax:
            # Restrict supervision to a narrow band around the surface.
            pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT)
            gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT)

        # L1 is used rather than L2:
        # - it is more robust to outliers,
        # - it preserves surface detail better.
        base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0]

        return base_loss / self.batch_size


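# Usage sketch (the shapes and the batch_size value are assumptions, not from
# this file). Brep2SDFLoss sums the L1 error, then normalizes by the number of
# points and again by batch_size, so the caller is expected to pass per-batch
# point sets:
#
#     criterion = Brep2SDFLoss(batch_size=8, clamping_distance=0.1)
#     loss = criterion(pred_sdf, gt_sdf)  # both of shape (num_points, 1)
#     loss.backward()
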
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
    """L1 SDF loss with an eikonal-style gradient constraint (|∇f| ≈ 1)."""
    # Ensure `points` participates in autograd. Note: if pred_sdf was computed
    # from the original non-grad points, the detached copy below is disconnected
    # from the graph, and the gradient term falls back to zero via allow_unused.
    if not points.requires_grad:
        points = points.detach().requires_grad_(True)

    # L1 reconstruction term.
    l1_loss = F.l1_loss(pred_sdf, gt_sdf)

    try:
        # Gradient of the predicted SDF with respect to the query points.
        grad = torch.autograd.grad(
            pred_sdf.sum(),
            points,
            create_graph=True,
            retain_graph=True,
            allow_unused=True
        )[0]

        if grad is not None:
            # Penalize deviation of the gradient norm from 1 (eikonal constraint).
            grad_constraint = F.mse_loss(
                torch.norm(grad, dim=-1),
                torch.ones_like(pred_sdf.squeeze(-1))
            )
        else:
            grad_constraint = torch.tensor(0.0, device=pred_sdf.device)

    except Exception as e:
        logger.warning(f"Gradient computation failed: {str(e)}")
        grad_constraint = torch.tensor(0.0, device=pred_sdf.device)

    return l1_loss + grad_weight * grad_constraint
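

# Minimal smoke test (a sketch; the toy network and tensor shapes are
# assumptions, not part of the original file). It shows that the gradient
# constraint in sdf_loss is only active when pred_sdf is computed from
# `points` inside the autograd graph.
if __name__ == "__main__":
    torch.manual_seed(0)
    points = torch.randn(128, 3, requires_grad=True)
    # Stand-in network so pred_sdf depends differentiably on points.
    net = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 1))
    pred_sdf = net(points)                         # (128, 1)
    gt_sdf = torch.randn(128, 1).clamp(-0.1, 0.1)  # fake targets near the surface band
    loss = sdf_loss(pred_sdf, gt_sdf, points, grad_weight=0.1)
    loss.backward()
    print(f"sdf_loss = {loss.item():.4f}")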