| 
						
						
							
								
							
						
						
					 | 
					@ -123,10 +123,16 @@ class ReconstructionRunner: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self.network.train()  # 设置网络为训练模式 | 
					 | 
					 | 
					            self.network.train()  # 设置网络为训练模式 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self.adjust_learning_rate(epoch)  # 调整学习率 | 
					 | 
					 | 
					            self.adjust_learning_rate(epoch)  # 调整学习率 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()  # 生成非流形点 | 
					 | 
					 | 
					            nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()  # 生成非流形点 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # nonmnfld_pnts: torch.Size([18432, 3]) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            #logger.info(f"mnfld_pnts: {mnfld_pnts.shape}") mnfld_pnts: torch.Size([16384, 3]) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            #logger.info(f"mnfld_sigma: {mnfld_sigma.shape}") mnfld_sigma: torch.Size([16384]) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # forward pass | 
					 | 
					 | 
					            # forward pass | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            mnfld_pred_all = self.network(mnfld_pnts)  # 进行前向传播,计算流形点的预测值 | 
					 | 
					 | 
					            mnfld_pred_all = self.network(mnfld_pnts)  # 进行前向传播,计算流形点的预测值 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            nonmnfld_pred_all = self.network(nonmnfld_pnts)  # 进行前向传播,计算非流形点的预测值 | 
					 | 
					 | 
					            nonmnfld_pred_all = self.network(nonmnfld_pnts)  # 进行前向传播,计算非流形点的预测值 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            #logger.info(f"mnfld_pred_all: {mnfld_pred_all.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            #logger.info(f"nonmnfld_pred_all: {nonmnfld_pred_all.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            mnfld_pred = mnfld_pred_all[:,0]  # 提取流形预测结果 | 
					 | 
					 | 
					            mnfld_pred = mnfld_pred_all[:,0]  # 提取流形预测结果 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            nonmnfld_pred = nonmnfld_pred_all[:,0]  # 提取非流形预测结果 | 
					 | 
					 | 
					            nonmnfld_pred = nonmnfld_pred_all[:,0]  # 提取非流形预测结果 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            loss = 0.0  # 初始化损失为 0 | 
					 | 
					 | 
					            loss = 0.0  # 初始化损失为 0 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -166,6 +172,7 @@ class ReconstructionRunner: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # last patch | 
					 | 
					 | 
					            # last patch | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch]  # 填充最后一个分支的流形预测值 | 
					 | 
					 | 
					            all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch]  # 填充最后一个分支的流形预测值 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            #logger.info(f"all_fi: {all_fi.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # manifold loss for patches | 
					 | 
					 | 
					            # manifold loss for patches | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            mnfld_loss_patch = torch.zeros(1).cuda()  # 初始化补丁流形损失 | 
					 | 
					 | 
					            mnfld_loss_patch = torch.zeros(1).cuda()  # 初始化补丁流形损失 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if not args.ab == 'patch':  # 检查是否为补丁损失 | 
					 | 
					 | 
					            if not args.ab == 'patch':  # 检查是否为补丁损失 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |