From d266c37a049c7834c197f7fa4ce85b58e4b21cb9 Mon Sep 17 00:00:00 2001 From: mckay Date: Tue, 20 May 2025 22:53:52 +0800 Subject: [PATCH] =?UTF-8?q?stage1=20=E4=BD=BF=E7=94=A8=E5=AE=8C=E6=95=B4lo?= =?UTF-8?q?ss?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/train.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 7f6407d..e05fa7c 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -251,13 +251,20 @@ class Trainer: gt_sdf = _gt_sdf[start_idx:end_idx] # SDF真值 normals = _normals[start_idx:end_idx] if args.use_normal else None # 法线 + nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals, 0.1) + + # --- 准备模型输入,启用梯度 --- mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 + nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 # --- 前向传播 --- mnfld_pred = self.model.forward_background( mnfld_pnts ) + nonmnfld_pred = self.model.forward_background( + nonmnfld_pnts + ) @@ -274,11 +281,14 @@ class Trainer: #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") logger.gpu_memory_stats("计算损失前") - loss, loss_details = self.loss_manager.compute_loss_stage1( + loss, loss_details = self.loss_manager.compute_loss( mnfld_pnts, + nonmnfld_pnts, normals, # 传递检查过的 normals gt_sdf, mnfld_pred, + nonmnfld_pred, + psdf ) else: loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)