
stage1: use the full loss

final

mckay, 3 weeks ago
commit d266c37a04
1 changed file: brep2sdf/train.py (12 changed lines)

@@ -251,13 +251,20 @@ class Trainer:
            gt_sdf = _gt_sdf[start_idx:end_idx]  # ground-truth SDF
            normals = _normals[start_idx:end_idx] if args.use_normal else None  # normals
            nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals, 0.1)
            # --- prepare model inputs, enable gradients ---
            mnfld_pnts.requires_grad_(True)  # enable gradients after the checks
            nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
            # --- forward pass ---
            mnfld_pred = self.model.forward_background(
                mnfld_pnts
            )
            nonmnfld_pred = self.model.forward_background(
                nonmnfld_pnts
            )
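
For readers without the rest of train.py in view, `self.sampler.get_norm_points` is not part of this diff. Below is a minimal sketch of what such a sampler plausibly does, assuming it perturbs the on-surface points along their normals with the given scale (0.1 here) and returns both the off-surface points and the pseudo-SDF implied by the offset (the `psdf` used further down); the actual brep2sdf sampler may differ.

```python
import torch
import torch.nn.functional as F

def get_norm_points(mnfld_pnts: torch.Tensor,
                    normals: torch.Tensor,
                    sigma: float = 0.1):
    """Hypothetical sketch (not the brep2sdf implementation): sample
    off-surface points by offsetting surface points along their normals.

    mnfld_pnts: (N, 3) points on the surface
    normals:    (N, 3) surface normals at those points
    Returns the perturbed points and a pseudo-SDF equal to the signed offset.
    """
    # One random signed offset per point, scaled by sigma
    offsets = torch.randn(mnfld_pnts.shape[0], 1, device=mnfld_pnts.device) * sigma
    unit_normals = F.normalize(normals, dim=-1)
    nonmnfld_pnts = mnfld_pnts + offsets * unit_normals
    psdf = offsets.squeeze(-1)  # signed distance implied by the offset
    return nonmnfld_pnts, psdf
```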
@@ -274,11 +281,14 @@ class Trainer:
            #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
            #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
            logger.gpu_memory_stats("计算损失前")
            loss, loss_details = self.loss_manager.compute_loss_stage1(
            loss, loss_details = self.loss_manager.compute_loss(
                mnfld_pnts,
                nonmnfld_pnts,
                normals,  # pass the checked normals
                gt_sdf,
                mnfld_pred,
                nonmnfld_pred,
                psdf
            )
        else:
            loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
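
The substance of the commit is the switch from `compute_loss_stage1` to `compute_loss`, so stage 1 now trains against the full loss rather than the reduced stage-1 variant. `compute_loss` itself is not shown in the diff; the following is a minimal sketch of a typical full SDF loss over the arguments passed above (manifold SDF term, normal alignment, eikonal regularization, pseudo-SDF supervision). The weights and exact term definitions are assumptions, not the actual LossManager implementation.

```python
import torch
import torch.nn.functional as F

def gradient(outputs: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    """Spatial gradient of the predicted SDF w.r.t. the input points."""
    return torch.autograd.grad(
        outputs=outputs, inputs=inputs,
        grad_outputs=torch.ones_like(outputs),
        create_graph=True, retain_graph=True,
    )[0]

def compute_loss(mnfld_pnts, nonmnfld_pnts, normals, gt_sdf,
                 mnfld_pred, nonmnfld_pred, psdf,
                 w_sdf=1.0, w_normal=1.0, w_eikonal=0.1, w_psdf=0.1):
    """Hypothetical sketch of a full SDF loss; brep2sdf's LossManager may differ."""
    # On-surface: predicted SDF should match the ground-truth SDF
    sdf_loss = F.mse_loss(mnfld_pred.squeeze(-1), gt_sdf)

    # Normal alignment: SDF gradient at surface points should match the normals
    mnfld_grad = gradient(mnfld_pred, mnfld_pnts)
    normal_loss = (1.0 - F.cosine_similarity(mnfld_grad, normals, dim=-1)).mean() \
        if normals is not None else mnfld_pred.new_zeros(())

    # Eikonal regularization: |∇f| ≈ 1 at the off-surface samples
    nonmnfld_grad = gradient(nonmnfld_pred, nonmnfld_pnts)
    eikonal_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1.0) ** 2).mean()

    # Pseudo-SDF supervision on the off-surface samples
    psdf_loss = F.mse_loss(nonmnfld_pred.squeeze(-1), psdf)

    loss = (w_sdf * sdf_loss + w_normal * normal_loss
            + w_eikonal * eikonal_loss + w_psdf * psdf_loss)
    details = {"sdf": sdf_loss.item(), "normal": float(normal_loss),
               "eikonal": eikonal_loss.item(), "psdf": psdf_loss.item()}
    return loss, details
```

Note that the `requires_grad_(True)` calls added earlier in the diff are what make the gradient-based terms (normal alignment and eikonal) computable at all.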
