From 40bb28cfe7c8528531802d5685f65d2addbe539e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=90=9B=E6=B6=B5?= Date: Tue, 19 Nov 2024 01:40:01 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20net=20dim=E9=97=AE=E9=A2=98=E3=80=82num*?= =?UTF-8?q?3,=E5=8F=98=E6=88=90[num,3],=E4=BF=9D=E7=95=99=E7=A9=BA?= =?UTF-8?q?=E9=97=B4=E7=89=B9=E5=BE=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/encoder.py | 123 +++++++++++++++++++++++++---------- 1 file changed, 87 insertions(+), 36 deletions(-) diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index d4e8726..66bdfab 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -130,6 +130,7 @@ class BRepFeatureEmbedder(nn.Module): self.num_surf_points = self.config.model.num_surf_points # 16 self.num_edge_points = self.config.model.num_edge_points # 4 + # Transformer编码器层 layer = nn.TransformerEncoderLayer( d_model=self.embed_dim, nhead=12, @@ -144,21 +145,22 @@ class BRepFeatureEmbedder(nn.Module): enable_nested_tensor=False ) - # 修改输入维度以匹配采样点数 - self.surfz_embed = nn.Sequential( - nn.Linear(3 * self.num_surf_points, self.embed_dim), # 3 * 16 - nn.LayerNorm(self.embed_dim), - nn.SiLU(), - nn.Linear(self.embed_dim, self.embed_dim), + # 修改为处理[num_points, 3]形状的输入 + self.surfz_embed = Encoder1D( + in_channels=3, + out_channels=self.embed_dim, + block_out_channels=(64, 128, 256), + layers_per_block=2 ) - self.edgez_embed = nn.Sequential( - nn.Linear(3 * self.num_edge_points, self.embed_dim), # 3 * 4 - nn.LayerNorm(self.embed_dim), - nn.SiLU(), - nn.Linear(self.embed_dim, self.embed_dim), + self.edgez_embed = Encoder1D( + in_channels=3, + out_channels=self.embed_dim, + block_out_channels=(64, 128, 256), + layers_per_block=2 ) + # 其他嵌入层保持不变 self.surfp_embed = nn.Sequential( nn.Linear(6, self.embed_dim), nn.LayerNorm(self.embed_dim), @@ -183,21 +185,37 @@ class BRepFeatureEmbedder(nn.Module): def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None): """ Args: - surf_z: 表面特征 [B, N, num_surf_points*3] - edge_z: 边特征 [B, M, num_edge_points*3] + surf_z: 表面点云 [B, N, num_surf_points, 3] + edge_z: 边点云 [B, M, num_edge_points, 3] surf_p: 表面点 [B, N, 6] edge_p: 边点 [B, M, 6] vert_p: 顶点点 [B, K, 6] mask: 注意力掩码 """ + B, N, _, _ = surf_z.shape + _, M, _, _ = edge_z.shape + _, K, _ = vert_p.shape + + # 重塑点云数据用于1D编码器 + surf_z = surf_z.reshape(B*N, 3, self.num_surf_points) # [B*N, 3, num_surf_points] + edge_z = edge_z.reshape(B*M, 3, self.num_edge_points) # [B*M, 3, num_edge_points] + # 特征嵌入 - surf_embeds = self.surfz_embed(surf_z) - edge_embeds = self.edgez_embed(edge_z) + surf_embeds = self.surfz_embed(surf_z) # [B*N, embed_dim, num_points] + edge_embeds = self.edgez_embed(edge_z) # [B*M, embed_dim, num_points] + + # 全局池化得到每个面/边的特征 + surf_embeds = surf_embeds.mean(dim=-1) # [B*N, embed_dim] + edge_embeds = edge_embeds.mean(dim=-1) # [B*M, embed_dim] + + # 重塑回批次维度 + surf_embeds = surf_embeds.reshape(B, N, -1) # [B, N, embed_dim] + edge_embeds = edge_embeds.reshape(B, M, -1) # [B, M, embed_dim] # 点嵌入 - surf_p_embeds = self.surfp_embed(surf_p) - edge_p_embeds = self.edgep_embed(edge_p) - vert_p_embeds = self.vertp_embed(vert_p) + surf_p_embeds = self.surfp_embed(surf_p) # [B, N, embed_dim] + edge_p_embeds = self.edgep_embed(edge_p) # [B, M, embed_dim] + vert_p_embeds = self.vertp_embed(vert_p) # [B, K, embed_dim] # 组合所有嵌入 if self.use_cf: @@ -205,13 +223,13 @@ class BRepFeatureEmbedder(nn.Module): surf_embeds + surf_p_embeds, edge_embeds + edge_p_embeds, vert_p_embeds - ], dim=1) + ], dim=1) # [B, N+M+K, embed_dim] else: embeds = torch.cat([ surf_p_embeds, edge_p_embeds, vert_p_embeds - ], dim=1) + ], dim=1) # [B, N+M+K, embed_dim] output = self.transformer(embeds, src_key_padding_mask=mask) return output @@ -290,8 +308,8 @@ class BRepToSDF(nn.Module): def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None): """ Args: - surf_z: 表面特征 [B, N, num_surf_points*3] - edge_z: 边特征 [B, M, num_edge_points*3] + surf_z: 表面特征 [B, N, 48] + edge_z: 边特征 [B, M, 12] surf_p: 表面点 [B, N, 6] edge_p: 边点 [B, M, 6] vert_p: 顶点点 [B, K, 6] @@ -346,24 +364,40 @@ def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): return l1_loss + grad_weight * grad_constraint def main(): - # 初始化模型 + # 获取配置 + config = get_default_config() + + # 从配置初始化模型 model = BRepToSDF( - brep_feature_dim=48, - use_cf=True, - embed_dim=768, - latent_dim=256 + brep_feature_dim=config.model.brep_feature_dim, # 48 + use_cf=config.model.use_cf, # True + embed_dim=config.model.embed_dim, # 768 + latent_dim=config.model.latent_dim # 256 ) - # 示例输入 - batch_size = 4 - num_surfs = 10 - num_edges = 20 - num_verts = 8 - num_queries = 1000 + # 从配置获取数据参数 + batch_size = config.train.batch_size # 32 + num_surfs = config.data.max_face # 64 + num_edges = config.data.max_edge # 64 + num_verts = 8 # 顶点数保持固定 + num_queries = 1000 # 查询点数保持固定 # 生成示例数据 - surf_z = torch.randn(batch_size, num_surfs, 48) - edge_z = torch.randn(batch_size, num_edges, 12) + surf_z = torch.randn( + batch_size, + num_surfs, + config.model.num_surf_points, # 16 + 3 + ) # [B, N, num_surf_points, 3] + + edge_z = torch.randn( + batch_size, + num_edges, + config.model.num_edge_points, # 4 + 3 + ) # [B, M, num_edge_points, 3] + + # 其他输入 surf_p = torch.randn(batch_size, num_surfs, 6) edge_p = torch.randn(batch_size, num_edges, 6) vert_p = torch.randn(batch_size, num_verts, 6) @@ -371,7 +405,24 @@ def main(): # 前向传播 sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) - print(f"Output SDF shape: {sdf.shape}") + + # 打印形状信息和配置信息 + print("\nConfiguration:") + print(f"Batch Size: {batch_size}") + print(f"Embed Dimension: {config.model.embed_dim}") + print(f"Surface Points: {config.model.num_surf_points}") + print(f"Edge Points: {config.model.num_edge_points}") + print(f"Max Faces: {config.data.max_face}") + print(f"Max Edges: {config.data.max_edge}") + + print("\nInput shapes:") + print(f"surf_z: {surf_z.shape}") # [32, 64, 16, 3] + print(f"edge_z: {edge_z.shape}") # [32, 64, 4, 3] + print(f"surf_p: {surf_p.shape}") # [32, 64, 6] + print(f"edge_p: {edge_p.shape}") # [32, 64, 6] + print(f"vert_p: {vert_p.shape}") # [32, 8, 6] + print(f"query_points: {query_points.shape}") # [32, 1000, 3] + print(f"\nOutput SDF shape: {sdf.shape}") # [32, 1000, 1] if __name__ == "__main__": main() \ No newline at end of file