diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index a5e978e..eb3e503 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -76,7 +76,7 @@ class UNetMidBlock1D(nn.Module): x = attn(x) return x -class Encoder1D(nn.Module): +class _Encoder1D(nn.Module): def __init__( self, in_channels: int = 3, @@ -114,47 +114,40 @@ class Encoder1D(nn.Module): x = self.mid_block(x) x = self.conv_out(x) return x - -class Encoder1D_(nn.Module): - """一维编码器""" + +class Encoder1D(nn.Module): def __init__( self, - in_channels: int = 3, - out_channels: int = 256, - block_out_channels: Tuple[int] = (64, 128, 256), - layers_per_block: int = 2, + num_points: int = 2, # 输入点数 + in_channels: int = 3, # xyz坐标 + out_dim: int = 6, # 展平后维度 (2*3) ): super().__init__() - self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 3, padding=1) - - self.down_blocks = nn.ModuleList([]) - in_ch = block_out_channels[0] - for out_ch in block_out_channels: - block = [] - for _ in range(layers_per_block): - block.append(ResConvBlock(in_ch, out_ch, out_ch)) - in_ch = out_ch - if out_ch != block_out_channels[-1]: - block.append(nn.AvgPool1d(2)) - self.down_blocks.append(nn.Sequential(*block)) - - self.mid_block = UNetMidBlock1D( - in_channels=block_out_channels[-1], - mid_channels=block_out_channels[-1], - ) + self.num_points = num_points + self.in_channels = in_channels + self.out_dim = out_dim - self.conv_out = nn.Sequential( - nn.GroupNorm(32, block_out_channels[-1]), - nn.SiLU(), - nn.Conv1d(block_out_channels[-1], out_channels, 3, padding=1), - ) - + assert num_points * in_channels == out_dim, \ + f"Flatten dimension {out_dim} must equal num_points({num_points}) * in_channels({in_channels})" + def forward(self, x): - x = self.conv_in(x) - for block in self.down_blocks: - x = block(x) - x = self.mid_block(x) - x = self.conv_out(x) + """ + Args: + x: [B, F, E, num_points, channels] + Returns: + x: [B, F, E, out_dim] + """ + orig_shape = x.shape[:-2] # [B, F, E] + + # 重塑以处理所有批次 + x = x.reshape(-1, self.num_points, self.in_channels) # [B*F*E, num_points, channels] + + # 展平点和通道维度 + x = x.reshape(x.size(0), -1) # [B*F*E, num_points*channels] + + # 恢复批次维度 + x = x.reshape(*orig_shape, self.out_dim) # [B, F, E, out_dim] + return x class BRepFeatureEmbedder(nn.Module): @@ -177,62 +170,37 @@ class BRepFeatureEmbedder(nn.Module): logger.info(f" num_edge_points: {self.num_edge_points}") logger.info(f" embed_dim: {self.embed_dim}") logger.info(f" use_cf: {self.use_cf}") - - # Transformer编码器层 - layer = nn.TransformerEncoderLayer( - d_model=self.embed_dim, - nhead=8, # 从12减少到8,使每个head的维度更大 - norm_first=True, # 改为True,先进行归一化 - dim_feedforward=self.embed_dim * 4, # 增大FFN维度 - dropout=0.1, - activation=F.gelu # 使用GELU激活函数 - ) - - # 添加初始化方法 - def _init_weights(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight, gain=1/math.sqrt(2)) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.LayerNorm): - torch.nn.init.ones_(module.weight) - torch.nn.init.zeros_(module.bias) - - self.transformer = nn.TransformerEncoder( - layer, - num_layers=6, # 从12减少到6层 - norm=nn.LayerNorm(self.embed_dim), - enable_nested_tensor=False - ) - - # 应用初始化 - self.transformer.apply(_init_weights) - - # 添加位置编码 - self.pos_embedding = nn.Parameter(torch.randn(1, 1000, self.embed_dim) * 0.02) # 最大序列长度设为1000 + + surf_out_dim = 3*self.num_surf_points # 面特征展平后维度,8*3=24 + edge_out_dim = 3*self.num_edge_points # 边特征展平后维度,2*3=6 + self.surfz_encoder = Encoder1D(num_points=self.num_surf_points,in_channels=3,out_dim=surf_out_dim) + self.edgez_encoder = Encoder1D(num_points=self.num_edge_points,in_channels=3,out_dim=edge_out_dim) # 修改为处理[num_points, 3]形状的输入 - self.surfz_embed = Encoder1D( - in_channels=3, - out_channels=self.embed_dim, - block_out_channels=(64, 128, self.embed_dim), - layers_per_block=2 - ) - - self.edgez_embed = Encoder1D( - in_channels=3, - out_channels=self.embed_dim, - block_out_channels=(64, 128, self.embed_dim), - layers_per_block=2 + self.surfz_embed = nn.Sequential( + self.surfz_encoder, # [B, 16, 3] -> [B, 48] + nn.Linear(surf_out_dim, self.embed_dim), # [B, 48] -> [B, embed_dim] + nn.LayerNorm(self.embed_dim), + nn.SiLU(), + nn.Linear(self.embed_dim, self.embed_dim), ) - - # 其他嵌入层保持不变 + self.surfp_embed = nn.Sequential( nn.Linear(6, self.embed_dim), nn.LayerNorm(self.embed_dim), nn.SiLU(), nn.Linear(self.embed_dim, self.embed_dim), ) + + + self.edgez_embed = nn.Sequential( + self.edgez_encoder, + nn.Linear(edge_out_dim, self.embed_dim), + nn.LayerNorm(self.embed_dim), + nn.SiLU(), + nn.Linear(self.embed_dim, self.embed_dim), + ) + self.edgep_embed = nn.Sequential( nn.Linear(6, self.embed_dim), @@ -243,86 +211,87 @@ class BRepFeatureEmbedder(nn.Module): # 修改vertp_embed的结构 self.vertp_embed = nn.Sequential( - nn.Linear(3, self.embed_dim // 2), - nn.LayerNorm(self.embed_dim // 2), - nn.ReLU(), - nn.Linear(self.embed_dim // 2, self.embed_dim) + nn.Linear(3, self.embed_dim), + nn.LayerNorm(self.embed_dim), + nn.SiLU(), + nn.Linear(self.embed_dim, self.embed_dim) ) # 添加一个额外的投影层 self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim) + + # 添加 transformer 初始化 + self.transformer = nn.TransformerEncoder( + encoder_layer=nn.TransformerEncoderLayer( + d_model=self.embed_dim, + nhead=8, # 注意力头数,通常是embed_dim的因子 + dim_feedforward=4*self.embed_dim, # 前馈网络维度,通常是embed_dim的4倍 + dropout=0.1, + activation='gelu', + batch_first=False # 因为我们用了transpose(0,1) + ), + num_layers=6 # transformer层数 + ) def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs): - B, F, E, _, _ = edge_ncs.shape - - # 处理顶点特征 - vertex_features = vertex_pos.view(B*F*E*2, -1) # 展平顶点坐标 - vertex_embed = self.vertp_embed(vertex_features) - vertex_embed = self.vertex_proj(vertex_embed) # 添加投影 - vertex_embed = vertex_embed.view(B, F, E, 2, -1) # 恢复形状 - - # 确保顶点特征参与后续计算 - edge_features = torch.cat([ - self.edgep_embed(edge_pos.view(B*F*E, -1)).view(B, F, E, -1), - vertex_embed.mean(dim=3) # 将顶点特征平均池化 - ], dim=-1) - - # 1. 处理边特征 - # 重塑边点云以适应1D编码器 - edge_ncs = edge_ncs.reshape(B*F*E, -1, 3).transpose(1, 2) # [B*max_face*max_edge, 3, num_edge_points] - edge_embeds = self.edgez_embed(edge_ncs) # [B*max_face*max_edge, embed_dim, num_edge_points] - edge_embeds = edge_embeds.mean(dim=-1) # [B*max_face*max_edge, embed_dim] - edge_embeds = edge_embeds.reshape(B, F, E, -1) # [B, max_face, max_edge, embed_dim] - - # 2. 处理面特征 - surf_ncs = surf_ncs.reshape(B*F, -1, 3).transpose(1, 2) # [B*max_face, 3, num_surf_points] - surf_embeds = self.surfz_embed(surf_ncs) # [B*max_face, embed_dim, num_surf_points] - surf_embeds = surf_embeds.mean(dim=-1) # [B*max_face, embed_dim] - surf_embeds = surf_embeds.reshape(B, F, -1) # [B, max_face, embed_dim] - - # 3. 处理位置编码 - # 边位置编码 - edge_pos = edge_pos.reshape(B*F*E, -1) # [B*max_face*max_edge, 6] - edge_p_embeds = self.edgep_embed(edge_pos) # [B*max_face*max_edge, embed_dim] - edge_p_embeds = edge_p_embeds.reshape(B, F, E, -1) # [B, max_face, max_edge, embed_dim] - - # 面位置编码 - surf_p_embeds = self.surfp_embed(surf_pos) # [B, max_face, embed_dim] + """ + Args: + edge_ncs: [B, F, E, num_edge_points, 3] - 边点云 + edge_pos: [B, F, E, 6] - 边位置编码 + edge_mask: [B, F, E] - 边掩码 + surf_ncs: [B, F, num_surf_points, 3] - 面点云 + surf_pos: [B, F, 6] - 面位置编码 + vertex_pos: [B, F, E, 2, 3] - 顶点位置 + """ + B, F, E = edge_pos.shape[:3] + + # 1. 处理顶点特征 + vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim] + vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim] + vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim] + + # 2. 处理边特征 + logger.info(f"edge_ncs shape: {edge_ncs.shape}") + edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim] + edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim] + + # 3. 处理面特征 + surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim] + surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim] # 4. 组合特征 if self.use_cf: - # 边特征 - edge_features = edge_embeds + edge_p_embeds # [B, max_face, max_edge, embed_dim] - edge_features = edge_features.reshape(B, F*E, -1) # [B, max_face*max_edge, embed_dim] + # 组合边特征 + edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim] + edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim] - # 面特征 - surf_features = surf_embeds + surf_p_embeds # [B, max_face, embed_dim] + # 组合面特征 + surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim] - # 组合所有特征 + # 拼接所有特征 embeds = torch.cat([ - edge_features, # [B, max_face*max_edge, embed_dim] - surf_features # [B, max_face, embed_dim] - ], dim=1) # [B, max_face*(max_edge+1), embed_dim] + edge_features, # [B, F*E, embed_dim] + surf_features # [B, F, embed_dim] + ], dim=1) # [B, F*E+F, embed_dim] else: # 只使用位置编码 - edge_features = edge_p_embeds.reshape(B, F*E, -1) # [B, max_face*max_edge, embed_dim] + edge_features = edge_p_embeds.reshape(B, F*E, -1) # [B, F*E, embed_dim] embeds = torch.cat([ - edge_features, # [B, max_face*max_edge, embed_dim] - surf_p_embeds # [B, max_face, embed_dim] - ], dim=1) # [B, max_face*(max_edge+1), embed_dim] + edge_features, # [B, F*E, embed_dim] + surf_p_embeds # [B, F, embed_dim] + ], dim=1) # [B, F*E+F, embed_dim] # 5. 处理掩码 if edge_mask is not None: - # 扩展掩码以匹配特征维度 - edge_mask = edge_mask.reshape(B, -1) # [B, max_face*max_edge] - surf_mask = torch.ones(B, F, device=edge_mask.device, dtype=torch.bool) # [B, max_face] - mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, max_face*(max_edge+1)] + edge_mask = edge_mask.reshape(B, -1) # [B, F*E] + surf_mask = torch.ones(B, F, device=edge_mask.device, dtype=torch.bool) # [B, F] + mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, F*E+F] else: mask = None # 6. Transformer处理 output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask) - return output.transpose(0, 1) # 确保输出维度为 [B, seq_len, embed_dim] + return output.transpose(0, 1) # [B, F*E+F, embed_dim] class SDFTransformer(nn.Module): """SDF Transformer编码器"""