From ff2cf974317ede3f8c03ff7733e61fce10fa9b14 Mon Sep 17 00:00:00 2001 From: mckay Date: Fri, 23 May 2025 16:23:52 +0800 Subject: [PATCH] =?UTF-8?q?loss=20=E5=88=86=E6=94=AFtrain?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/train.py | 53 +++++++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 2bd2aa5..dd6a281 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -281,15 +281,24 @@ class Trainer: #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") logger.gpu_memory_stats("计算损失前") - loss, loss_details = self.loss_manager.compute_loss( - mnfld_pnts, - nonmnfld_pnts, - normals, # 传递检查过的 normals - gt_sdf, - mnfld_pred, - nonmnfld_pred, - psdf - ) + + if args.only_zero_surface: + loss, loss_details = self.loss_manager.compute_loss( + mnfld_pnts, + nonmnfld_pnts, + normals, # 传递检查过的 normals + gt_sdf, + mnfld_pred, + nonmnfld_pred, + psdf + ) + else: + loss, loss_details = self.loss_manager.compute_loss_stage1( + mnfld_pnts, + normals, # 传递检查过的 normals + gt_sdf, + mnfld_pred + ) else: loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) @@ -696,15 +705,23 @@ class Trainer: #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") #logger.gpu_memory_stats("计算损失前") - loss, loss_details = self.loss_manager.compute_loss( - mnfld_pnts, - nonmnfld_pnts, - normals, # 传递检查过的 normals - gt_sdf, - mnfld_pred, - nonmnfld_pred, - psdf - ) + if args.only_zero_surface: + loss, loss_details = self.loss_manager.compute_loss( + mnfld_pnts, + nonmnfld_pnts, + normals, # 传递检查过的 normals + gt_sdf, + mnfld_pred, + nonmnfld_pred, + psdf + ) + else: + loss, loss_details = self.loss_manager.compute_loss_stage1( + mnfld_pnts, + normals, # 传递检查过的 normals + gt_sdf, + mnfld_pred + ) #logger.gpu_memory_stats("计算损失后") else: loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)