From 92bcce4c2a7cf267b7c1989dc4336a386e9927bf Mon Sep 17 00:00:00 2001 From: mckay Date: Sat, 7 Dec 2024 13:01:24 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20loss=20=E5=A2=9E=E5=8A=A0batch=E5=BD=92?= =?UTF-8?q?=E4=B8=80=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/loss.py | 7 +++++-- brep2sdf/train.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py index adbe2cf..d43e1ed 100644 --- a/brep2sdf/networks/loss.py +++ b/brep2sdf/networks/loss.py @@ -3,9 +3,12 @@ import torch.nn as nn from brep2sdf.config.default_config import get_default_config + + class Brep2SDFLoss: """解释Brep2SDF的loss设计原理""" - def __init__(self, enforce_minmax: bool=True, clamping_distance: float = 0.1): + def __init__(self, batch_size:float, enforce_minmax: bool=True, clamping_distance: float = 0.1): + self.batch_size = batch_size self.l1_loss = nn.L1Loss(reduction='sum') self.enforce_minmax = enforce_minmax self.minT = -clamping_distance @@ -33,7 +36,7 @@ class Brep2SDFLoss: base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0] - return base_loss + return base_loss / self.batch_size def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 3a9630e..62ee8dd 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -48,6 +48,7 @@ class Trainer: self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') clamping_distance = self.config.train.clamping_distance self.criterion = Brep2SDFLoss( + batch_size = config.train.batch_size, enforce_minmax= (clamping_distance > 0), clamping_distance= clamping_distance )