|
|
@ -182,6 +182,49 @@ class LossManager: |
|
|
|
|
|
|
|
return total_loss, loss_details |
|
|
|
|
|
|
|
def compute_loss_stage1(self, |
|
|
|
mnfld_pnts, |
|
|
|
normals, |
|
|
|
gt_sdfs, |
|
|
|
mnfld_pred, |
|
|
|
): |
|
|
|
""" |
|
|
|
计算流型损失的逻辑 |
|
|
|
|
|
|
|
:param outputs: 模型的输出 |
|
|
|
:return: 计算得到的流型损失值 |
|
|
|
""" |
|
|
|
# 强制类型转换确保一致性 |
|
|
|
normals = normals.to(torch.float32) |
|
|
|
mnfld_pred = mnfld_pred.to(torch.float32) |
|
|
|
gt_sdfs = gt_sdfs.to(torch.float32) |
|
|
|
|
|
|
|
# 计算流形损失 |
|
|
|
manifold_loss = self.position_loss(mnfld_pred, gt_sdfs) |
|
|
|
|
|
|
|
# 计算法线损失 |
|
|
|
normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred) |
|
|
|
#logger.gpu_memory_stats("计算法线损失后") |
|
|
|
|
|
|
|
|
|
|
|
# 计算一致性损失 |
|
|
|
#onsistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi) |
|
|
|
|
|
|
|
# 计算修正损失 |
|
|
|
#correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi) |
|
|
|
|
|
|
|
|
|
|
|
# 汇总损失 |
|
|
|
loss_details = { |
|
|
|
"manifold": self.weights["manifold"] * manifold_loss, |
|
|
|
"normals": self.weights["normals"] * normals_loss |
|
|
|
} |
|
|
|
|
|
|
|
# 计算总损失 |
|
|
|
total_loss = sum(loss_details.values()) |
|
|
|
|
|
|
|
return total_loss, loss_details |
|
|
|
|
|
|
|
def compute_loss_volume(self, |
|
|
|
mnfld_pnts, |
|
|
|
nonmnfld_pnts, |
|
|
|