You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
190 lines
8.4 KiB
190 lines
8.4 KiB
import os
|
|
import sys
|
|
import time
|
|
|
|
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
|
|
sys.path.append(project_dir)
|
|
os.chdir(project_dir)
|
|
|
|
import torch
|
|
from model.network import gradient
|
|
|
|
class LossManager:
|
|
def __init__(self, ablation, **condition_kwargs):
|
|
self.weights = {
|
|
"manifold": 1,
|
|
"feature_manifold": 1, # 原文里面和manifold的权重是一样的
|
|
"normals": 1,
|
|
"eikonal": 1,
|
|
"offsurface": 1,
|
|
"consistency": 1,
|
|
"correction": 1,
|
|
}
|
|
self.condition_kwargs = condition_kwargs
|
|
self.ablation = ablation # 消融实验用
|
|
|
|
def _get_condition_kwargs(self, key):
|
|
"""
|
|
获取条件参数, 期望
|
|
ab: 损失类型 【overall, patch, off, cons, cc, cor,】
|
|
siren: 是否使用SIREN
|
|
epoch: 当前epoch
|
|
baseline: 是否为baseline
|
|
"""
|
|
if key in self.condition_kwargs:
|
|
return self.condition_kwargs[key]
|
|
else:
|
|
raise ValueError(f"Key {key} not found in condition_kwargs")
|
|
|
|
|
|
def pre_process(self, mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
|
|
"""
|
|
预处理
|
|
"""
|
|
mnfld_pred_h = mnfld_pred_all[:,0] # 提取流形预测结果
|
|
nonmnfld_pred_h = nonmnfld_pred_all[:,0] # 提取非流形预测结果
|
|
mnfld_grad = gradient(mnfld_pnts, mnfld_pred_h) # 计算流形点的梯度
|
|
|
|
all_fi = torch.zeros([n_batchsize, 1], device = 'cuda') # 初始化所有流形预测值
|
|
for i in range(n_branch - 1):
|
|
all_fi[i * n_patch_batch : (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch : (i + 1) * n_patch_batch, i + 1] # 填充流形预测值
|
|
# last patch
|
|
all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch] # 填充最后一个分支的流形预测值
|
|
|
|
return mnfld_pred_h, nonmnfld_pred_h, mnfld_grad, all_fi
|
|
|
|
def position_loss(self, outputs):
|
|
"""
|
|
计算流型损失的逻辑
|
|
|
|
:param outputs: 模型的输出
|
|
:return: 计算得到的流型损失值
|
|
"""
|
|
|
|
# 计算流型损失(这里使用均方误差作为示例)
|
|
manifold_loss = (outputs.abs()).mean() # 计算流型损失
|
|
return manifold_loss
|
|
|
|
def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, all_fi: torch.Tensor, patch_sup: bool = True) -> torch.Tensor:
|
|
"""
|
|
计算法线损失
|
|
|
|
:param normals: 法线
|
|
:param mnfld_pnts: 流型点
|
|
:param all_fi: 所有流型预测值
|
|
:param patch_sup: 是否支持补丁
|
|
:return: 计算得到的法线损失
|
|
"""
|
|
# NOTE 源代码 这里还有复杂逻辑
|
|
# 计算分支梯度
|
|
branch_grad = gradient(mnfld_pnts, all_fi[:, 0]) # 计算分支梯度
|
|
|
|
# 计算法线损失
|
|
normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # 计算法线损失
|
|
|
|
return normals_loss # 返回加权后的法线损失
|
|
|
|
def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred):
|
|
"""
|
|
计算Eikonal损失
|
|
"""
|
|
grad_loss_h = torch.zeros(1).cuda() # 初始化 Eikonal 损失
|
|
single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred) # 计算非流形点的梯度
|
|
eikonal_loss = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() # 计算 Eikonal 损失
|
|
return eikonal_loss
|
|
|
|
def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred):
|
|
"""
|
|
Eo
|
|
惩罚远离表面但是预测值接近0的点
|
|
"""
|
|
offsurface_loss = torch.zeros(1).cuda()
|
|
if not self.ablation == 'off':
|
|
offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred)).mean() # 计算离表面损失
|
|
return offsurface_loss
|
|
|
|
def consistency_loss(self, mnfld_pnts, mnfld_pred, all_fi):
|
|
"""
|
|
惩罚流形点预测值和非流形点预测值不一致的点
|
|
"""
|
|
mnfld_consistency_loss = torch.zeros(1).cuda()
|
|
if not (self.ablation == 'cons' or self.ablation == 'cc'):
|
|
mnfld_consistency_loss = (mnfld_pred - all_fi[:,0]).abs().mean() # 计算流形一致性损失
|
|
return mnfld_consistency_loss
|
|
|
|
def correction_loss(self, mnfld_pnts, mnfld_pred, all_fi, th_closeness = 1e-5, a_correction = 100):
|
|
"""
|
|
修正损失
|
|
"""
|
|
correction_loss = torch.zeros(1).cuda() # 初始化修正损失
|
|
if not (self.ablation == 'cor' or self.ablation == 'cc'):
|
|
mismatch_id = torch.abs(mnfld_pred - all_fi[:,0]) > th_closeness # 计算不匹配的 ID
|
|
if mismatch_id.sum() != 0: # 如果存在不匹配
|
|
correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:,0])[mismatch_id]).mean() # 计算修正损失
|
|
return correction_loss
|
|
|
|
def compute_loss(self, mnfld_pnts, normals, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
|
|
"""
|
|
计算流型损失的逻辑
|
|
|
|
:param outputs: 模型的输出
|
|
:return: 计算得到的流型损失值
|
|
"""
|
|
mnfld_pred, nonmnfld_pred, mnfld_grad, all_fi = self.pre_process(mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last)
|
|
manifold_loss = torch.zeros(1).cuda()
|
|
# 计算流型损失(这里使用均方误差作为示例)
|
|
if not self.ablation == 'overall':
|
|
manifold_loss = (mnfld_pred.abs()).mean() # 计算流型损失
|
|
'''
|
|
if args.feature_sample: # 如果启用了特征采样
|
|
feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda() # 随机选择特征点
|
|
feature_pnts = self.feature_data[feature_indices] # 获取特征点数据
|
|
feature_mask_pair = self.feature_data_mask_pair[feature_indices] # 获取特征掩码对
|
|
feature_pred_all = self.network(feature_pnts) # 进行前向传播,计算特征点的预测值
|
|
feature_pred = feature_pred_all[:,0] # 提取特征预测结果
|
|
feature_mnfld_loss = feature_pred.abs().mean() # 计算特征流形损失
|
|
loss = loss + weight_mnfld_h * feature_mnfld_loss # 将特征流形损失加权到总损失中
|
|
|
|
# patch loss:
|
|
feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:,0].tolist()] # 获取左侧特征 ID
|
|
feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:,1].tolist()] # 获取右侧特征 ID
|
|
feature_fis_left = feature_pred_all[feature_id_left] # 获取左侧特征预测值
|
|
feature_fis_right = feature_pred_all[feature_id_right] # 获取右侧特征预测值
|
|
feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean() # 计算补丁损失
|
|
loss += feature_loss_patch # 将补丁损失加权到总损失中
|
|
|
|
# consistency loss:
|
|
feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean() # 计算一致性损失
|
|
'''
|
|
manifold_loss_patch = torch.zeros(1).cuda()
|
|
if self.ablation == 'patch':
|
|
manifold_loss_patch = all_fi[:,0].abs().mean()
|
|
|
|
# 计算法线损失
|
|
normals_loss = self.normals_loss(normals, mnfld_pnts, all_fi, patch_sup=True)
|
|
|
|
# 计算Eikonal损失
|
|
eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred_all)
|
|
|
|
# 计算离表面损失
|
|
offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred_all)
|
|
|
|
# 计算一致性损失
|
|
consistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi)
|
|
|
|
# 计算修正损失
|
|
correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
|
|
|
|
# 计算总损失
|
|
total_loss = (self.weights["manifold"] * manifold_loss + \
|
|
#self.weights["feature_manifold"] * feature_manifold_loss + \
|
|
manifold_loss_patch + \
|
|
self.weights["normals"] * normals_loss + \
|
|
self.weights["eikonal"] * eikonal_loss + \
|
|
self.weights["offsurface"] * offsurface_loss + \
|
|
self.weights["consistency"] * consistency_loss + \
|
|
self.weights["correction"] * correction_loss)
|
|
return total_loss
|
|
|
|
|
|
|
|
|