|
|
@ -2,41 +2,66 @@ import torch |
|
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
from brep2sdf.config.default_config import get_default_config |
|
|
|
from brep2sdf.utils.logger import logger |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Brep2SDFLoss: |
|
|
|
"""解释Brep2SDF的loss设计原理""" |
|
|
|
def __init__(self, batch_size:float, enforce_minmax: bool=True, clamping_distance: float = 0.1): |
|
|
|
class Brep2SDFLoss(nn.Module): |
|
|
|
def __init__(self, batch_size:float=1, enforce_minmax: bool=True, clamping_distance: float=0.1, |
|
|
|
grad_weight=0.1, warmup_epochs=10): |
|
|
|
super().__init__() |
|
|
|
self.batch_size = batch_size |
|
|
|
self.l1_loss = nn.L1Loss(reduction='sum') |
|
|
|
self.enforce_minmax = enforce_minmax |
|
|
|
self.minT = -clamping_distance |
|
|
|
self.maxT = clamping_distance |
|
|
|
|
|
|
|
def __call__(self, pred_sdf, gt_sdf): |
|
|
|
"""使类可直接调用""" |
|
|
|
return self.forward(pred_sdf, gt_sdf) |
|
|
|
|
|
|
|
def forward(self, pred_sdf, gt_sdf): |
|
|
|
""" |
|
|
|
pred_sdf: 预测的SDF值 |
|
|
|
gt_sdf: 真实的SDF值 |
|
|
|
latent_vecs: 形状编码, 来自 brep |
|
|
|
epoch: 当前训练轮次 |
|
|
|
self.grad_weight = grad_weight |
|
|
|
self.warmup_epochs = warmup_epochs |
|
|
|
|
|
|
|
self.l1_loss = nn.L1Loss(reduction='mean') |
|
|
|
|
|
|
|
def forward(self, pred_sdf, gt_sdf, points=None, epoch=None): |
|
|
|
"""计算SDF预测的损失 |
|
|
|
|
|
|
|
Args: |
|
|
|
pred_sdf: 预测的SDF值 [B, N, 1] |
|
|
|
gt_sdf: 真实的SDF值 [B, N, 1] |
|
|
|
points: 查询点坐标 [B, N, 3],用于计算梯度损失 |
|
|
|
epoch: 当前训练轮次,用于损失权重调整 |
|
|
|
""" |
|
|
|
|
|
|
|
# 1. 对SDF值进行clamp |
|
|
|
if self.enforce_minmax: |
|
|
|
pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT) |
|
|
|
gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT) |
|
|
|
|
|
|
|
# 1. L1 Loss的优势 |
|
|
|
# - 对异常值更鲁棒 |
|
|
|
# - 能更好地保持表面细节 |
|
|
|
base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0] |
|
|
|
|
|
|
|
# 2. 计算基础L1损失 |
|
|
|
l1_loss = self.l1_loss(pred_sdf, gt_sdf) |
|
|
|
|
|
|
|
# 3. 计算梯度损失(如果提供了points) |
|
|
|
grad_loss = 0 |
|
|
|
if points is not None: |
|
|
|
grad_loss = self.gradient_loss(pred_sdf, points) |
|
|
|
|
|
|
|
# 4. 根据epoch调整权重 |
|
|
|
if epoch is not None: |
|
|
|
weight = min(1.0, epoch / self.warmup_epochs) |
|
|
|
else: |
|
|
|
weight = 1.0 |
|
|
|
|
|
|
|
# 5. 组合损失 |
|
|
|
total_loss = l1_loss + self.grad_weight * grad_loss * weight |
|
|
|
|
|
|
|
return total_loss |
|
|
|
|
|
|
|
def gradient_loss(self, pred_sdf, points): |
|
|
|
"""计算SDF梯度损失""" |
|
|
|
grad_pred = torch.autograd.grad( |
|
|
|
pred_sdf.sum(), |
|
|
|
points, |
|
|
|
create_graph=True |
|
|
|
)[0] |
|
|
|
|
|
|
|
grad_norm = torch.norm(grad_pred, dim=-1) |
|
|
|
grad_loss = torch.mean((grad_norm - 1.0) ** 2) |
|
|
|
|
|
|
|
return base_loss / self.batch_size |
|
|
|
return grad_loss |
|
|
|
|
|
|
|
|
|
|
|
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): |
|
|
|