|
|
@ -197,14 +197,14 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
self.surfz_embed = Encoder1D( |
|
|
|
in_channels=3, |
|
|
|
out_channels=self.embed_dim, |
|
|
|
block_out_channels=(64, 128, 256), |
|
|
|
block_out_channels=(64, 128, self.embed_dim), |
|
|
|
layers_per_block=2 |
|
|
|
) |
|
|
|
|
|
|
|
self.edgez_embed = Encoder1D( |
|
|
|
in_channels=3, |
|
|
|
out_channels=self.embed_dim, |
|
|
|
block_out_channels=(64, 128, 256), |
|
|
|
block_out_channels=(64, 128, self.embed_dim), |
|
|
|
layers_per_block=2 |
|
|
|
) |
|
|
|
|
|
|
@ -410,6 +410,11 @@ class BRepToSDF(nn.Module): |
|
|
|
B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries |
|
|
|
|
|
|
|
try: |
|
|
|
# 确保query_points需要梯度 |
|
|
|
if not query_points.requires_grad: |
|
|
|
query_points = query_points.detach().requires_grad_(True) |
|
|
|
|
|
|
|
|
|
|
|
# 1. B-rep特征编码 |
|
|
|
brep_features = self.brep_embedder( |
|
|
|
edge_ncs=edge_ncs, # [B, max_face, max_edge, num_edge_points, 3] |
|
|
@ -438,6 +443,10 @@ class BRepToSDF(nn.Module): |
|
|
|
|
|
|
|
# 6. SDF预测 |
|
|
|
sdf = self.sdf_head(combined_features) # [B, Q, 1] |
|
|
|
|
|
|
|
if not sdf.requires_grad: |
|
|
|
logger.warning("SDF output does not require grad!") |
|
|
|
|
|
|
|
|
|
|
|
return sdf |
|
|
|
|
|
|
@ -456,19 +465,34 @@ class BRepToSDF(nn.Module): |
|
|
|
|
|
|
|
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
    """SDF loss: L1 data term plus an eikonal (unit-gradient-norm) constraint.

    Args:
        pred_sdf: predicted SDF values, shape [B, Q, 1]; must be a function of
            `points` for the gradient term to be non-trivial.
        gt_sdf: ground-truth SDF values, same shape as `pred_sdf`.
        points: query points at which the SDF was evaluated.
        grad_weight: weight of the eikonal term in the combined loss.

    Returns:
        Scalar tensor: l1_loss + grad_weight * grad_constraint.
    """
    # Ensure points participate in autograd so d(sdf)/d(points) is defined.
    # NOTE(review): detach().requires_grad_(True) creates a NEW leaf that is
    # disconnected from the graph that produced pred_sdf — if the caller passed
    # points without grad, `grad` below comes back None and the eikonal term
    # silently degenerates to 0. TODO confirm this fallback is intended.
    if not points.requires_grad:
        points = points.detach().requires_grad_(True)

    # L1 data term.
    l1_loss = F.l1_loss(pred_sdf, gt_sdf)

    try:
        # Eikonal constraint: ||grad_points sdf|| should be 1 everywhere.
        grad = torch.autograd.grad(
            pred_sdf.sum(),
            points,
            create_graph=True,   # keep graph so the constraint is differentiable
            retain_graph=True,
            allow_unused=True    # tolerate pred_sdf not depending on points
        )[0]

        if grad is not None:
            grad_constraint = F.mse_loss(
                torch.norm(grad, dim=-1),
                torch.ones_like(pred_sdf.squeeze(-1))
            )
        else:
            # pred_sdf did not depend on points (see NOTE above): no penalty.
            grad_constraint = torch.tensor(0.0, device=pred_sdf.device)

    except Exception as e:
        # Best-effort: a failed gradient computation must not kill training.
        logger.warning(f"Gradient computation failed: {str(e)}")
        grad_constraint = torch.tensor(0.0, device=pred_sdf.device)

    return l1_loss + grad_weight * grad_constraint
|
|
|
|
|
|
|