From 8f232a2103cd12a62568b65d0eb6b810d587fd3c Mon Sep 17 00:00:00 2001 From: mckay Date: Sat, 23 Nov 2024 20:29:53 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=9C=A8networks=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E9=87=8C=E9=9D=A2=E7=A7=BB=E9=99=A4loss=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E7=9A=84=E4=B8=9C=E8=A5=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/network.py | 33 +-------------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index 21279bc..889d61d 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -119,38 +119,7 @@ class BRepToSDF(nn.Module): logger.error(f" query_points: {query_points.shape}") raise -def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): - """SDF损失函数""" - # 确保points需要梯度 - if not points.requires_grad: - points = points.detach().requires_grad_(True) - - # L1损失 - l1_loss = F.l1_loss(pred_sdf, gt_sdf) - - try: - # 梯度约束损失 - grad = torch.autograd.grad( - pred_sdf.sum(), - points, - create_graph=True, - retain_graph=True, - allow_unused=True - )[0] - - if grad is not None: - grad_constraint = F.mse_loss( - torch.norm(grad, dim=-1), - torch.ones_like(pred_sdf.squeeze(-1)) - ) - else: - grad_constraint = torch.tensor(0.0, device=pred_sdf.device) - - except Exception as e: - logger.warning(f"Gradient computation failed: {str(e)}") - grad_constraint = torch.tensor(0.0, device=pred_sdf.device) - - return l1_loss + grad_weight * grad_constraint + def main(): # 获取配置