import os import sys import time project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) sys.path.append(project_dir) os.chdir(project_dir) import torch from model.network import gradient class LossManager: def __init__(self, ablation, **condition_kwargs): self.weights = { "manifold": 1, "feature_manifold": 1, # 原文里面和manifold的权重是一样的 "normals": 1, "eikonal": 1, "offsurface": 1, "consistency": 1, "correction": 1, } self.condition_kwargs = condition_kwargs self.ablation = ablation # 消融实验用 def _get_condition_kwargs(self, key): """ 获取条件参数, 期望 ab: 损失类型 【overall, patch, off, cons, cc, cor,】 siren: 是否使用SIREN epoch: 当前epoch baseline: 是否为baseline """ if key in self.condition_kwargs: return self.condition_kwargs[key] else: raise ValueError(f"Key {key} not found in condition_kwargs") def pre_process(self, mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last): """ 预处理 """ mnfld_pred_h = mnfld_pred_all[:,0] # 提取流形预测结果 nonmnfld_pred_h = nonmnfld_pred_all[:,0] # 提取非流形预测结果 mnfld_grad = gradient(mnfld_pnts, mnfld_pred_h) # 计算流形点的梯度 all_fi = torch.zeros([n_batchsize, 1], device = 'cuda') # 初始化所有流形预测值 for i in range(n_branch - 1): all_fi[i * n_patch_batch : (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch : (i + 1) * n_patch_batch, i + 1] # 填充流形预测值 # last patch all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch] # 填充最后一个分支的流形预测值 return mnfld_pred_h, nonmnfld_pred_h, mnfld_grad, all_fi def position_loss(self, outputs): """ 计算流型损失的逻辑 :param outputs: 模型的输出 :return: 计算得到的流型损失值 """ # 计算流型损失(这里使用均方误差作为示例) manifold_loss = (outputs.abs()).mean() # 计算流型损失 return manifold_loss def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, all_fi: torch.Tensor, patch_sup: bool = True) -> torch.Tensor: """ 计算法线损失 :param normals: 法线 :param mnfld_pnts: 流型点 :param all_fi: 所有流型预测值 :param patch_sup: 是否支持补丁 :return: 计算得到的法线损失 """ # NOTE 源代码 这里还有复杂逻辑 # 计算分支梯度 branch_grad = gradient(mnfld_pnts, all_fi[:, 0]) # 计算分支梯度 # 计算法线损失 normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # 计算法线损失 return normals_loss # 返回加权后的法线损失 def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred): """ 计算Eikonal损失 """ grad_loss_h = torch.zeros(1).cuda() # 初始化 Eikonal 损失 single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred) # 计算非流形点的梯度 eikonal_loss = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() # 计算 Eikonal 损失 return eikonal_loss def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred): """ Eo 惩罚远离表面但是预测值接近0的点 """ offsurface_loss = torch.zeros(1).cuda() if not self.ablation == 'off': offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred)).mean() # 计算离表面损失 return offsurface_loss def consistency_loss(self, mnfld_pnts, mnfld_pred, all_fi): """ 惩罚流形点预测值和非流形点预测值不一致的点 """ mnfld_consistency_loss = torch.zeros(1).cuda() if not (self.ablation == 'cons' or self.ablation == 'cc'): mnfld_consistency_loss = (mnfld_pred - all_fi[:,0]).abs().mean() # 计算流形一致性损失 return mnfld_consistency_loss def correction_loss(self, mnfld_pnts, mnfld_pred, all_fi, th_closeness = 1e-5, a_correction = 100): """ 修正损失 """ correction_loss = torch.zeros(1).cuda() # 初始化修正损失 if not (self.ablation == 'cor' or self.ablation == 'cc'): mismatch_id = torch.abs(mnfld_pred - all_fi[:,0]) > th_closeness # 计算不匹配的 ID if mismatch_id.sum() != 0: # 如果存在不匹配 correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:,0])[mismatch_id]).mean() # 计算修正损失 return correction_loss def compute_loss(self, mnfld_pnts, normals, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last): """ 计算流型损失的逻辑 :param outputs: 模型的输出 :return: 计算得到的流型损失值 """ mnfld_pred, nonmnfld_pred, mnfld_grad, all_fi = self.pre_process(mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last) manifold_loss = torch.zeros(1).cuda() # 计算流型损失(这里使用均方误差作为示例) if not self.ablation == 'overall': manifold_loss = (mnfld_pred.abs()).mean() # 计算流型损失 ''' if args.feature_sample: # 如果启用了特征采样 feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda() # 随机选择特征点 feature_pnts = self.feature_data[feature_indices] # 获取特征点数据 feature_mask_pair = self.feature_data_mask_pair[feature_indices] # 获取特征掩码对 feature_pred_all = self.network(feature_pnts) # 进行前向传播,计算特征点的预测值 feature_pred = feature_pred_all[:,0] # 提取特征预测结果 feature_mnfld_loss = feature_pred.abs().mean() # 计算特征流形损失 loss = loss + weight_mnfld_h * feature_mnfld_loss # 将特征流形损失加权到总损失中 # patch loss: feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:,0].tolist()] # 获取左侧特征 ID feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:,1].tolist()] # 获取右侧特征 ID feature_fis_left = feature_pred_all[feature_id_left] # 获取左侧特征预测值 feature_fis_right = feature_pred_all[feature_id_right] # 获取右侧特征预测值 feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean() # 计算补丁损失 loss += feature_loss_patch # 将补丁损失加权到总损失中 # consistency loss: feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean() # 计算一致性损失 ''' manifold_loss_patch = torch.zeros(1).cuda() if self.ablation == 'patch': manifold_loss_patch = all_fi[:,0].abs().mean() # 计算法线损失 normals_loss = self.normals_loss(normals, mnfld_pnts, all_fi, patch_sup=True) # 计算Eikonal损失 eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred_all) # 计算离表面损失 offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred_all) # 计算一致性损失 consistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi) # 计算修正损失 correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi) loss_details = { "manifold": self.weights["manifold"] * manifold_loss, "manifold_patch": manifold_loss_patch, "normals": self.weights["normals"] * normals_loss, "eikonal": self.weights["eikonal"] * eikonal_loss, "offsurface": self.weights["offsurface"] * offsurface_loss, "consistency": self.weights["consistency"] * consistency_loss, "correction": self.weights["correction"] * correction_loss, } # 计算总损失 total_loss = sum(loss_details.values()) return total_loss, loss_details