Browse Source

fix:使用size()而不是shape来获取维度

main
mckay 4 months ago
parent
commit
a2cb2c6e8a
  1. 12
      brep2sdf/networks/encoder.py

12
brep2sdf/networks/encoder.py

@ -192,13 +192,15 @@ class BRepFeatureEmbedder(nn.Module):
vert_p: 顶点点 [B, K, 6]
mask: 注意力掩码
"""
B, N, _, _ = surf_z.shape
_, M, _, _ = edge_z.shape
_, K, _ = vert_p.shape
# 获取批次大小和其他维度
B = surf_z.size(0)
N = surf_z.size(1)
M = edge_z.size(1)
K = vert_p.size(1)
# 重塑点云数据用于1D编码器
surf_z = surf_z.reshape(B*N, 3, self.num_surf_points) # [B*N, 3, num_surf_points]
edge_z = edge_z.reshape(B*M, 3, self.num_edge_points) # [B*M, 3, num_edge_points]
surf_z = surf_z.reshape(B*N, self.num_surf_points, 3).transpose(1, 2) # [B*N, 3, num_surf_points]
edge_z = edge_z.reshape(B*M, self.num_edge_points, 3).transpose(1, 2) # [B*M, 3, num_edge_points]
# 特征嵌入
surf_embeds = self.surfz_embed(surf_z) # [B*N, embed_dim, num_points]

Loading…
Cancel
Save