| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -281,15 +281,24 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.gpu_memory_stats("计算损失前") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss, loss_details = self.loss_manager.compute_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        nonmnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        normals, # 传递检查过的 normals | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        nonmnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        psdf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if args.only_zero_surface: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        loss, loss_details = self.loss_manager.compute_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            nonmnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            normals, # 传递检查过的 normals | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            mnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            nonmnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            psdf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        )        | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        loss, loss_details = self.loss_manager.compute_loss_stage1( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            normals, # 传递检查过的 normals | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            mnfld_pred | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        )               | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -696,15 +705,23 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #logger.gpu_memory_stats("计算损失前") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss, loss_details = self.loss_manager.compute_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        nonmnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        normals, # 传递检查过的 normals | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        nonmnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        psdf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    )                     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if args.only_zero_surface: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        loss, loss_details = self.loss_manager.compute_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            nonmnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            normals, # 传递检查过的 normals | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            mnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            nonmnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            psdf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        )        | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        loss, loss_details = self.loss_manager.compute_loss_stage1( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            normals, # 传递检查过的 normals | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            mnfld_pred | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        )                | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #logger.gpu_memory_stats("计算损失后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |