Browse Source

loss 分支train

final
mckay 2 weeks ago
parent
commit
ff2cf97431
  1. 53
      brep2sdf/train.py

53
brep2sdf/train.py

@ -281,15 +281,24 @@ class Trainer:
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("计算损失前") logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss(
mnfld_pnts, if args.only_zero_surface:
nonmnfld_pnts, loss, loss_details = self.loss_manager.compute_loss(
normals, # 传递检查过的 normals mnfld_pnts,
gt_sdf, nonmnfld_pnts,
mnfld_pred, normals, # 传递检查过的 normals
nonmnfld_pred, gt_sdf,
psdf mnfld_pred,
) nonmnfld_pred,
psdf
)
else:
loss, loss_details = self.loss_manager.compute_loss_stage1(
mnfld_pnts,
normals, # 传递检查过的 normals
gt_sdf,
mnfld_pred
)
else: else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
@ -696,15 +705,23 @@ class Trainer:
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
#logger.gpu_memory_stats("计算损失前") #logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss( if args.only_zero_surface:
mnfld_pnts, loss, loss_details = self.loss_manager.compute_loss(
nonmnfld_pnts, mnfld_pnts,
normals, # 传递检查过的 normals nonmnfld_pnts,
gt_sdf, normals, # 传递检查过的 normals
mnfld_pred, gt_sdf,
nonmnfld_pred, mnfld_pred,
psdf nonmnfld_pred,
) psdf
)
else:
loss, loss_details = self.loss_manager.compute_loss_stage1(
mnfld_pnts,
normals, # 传递检查过的 normals
gt_sdf,
mnfld_pred
)
#logger.gpu_memory_stats("计算损失后") #logger.gpu_memory_stats("计算损失后")
else: else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)

Loading…
Cancel
Save