Browse Source

fix: net dim问题。num*3,变成[num,3],保留空间特征

main
王琛涵 4 months ago
parent
commit
40bb28cfe7
  1. 123
      brep2sdf/networks/encoder.py

123
brep2sdf/networks/encoder.py

@ -130,6 +130,7 @@ class BRepFeatureEmbedder(nn.Module):
self.num_surf_points = self.config.model.num_surf_points # 16
self.num_edge_points = self.config.model.num_edge_points # 4
# Transformer编码器层
layer = nn.TransformerEncoderLayer(
d_model=self.embed_dim,
nhead=12,
@ -144,21 +145,22 @@ class BRepFeatureEmbedder(nn.Module):
enable_nested_tensor=False
)
# 修改输入维度以匹配采样点数
self.surfz_embed = nn.Sequential(
nn.Linear(3 * self.num_surf_points, self.embed_dim), # 3 * 16
nn.LayerNorm(self.embed_dim),
nn.SiLU(),
nn.Linear(self.embed_dim, self.embed_dim),
# 修改为处理[num_points, 3]形状的输入
self.surfz_embed = Encoder1D(
in_channels=3,
out_channels=self.embed_dim,
block_out_channels=(64, 128, 256),
layers_per_block=2
)
self.edgez_embed = nn.Sequential(
nn.Linear(3 * self.num_edge_points, self.embed_dim), # 3 * 4
nn.LayerNorm(self.embed_dim),
nn.SiLU(),
nn.Linear(self.embed_dim, self.embed_dim),
self.edgez_embed = Encoder1D(
in_channels=3,
out_channels=self.embed_dim,
block_out_channels=(64, 128, 256),
layers_per_block=2
)
# 其他嵌入层保持不变
self.surfp_embed = nn.Sequential(
nn.Linear(6, self.embed_dim),
nn.LayerNorm(self.embed_dim),
@ -183,21 +185,37 @@ class BRepFeatureEmbedder(nn.Module):
def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None):
"""
Args:
surf_z: 表面特征 [B, N, num_surf_points*3]
edge_z: 特征 [B, M, num_edge_points*3]
surf_z: 表面点云 [B, N, num_surf_points, 3]
edge_z: 点云 [B, M, num_edge_points, 3]
surf_p: 表面点 [B, N, 6]
edge_p: 边点 [B, M, 6]
vert_p: 顶点点 [B, K, 6]
mask: 注意力掩码
"""
B, N, _, _ = surf_z.shape
_, M, _, _ = edge_z.shape
_, K, _ = vert_p.shape
# 重塑点云数据用于1D编码器
surf_z = surf_z.reshape(B*N, 3, self.num_surf_points) # [B*N, 3, num_surf_points]
edge_z = edge_z.reshape(B*M, 3, self.num_edge_points) # [B*M, 3, num_edge_points]
# 特征嵌入
surf_embeds = self.surfz_embed(surf_z)
edge_embeds = self.edgez_embed(edge_z)
surf_embeds = self.surfz_embed(surf_z) # [B*N, embed_dim, num_points]
edge_embeds = self.edgez_embed(edge_z) # [B*M, embed_dim, num_points]
# 全局池化得到每个面/边的特征
surf_embeds = surf_embeds.mean(dim=-1) # [B*N, embed_dim]
edge_embeds = edge_embeds.mean(dim=-1) # [B*M, embed_dim]
# 重塑回批次维度
surf_embeds = surf_embeds.reshape(B, N, -1) # [B, N, embed_dim]
edge_embeds = edge_embeds.reshape(B, M, -1) # [B, M, embed_dim]
# 点嵌入
surf_p_embeds = self.surfp_embed(surf_p)
edge_p_embeds = self.edgep_embed(edge_p)
vert_p_embeds = self.vertp_embed(vert_p)
surf_p_embeds = self.surfp_embed(surf_p) # [B, N, embed_dim]
edge_p_embeds = self.edgep_embed(edge_p) # [B, M, embed_dim]
vert_p_embeds = self.vertp_embed(vert_p) # [B, K, embed_dim]
# 组合所有嵌入
if self.use_cf:
@ -205,13 +223,13 @@ class BRepFeatureEmbedder(nn.Module):
surf_embeds + surf_p_embeds,
edge_embeds + edge_p_embeds,
vert_p_embeds
], dim=1)
], dim=1) # [B, N+M+K, embed_dim]
else:
embeds = torch.cat([
surf_p_embeds,
edge_p_embeds,
vert_p_embeds
], dim=1)
], dim=1) # [B, N+M+K, embed_dim]
output = self.transformer(embeds, src_key_padding_mask=mask)
return output
@ -290,8 +308,8 @@ class BRepToSDF(nn.Module):
def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None):
"""
Args:
surf_z: 表面特征 [B, N, num_surf_points*3]
edge_z: 边特征 [B, M, num_edge_points*3]
surf_z: 表面特征 [B, N, 48]
edge_z: 边特征 [B, M, 12]
surf_p: 表面点 [B, N, 6]
edge_p: 边点 [B, M, 6]
vert_p: 顶点点 [B, K, 6]
@ -346,24 +364,40 @@ def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
return l1_loss + grad_weight * grad_constraint
def main():
# 初始化模型
# 获取配置
config = get_default_config()
# 从配置初始化模型
model = BRepToSDF(
brep_feature_dim=48,
use_cf=True,
embed_dim=768,
latent_dim=256
brep_feature_dim=config.model.brep_feature_dim, # 48
use_cf=config.model.use_cf, # True
embed_dim=config.model.embed_dim, # 768
latent_dim=config.model.latent_dim # 256
)
# 示例输入
batch_size = 4
num_surfs = 10
num_edges = 20
num_verts = 8
num_queries = 1000
# 从配置获取数据参数
batch_size = config.train.batch_size # 32
num_surfs = config.data.max_face # 64
num_edges = config.data.max_edge # 64
num_verts = 8 # 顶点数保持固定
num_queries = 1000 # 查询点数保持固定
# 生成示例数据
surf_z = torch.randn(batch_size, num_surfs, 48)
edge_z = torch.randn(batch_size, num_edges, 12)
surf_z = torch.randn(
batch_size,
num_surfs,
config.model.num_surf_points, # 16
3
) # [B, N, num_surf_points, 3]
edge_z = torch.randn(
batch_size,
num_edges,
config.model.num_edge_points, # 4
3
) # [B, M, num_edge_points, 3]
# 其他输入
surf_p = torch.randn(batch_size, num_surfs, 6)
edge_p = torch.randn(batch_size, num_edges, 6)
vert_p = torch.randn(batch_size, num_verts, 6)
@ -371,7 +405,24 @@ def main():
# 前向传播
sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
print(f"Output SDF shape: {sdf.shape}")
# 打印形状信息和配置信息
print("\nConfiguration:")
print(f"Batch Size: {batch_size}")
print(f"Embed Dimension: {config.model.embed_dim}")
print(f"Surface Points: {config.model.num_surf_points}")
print(f"Edge Points: {config.model.num_edge_points}")
print(f"Max Faces: {config.data.max_face}")
print(f"Max Edges: {config.data.max_edge}")
print("\nInput shapes:")
print(f"surf_z: {surf_z.shape}") # [32, 64, 16, 3]
print(f"edge_z: {edge_z.shape}") # [32, 64, 4, 3]
print(f"surf_p: {surf_p.shape}") # [32, 64, 6]
print(f"edge_p: {edge_p.shape}") # [32, 64, 6]
print(f"vert_p: {vert_p.shape}") # [32, 8, 6]
print(f"query_points: {query_points.shape}") # [32, 1000, 3]
print(f"\nOutput SDF shape: {sdf.shape}") # [32, 1000, 1]
if __name__ == "__main__":
main()
Loading…
Cancel
Save