Browse Source

refactor: Add detailed comments and improve code readability in run.py training method

- Added comprehensive inline comments explaining each step of the training process in the `run_nhrepnet_training` method
- Improved code structure by adding descriptive comments for variable initializations and key computational steps
- Enhanced code readability by breaking down complex operations with clear explanatory comments
- Maintained existing functionality while providing better code documentation
NH-Rep
mckay 4 months ago
parent
commit
285aaf86dd
  1. 346
      code/conversion/run.py

346
code/conversion/run.py

@ -47,209 +47,209 @@ logger.info(f"project_dir: {project_dir}")
class ReconstructionRunner:
def run_nhrepnet_training(self):
print("running")
self.data = self.data.cuda()
self.data.requires_grad_()
feature_mask_cpu = self.feature_mask.numpy()
self.feature_mask = self.feature_mask.cuda()
n_branch = int(torch.max(self.feature_mask).item())
n_batchsize = self.points_batch
n_patch_batch = n_batchsize // n_branch
n_patch_last = n_batchsize - n_patch_batch * (n_branch - 1)
patch_sup = True
weight_mnfld_h = 1
weight_mnfld_cs = 1
weight_correction = 1
a_correction = 100
patch_id = []
patch_id_n = []
print("running") # 输出训练开始的提示信息
self.data = self.data.cuda() # 将数据移动到 GPU 上
self.data.requires_grad_() # 设置数据以便计算梯度
feature_mask_cpu = self.feature_mask.numpy() # 将特征掩码转换为 NumPy 数组
self.feature_mask = self.feature_mask.cuda() # 将特征掩码移动到 GPU 上
n_branch = int(torch.max(self.feature_mask).item()) # 计算分支数量
n_batchsize = self.points_batch # 设置批次大小
n_patch_batch = n_batchsize // n_branch # 计算每个分支的补丁批次大小
n_patch_last = n_batchsize - n_patch_batch * (n_branch - 1) # 计算最后一个分支的补丁大小
patch_sup = True # 设置补丁支持标志
weight_mnfld_h = 1 # 初始化流形损失权重
weight_mnfld_cs = 1 # 初始化流形一致性损失权重
weight_correction = 1 # 初始化修正损失权重
a_correction = 100 # 初始化修正损失的系数
patch_id = [] # 初始化补丁 ID 列表
patch_id_n = [] # 初始化补丁数量列表
for i in range(n_branch):
patch_id = patch_id + [np.where(feature_mask_cpu == i + 1)[0]]
patch_id_n = patch_id_n + [patch_id[i].shape[0]]
if self.eval:
print("evaluating epoch: {0}".format(self.startepoch))
my_path = os.path.join(self.cur_exp_dir, 'evaluation', str(self.startepoch))
utils.mkdir_ifnotexists(os.path.join(self.cur_exp_dir, 'evaluation'))
utils.mkdir_ifnotexists(my_path)
patch_id = patch_id + [np.where(feature_mask_cpu == i + 1)[0]] # 找到每个分支的补丁 ID
patch_id_n = patch_id_n + [patch_id[i].shape[0]] # 记录每个补丁的数量
if self.eval: # 检查是否处于评估模式
print("evaluating epoch: {0}".format(self.startepoch)) # 输出当前评估的轮次
my_path = os.path.join(self.cur_exp_dir, 'evaluation', str(self.startepoch)) # 创建评估结果保存路径
utils.mkdir_ifnotexists(os.path.join(self.cur_exp_dir, 'evaluation')) # 确保评估目录存在
utils.mkdir_ifnotexists(my_path) # 确保当前评估路径存在
for i in range(1):
self.network.flag_output = i + 1
self.plot_shapes(epoch=self.startepoch, path=my_path, file_suffix = "_" + str(i), with_cuts = True)
self.network.flag_output = 0
return
print("training begin")
if args.summary == True:
writer = SummaryWriter(os.path.join("summary", self.foldername))
self.network.flag_output = i + 1 # 设置网络输出标志
self.plot_shapes(epoch=self.startepoch, path=my_path, file_suffix = "_" + str(i), with_cuts = True) # 绘制评估结果
self.network.flag_output = 0 # 将输出标志重置为 0
return # 结束方法
print("training begin") # 输出训练开始的提示信息
if args.summary == True: # 如果启用了摘要记录
writer = SummaryWriter(os.path.join("summary", self.foldername)) # 创建一个 SummaryWriter 实例
# branch mask is predefined
branch_mask = torch.zeros(n_branch, n_batchsize).cuda()
single_branch_mask_gt = torch.zeros(n_batchsize, n_branch).cuda()
single_branch_mask_id = torch.zeros([n_batchsize], dtype = torch.long).cuda()
branch_mask = torch.zeros(n_branch, n_batchsize).cuda() # 初始化分支掩码
single_branch_mask_gt = torch.zeros(n_batchsize, n_branch).cuda() # 初始化单分支掩码
single_branch_mask_id = torch.zeros([n_batchsize], dtype = torch.long).cuda() # 初始化单分支 ID
for i in range(n_branch - 1):
branch_mask[i, i * n_patch_batch : (i + 1) * n_patch_batch] = 1.0
single_branch_mask_gt[i * n_patch_batch : (i + 1) * n_patch_batch, i] = 1.0
single_branch_mask_id[i * n_patch_batch : (i + 1) * n_patch_batch] = i
branch_mask[i, i * n_patch_batch : (i + 1) * n_patch_batch] = 1.0 # 设置当前分支的补丁掩码
single_branch_mask_gt[i * n_patch_batch : (i + 1) * n_patch_batch, i] = 1.0 # 设置单分支的真实掩码
single_branch_mask_id[i * n_patch_batch : (i + 1) * n_patch_batch] = i # 设置单分支 ID
#last patch
branch_mask[n_branch - 1, (n_branch - 1) * n_patch_batch:] = 1.0
single_branch_mask_gt[(n_branch - 1) * n_patch_batch:, (n_branch - 1)] = 1.0
single_branch_mask_id[(n_branch - 1) * n_patch_batch:] = (n_branch - 1)
# last patch
branch_mask[n_branch - 1, (n_branch - 1) * n_patch_batch:] = 1.0 # 设置最后一个分支的补丁掩码
single_branch_mask_gt[(n_branch - 1) * n_patch_batch:, (n_branch - 1)] = 1.0 # 设置最后一个分支的真实掩码
single_branch_mask_id[(n_branch - 1) * n_patch_batch:] = (n_branch - 1) # 设置最后一个分支 ID
for epoch in range(self.startepoch, self.nepochs + 1):
indices = torch.empty(0,dtype=torch.int64).cuda()
for epoch in range(self.startepoch, self.nepochs + 1): # 开始训练循环
indices = torch.empty(0,dtype=torch.int64).cuda() # 初始化索引张量
for i in range(n_branch - 1):
indices_nonfeature = torch.tensor(patch_id[i][np.random.choice(patch_id_n[i], n_patch_batch, True)]).cuda()
indices = torch.cat((indices, indices_nonfeature), 0)
#last patch
indices_nonfeature = torch.tensor(patch_id[n_branch - 1][np.random.choice(patch_id_n[n_branch - 1], n_patch_last, True)]).cuda()
indices = torch.cat((indices, indices_nonfeature), 0)
indices_nonfeature = torch.tensor(patch_id[i][np.random.choice(patch_id_n[i], n_patch_batch, True)]).cuda() # 随机选择非特征补丁的索引
indices = torch.cat((indices, indices_nonfeature), 0) # 将索引添加到总索引中
# last patch
indices_nonfeature = torch.tensor(patch_id[n_branch - 1][np.random.choice(patch_id_n[n_branch - 1], n_patch_last, True)]).cuda() # 处理最后一个补丁的索引
indices = torch.cat((indices, indices_nonfeature), 0) # 将最后一个补丁的索引添加到总索引中
cur_data = self.data[indices]
mnfld_pnts = cur_data[:, :self.d_in] #n_indices x 3
mnfld_sigma = self.local_sigma[indices] #noise points
cur_data = self.data[indices] # 根据索引获取当前数据
mnfld_pnts = cur_data[:, :self.d_in] # 提取流形点
mnfld_sigma = self.local_sigma[indices] # 提取噪声点
if epoch % self.conf.get_int('train.checkpoint_frequency') == 0:
if epoch % self.conf.get_int('train.checkpoint_frequency') == 0: # 每隔一定轮次保存检查点
self.save_checkpoints(epoch)
if epoch % self.conf.get_int('train.plot_frequency') == 0:
print('plot validation epoch: ', epoch)
if epoch % self.conf.get_int('train.plot_frequency') == 0: # 每隔一定轮次绘制验证结果
print('plot validation epoch: ', epoch) # 输出当前绘制的轮次
for i in range(n_branch + 1):
self.network.flag_output = i + 1
self.plot_shapes(epoch, file_suffix = "_" + str(i), with_cuts = False)
self.network.flag_output = 0
self.network.flag_output = i + 1 # 设置网络输出标志
self.plot_shapes(epoch, file_suffix = "_" + str(i), with_cuts = False) # 绘制形状
self.network.flag_output = 0 # 将输出标志重置为 0
self.network.train()
self.adjust_learning_rate(epoch)
nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()
self.network.train() # 设置网络为训练模式
self.adjust_learning_rate(epoch) # 调整学习率
nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze() # 生成非流形点
# forward pass
mnfld_pred_all = self.network(mnfld_pnts)
nonmnfld_pred_all = self.network(nonmnfld_pnts)
mnfld_pred = mnfld_pred_all[:,0]
nonmnfld_pred = nonmnfld_pred_all[:,0]
loss = 0.0
mnfld_grad = gradient(mnfld_pnts, mnfld_pred)
mnfld_pred_all = self.network(mnfld_pnts) # 进行前向传播,计算流形点的预测值
nonmnfld_pred_all = self.network(nonmnfld_pnts) # 进行前向传播,计算非流形点的预测值
mnfld_pred = mnfld_pred_all[:,0] # 提取流形预测结果
nonmnfld_pred = nonmnfld_pred_all[:,0] # 提取非流形预测结果
loss = 0.0 # 初始化损失为 0
mnfld_grad = gradient(mnfld_pnts, mnfld_pred) # 计算流形点的梯度
# manifold loss
mnfld_loss = torch.zeros(1).cuda()
if not args.ab == 'overall':
mnfld_loss = (mnfld_pred.abs()).mean()
loss = loss + weight_mnfld_h * mnfld_loss
#feature sample
if args.feature_sample:
feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda()
feature_pnts = self.feature_data[feature_indices]
feature_mask_pair = self.feature_data_mask_pair[feature_indices]
feature_pred_all = self.network(feature_pnts)
feature_pred = feature_pred_all[:,0]
feature_mnfld_loss = feature_pred.abs().mean()
loss = loss + weight_mnfld_h * feature_mnfld_loss #|h|
mnfld_loss = torch.zeros(1).cuda() # 初始化流形损失
if not args.ab == 'overall': # 检查是否为整体损失
mnfld_loss = (mnfld_pred.abs()).mean() # 计算流形损失
loss = loss + weight_mnfld_h * mnfld_loss # 将流形损失加权到总损失中
# feature sample
if args.feature_sample: # 如果启用了特征采样
feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda() # 随机选择特征点
feature_pnts = self.feature_data[feature_indices] # 获取特征点数据
feature_mask_pair = self.feature_data_mask_pair[feature_indices] # 获取特征掩码对
feature_pred_all = self.network(feature_pnts) # 进行前向传播,计算特征点的预测值
feature_pred = feature_pred_all[:,0] # 提取特征预测结果
feature_mnfld_loss = feature_pred.abs().mean() # 计算特征流形损失
loss = loss + weight_mnfld_h * feature_mnfld_loss # 将特征流形损失加权到总损失中
#patch loss:
feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:,0].tolist()]
feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:,1].tolist()]
feature_fis_left = feature_pred_all[feature_id_left]
feature_fis_right = feature_pred_all[feature_id_right]
feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean()
loss += feature_loss_patch
#consistency loss:
feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean()
loss += weight_mnfld_cs * feature_loss_cons
all_fi = torch.zeros([n_batchsize, 1], device = 'cuda')
# patch loss:
feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:,0].tolist()] # 获取左侧特征 ID
feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:,1].tolist()] # 获取右侧特征 ID
feature_fis_left = feature_pred_all[feature_id_left] # 获取左侧特征预测值
feature_fis_right = feature_pred_all[feature_id_right] # 获取右侧特征预测值
feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean() # 计算补丁损失
loss += feature_loss_patch # 将补丁损失加权到总损失中
# consistency loss:
feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean() # 计算一致性损失
loss += weight_mnfld_cs * feature_loss_cons # 将一致性损失加权到总损失中
all_fi = torch.zeros([n_batchsize, 1], device = 'cuda') # 初始化所有流形预测值
for i in range(n_branch - 1):
all_fi[i * n_patch_batch : (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch : (i + 1) * n_patch_batch, i + 1]
#last patch
all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch]
all_fi[i * n_patch_batch : (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch : (i + 1) * n_patch_batch, i + 1] # 填充流形预测值
# last patch
all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch] # 填充最后一个分支的流形预测值
# manifold loss for patches
mnfld_loss_patch = torch.zeros(1).cuda()
if not args.ab == 'patch':
if patch_sup:
mnfld_loss_patch = all_fi[:,0].abs().mean()
loss = loss + mnfld_loss_patch
#correction loss
correction_loss = torch.zeros(1).cuda()
if not (args.ab == 'cor' or args.ab == 'cc') and epoch > 10000 and not args.baseline:
mismatch_id = torch.abs(mnfld_pred - all_fi[:,0]) > args.th_closeness
if mismatch_id.sum() != 0:
correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:,0])[mismatch_id]).mean()
loss = loss + weight_correction * correction_loss
#off surface_loss
offsurface_loss = torch.zeros(1).cuda()
if not args.ab == 'off':
offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred[n_batchsize:])).mean()
loss = loss + offsurface_loss
#manifold consistency loss
mnfld_consistency_loss = torch.zeros(1).cuda()
if not (args.ab == 'cons' or args.ab == 'cc'):
mnfld_consistency_loss = (mnfld_pred - all_fi[:,0]).abs().mean()
loss = loss + weight_mnfld_cs * mnfld_consistency_loss
#eikonal loss for h
grad_loss_h = torch.zeros(1).cuda()
single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred_all[:,0])
grad_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
loss = loss + self.grad_lambda * grad_loss_h
mnfld_loss_patch = torch.zeros(1).cuda() # 初始化补丁流形损失
if not args.ab == 'patch': # 检查是否为补丁损失
if patch_sup: # 如果支持补丁
mnfld_loss_patch = all_fi[:,0].abs().mean() # 计算补丁流形损失
loss = loss + mnfld_loss_patch # 将补丁流形损失加权到总损失中
# correction loss
correction_loss = torch.zeros(1).cuda() # 初始化修正损失
if not (args.ab == 'cor' or args.ab == 'cc') and epoch > 10000 and not args.baseline: # 检查修正损失的条件
mismatch_id = torch.abs(mnfld_pred - all_fi[:,0]) > args.th_closeness # 计算不匹配的 ID
if mismatch_id.sum() != 0: # 如果存在不匹配
correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:,0])[mismatch_id]).mean() # 计算修正损失
loss = loss + weight_correction * correction_loss # 将修正损失加权到总损失中
# off surface loss
offsurface_loss = torch.zeros(1).cuda() # 初始化离表面损失
if not args.ab == 'off': # 检查是否为离表面损失
offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred[n_batchsize:])).mean() # 计算离表面损失
loss = loss + offsurface_loss # 将离表面损失加权到总损失中
# manifold consistency loss
mnfld_consistency_loss = torch.zeros(1).cuda() # 初始化流形一致性损失
if not (args.ab == 'cons' or args.ab == 'cc'): # 检查是否为一致性损失
mnfld_consistency_loss = (mnfld_pred - all_fi[:,0]).abs().mean() # 计算流形一致性损失
loss = loss + weight_mnfld_cs * mnfld_consistency_loss # 将一致性损失加权到总损失中
# eikonal loss for h
grad_loss_h = torch.zeros(1).cuda() # 初始化 Eikonal 损失
single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred_all[:,0]) # 计算非流形点的梯度
grad_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() # 计算 Eikonal 损失
loss = loss + self.grad_lambda * grad_loss_h # 将 Eikonal 损失加权到总损失中
# normals loss
normals_loss_h = torch.zeros(1).cuda()
normals_loss = torch.zeros(1).cuda()
normal_consistency_loss = torch.zeros(1).cuda()
if not args.siren:
if not args.ab == 'normal' and self.with_normals:
#all normals
normals = cur_data[:, -self.d_in:]
if patch_sup:
branch_grad = gradient(mnfld_pnts, all_fi[:,0])
normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean()
loss = loss + self.normals_lambda * normals_loss
#only supervised, not used for loss computation
mnfld_grad = gradient(mnfld_pnts, mnfld_pred_all[:, 0])
normal_consistency_loss = (mnfld_grad - branch_grad).abs().norm(2, dim=1).mean()
normals_loss_h = torch.zeros(1).cuda() # 初始化法线损失
normals_loss = torch.zeros(1).cuda() # 初始化法线损失
normal_consistency_loss = torch.zeros(1).cuda() # 初始化法线一致性损失
if not args.siren: # 检查是否使用 SIREN
if not args.ab == 'normal' and self.with_normals: # 检查法线损失的条件
# all normals
normals = cur_data[:, -self.d_in:] # 提取法线
if patch_sup: # 如果支持补丁
branch_grad = gradient(mnfld_pnts, all_fi[:,0]) # 计算分支梯度
normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # 计算法线损失
loss = loss + self.normals_lambda * normals_loss # 将法线损失加权到总损失中
# only supervised, not used for loss computation
mnfld_grad = gradient(mnfld_pnts, mnfld_pred_all[:, 0]) # 计算流形梯度
normal_consistency_loss = (mnfld_grad - branch_grad).abs().norm(2, dim=1).mean() # 计算法线一致性损失
else:
single_nonmnfld_grad = gradient(mnfld_pnts, all_fi[:,0])
normals_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
loss = loss + self.normals_lambda * normals_loss_h
single_nonmnfld_grad = gradient(mnfld_pnts, all_fi[:,0]) # 计算非流形点的梯度
normals_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() # 计算法线损失
loss = loss + self.normals_lambda * normals_loss_h # 将法线损失加权到总损失中
else:
#compute consine normal
normals = cur_data[:, -self.d_in:]
normals_loss_h = (1 - F.cosine_similarity(mnfld_grad, normals, dim=-1)).mean()
loss = loss + self.normals_lambda * normals_loss_h
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
#tensorboard
if args.summary == True and epoch % 100 == 0:
writer.add_scalar('Loss/Total loss', loss.item(), epoch)
writer.add_scalar('Loss/Manifold loss h', mnfld_loss.item(), epoch)
writer.add_scalar('Loss/Manifold patch loss', mnfld_loss_patch.item(), epoch)
writer.add_scalar('Loss/Manifold cons loss', mnfld_consistency_loss.item(), epoch)
writer.add_scalar('Loss/Grad loss h',self.grad_lambda * grad_loss_h.item(), epoch)
writer.add_scalar('Loss/Normal loss all', self.normals_lambda * normals_loss.item(), epoch)
writer.add_scalar('Loss/Normal cs loss', self.normals_lambda * normal_consistency_loss.item(), epoch)
writer.add_scalar('Loss/Assignment loss', correction_loss.item(), epoch)
writer.add_scalar('Loss/Offsurface loss', offsurface_loss.item(), epoch)
if epoch % self.conf.get_int('train.status_frequency') == 0:
# compute cosine normal
normals = cur_data[:, -self.d_in:] # 提取法线
normals_loss_h = (1 - F.cosine_similarity(mnfld_grad, normals, dim=-1)).mean() # 计算法线的余弦相似度损失
loss = loss + self.normals_lambda * normals_loss_h # 将法线损失加权到总损失中
self.optimizer.zero_grad() # 清零优化器的梯度
loss.backward() # 反向传播计算梯度
self.optimizer.step() # 更新模型参数
# tensorboard
if args.summary == True and epoch % 100 == 0: # 每 100 轮记录损失到 TensorBoard
writer.add_scalar('Loss/Total loss', loss.item(), epoch) # 记录总损失
writer.add_scalar('Loss/Manifold loss h', mnfld_loss.item(), epoch) # 记录流形损失
writer.add_scalar('Loss/Manifold patch loss', mnfld_loss_patch.item(), epoch) # 记录补丁流形损失
writer.add_scalar('Loss/Manifold cons loss', mnfld_consistency_loss.item(), epoch) # 记录流形一致性损失
writer.add_scalar('Loss/Grad loss h',self.grad_lambda * grad_loss_h.item(), epoch) # 记录 Eikonal 损失
writer.add_scalar('Loss/Normal loss all', self.normals_lambda * normals_loss.item(), epoch) # 记录法线损失
writer.add_scalar('Loss/Normal cs loss', self.normals_lambda * normal_consistency_loss.item(), epoch) # 记录法线一致性损失
writer.add_scalar('Loss/Assignment loss', correction_loss.item(), epoch) # 记录修正损失
writer.add_scalar('Loss/Offsurface loss', offsurface_loss.item(), epoch) # 记录离表面损失
if epoch % self.conf.get_int('train.status_frequency') == 0: # 每隔一定轮次记录训练状态
logger.info('Train Epoch: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\t Manifold loss: {:.6f}'
'\tManifold patch loss: {:.6f}\t grad loss h: {:.6f}\t normals loss all: {:.6f}\t normals loss h: {:.6f}\t Manifold consistency loss: {:.6f}\tCorrection loss: {:.6f}\t Offsurface loss: {:.6f}'.format(
epoch, self.nepochs, 100. * epoch / self.nepochs,
loss.item(), mnfld_loss.item(), mnfld_loss_patch.item(), grad_loss_h.item(), normals_loss.item(), normals_loss_h.item(), mnfld_consistency_loss.item(), correction_loss.item(), offsurface_loss.item()))
if args.feature_sample:
logger.info('feature mnfld loss: {} patch loss: {} cons loss: {}'.format(feature_mnfld_loss.item(), feature_loss_patch.item(), feature_loss_cons.item()))
if args.feature_sample: # 如果启用了特征采样
logger.info('feature mnfld loss: {} patch loss: {} cons loss: {}'.format(feature_mnfld_loss.item(), feature_loss_patch.item(), feature_loss_cons.item())) # 记录特征损失信息
self.tracing()
self.tracing() # 调用 tracing 方法,可能用于记录或保存训练过程中的某些信息
def tracing(self):
#network definition

Loading…
Cancel
Save