| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -182,6 +182,49 @@ class LossManager: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return total_loss, loss_details | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def compute_loss_stage1(self,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                normals, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                gt_sdfs, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                mnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        计算流型损失的逻辑 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :param outputs: 模型的输出 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :return: 计算得到的流型损失值      | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 强制类型转换确保一致性 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        normals = normals.to(torch.float32) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        mnfld_pred = mnfld_pred.to(torch.float32) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        gt_sdfs = gt_sdfs.to(torch.float32) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 计算流形损失  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        manifold_loss = self.position_loss(mnfld_pred, gt_sdfs) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 计算法线损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.gpu_memory_stats("计算法线损失后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 计算一致性损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #onsistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 计算修正损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 汇总损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        loss_details = { | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "manifold": self.weights["manifold"] * manifold_loss, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "normals": self.weights["normals"] * normals_loss | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 计算总损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        total_loss = sum(loss_details.values()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return total_loss, loss_details | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def compute_loss_volume(self,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                nonmnfld_pnts, | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |