| 
						
						
							
								
							
						
						
					 | 
					@ -2,41 +2,66 @@ import torch | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					import torch.nn as nn | 
					 | 
					 | 
					import torch.nn as nn | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.config.default_config import get_default_config | 
					 | 
					 | 
					from brep2sdf.config.default_config import get_default_config | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					from brep2sdf.utils.logger import logger | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					
 | 
					 | 
					 | 
					class Brep2SDFLoss(nn.Module): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					
 | 
					 | 
					 | 
					    def __init__(self, batch_size:float=1, enforce_minmax: bool=True, clamping_distance: float=0.1, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					class Brep2SDFLoss: | 
					 | 
					 | 
					                 grad_weight=0.1, warmup_epochs=10): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    """解释Brep2SDF的loss设计原理""" | 
					 | 
					 | 
					        super().__init__() | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					    def __init__(self, batch_size:float, enforce_minmax: bool=True, clamping_distance: float = 0.1): | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        self.batch_size = batch_size | 
					 | 
					 | 
					        self.batch_size = batch_size | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.l1_loss = nn.L1Loss(reduction='sum') | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.enforce_minmax = enforce_minmax | 
					 | 
					 | 
					        self.enforce_minmax = enforce_minmax | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.minT = -clamping_distance | 
					 | 
					 | 
					        self.minT = -clamping_distance | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.maxT = clamping_distance | 
					 | 
					 | 
					        self.maxT = clamping_distance | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.grad_weight = grad_weight | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.warmup_epochs = warmup_epochs | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def __call__(self, pred_sdf, gt_sdf): | 
					 | 
					 | 
					        self.l1_loss = nn.L1Loss(reduction='mean') | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					        """使类可直接调用""" | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return self.forward(pred_sdf, gt_sdf) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def forward(self, pred_sdf, gt_sdf): | 
					 | 
					 | 
					    def forward(self, pred_sdf, gt_sdf, points=None, epoch=None): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """计算SDF预测的损失 | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					        pred_sdf: 预测的SDF值 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        gt_sdf: 真实的SDF值 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        latent_vecs: 形状编码, 来自 brep | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        epoch: 当前训练轮次 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        Args: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            pred_sdf: 预测的SDF值 [B, N, 1] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            gt_sdf: 真实的SDF值 [B, N, 1] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            points: 查询点坐标 [B, N, 3],用于计算梯度损失 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            epoch: 当前训练轮次,用于损失权重调整 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        """ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 1. 对SDF值进行clamp | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        if self.enforce_minmax: | 
					 | 
					 | 
					        if self.enforce_minmax: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT) | 
					 | 
					 | 
					            pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT) | 
					 | 
					 | 
					            gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 1. L1 Loss的优势 | 
					 | 
					 | 
					        # 2. 计算基础L1损失 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # - 对异常值更鲁棒 | 
					 | 
					 | 
					        l1_loss = self.l1_loss(pred_sdf, gt_sdf) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # - 能更好地保持表面细节 | 
					 | 
					 | 
					         | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0] | 
					 | 
					 | 
					        # 3. 计算梯度损失(如果提供了points) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        grad_loss = 0 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        if points is not None: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            grad_loss = self.gradient_loss(pred_sdf, points) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 4. 根据epoch调整权重 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        if epoch is not None: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            weight = min(1.0, epoch / self.warmup_epochs) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            weight = 1.0 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 5. 组合损失 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        total_loss = l1_loss + self.grad_weight * grad_loss * weight | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        return total_loss | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    def gradient_loss(self, pred_sdf, points): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        """计算SDF梯度损失""" | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        grad_pred = torch.autograd.grad( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            pred_sdf.sum(),  | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            points,  | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            create_graph=True | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        )[0] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        grad_norm = torch.norm(grad_pred, dim=-1) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        grad_loss = torch.mean((grad_norm - 1.0) ** 2) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        return base_loss / self.batch_size | 
					 | 
					 | 
					        return grad_loss | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
					 | 
					 | 
					def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |