From a2cb2c6e8ac2689cf4a94e9663115d18a40d73d2 Mon Sep 17 00:00:00 2001 From: mckay Date: Tue, 19 Nov 2024 02:28:25 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E4=BD=BF=E7=94=A8size()=E8=80=8C=E4=B8=8D?= =?UTF-8?q?=E6=98=AFshape=E6=9D=A5=E8=8E=B7=E5=8F=96=E7=BB=B4=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/encoder.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 66bdfab..ca121a9 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -192,13 +192,15 @@ class BRepFeatureEmbedder(nn.Module): vert_p: 顶点点 [B, K, 6] mask: 注意力掩码 """ - B, N, _, _ = surf_z.shape - _, M, _, _ = edge_z.shape - _, K, _ = vert_p.shape + # 获取批次大小和其他维度 + B = surf_z.size(0) + N = surf_z.size(1) + M = edge_z.size(1) + K = vert_p.size(1) # 重塑点云数据用于1D编码器 - surf_z = surf_z.reshape(B*N, 3, self.num_surf_points) # [B*N, 3, num_surf_points] - edge_z = edge_z.reshape(B*M, 3, self.num_edge_points) # [B*M, 3, num_edge_points] + surf_z = surf_z.reshape(B*N, self.num_surf_points, 3).transpose(1, 2) # [B*N, 3, num_surf_points] + edge_z = edge_z.reshape(B*M, self.num_edge_points, 3).transpose(1, 2) # [B*M, 3, num_edge_points] # 特征嵌入 surf_embeds = self.surfz_embed(surf_z) # [B*N, embed_dim, num_points]