diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 7f6407d..e05fa7c 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -251,13 +251,20 @@ class Trainer: gt_sdf = _gt_sdf[start_idx:end_idx] # SDF真值 normals = _normals[start_idx:end_idx] if args.use_normal else None # 法线 + nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals, 0.1) + + # --- 准备模型输入,启用梯度 --- mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 + nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 # --- 前向传播 --- mnfld_pred = self.model.forward_background( mnfld_pnts ) + nonmnfld_pred = self.model.forward_background( + nonmnfld_pnts + ) @@ -274,11 +281,14 @@ class Trainer: #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") logger.gpu_memory_stats("计算损失前") - loss, loss_details = self.loss_manager.compute_loss_stage1( + loss, loss_details = self.loss_manager.compute_loss( mnfld_pnts, + nonmnfld_pnts, normals, # 传递检查过的 normals gt_sdf, mnfld_pred, + nonmnfld_pred, + psdf ) else: loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)