| 
						
						
							
								
							
						
						
					 | 
					@ -3,9 +3,12 @@ import torch.nn as nn | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.config.default_config import get_default_config | 
					 | 
					 | 
					from brep2sdf.config.default_config import get_default_config | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					class Brep2SDFLoss: | 
					 | 
					 | 
					class Brep2SDFLoss: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    """解释Brep2SDF的loss设计原理""" | 
					 | 
					 | 
					    """解释Brep2SDF的loss设计原理""" | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def __init__(self, enforce_minmax: bool=True, clamping_distance: float = 0.1): | 
					 | 
					 | 
					    def __init__(self, batch_size:float, enforce_minmax: bool=True, clamping_distance: float = 0.1): | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.batch_size = batch_size | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.l1_loss = nn.L1Loss(reduction='sum') | 
					 | 
					 | 
					        self.l1_loss = nn.L1Loss(reduction='sum') | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.enforce_minmax = enforce_minmax | 
					 | 
					 | 
					        self.enforce_minmax = enforce_minmax | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.minT = -clamping_distance | 
					 | 
					 | 
					        self.minT = -clamping_distance | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -33,7 +36,7 @@ class Brep2SDFLoss: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0] | 
					 | 
					 | 
					        base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        return base_loss | 
					 | 
					 | 
					        return base_loss / self.batch_size | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
					 | 
					 | 
					def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |