|
|
@ -123,10 +123,16 @@ class ReconstructionRunner: |
|
|
|
self.network.train() # 设置网络为训练模式 |
|
|
|
self.adjust_learning_rate(epoch) # 调整学习率 |
|
|
|
nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze() # 生成非流形点 |
|
|
|
# nonmnfld_pnts: torch.Size([18432, 3]) |
|
|
|
#logger.info(f"mnfld_pnts: {mnfld_pnts.shape}") mnfld_pnts: torch.Size([16384, 3]) |
|
|
|
#logger.info(f"mnfld_sigma: {mnfld_sigma.shape}") mnfld_sigma: torch.Size([16384]) |
|
|
|
|
|
|
|
# forward pass |
|
|
|
mnfld_pred_all = self.network(mnfld_pnts) # 进行前向传播,计算流形点的预测值 |
|
|
|
nonmnfld_pred_all = self.network(nonmnfld_pnts) # 进行前向传播,计算非流形点的预测值 |
|
|
|
#logger.info(f"mnfld_pred_all: {mnfld_pred_all.shape}") |
|
|
|
#logger.info(f"nonmnfld_pred_all: {nonmnfld_pred_all.shape}") |
|
|
|
|
|
|
|
mnfld_pred = mnfld_pred_all[:,0] # 提取流形预测结果 |
|
|
|
nonmnfld_pred = nonmnfld_pred_all[:,0] # 提取非流形预测结果 |
|
|
|
loss = 0.0 # 初始化损失为 0 |
|
|
@ -166,6 +172,7 @@ class ReconstructionRunner: |
|
|
|
# last patch |
|
|
|
all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch] # 填充最后一个分支的流形预测值 |
|
|
|
|
|
|
|
#logger.info(f"all_fi: {all_fi.shape}") |
|
|
|
# manifold loss for patches |
|
|
|
mnfld_loss_patch = torch.zeros(1).cuda() # 初始化补丁流形损失 |
|
|
|
if not args.ab == 'patch': # 检查是否为补丁损失 |
|
|
|