3 changed files with 94 additions and 13 deletions
			
			
		@ -0,0 +1,70 @@ | 
				
			|||
import torch | 
				
			|||
import torch.nn as nn | 
				
			|||
 | 
				
			|||
from brep2sdf.config.default_config import get_default_config | 
				
			|||
 | 
				
			|||
class Brep2SDFLoss: | 
				
			|||
    """解释Brep2SDF的loss设计原理""" | 
				
			|||
    def __init__(self, enforce_minmax: bool=True, clamping_distance: float = 0.1): | 
				
			|||
        self.l1_loss = nn.L1Loss(reduction='sum') | 
				
			|||
        self.enforce_minmax = enforce_minmax | 
				
			|||
        self.minT = -clamping_distance | 
				
			|||
        self.maxT = clamping_distance | 
				
			|||
 | 
				
			|||
    def __call__(self, pred_sdf, gt_sdf): | 
				
			|||
        """使类可直接调用""" | 
				
			|||
        return self.forward(pred_sdf, gt_sdf) | 
				
			|||
           | 
				
			|||
    def forward(self, pred_sdf, gt_sdf): | 
				
			|||
        """ | 
				
			|||
        pred_sdf: 预测的SDF值 | 
				
			|||
        gt_sdf: 真实的SDF值 | 
				
			|||
        latent_vecs: 形状编码, 来自 brep | 
				
			|||
        epoch: 当前训练轮次 | 
				
			|||
        """ | 
				
			|||
 | 
				
			|||
        if self.enforce_minmax: | 
				
			|||
            pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT) | 
				
			|||
            gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT) | 
				
			|||
 | 
				
			|||
        # 1. L1 Loss的优势 | 
				
			|||
        # - 对异常值更鲁棒 | 
				
			|||
        # - 能更好地保持表面细节 | 
				
			|||
        base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0] | 
				
			|||
         | 
				
			|||
         | 
				
			|||
        return base_loss | 
				
			|||
 | 
				
			|||
 | 
				
			|||
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
				
			|||
    """SDF损失函数""" | 
				
			|||
    # 确保points需要梯度 | 
				
			|||
    if not points.requires_grad: | 
				
			|||
        points = points.detach().requires_grad_(True) | 
				
			|||
     | 
				
			|||
    # L1损失 | 
				
			|||
    l1_loss = F.l1_loss(pred_sdf, gt_sdf) | 
				
			|||
     | 
				
			|||
    try: | 
				
			|||
        # 梯度约束损失 | 
				
			|||
        grad = torch.autograd.grad( | 
				
			|||
            pred_sdf.sum(),  | 
				
			|||
            points, | 
				
			|||
            create_graph=True, | 
				
			|||
            retain_graph=True, | 
				
			|||
            allow_unused=True | 
				
			|||
        )[0] | 
				
			|||
         | 
				
			|||
        if grad is not None: | 
				
			|||
            grad_constraint = F.mse_loss( | 
				
			|||
                torch.norm(grad, dim=-1), | 
				
			|||
                torch.ones_like(pred_sdf.squeeze(-1)) | 
				
			|||
            ) | 
				
			|||
        else: | 
				
			|||
            grad_constraint = torch.tensor(0.0, device=pred_sdf.device) | 
				
			|||
             | 
				
			|||
    except Exception as e: | 
				
			|||
        logger.warning(f"Gradient computation failed: {str(e)}") | 
				
			|||
        grad_constraint = torch.tensor(0.0, device=pred_sdf.device) | 
				
			|||
     | 
				
			|||
    return l1_loss + grad_weight * grad_constraint | 
				
			|||
					Loading…
					
					
				
		Reference in new issue