From 229c0aeffc1d70915f6f7ea176662fbcde81d532 Mon Sep 17 00:00:00 2001 From: mckay Date: Sat, 23 Nov 2024 14:28:44 +0800 Subject: [PATCH] fix: can train --- brep2sdf/networks/encoder.py | 48 +++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 38eb340..cda7ab0 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -197,14 +197,14 @@ class BRepFeatureEmbedder(nn.Module): self.surfz_embed = Encoder1D( in_channels=3, out_channels=self.embed_dim, - block_out_channels=(64, 128, 256), + block_out_channels=(64, 128, self.embed_dim), layers_per_block=2 ) self.edgez_embed = Encoder1D( in_channels=3, out_channels=self.embed_dim, - block_out_channels=(64, 128, 256), + block_out_channels=(64, 128, self.embed_dim), layers_per_block=2 ) @@ -410,6 +410,11 @@ class BRepToSDF(nn.Module): B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries try: + # 确保query_points需要梯度 + if not query_points.requires_grad: + query_points = query_points.detach().requires_grad_(True) + + # 1. B-rep特征编码 brep_features = self.brep_embedder( edge_ncs=edge_ncs, # [B, max_face, max_edge, num_edge_points, 3] @@ -438,6 +443,10 @@ class BRepToSDF(nn.Module): # 6. SDF预测 sdf = self.sdf_head(combined_features) # [B, Q, 1] + + if not sdf.requires_grad: + logger.warning("SDF output does not require grad!") + return sdf @@ -456,19 +465,34 @@ class BRepToSDF(nn.Module): def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): """SDF损失函数""" + # 确保points需要梯度 + if not points.requires_grad: + points = points.detach().requires_grad_(True) + # L1损失 l1_loss = F.l1_loss(pred_sdf, gt_sdf) - # 梯度约束损失 - grad = torch.autograd.grad( - pred_sdf.sum(), - points, - create_graph=True - )[0] - grad_constraint = F.mse_loss( - torch.norm(grad, dim=-1), - torch.ones_like(pred_sdf.squeeze(-1)) - ) + try: + # 梯度约束损失 + grad = torch.autograd.grad( + pred_sdf.sum(), + points, + create_graph=True, + retain_graph=True, + allow_unused=True + )[0] + + if grad is not None: + grad_constraint = F.mse_loss( + torch.norm(grad, dim=-1), + torch.ones_like(pred_sdf.squeeze(-1)) + ) + else: + grad_constraint = torch.tensor(0.0, device=pred_sdf.device) + + except Exception as e: + logger.warning(f"Gradient computation failed: {str(e)}") + grad_constraint = torch.tensor(0.0, device=pred_sdf.device) return l1_loss + grad_weight * grad_constraint