Browse Source

fix: can train

main
mckay 4 months ago
parent
commit
229c0aeffc
  1. 48
      brep2sdf/networks/encoder.py

48
brep2sdf/networks/encoder.py

@ -197,14 +197,14 @@ class BRepFeatureEmbedder(nn.Module):
self.surfz_embed = Encoder1D(
in_channels=3,
out_channels=self.embed_dim,
block_out_channels=(64, 128, 256),
block_out_channels=(64, 128, self.embed_dim),
layers_per_block=2
)
self.edgez_embed = Encoder1D(
in_channels=3,
out_channels=self.embed_dim,
block_out_channels=(64, 128, 256),
block_out_channels=(64, 128, self.embed_dim),
layers_per_block=2
)
@ -410,6 +410,11 @@ class BRepToSDF(nn.Module):
B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries
try:
# 确保query_points需要梯度
if not query_points.requires_grad:
query_points = query_points.detach().requires_grad_(True)
# 1. B-rep特征编码
brep_features = self.brep_embedder(
edge_ncs=edge_ncs, # [B, max_face, max_edge, num_edge_points, 3]
@ -438,6 +443,10 @@ class BRepToSDF(nn.Module):
# 6. SDF预测
sdf = self.sdf_head(combined_features) # [B, Q, 1]
if not sdf.requires_grad:
logger.warning("SDF output does not require grad!")
return sdf
@ -456,19 +465,34 @@ class BRepToSDF(nn.Module):
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
"""SDF损失函数"""
# 确保points需要梯度
if not points.requires_grad:
points = points.detach().requires_grad_(True)
# L1损失
l1_loss = F.l1_loss(pred_sdf, gt_sdf)
# 梯度约束损失
grad = torch.autograd.grad(
pred_sdf.sum(),
points,
create_graph=True
)[0]
grad_constraint = F.mse_loss(
torch.norm(grad, dim=-1),
torch.ones_like(pred_sdf.squeeze(-1))
)
try:
# 梯度约束损失
grad = torch.autograd.grad(
pred_sdf.sum(),
points,
create_graph=True,
retain_graph=True,
allow_unused=True
)[0]
if grad is not None:
grad_constraint = F.mse_loss(
torch.norm(grad, dim=-1),
torch.ones_like(pred_sdf.squeeze(-1))
)
else:
grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
except Exception as e:
logger.warning(f"Gradient computation failed: {str(e)}")
grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
return l1_loss + grad_weight * grad_constraint

Loading…
Cancel
Save