| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -251,13 +251,20 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            gt_sdf = _gt_sdf[start_idx:end_idx]  # SDF真值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            normals = _normals[start_idx:end_idx] if args.use_normal else None  # 法线 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals, 0.1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # --- 准备模型输入,启用梯度 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # --- 前向传播 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mnfld_pred = self.model.forward_background( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                mnfld_pnts | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pred = self.model.forward_background( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                nonmnfld_pnts | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -274,11 +281,14 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.gpu_memory_stats("计算损失前") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss, loss_details = self.loss_manager.compute_loss_stage1( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss, loss_details = self.loss_manager.compute_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        nonmnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        normals, # 传递检查过的 normals | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        nonmnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        psdf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |