From 596a7b6565b21bce77c22b01310d82c4a630796d Mon Sep 17 00:00:00 2001 From: mckay Date: Thu, 28 Nov 2024 22:24:02 +0800 Subject: [PATCH] =?UTF-8?q?tofix=EF=BC=9A=20=E6=A6=82=E7=8E=87=E6=80=A7?= =?UTF-8?q?=E6=9C=89=E6=A2=AF=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/encoder.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index eb3e503..50becc4 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -251,7 +251,6 @@ class BRepFeatureEmbedder(nn.Module): vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim] # 2. 处理边特征 - logger.info(f"edge_ncs shape: {edge_ncs.shape}") edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim] edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim] @@ -441,6 +440,13 @@ class BRepToSDF(nn.Module): """ B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries + #logger.info(f"query_points requires_grad: {query_points.requires_grad}") + #logger.info(f"edge_ncs requires_grad: {edge_ncs.requires_grad}") + #logger.info(f"edge_pos requires_grad: {edge_pos.requires_grad}") + #logger.info(f"edge_mask requires_grad: {edge_mask.requires_grad}") + #logger.info(f"surf_ncs requires_grad: {surf_ncs.requires_grad}") + #logger.info(f"surf_pos requires_grad: {surf_pos.requires_grad}") + try: # 确保query_points需要梯度 if not query_points.requires_grad: @@ -544,13 +550,13 @@ def main(): # 生成测试数据 test_data = { - 'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3), - 'edge_pos': torch.randn(batch_size, max_face, max_edge, 6), + 'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3, requires_grad=True), + 'edge_pos': torch.randn(batch_size, max_face, max_edge, 6, requires_grad=True), 'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool), - 'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3), - 'surf_pos': torch.randn(batch_size, max_face, 6), - 'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3), - 'query_points': torch.randn(batch_size, 1000, 3) # 1000个查询点 + 'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3, requires_grad=True), + 'surf_pos': torch.randn(batch_size, max_face, 6, requires_grad=True), + 'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3, requires_grad=True), + 'query_points': torch.randn(batch_size, 1000, 3, requires_grad=True) # 1000个查询点 } # 打印输入数据形状