import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class ResConvBlock(nn.Module):
    """Residual 1D convolution block."""

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, mid_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, mid_channels)
        self.conv2 = nn.Conv1d(mid_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.act = nn.SiLU()
        self.conv_shortcut = None
        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x)
        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)
        return self.act(x + residual)


class SelfAttention1d(nn.Module):
    """1D multi-head self-attention over the length dimension."""

    def __init__(self, channels: int, num_head_channels: int):
        super().__init__()
        self.num_heads = channels // num_head_channels
        self.scale = num_head_channels ** -0.5
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        b, c, l = x.shape
        qkv = self.qkv(x).reshape(b, 3, self.num_heads, c // self.num_heads, l)
        q, k, v = qkv.unbind(1)                            # each: [b, heads, head_dim, l]
        attn = (q.transpose(-2, -1) @ k) * self.scale      # [b, heads, l, l], attention over positions
        attn = attn.softmax(dim=-1)
        x = (v @ attn.transpose(-2, -1)).reshape(b, c, l)  # weighted sum over positions
        x = self.proj(x)
        return x


class UNetMidBlock1D(nn.Module):
    """U-Net mid block: alternating residual convolutions and self-attention."""

    def __init__(self, in_channels: int, mid_channels: int):
        super().__init__()
        self.resnets = nn.ModuleList([
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, in_channels),
        ])
        self.attentions = nn.ModuleList([
            SelfAttention1d(mid_channels, mid_channels // 32) for _ in range(3)
        ])

    def forward(self, x):
        for attn, resnet in zip(self.attentions, self.resnets):
            x = resnet(x)
            x = attn(x)
        return x


class Encoder1D(nn.Module):
    """1D encoder: residual down blocks followed by a mid block."""

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 256,
        block_out_channels: Tuple[int, ...] = (64, 128, 256),
        layers_per_block: int = 2,
    ):
        super().__init__()
        self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 3, padding=1)

        self.down_blocks = nn.ModuleList([])
        in_ch = block_out_channels[0]
        for out_ch in block_out_channels:
            block = []
            for _ in range(layers_per_block):
                block.append(ResConvBlock(in_ch, out_ch, out_ch))
                in_ch = out_ch
            if out_ch != block_out_channels[-1]:
                block.append(nn.AvgPool1d(2))  # downsample after every stage except the last
            self.down_blocks.append(nn.Sequential(*block))

        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )

        self.conv_out = nn.Sequential(
            nn.GroupNorm(32, block_out_channels[-1]),
            nn.SiLU(),
            nn.Conv1d(block_out_channels[-1], out_channels, 3, padding=1),
        )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        x = self.conv_out(x)
        return x


class BRepFeatureEmbedder(nn.Module):
    """B-rep feature embedder: projects surface/edge/vertex inputs into a shared
    token sequence and mixes them with a Transformer encoder."""

    def __init__(self, use_cf: bool = True):
        super().__init__()
        self.embed_dim = 768
        self.use_cf = use_cf

        layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=12,
            norm_first=False,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,  # inputs are [B, seq, dim]
        )
        self.transformer = nn.TransformerEncoder(
            layer,
            num_layers=12,
            norm=nn.LayerNorm(self.embed_dim),
            enable_nested_tensor=False,  # keep the regular (non-nested) tensor path
        )

        self.surfz_embed = nn.Sequential(
            nn.Linear(3 * 16, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.edgez_embed = nn.Sequential(
            nn.Linear(3 * 4, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.surfp_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.edgep_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.vertp_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )

    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None):
        # Feature embeddings
        surf_embeds = self.surfz_embed(surf_z)
        edge_embeds = self.edgez_embed(edge_z)

        # Point embeddings
        surf_p_embeds = self.surfp_embed(surf_p)
        edge_p_embeds = self.edgep_embed(edge_p)
        vert_p_embeds = self.vertp_embed(vert_p)

        # Combine all embeddings into one token sequence [B, N+M+K, embed_dim]
        if self.use_cf:
            embeds = torch.cat([
                surf_embeds + surf_p_embeds,
                edge_embeds + edge_p_embeds,
                vert_p_embeds,
            ], dim=1)
        else:
            embeds = torch.cat([
                surf_p_embeds,
                edge_p_embeds,
                vert_p_embeds,
            ], dim=1)

        output = self.transformer(embeds, src_key_padding_mask=mask)
        return output
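
# -----------------------------------------------------------------------------
# Example (illustrative only): building the padding mask consumed by
# BRepFeatureEmbedder (and forwarded by BRepToSDF). The embedder concatenates
# [surface | edge | vertex] tokens into a sequence of length N + M + K, so the
# mask must follow the same ordering; True marks padded positions, matching
# PyTorch's src_key_padding_mask convention. The helper name and the
# per-sample count tensors are assumptions for this sketch, not part of the
# original interface.
# -----------------------------------------------------------------------------
def build_padding_mask_example(surf_counts, edge_counts, vert_counts,
                               num_surfs: int, num_edges: int, num_verts: int):
    """Return a [B, N+M+K] boolean mask with True at padded token positions."""
    def pad_mask(counts, max_len):
        idx = torch.arange(max_len, device=counts.device)
        return idx.unsqueeze(0) >= counts.unsqueeze(1)  # True where index >= valid count
    return torch.cat([
        pad_mask(surf_counts, num_surfs),
        pad_mask(edge_counts, num_edges),
        pad_mask(vert_counts, num_verts),
    ], dim=1)
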
class SDFTransformer(nn.Module):
    """Transformer encoder used to fuse features for SDF prediction."""

    def __init__(self, embed_dim: int = 768, num_layers: int = 6):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=8,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,
            norm_first=False,  # post-norm layers
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers)

    def forward(self, x, mask=None):
        return self.transformer(x, src_key_padding_mask=mask)


class SDFHead(nn.Module):
    """SDF prediction head: MLP ending in Tanh, so outputs lie in [-1, 1]."""

    def __init__(self, embed_dim: int = 768 * 2):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.LayerNorm(embed_dim // 2),
            nn.ReLU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.LayerNorm(embed_dim // 4),
            nn.ReLU(),
            nn.Linear(embed_dim // 4, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.mlp(x)


class BRepToSDF(nn.Module):
    """Maps B-rep features and query points to signed distance values."""

    def __init__(
        self,
        brep_feature_dim: int = 48,
        use_cf: bool = True,
        embed_dim: int = 768,
        latent_dim: int = 256,
    ):
        super().__init__()
        self.embed_dim = embed_dim

        # 1. Query point encoder
        self.query_encoder = nn.Sequential(
            nn.Linear(3, embed_dim // 4),
            nn.LayerNorm(embed_dim // 4),
            nn.ReLU(),
            nn.Linear(embed_dim // 4, embed_dim // 2),
            nn.LayerNorm(embed_dim // 2),
            nn.ReLU(),
            nn.Linear(embed_dim // 2, embed_dim),
        )

        # 2. B-rep feature embedder
        self.feature_embedder = BRepFeatureEmbedder(use_cf=use_cf)

        # 3. Feature fusion Transformer
        self.transformer = SDFTransformer(
            embed_dim=embed_dim,
            num_layers=6,
        )

        # 4. SDF prediction head
        self.sdf_head = SDFHead(embed_dim=embed_dim * 2)

    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None):
        """
        Args:
            surf_z: surface features [B, N, 48]
            edge_z: edge features [B, M, 12]
            surf_p: surface points [B, N, 6]
            edge_p: edge points [B, M, 6]
            vert_p: vertex points [B, K, 6]
            query_points: query points [B, Q, 3]
            mask: padding mask [B, N+M+K], True marks padded tokens

        Returns:
            sdf: [B, Q, 1]
        """
        B, Q, _ = query_points.shape

        # 1. B-rep feature embedding
        brep_features = self.feature_embedder(
            surf_z, edge_z, surf_p, edge_p, vert_p, mask
        )  # [B, N+M+K, embed_dim]

        # 2. Query point encoding
        query_features = self.query_encoder(query_points)  # [B, Q, embed_dim]

        # 3. Global feature by mean pooling over the B-rep tokens
        global_features = brep_features.mean(dim=1)  # [B, embed_dim]

        # 4. Broadcast the global feature to every query point
        expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1)  # [B, Q, embed_dim]

        # 5. Concatenate the global feature with the per-query features
        combined_features = torch.cat([
            expanded_features,  # [B, Q, embed_dim]
            query_features,     # [B, Q, embed_dim]
        ], dim=-1)  # [B, Q, embed_dim*2]

        # 6. SDF prediction
        sdf = self.sdf_head(combined_features)  # [B, Q, 1]

        return sdf
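
# -----------------------------------------------------------------------------
# Example (illustrative only): one training step using sdf_loss (defined just
# below). The gradient-constraint term differentiates the predicted SDF with
# respect to the query points, so the points must have requires_grad=True
# before the forward pass. The argument layout and the optimizer are
# assumptions for this sketch, not part of the original interface.
# -----------------------------------------------------------------------------
def train_step_example(model, optimizer, surf_z, edge_z, surf_p, edge_p,
                       vert_p, query_points, gt_sdf, mask=None):
    query_points = query_points.clone().requires_grad_(True)  # enable the gradient term
    pred_sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask)
    loss = sdf_loss(pred_sdf, gt_sdf, query_points, grad_weight=0.1)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
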
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
    """SDF loss: L1 term plus a gradient (eikonal) constraint.

    `points` must have requires_grad=True and `pred_sdf` must be computed
    from them, otherwise the gradient term cannot be evaluated.
    """
    # L1 loss on the predicted signed distances
    l1_loss = F.l1_loss(pred_sdf, gt_sdf)

    # Gradient constraint: the SDF gradient should have unit norm
    grad = torch.autograd.grad(
        pred_sdf.sum(), points, create_graph=True
    )[0]
    grad_constraint = F.mse_loss(
        torch.norm(grad, dim=-1),
        torch.ones_like(pred_sdf.squeeze(-1)),
    )

    return l1_loss + grad_weight * grad_constraint


def main():
    # Initialize the model
    model = BRepToSDF(
        brep_feature_dim=48,
        use_cf=True,
        embed_dim=768,
        latent_dim=256,
    )

    # Example input sizes
    batch_size = 4
    num_surfs = 10
    num_edges = 20
    num_verts = 8
    num_queries = 1000

    # Generate example data
    surf_z = torch.randn(batch_size, num_surfs, 48)
    edge_z = torch.randn(batch_size, num_edges, 12)
    surf_p = torch.randn(batch_size, num_surfs, 6)
    edge_p = torch.randn(batch_size, num_edges, 6)
    vert_p = torch.randn(batch_size, num_verts, 6)
    query_points = torch.randn(batch_size, num_queries, 3)

    # Forward pass
    sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
    print(f"Output SDF shape: {sdf.shape}")


if __name__ == "__main__":
    main()
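
# -----------------------------------------------------------------------------
# Example (illustrative only): evaluating a trained BRepToSDF model on a dense
# grid of query points in chunks, e.g. as a precursor to isosurface extraction.
# The resolution, bounds, and chunk size are assumptions for this sketch; the
# B-rep tensors are expected with batch size 1.
# -----------------------------------------------------------------------------
@torch.no_grad()
def query_sdf_grid_example(model, surf_z, edge_z, surf_p, edge_p, vert_p,
                           resolution: int = 64, bound: float = 1.0,
                           chunk_size: int = 32768):
    """Return SDF values on a regular [resolution]^3 grid."""
    model.eval()
    axis = torch.linspace(-bound, bound, resolution)
    grid = torch.stack(torch.meshgrid(axis, axis, axis, indexing="ij"), dim=-1)  # [R, R, R, 3]
    points = grid.reshape(1, -1, 3)  # [1, R^3, 3]
    sdf_chunks = []
    for start in range(0, points.shape[1], chunk_size):
        chunk = points[:, start:start + chunk_size]
        sdf_chunks.append(model(surf_z, edge_z, surf_p, edge_p, vert_p, chunk))
    return torch.cat(sdf_chunks, dim=1).reshape(resolution, resolution, resolution)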