diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index ca121a9..fb16b6c 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union from brep2sdf.config.default_config import get_default_config +from brep2sdf.utils.logger import logger class ResConvBlock(nn.Module): """残差卷积块""" @@ -119,16 +120,24 @@ class Encoder1D(nn.Module): class BRepFeatureEmbedder(nn.Module): """B-rep特征嵌入器""" - def __init__(self, use_cf: bool = True): + def __init__(self, config=None): super().__init__() - # 获取配置 - self.config = get_default_config() - self.embed_dim = 768 - self.use_cf = use_cf + if config is None: + self.config = get_default_config() + else: + self.config = config + + self.num_surf_points = self.config.model.num_surf_points + self.num_edge_points = self.config.model.num_edge_points + self.embed_dim = self.config.model.embed_dim + self.use_cf = self.config.model.use_cf - # 使用配置中的采样点数 - self.num_surf_points = self.config.model.num_surf_points # 16 - self.num_edge_points = self.config.model.num_edge_points # 4 + # 打印初始化信息 + logger.info(f"BRepFeatureEmbedder config:") + logger.info(f" num_surf_points: {self.num_surf_points}") + logger.info(f" num_edge_points: {self.num_edge_points}") + logger.info(f" embed_dim: {self.embed_dim}") + logger.info(f" use_cf: {self.use_cf}") # Transformer编码器层 layer = nn.TransformerEncoderLayer( @@ -182,59 +191,93 @@ class BRepFeatureEmbedder(nn.Module): nn.Linear(self.embed_dim, self.embed_dim), ) - def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None): - """ - Args: - surf_z: 表面点云 [B, N, num_surf_points, 3] - edge_z: 边点云 [B, M, num_edge_points, 3] - surf_p: 表面点 [B, N, 6] - edge_p: 边点 [B, M, 6] - vert_p: 顶点点 [B, K, 6] - mask: 注意力掩码 - """ - # 获取批次大小和其他维度 - B = surf_z.size(0) - N = surf_z.size(1) - M = edge_z.size(1) - K = vert_p.size(1) - - # 重塑点云数据用于1D编码器 - surf_z = surf_z.reshape(B*N, self.num_surf_points, 3).transpose(1, 2) # [B*N, 3, num_surf_points] - edge_z = edge_z.reshape(B*M, self.num_edge_points, 3).transpose(1, 2) # [B*M, 3, num_edge_points] - - # 特征嵌入 - surf_embeds = self.surfz_embed(surf_z) # [B*N, embed_dim, num_points] - edge_embeds = self.edgez_embed(edge_z) # [B*M, embed_dim, num_points] + def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, data_class=None): + """B-rep特征嵌入器的前向传播 - # 全局池化得到每个面/边的特征 - surf_embeds = surf_embeds.mean(dim=-1) # [B*N, embed_dim] - edge_embeds = edge_embeds.mean(dim=-1) # [B*M, embed_dim] - - # 重塑回批次维度 - surf_embeds = surf_embeds.reshape(B, N, -1) # [B, N, embed_dim] - edge_embeds = edge_embeds.reshape(B, M, -1) # [B, M, embed_dim] + Args: + edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3] + edge_pos: 边位置 [B, max_face, max_edge, 6] + edge_mask: 边掩码 [B, max_face, max_edge] + surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3] + surf_pos: 面位置 [B, max_face, 6] + vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3] - # 点嵌入 - surf_p_embeds = self.surfp_embed(surf_p) # [B, N, embed_dim] - edge_p_embeds = self.edgep_embed(edge_p) # [B, M, embed_dim] - vert_p_embeds = self.vertp_embed(vert_p) # [B, K, embed_dim] + Returns: + embeds: [B, max_face*(max_edge+1), embed_dim] + """ + B = self.config.train.batch_size + max_face = self.config.data.max_face + max_edge = self.config.data.max_edge - # 组合所有嵌入 - if self.use_cf: - embeds = torch.cat([ - surf_embeds + surf_p_embeds, - edge_embeds + edge_p_embeds, - vert_p_embeds - ], dim=1) # [B, N+M+K, embed_dim] - else: - embeds = torch.cat([ - surf_p_embeds, - edge_p_embeds, - vert_p_embeds - ], dim=1) # [B, N+M+K, embed_dim] + try: + # 1. 处理边特征 + # 重塑边点云以适应1D编码器 + edge_ncs = edge_ncs.reshape(B*max_face*max_edge, -1, 3).transpose(1, 2) # [B*max_face*max_edge, 3, num_edge_points] + edge_embeds = self.edgez_embed(edge_ncs) # [B*max_face*max_edge, embed_dim, num_edge_points] + edge_embeds = edge_embeds.mean(dim=-1) # [B*max_face*max_edge, embed_dim] + edge_embeds = edge_embeds.reshape(B, max_face, max_edge, -1) # [B, max_face, max_edge, embed_dim] + + # 2. 处理面特征 + surf_ncs = surf_ncs.reshape(B*max_face, -1, 3).transpose(1, 2) # [B*max_face, 3, num_surf_points] + surf_embeds = self.surfz_embed(surf_ncs) # [B*max_face, embed_dim, num_surf_points] + surf_embeds = surf_embeds.mean(dim=-1) # [B*max_face, embed_dim] + surf_embeds = surf_embeds.reshape(B, max_face, -1) # [B, max_face, embed_dim] + + # 3. 处理位置编码 + # 边位置编码 + edge_pos = edge_pos.reshape(B*max_face*max_edge, -1) # [B*max_face*max_edge, 6] + edge_p_embeds = self.edgep_embed(edge_pos) # [B*max_face*max_edge, embed_dim] + edge_p_embeds = edge_p_embeds.reshape(B, max_face, max_edge, -1) # [B, max_face, max_edge, embed_dim] + + # 面位置编码 + surf_p_embeds = self.surfp_embed(surf_pos) # [B, max_face, embed_dim] - output = self.transformer(embeds, src_key_padding_mask=mask) - return output + # 4. 组合特征 + if self.use_cf: + # 边特征 + edge_features = edge_embeds + edge_p_embeds # [B, max_face, max_edge, embed_dim] + edge_features = edge_features.reshape(B, max_face*max_edge, -1) # [B, max_face*max_edge, embed_dim] + + # 面特征 + surf_features = surf_embeds + surf_p_embeds # [B, max_face, embed_dim] + + # 组合所有特征 + embeds = torch.cat([ + edge_features, # [B, max_face*max_edge, embed_dim] + surf_features # [B, max_face, embed_dim] + ], dim=1) # [B, max_face*(max_edge+1), embed_dim] + else: + # 只使用位置编码 + edge_features = edge_p_embeds.reshape(B, max_face*max_edge, -1) # [B, max_face*max_edge, embed_dim] + embeds = torch.cat([ + edge_features, # [B, max_face*max_edge, embed_dim] + surf_p_embeds # [B, max_face, embed_dim] + ], dim=1) # [B, max_face*(max_edge+1), embed_dim] + + # 5. 处理掩码 + if edge_mask is not None: + # 扩展掩码以匹配特征维度 + edge_mask = edge_mask.reshape(B, -1) # [B, max_face*max_edge] + surf_mask = torch.ones(B, max_face, device=edge_mask.device, dtype=torch.bool) # [B, max_face] + mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, max_face*(max_edge+1)] + else: + mask = None + + # 6. Transformer处理 + output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask) + return output.transpose(0, 1) # 确保输出维度为 [B, seq_len, embed_dim] + + except Exception as e: + logger.error(f"Error in BRepFeatureEmbedder forward pass:") + logger.error(f" Error message: {str(e)}") + logger.error(f" Input shapes:") + logger.error(f" edge_ncs: {edge_ncs.shape}") + logger.error(f" edge_pos: {edge_pos.shape}") + logger.error(f" edge_mask: {edge_mask.shape}") + logger.error(f" surf_ncs: {surf_ncs.shape}") + logger.error(f" surf_pos: {surf_pos.shape}") + logger.error(f" vertex_pos: {vertex_pos.shape}") + raise class SDFTransformer(nn.Module): """SDF Transformer编码器""" @@ -296,7 +339,7 @@ class BRepToSDF(nn.Module): ) # 2. B-rep特征编码器 - self.feature_embedder = BRepFeatureEmbedder(use_cf=use_cf) + self.brep_embedder = BRepFeatureEmbedder() # 3. 特征融合Transformer self.transformer = SDFTransformer( @@ -307,45 +350,68 @@ class BRepToSDF(nn.Module): # 4. SDF预测头 self.sdf_head = SDFHead(embed_dim=embed_dim*2) - def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None): - """ + def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None): + """B-rep到SDF的前向传播 + Args: - surf_z: 表面特征 [B, N, 48] - edge_z: 边特征 [B, M, 12] - surf_p: 表面点 [B, N, 6] - edge_p: 边点 [B, M, 6] - vert_p: 顶点点 [B, K, 6] - query_points: 查询点 [B, Q, 3] - mask: 注意力掩码 + edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3] + edge_pos: 边位置 [B, max_face, max_edge, 6] + edge_mask: 边掩码 [B, max_face, max_edge] + surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3] + surf_pos: 面位置 [B, max_face, 6] + vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3] + query_points: 查询点 [B, num_queries, 3] + data_class: (可选) 类别标签 + Returns: - sdf: [B, Q, 1] + sdf: 预测的SDF值 [B, num_queries, 1] """ - B, Q, _ = query_points.shape - - # 1. B-rep特征嵌入 - brep_features = self.feature_embedder( - surf_z, edge_z, surf_p, edge_p, vert_p, mask - ) # [B, N+M+K, embed_dim] - - # 2. 查询点编码 - query_features = self.query_encoder(query_points) # [B, Q, embed_dim] - - # 3. 提取全局特征 - global_features = brep_features.mean(dim=1) # [B, embed_dim] - - # 4. 为每个查询点准备特征 - expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1) # [B, Q, embed_dim] - - # 5. 连接查询点特征和全局特征 - combined_features = torch.cat([ - expanded_features, # [B, Q, embed_dim] - query_features # [B, Q, embed_dim] - ], dim=-1) # [B, Q, embed_dim*2] + B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries - # 6. SDF预测 - sdf = self.sdf_head(combined_features) # [B, Q, 1] - - return sdf + try: + # 1. B-rep特征编码 + brep_features = self.brep_embedder( + edge_ncs=edge_ncs, # [B, max_face, max_edge, num_edge_points, 3] + edge_pos=edge_pos, # [B, max_face, max_edge, 6] + edge_mask=edge_mask, # [B, max_face, max_edge] + surf_ncs=surf_ncs, # [B, max_face, num_surf_points, 3] + surf_pos=surf_pos, # [B, max_face, 6] + vertex_pos=vertex_pos, # [B, max_face, max_edge, 2, 3] + data_class=data_class + ) # [B, max_face*(max_edge+1), embed_dim] + + # 2. 查询点编码 + query_features = self.query_encoder(query_points) # [B, Q, embed_dim] + + # 3. 提取全局特征 + global_features = brep_features.mean(dim=1) # [B, embed_dim] + + # 4. 为每个查询点准备特征 + expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1) # [B, Q, embed_dim] + + # 5. 连接查询点特征和全局特征 + combined_features = torch.cat([ + expanded_features, # [B, Q, embed_dim] + query_features # [B, Q, embed_dim] + ], dim=-1) # [B, Q, embed_dim*2] + + # 6. SDF预测 + sdf = self.sdf_head(combined_features) # [B, Q, 1] + + return sdf + + except Exception as e: + logger.error(f"Error in BRepToSDF forward pass:") + logger.error(f" Error message: {str(e)}") + logger.error(f" Input shapes:") + logger.error(f" edge_ncs: {edge_ncs.shape}") + logger.error(f" edge_pos: {edge_pos.shape}") + logger.error(f" edge_mask: {edge_mask.shape}") + logger.error(f" surf_ncs: {surf_ncs.shape}") + logger.error(f" surf_pos: {surf_pos.shape}") + logger.error(f" vertex_pos: {vertex_pos.shape}") + logger.error(f" query_points: {query_points.shape}") + raise def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): """SDF损失函数""" @@ -384,47 +450,73 @@ def main(): num_verts = 8 # 顶点数保持固定 num_queries = 1000 # 查询点数保持固定 - # 生成示例数据 - surf_z = torch.randn( - batch_size, - num_surfs, - config.model.num_surf_points, # 16 + # 更新测试数据维度 + edge_ncs = torch.randn( + batch_size, + num_surfs, # max_face + num_edges, # max_edge + config.model.num_edge_points, + 3 + ) # [B, max_face, max_edge, num_edge_points, 3] + + edge_pos = torch.randn( + batch_size, + num_surfs, + num_edges, + 6 + ) # [B, max_face, max_edge, 6] + + edge_mask = torch.ones( + batch_size, + num_surfs, + num_edges, + dtype=torch.bool + ) # [B, max_face, max_edge] + + surf_ncs = torch.randn( + batch_size, + num_surfs, + config.model.num_surf_points, 3 - ) # [B, N, num_surf_points, 3] + ) # [B, max_face, num_surf_points, 3] + + surf_pos = torch.randn( + batch_size, + num_surfs, + 6 + ) # [B, max_face, 6] - edge_z = torch.randn( - batch_size, - num_edges, - config.model.num_edge_points, # 4 + vertex_pos = torch.randn( + batch_size, + num_surfs, + num_edges, + 2, 3 - ) # [B, M, num_edge_points, 3] + ) # [B, max_face, max_edge, 2, 3] - # 其他输入 - surf_p = torch.randn(batch_size, num_surfs, 6) - edge_p = torch.randn(batch_size, num_edges, 6) - vert_p = torch.randn(batch_size, num_verts, 6) query_points = torch.randn(batch_size, num_queries, 3) - # 前向传播 - sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) - - # 打印形状信息和配置信息 - print("\nConfiguration:") - print(f"Batch Size: {batch_size}") - print(f"Embed Dimension: {config.model.embed_dim}") - print(f"Surface Points: {config.model.num_surf_points}") - print(f"Edge Points: {config.model.num_edge_points}") - print(f"Max Faces: {config.data.max_face}") - print(f"Max Edges: {config.data.max_edge}") + # 更新前向传播调用 + sdf = model( + edge_ncs=edge_ncs, + edge_pos=edge_pos, + edge_mask=edge_mask, + surf_ncs=surf_ncs, + surf_pos=surf_pos, + vertex_pos=vertex_pos, + query_points=query_points + ) + # 更新打印信息 print("\nInput shapes:") - print(f"surf_z: {surf_z.shape}") # [32, 64, 16, 3] - print(f"edge_z: {edge_z.shape}") # [32, 64, 4, 3] - print(f"surf_p: {surf_p.shape}") # [32, 64, 6] - print(f"edge_p: {edge_p.shape}") # [32, 64, 6] - print(f"vert_p: {vert_p.shape}") # [32, 8, 6] - print(f"query_points: {query_points.shape}") # [32, 1000, 3] - print(f"\nOutput SDF shape: {sdf.shape}") # [32, 1000, 1] + print(f"edge_ncs: {edge_ncs.shape}") + print(f"edge_pos: {edge_pos.shape}") + print(f"edge_mask: {edge_mask.shape}") + print(f"surf_ncs: {surf_ncs.shape}") + print(f"surf_pos: {surf_pos.shape}") + print(f"vertex_pos: {vertex_pos.shape}") + print(f"query_points: {query_points.shape}") + print(f"\nOutput SDF shape: {sdf.shape}") if __name__ == "__main__": main() \ No newline at end of file