import os
import sys
import time
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)
import torch
from model.network import gradient
class LossManager:
    def __init__(self, ablation, **condition_kwargs):
        self.weights = {
            "manifold": 1,
            "feature_manifold": 1,  # same weight as "manifold" in the original paper
            "normals": 1,
            "eikonal": 1,
            "offsurface": 1,
            "consistency": 1,
            "correction": 1,
        }
        self.condition_kwargs = condition_kwargs
        self.ablation = ablation  # which loss term to ablate, if any
    def _get_condition_kwargs(self, key):
        """
        Fetch a conditional argument. Expected keys:
        ab: ablation / loss type, one of [overall, patch, off, cons, cc, cor]
        siren: whether SIREN is used
        epoch: current epoch
        baseline: whether this run is the baseline
        """
        if key in self.condition_kwargs:
            return self.condition_kwargs[key]
        else:
            raise ValueError(f"Key {key} not found in condition_kwargs")
    def pre_process(self, mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
        """
        Pre-processing: split the network outputs into the quantities used by the individual loss terms.
        """
        mnfld_pred_h = mnfld_pred_all[:, 0]  # overall prediction h at the manifold points
        nonmnfld_pred_h = nonmnfld_pred_all[:, 0]  # overall prediction h at the non-manifold points
        mnfld_grad = gradient(mnfld_pnts, mnfld_pred_h)  # gradient of h at the manifold points
        all_fi = torch.zeros([n_batchsize, 1], device='cuda')  # per-point branch predictions
        for i in range(n_branch - 1):
            all_fi[i * n_patch_batch : (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch : (i + 1) * n_patch_batch, i + 1]  # branch i on its own patch
        # last patch
        all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch]  # last branch on the remaining points
        return mnfld_pred_h, nonmnfld_pred_h, mnfld_grad, all_fi
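    # Layout assumed by pre_process: mnfld_pred_all has shape [n_batchsize, n_branch + 1],
    # where column 0 is the overall value h and column i + 1 is the branch prediction f_i;
    # rows [i * n_patch_batch, (i + 1) * n_patch_batch) hold the samples of patch i, and the
    # last patch takes the remaining n_patch_last rows.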
    def position_loss(self, outputs):
        """
        Manifold (position) loss: the predictions at on-surface points should be zero.
        :param outputs: network outputs at the manifold points
        :return: manifold loss value
        """
        manifold_loss = (outputs.abs()).mean()  # mean absolute value of the on-surface predictions
        return manifold_loss
    def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, all_fi: torch.Tensor, patch_sup: bool = True) -> torch.Tensor:
        """
        Normals loss: the gradient of each branch should match the ground-truth normals on its patch.
        :param normals: ground-truth normals
        :param mnfld_pnts: manifold points
        :param all_fi: per-point branch predictions
        :param patch_sup: whether patch supervision is used
        :return: normals loss value
        """
        # NOTE: the original code has additional logic here
        branch_grad = gradient(mnfld_pnts, all_fi[:, 0])  # gradient of the branch predictions
        normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean()  # norm of the normal error
        return normals_loss
    def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred):
        """
        Eikonal loss: the gradient norm of the prediction should equal 1.
        """
        single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred)  # gradient at the non-manifold points
        eikonal_loss = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()  # squared deviation of the gradient norm from 1
        return eikonal_loss
    def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred):
        """
        Eo: penalize points far from the surface whose predicted value is close to 0.
        """
        offsurface_loss = torch.zeros(1).cuda()
        if not self.ablation == 'off':
            offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred)).mean()
        return offsurface_loss
    def consistency_loss(self, mnfld_pnts, mnfld_pred, all_fi):
        """
        Penalize manifold points where the overall prediction h and the branch prediction disagree.
        """
        mnfld_consistency_loss = torch.zeros(1).cuda()
        if not (self.ablation == 'cons' or self.ablation == 'cc'):
            mnfld_consistency_loss = (mnfld_pred - all_fi[:, 0]).abs().mean()
        return mnfld_consistency_loss
    def correction_loss(self, mnfld_pnts, mnfld_pred, all_fi, th_closeness=1e-5, a_correction=100):
        """
        Correction loss: strongly penalize points where the overall prediction h and the
        branch prediction differ by more than th_closeness.
        """
        correction_loss = torch.zeros(1).cuda()
        if not (self.ablation == 'cor' or self.ablation == 'cc'):
            mismatch_id = torch.abs(mnfld_pred - all_fi[:, 0]) > th_closeness  # mask of mismatching points
            if mismatch_id.sum() != 0:  # only if any mismatch exists
                correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:, 0])[mismatch_id]).mean()
        return correction_loss
    def compute_loss(self, mnfld_pnts, normals, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
        """
        Combine all loss terms into the total training loss.
        :param mnfld_pnts: manifold (on-surface) points
        :param normals: ground-truth normals at the manifold points
        :param mnfld_pred_all: network outputs at the manifold points
        :param nonmnfld_pnts: non-manifold (off-surface) points
        :param nonmnfld_pred_all: network outputs at the non-manifold points
        :return: (total_loss, loss_details)
        """
        mnfld_pred, nonmnfld_pred, mnfld_grad, all_fi = self.pre_process(mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last)
        manifold_loss = torch.zeros(1).cuda()
        if not self.ablation == 'overall':
            manifold_loss = (mnfld_pred.abs()).mean()  # manifold loss: h should vanish on the surface
        '''
        if args.feature_sample:  # if feature sampling is enabled
            feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda()  # randomly pick feature points
            feature_pnts = self.feature_data[feature_indices]  # feature point coordinates
            feature_mask_pair = self.feature_data_mask_pair[feature_indices]  # feature mask pairs
            feature_pred_all = self.network(feature_pnts)  # forward pass on the feature points
            feature_pred = feature_pred_all[:, 0]  # overall prediction at the feature points
            feature_mnfld_loss = feature_pred.abs().mean()  # feature manifold loss
            loss = loss + weight_mnfld_h * feature_mnfld_loss  # add the weighted feature manifold loss to the total
            # patch loss:
            feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:, 0].tolist()]  # left-side branch indices
            feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:, 1].tolist()]  # right-side branch indices
            feature_fis_left = feature_pred_all[feature_id_left]  # left-side branch predictions
            feature_fis_right = feature_pred_all[feature_id_right]  # right-side branch predictions
            feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean()  # patch loss
            loss += feature_loss_patch  # add the patch loss to the total
            # consistency loss:
            feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean()  # consistency loss
        '''
        manifold_loss_patch = torch.zeros(1).cuda()
        if not self.ablation == 'patch':
            manifold_loss_patch = all_fi[:, 0].abs().mean()  # patch loss: each branch prediction should vanish on its own patch
        # normals loss
        normals_loss = self.normals_loss(normals, mnfld_pnts, all_fi, patch_sup=True)
        # Eikonal loss on the overall prediction h at the non-manifold points
        eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred)
        # off-surface loss on the overall prediction h at the non-manifold points
        offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred)
        # consistency loss
        consistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi)
        # correction loss
        correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
        loss_details = {
            "manifold": self.weights["manifold"] * manifold_loss,
            "manifold_patch": manifold_loss_patch,
            "normals": self.weights["normals"] * normals_loss,
            "eikonal": self.weights["eikonal"] * eikonal_loss,
            "offsurface": self.weights["offsurface"] * offsurface_loss,
            "consistency": self.weights["consistency"] * consistency_loss,
            "correction": self.weights["correction"] * correction_loss,
        }
        # total loss
        total_loss = sum(loss_details.values())
        return total_loss, loss_details
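

# Minimal usage sketch (not part of the original module; the stand-in `network`, the point
# counts, and ablation='none' are illustrative assumptions). It assumes the network maps
# [N, 3] points to [N, n_branch + 1] outputs with the overall value h in column 0, and that
# the input points carry requires_grad=True so gradient() can differentiate through them.
if __name__ == "__main__":
    n_branch = 4
    n_patch_batch = 1024
    n_batchsize = n_branch * n_patch_batch
    n_patch_last = n_patch_batch  # last patch takes the remaining points

    # hypothetical stand-in network; the real model comes from model.network
    network = torch.nn.Linear(3, n_branch + 1).cuda()

    mnfld_pnts = torch.rand(n_batchsize, 3, device='cuda', requires_grad=True)     # on-surface samples
    normals = torch.rand(n_batchsize, 3, device='cuda')                            # ground-truth normals
    nonmnfld_pnts = torch.rand(n_batchsize, 3, device='cuda', requires_grad=True)  # off-surface samples

    mnfld_pred_all = network(mnfld_pnts)
    nonmnfld_pred_all = network(nonmnfld_pnts)

    loss_manager = LossManager(ablation='none')
    total_loss, loss_details = loss_manager.compute_loss(
        mnfld_pnts, normals, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all,
        n_batchsize, n_branch, n_patch_batch, n_patch_last)
    total_loss.backward()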