diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py index 7b4b394..eaebe9d 100644 --- a/brep2sdf/networks/loss.py +++ b/brep2sdf/networks/loss.py @@ -4,74 +4,6 @@ from .network import gradient from brep2sdf.config.default_config import get_default_config from brep2sdf.utils.logger import logger -class Brep2SDFLoss(nn.Module): - def __init__(self, batch_size:float=1, enforce_minmax: bool=True, clamping_distance: float=0.1, - grad_weight=0.1, warmup_epochs=10): - super().__init__() - self.batch_size = batch_size - self.enforce_minmax = enforce_minmax - self.minT = -clamping_distance - self.maxT = clamping_distance - self.grad_weight = grad_weight - self.warmup_epochs = warmup_epochs - - self.l1_loss = nn.L1Loss(reduction='mean') - self.mse_loss = nn.MSELoss(reduction='mean') - - def forward(self, pred_sdf, gt_sdf, points=None, epoch=None): - """计算SDF预测的损失 - - Args: - pred_sdf: 预测的SDF值 [B, N, 1] - gt_sdf: 真实的SDF值 [B, N, 1] - points: 查询点坐标 [B, N, 3],用于计算梯度损失 - epoch: 当前训练轮次,用于损失权重调整 - """ - # 1. 对SDF值进行clamp - if self.enforce_minmax: - pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT) - gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT) - - # 2. 计算基础L1损失 - l1_loss = self.l1_loss(pred_sdf, gt_sdf) - #mse_loss = self.mse_loss(pred_sdf, gt_sdf) - - # 3. 计算梯度损失(如果提供了points) - grad_loss = 0 - if points is not None: - grad_loss = self.gradient_loss(pred_sdf, points) - - # 4. 根据epoch调整权重 - if epoch is not None: - weight = min(1.0, epoch / self.warmup_epochs) - else: - weight = 1.0 - - # 5. 组合损失 - total_loss = l1_loss + self.grad_weight * grad_loss * weight - - return total_loss - - def gradient_loss(self, pred_sdf, points): - """计算SDF梯度损失""" - grad_pred = torch.autograd.grad( - pred_sdf.sum(), - points, - create_graph=True - )[0] - - grad_norm = torch.norm(grad_pred, dim=-1) - grad_loss = torch.mean((grad_norm - 1.0) ** 2) - - return grad_loss - - - - - - - - class LossManager: def __init__(self, ablation, **condition_kwargs):