| 
						
						
							
								
							
						
						
					 | 
					@ -315,7 +315,7 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            start_idx = batch_idx * batch_size | 
					 | 
					 | 
					            start_idx = batch_idx * batch_size | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            end_idx = min((batch_idx + 1) * batch_size, num_points) | 
					 | 
					 | 
					            end_idx = min((batch_idx + 1) * batch_size, num_points) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 | 
					 | 
					 | 
					            mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0)  # 生成非流形点 | 
					 | 
					 | 
					             | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach()  # 取最后一列作为SDF真值 | 
					 | 
					 | 
					            gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach()  # 取最后一列作为SDF真值 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            normals = None | 
					 | 
					 | 
					            normals = None | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if args.use_normal: | 
					 | 
					 | 
					            if args.use_normal: | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -323,7 +323,11 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.") | 
					 | 
					 | 
					                    logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    return float('inf') | 
					 | 
					 | 
					                    return float('inf') | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach() # 取中间3列作为法线 | 
					 | 
					 | 
					                normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach() # 取中间3列作为法线 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                nonmnfld_pnts,psdf = self.sampler.get_norm_points(mnfld_pnts,normals)  # 生成非流形点 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.debug((mnfld_pnts,nonmnfld_pnts,psdf)) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0)  # 生成非流形点 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 执行检查 | 
					 | 
					 | 
					            # 执行检查 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if  self.debug_mode: | 
					 | 
					 | 
					            if  self.debug_mode: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf') | 
					 | 
					 | 
					                if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf') | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -372,7 +376,8 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                        normals, # 传递检查过的 normals | 
					 | 
					 | 
					                        normals, # 传递检查过的 normals | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                        gt_sdf, | 
					 | 
					 | 
					                        gt_sdf, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                        mnfld_pred, | 
					 | 
					 | 
					                        mnfld_pred, | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                        nonmnfld_pred | 
					 | 
					 | 
					                        nonmnfld_pred, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        psdf | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    ) | 
					 | 
					 | 
					                    ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                else: | 
					 | 
					 | 
					                else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) | 
					 | 
					 | 
					                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |