diff --git a/code/conversion/loss.py b/code/conversion/loss.py index c5bc997..8bf045b 100644 --- a/code/conversion/loss.py +++ b/code/conversion/loss.py @@ -38,6 +38,30 @@ class LossManager: return self.normals_lambda * normals_loss # 返回加权后的法线损失 + def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred_all): + """ + 计算Eikonal损失 + """ + grad_loss_h = torch.zeros(1).cuda() # 初始化 Eikonal 损失 + single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred_all[:,0]) # 计算非流形点的梯度 + eikonal_loss = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() # 计算 Eikonal 损失 + return eikonal_loss + + def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred_all): + """ + Eo + 惩罚远离表面但是预测值接近0的点 + """ + offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred_all[:,0])).mean() # 计算离表面损失 + return offsurface_loss + + def consistency_loss(self, mnfld_pnts, mnfld_pred_all, all_fi): + """ + 惩罚流形点预测值和非流形点预测值不一致的点 + """ + mnfld_consistency_loss = (mnfld_pred - all_fi[:,0]).abs().mean() # 计算流形一致性损失 + return mnfld_consistency_loss + def compute_loss(self, outputs): """ 计算流型损失的逻辑