diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index 21279bc..889d61d 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -119,38 +119,7 @@ class BRepToSDF(nn.Module): logger.error(f" query_points: {query_points.shape}") raise -def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): - """SDF损失函数""" - # 确保points需要梯度 - if not points.requires_grad: - points = points.detach().requires_grad_(True) - - # L1损失 - l1_loss = F.l1_loss(pred_sdf, gt_sdf) - - try: - # 梯度约束损失 - grad = torch.autograd.grad( - pred_sdf.sum(), - points, - create_graph=True, - retain_graph=True, - allow_unused=True - )[0] - - if grad is not None: - grad_constraint = F.mse_loss( - torch.norm(grad, dim=-1), - torch.ones_like(pred_sdf.squeeze(-1)) - ) - else: - grad_constraint = torch.tensor(0.0, device=pred_sdf.device) - - except Exception as e: - logger.warning(f"Gradient computation failed: {str(e)}") - grad_constraint = torch.tensor(0.0, device=pred_sdf.device) - - return l1_loss + grad_weight * grad_constraint + def main(): # 获取配置