Browse Source

refactor: 在networks文件里面移除loss相关的东西

main
mckay 4 months ago
parent
commit
8f232a2103
  1. 33
      brep2sdf/networks/network.py

33
brep2sdf/networks/network.py

@ -119,38 +119,7 @@ class BRepToSDF(nn.Module):
logger.error(f" query_points: {query_points.shape}")
raise
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
"""SDF损失函数"""
# 确保points需要梯度
if not points.requires_grad:
points = points.detach().requires_grad_(True)
# L1损失
l1_loss = F.l1_loss(pred_sdf, gt_sdf)
try:
# 梯度约束损失
grad = torch.autograd.grad(
pred_sdf.sum(),
points,
create_graph=True,
retain_graph=True,
allow_unused=True
)[0]
if grad is not None:
grad_constraint = F.mse_loss(
torch.norm(grad, dim=-1),
torch.ones_like(pred_sdf.squeeze(-1))
)
else:
grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
except Exception as e:
logger.warning(f"Gradient computation failed: {str(e)}")
grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
return l1_loss + grad_weight * grad_constraint
def main():
# 获取配置

Loading…
Cancel
Save