|
|
@ -3,9 +3,12 @@ import torch.nn as nn |
|
|
|
|
|
|
|
from brep2sdf.config.default_config import get_default_config |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Brep2SDFLoss: |
|
|
|
"""解释Brep2SDF的loss设计原理""" |
|
|
|
def __init__(self, enforce_minmax: bool=True, clamping_distance: float = 0.1): |
|
|
|
def __init__(self, batch_size:float, enforce_minmax: bool=True, clamping_distance: float = 0.1): |
|
|
|
self.batch_size = batch_size |
|
|
|
self.l1_loss = nn.L1Loss(reduction='sum') |
|
|
|
self.enforce_minmax = enforce_minmax |
|
|
|
self.minT = -clamping_distance |
|
|
@ -33,7 +36,7 @@ class Brep2SDFLoss: |
|
|
|
base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0] |
|
|
|
|
|
|
|
|
|
|
|
return base_loss |
|
|
|
return base_loss / self.batch_size |
|
|
|
|
|
|
|
|
|
|
|
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): |
|
|
|