|
|
@ -281,15 +281,24 @@ class Trainer: |
|
|
|
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") |
|
|
|
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") |
|
|
|
logger.gpu_memory_stats("计算损失前") |
|
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
|
mnfld_pnts, |
|
|
|
nonmnfld_pnts, |
|
|
|
normals, # 传递检查过的 normals |
|
|
|
gt_sdf, |
|
|
|
mnfld_pred, |
|
|
|
nonmnfld_pred, |
|
|
|
psdf |
|
|
|
) |
|
|
|
|
|
|
|
if args.only_zero_surface: |
|
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
|
mnfld_pnts, |
|
|
|
nonmnfld_pnts, |
|
|
|
normals, # 传递检查过的 normals |
|
|
|
gt_sdf, |
|
|
|
mnfld_pred, |
|
|
|
nonmnfld_pred, |
|
|
|
psdf |
|
|
|
) |
|
|
|
else: |
|
|
|
loss, loss_details = self.loss_manager.compute_loss_stage1( |
|
|
|
mnfld_pnts, |
|
|
|
normals, # 传递检查过的 normals |
|
|
|
gt_sdf, |
|
|
|
mnfld_pred |
|
|
|
) |
|
|
|
else: |
|
|
|
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) |
|
|
|
|
|
|
@ -696,15 +705,23 @@ class Trainer: |
|
|
|
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") |
|
|
|
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") |
|
|
|
#logger.gpu_memory_stats("计算损失前") |
|
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
|
mnfld_pnts, |
|
|
|
nonmnfld_pnts, |
|
|
|
normals, # 传递检查过的 normals |
|
|
|
gt_sdf, |
|
|
|
mnfld_pred, |
|
|
|
nonmnfld_pred, |
|
|
|
psdf |
|
|
|
) |
|
|
|
if args.only_zero_surface: |
|
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
|
mnfld_pnts, |
|
|
|
nonmnfld_pnts, |
|
|
|
normals, # 传递检查过的 normals |
|
|
|
gt_sdf, |
|
|
|
mnfld_pred, |
|
|
|
nonmnfld_pred, |
|
|
|
psdf |
|
|
|
) |
|
|
|
else: |
|
|
|
loss, loss_details = self.loss_manager.compute_loss_stage1( |
|
|
|
mnfld_pnts, |
|
|
|
normals, # 传递检查过的 normals |
|
|
|
gt_sdf, |
|
|
|
mnfld_pred |
|
|
|
) |
|
|
|
#logger.gpu_memory_stats("计算损失后") |
|
|
|
else: |
|
|
|
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) |
|
|
|