|
|
@ -119,38 +119,7 @@ class BRepToSDF(nn.Module): |
|
|
|
logger.error(f" query_points: {query_points.shape}") |
|
|
|
raise |
|
|
|
|
|
|
|
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
    """SDF loss: L1 data term plus a weighted eikonal (unit-gradient) constraint.

    Args:
        pred_sdf: predicted signed-distance values; gradient of its sum is
            taken w.r.t. ``points``.
        gt_sdf: ground-truth SDF values, same shape as ``pred_sdf``.
        points: query points the SDF was evaluated at; must be the same
            tensor (still in the autograd graph) that produced ``pred_sdf``
            for the eikonal term to be non-zero.
        grad_weight: weight of the eikonal constraint (default 0.1).

    Returns:
        Scalar tensor: ``l1_loss + grad_weight * grad_constraint``.
    """
    # NOTE(review): detaching and re-enabling grad creates a NEW leaf tensor
    # that is NOT in the graph which produced pred_sdf — in that case
    # autograd.grad returns None below and the constraint silently becomes 0.
    # A real fix requires the caller to pass grad-enabled points up front.
    if not points.requires_grad:
        points = points.detach().requires_grad_(True)

    # L1 data term (mean reduction).
    l1_loss = F.l1_loss(pred_sdf, gt_sdf)

    try:
        # Eikonal constraint: gradient of the summed SDF w.r.t. the inputs.
        # create_graph=True keeps the constraint differentiable for training;
        # allow_unused=True yields None instead of raising when points are
        # disconnected from pred_sdf (handled below).
        grad = torch.autograd.grad(
            pred_sdf.sum(),
            points,
            create_graph=True,
            retain_graph=True,
            allow_unused=True
        )[0]

        if grad is not None:
            grad_norm = torch.norm(grad, dim=-1)
            # Fix: target ones must be shaped like the gradient norm itself.
            # The original used ones_like(pred_sdf.squeeze(-1)), which
            # mismatches grad_norm whenever pred_sdf's trailing dim is not 1.
            grad_constraint = F.mse_loss(grad_norm, torch.ones_like(grad_norm))
        else:
            grad_constraint = torch.tensor(0.0, device=pred_sdf.device)

    except Exception as e:
        # Best-effort: a failed gradient computation disables the eikonal
        # term for this step instead of aborting training.
        logger.warning(f"Gradient computation failed: {str(e)}")
        grad_constraint = torch.tensor(0.0, device=pred_sdf.device)

    return l1_loss + grad_weight * grad_constraint
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
# 获取配置 |
|
|
|