You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
76 lines
2.6 KiB
76 lines
2.6 KiB
import torch
|
|
|
|
|
|
class LossManager:
|
|
def __init__(self):
|
|
pass
|
|
|
|
def position_loss(self, outputs):
|
|
"""
|
|
计算流型损失的逻辑
|
|
|
|
:param outputs: 模型的输出
|
|
:return: 计算得到的流型损失值
|
|
"""
|
|
|
|
# 计算流型损失(这里使用均方误差作为示例)
|
|
manifold_loss = (outputs.abs()).mean() # 计算流型损失
|
|
return manifold_loss
|
|
|
|
def normals_loss(self, cur_data: torch.Tensor, mnfld_pnts: torch.Tensor, all_fi: torch.Tensor, patch_sup: bool) -> torch.Tensor:
|
|
"""
|
|
计算法线损失
|
|
|
|
:param cur_data: 当前数据,包含法线信息
|
|
:param mnfld_pnts: 流型点
|
|
:param all_fi: 所有流型预测值
|
|
:param patch_sup: 是否支持补丁
|
|
:return: 计算得到的法线损失
|
|
"""
|
|
# 提取法线
|
|
normals = cur_data[:, -self.d_in:]
|
|
|
|
# 计算分支梯度
|
|
branch_grad = gradient(mnfld_pnts, all_fi[:, 0]) # 计算分支梯度
|
|
|
|
# 计算法线损失
|
|
normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # 计算法线损失
|
|
|
|
return self.normals_lambda * normals_loss # 返回加权后的法线损失
|
|
|
|
def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred_all):
|
|
"""
|
|
计算Eikonal损失
|
|
"""
|
|
grad_loss_h = torch.zeros(1).cuda() # 初始化 Eikonal 损失
|
|
single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred_all[:,0]) # 计算非流形点的梯度
|
|
eikonal_loss = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() # 计算 Eikonal 损失
|
|
return eikonal_loss
|
|
|
|
def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred_all):
|
|
"""
|
|
Eo
|
|
惩罚远离表面但是预测值接近0的点
|
|
"""
|
|
offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred_all[:,0])).mean() # 计算离表面损失
|
|
return offsurface_loss
|
|
|
|
def consistency_loss(self, mnfld_pnts, mnfld_pred_all, all_fi):
|
|
"""
|
|
惩罚流形点预测值和非流形点预测值不一致的点
|
|
"""
|
|
mnfld_consistency_loss = (mnfld_pred - all_fi[:,0]).abs().mean() # 计算流形一致性损失
|
|
return mnfld_consistency_loss
|
|
|
|
def compute_loss(self, outputs):
|
|
"""
|
|
计算流型损失的逻辑
|
|
|
|
:param outputs: 模型的输出
|
|
:return: 计算得到的流型损失值
|
|
"""
|
|
|
|
# 计算流型损失(这里使用均方误差作为示例)
|
|
manifold_loss = (outputs.abs()).mean() # 计算流型损失
|
|
return manifold_loss
|
|
|
|
|