Browse Source

fix: can train

main
mckay 7 months ago
parent
commit
229c0aeffc
  1. 48
      brep2sdf/networks/encoder.py

48
brep2sdf/networks/encoder.py

@ -197,14 +197,14 @@ class BRepFeatureEmbedder(nn.Module):
self.surfz_embed = Encoder1D( self.surfz_embed = Encoder1D(
in_channels=3, in_channels=3,
out_channels=self.embed_dim, out_channels=self.embed_dim,
block_out_channels=(64, 128, 256), block_out_channels=(64, 128, self.embed_dim),
layers_per_block=2 layers_per_block=2
) )
self.edgez_embed = Encoder1D( self.edgez_embed = Encoder1D(
in_channels=3, in_channels=3,
out_channels=self.embed_dim, out_channels=self.embed_dim,
block_out_channels=(64, 128, 256), block_out_channels=(64, 128, self.embed_dim),
layers_per_block=2 layers_per_block=2
) )
@ -410,6 +410,11 @@ class BRepToSDF(nn.Module):
B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries
try: try:
# 确保query_points需要梯度
if not query_points.requires_grad:
query_points = query_points.detach().requires_grad_(True)
# 1. B-rep特征编码 # 1. B-rep特征编码
brep_features = self.brep_embedder( brep_features = self.brep_embedder(
edge_ncs=edge_ncs, # [B, max_face, max_edge, num_edge_points, 3] edge_ncs=edge_ncs, # [B, max_face, max_edge, num_edge_points, 3]
@ -438,6 +443,10 @@ class BRepToSDF(nn.Module):
# 6. SDF预测 # 6. SDF预测
sdf = self.sdf_head(combined_features) # [B, Q, 1] sdf = self.sdf_head(combined_features) # [B, Q, 1]
if not sdf.requires_grad:
logger.warning("SDF output does not require grad!")
return sdf return sdf
@ -456,19 +465,34 @@ class BRepToSDF(nn.Module):
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
"""SDF损失函数""" """SDF损失函数"""
# 确保points需要梯度
if not points.requires_grad:
points = points.detach().requires_grad_(True)
# L1损失 # L1损失
l1_loss = F.l1_loss(pred_sdf, gt_sdf) l1_loss = F.l1_loss(pred_sdf, gt_sdf)
# 梯度约束损失 try:
grad = torch.autograd.grad( # 梯度约束损失
pred_sdf.sum(), grad = torch.autograd.grad(
points, pred_sdf.sum(),
create_graph=True points,
)[0] create_graph=True,
grad_constraint = F.mse_loss( retain_graph=True,
torch.norm(grad, dim=-1), allow_unused=True
torch.ones_like(pred_sdf.squeeze(-1)) )[0]
)
if grad is not None:
grad_constraint = F.mse_loss(
torch.norm(grad, dim=-1),
torch.ones_like(pred_sdf.squeeze(-1))
)
else:
grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
except Exception as e:
logger.warning(f"Gradient computation failed: {str(e)}")
grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
return l1_loss + grad_weight * grad_constraint return l1_loss + grad_weight * grad_constraint

Loading…
Cancel
Save