You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

52 lines
1.6 KiB

import torch
class LossManager:
def __init__(self):
pass
def position_loss(self, outputs):
"""
计算流型损失的逻辑
:param outputs: 模型的输出
:return: 计算得到的流型损失值
"""
# 计算流型损失(这里使用均方误差作为示例)
manifold_loss = (outputs.abs()).mean() # 计算流型损失
return manifold_loss
def normals_loss(self, cur_data: torch.Tensor, mnfld_pnts: torch.Tensor, all_fi: torch.Tensor, patch_sup: bool) -> torch.Tensor:
"""
计算法线损失
:param cur_data: 当前数据,包含法线信息
:param mnfld_pnts: 流型点
:param all_fi: 所有流型预测值
:param patch_sup: 是否支持补丁
:return: 计算得到的法线损失
"""
# 提取法线
normals = cur_data[:, -self.d_in:]
# 计算分支梯度
branch_grad = gradient(mnfld_pnts, all_fi[:, 0]) # 计算分支梯度
# 计算法线损失
normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # 计算法线损失
return self.normals_lambda * normals_loss # 返回加权后的法线损失
def compute_loss(self, outputs):
"""
计算流型损失的逻辑
:param outputs: 模型的输出
:return: 计算得到的流型损失值
"""
# 计算流型损失(这里使用均方误差作为示例)
manifold_loss = (outputs.abs()).mean() # 计算流型损失
return manifold_loss