Browse Source

fix: normalize loss by batch size

main
mckay 3 months ago
parent
commit
92bcce4c2a
  1. 7
      brep2sdf/networks/loss.py
  2. 1
      brep2sdf/train.py

7
brep2sdf/networks/loss.py

@ -3,9 +3,12 @@ import torch.nn as nn
from brep2sdf.config.default_config import get_default_config
class Brep2SDFLoss:
"""解释Brep2SDF的loss设计原理"""
def __init__(self, enforce_minmax: bool=True, clamping_distance: float = 0.1):
def __init__(self, batch_size:float, enforce_minmax: bool=True, clamping_distance: float = 0.1):
self.batch_size = batch_size
self.l1_loss = nn.L1Loss(reduction='sum')
self.enforce_minmax = enforce_minmax
self.minT = -clamping_distance
@ -33,7 +36,7 @@ class Brep2SDFLoss:
base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0]
return base_loss
return base_loss / self.batch_size
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):

1
brep2sdf/train.py

@ -48,6 +48,7 @@ class Trainer:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clamping_distance = self.config.train.clamping_distance
self.criterion = Brep2SDFLoss(
batch_size = config.train.batch_size,
enforce_minmax= (clamping_distance > 0),
clamping_distance= clamping_distance
)

Loading…
Cancel
Save