import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

from brep2sdf.config.default_config import get_default_config


class ResConvBlock(nn.Module):
    """Residual 1D convolution block.

    Note: GroupNorm(32, ...) requires mid_channels and out_channels to be
    divisible by 32, which holds for the 64/128/256 widths used below.
    """

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, mid_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, mid_channels)
        self.conv2 = nn.Conv1d(mid_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.act = nn.SiLU()
        # 1x1 projection so the residual matches the output channel count
        self.conv_shortcut = None
        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x)
        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)
        return self.act(x + residual)
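

# Minimal shape sketch (illustrative only; the channel sizes here are example
# values, not taken from the configs below). ResConvBlock preserves sequence
# length and maps channels in -> out:
#   blk = ResConvBlock(64, 128, 128)
#   y = blk(torch.randn(2, 64, 16))  # -> [2, 128, 16]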


class SelfAttention1d(nn.Module):
    """1D multi-head self-attention over the sequence dimension.

    `channels` must be divisible by `num_head_channels`.
    """

    def __init__(self, channels: int, num_head_channels: int):
        super().__init__()
        self.num_heads = channels // num_head_channels
        self.scale = num_head_channels ** -0.5
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        b, c, l = x.shape
        qkv = self.qkv(x).reshape(b, 3, self.num_heads, c // self.num_heads, l)
        q, k, v = qkv.unbind(1)  # each [b, heads, head_dim, l]
        # Attend over sequence positions: move the length axis before head_dim
        q, k, v = (t.transpose(-2, -1) for t in (q, k, v))  # [b, heads, l, head_dim]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [b, heads, l, l]
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(-2, -1).reshape(b, c, l)
        x = self.proj(x)
        return x
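

# Attention shape sketch (illustrative; the sizes are assumptions): with
# channels=256 and num_head_channels=8 the layer runs 32 heads, and an input
# of shape [B, 256, L] comes back unchanged in shape:
#   attn = SelfAttention1d(256, 8)
#   y = attn(torch.randn(2, 256, 16))  # -> [2, 256, 16]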


class UNetMidBlock1D(nn.Module):
    """U-Net middle block: alternating residual and self-attention layers.

    The last ResConvBlock maps back to `in_channels`, so the final attention
    layer assumes in_channels == mid_channels (true as used in Encoder1D).
    """

    def __init__(self, in_channels: int, mid_channels: int):
        super().__init__()
        self.resnets = nn.ModuleList([
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, in_channels),
        ])
        self.attentions = nn.ModuleList([
            SelfAttention1d(mid_channels, mid_channels // 32)
            for _ in range(3)
        ])

    def forward(self, x):
        for attn, resnet in zip(self.attentions, self.resnets):
            x = resnet(x)
            x = attn(x)
        return x


class Encoder1D(nn.Module):
    """1D encoder: stacked residual blocks with average-pool downsampling."""

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 256,
        block_out_channels: Tuple[int, ...] = (64, 128, 256),
        layers_per_block: int = 2,
    ):
        super().__init__()
        self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 3, padding=1)

        self.down_blocks = nn.ModuleList([])
        in_ch = block_out_channels[0]
        for out_ch in block_out_channels:
            block = []
            for _ in range(layers_per_block):
                block.append(ResConvBlock(in_ch, out_ch, out_ch))
                in_ch = out_ch
            # Downsample by 2 after every stage except the last
            if out_ch != block_out_channels[-1]:
                block.append(nn.AvgPool1d(2))
            self.down_blocks.append(nn.Sequential(*block))

        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )

        self.conv_out = nn.Sequential(
            nn.GroupNorm(32, block_out_channels[-1]),
            nn.SiLU(),
            nn.Conv1d(block_out_channels[-1], out_channels, 3, padding=1),
        )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        x = self.conv_out(x)
        return x
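

# End-to-end shape sketch for the default settings (illustrative): with
# block_out_channels=(64, 128, 256) there are two AvgPool1d(2) stages, so a
# 16-point surface sample is downsampled 16 -> 8 -> 4:
#   enc = Encoder1D(in_channels=3, out_channels=768)
#   z = enc(torch.randn(2, 3, 16))  # -> [2, 768, 4]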


class BRepFeatureEmbedder(nn.Module):
    """B-rep feature embedder."""

    def __init__(self, use_cf: bool = True):
        super().__init__()
        # Load configuration
        self.config = get_default_config()
        self.embed_dim = 768
        self.use_cf = use_cf

        # Numbers of sample points from the config
        self.num_surf_points = self.config.model.num_surf_points  # 16
        self.num_edge_points = self.config.model.num_edge_points  # 4

        # Transformer encoder (batch_first so inputs stay [B, seq, embed_dim])
        layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=12,
            norm_first=False,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            layer,
            num_layers=12,
            norm=nn.LayerNorm(self.embed_dim),
            enable_nested_tensor=False,
        )

        # 1D encoders for point-sampled geometry of shape [num_points, 3]
        self.surfz_embed = Encoder1D(
            in_channels=3,
            out_channels=self.embed_dim,
            block_out_channels=(64, 128, 256),
            layers_per_block=2,
        )
        self.edgez_embed = Encoder1D(
            in_channels=3,
            out_channels=self.embed_dim,
            block_out_channels=(64, 128, 256),
            layers_per_block=2,
        )

        # MLP embeddings for surface, edge, and vertex coordinates
        self.surfp_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.edgep_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.vertp_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None):
        """
        Args:
            surf_z: surface point samples [B, N, num_surf_points, 3]
            edge_z: edge point samples [B, M, num_edge_points, 3]
            surf_p: surface points [B, N, 6]
            edge_p: edge points [B, M, 6]
            vert_p: vertex points [B, K, 6]
            mask: attention mask (used as src_key_padding_mask)
        """
        B = surf_z.size(0)
        N = surf_z.size(1)
        M = edge_z.size(1)
        K = vert_p.size(1)

        # Flatten batch/entity dims and move channels first for the 1D encoders
        surf_z = surf_z.reshape(B * N, self.num_surf_points, 3).transpose(1, 2)  # [B*N, 3, num_surf_points]
        edge_z = edge_z.reshape(B * M, self.num_edge_points, 3).transpose(1, 2)  # [B*M, 3, num_edge_points]

        # Geometry feature embedding
        surf_embeds = self.surfz_embed(surf_z)  # [B*N, embed_dim, num_points]
        edge_embeds = self.edgez_embed(edge_z)  # [B*M, embed_dim, num_points]

        # Global pooling -> one feature per face/edge
        surf_embeds = surf_embeds.mean(dim=-1)  # [B*N, embed_dim]
        edge_embeds = edge_embeds.mean(dim=-1)  # [B*M, embed_dim]

        # Restore the batch dimension
        surf_embeds = surf_embeds.reshape(B, N, -1)  # [B, N, embed_dim]
        edge_embeds = edge_embeds.reshape(B, M, -1)  # [B, M, embed_dim]

        # Point embeddings
        surf_p_embeds = self.surfp_embed(surf_p)  # [B, N, embed_dim]
        edge_p_embeds = self.edgep_embed(edge_p)  # [B, M, embed_dim]
        vert_p_embeds = self.vertp_embed(vert_p)  # [B, K, embed_dim]

        # Combine all embeddings into one token sequence
        if self.use_cf:
            embeds = torch.cat([
                surf_embeds + surf_p_embeds,
                edge_embeds + edge_p_embeds,
                vert_p_embeds,
            ], dim=1)  # [B, N+M+K, embed_dim]
        else:
            embeds = torch.cat([
                surf_p_embeds,
                edge_p_embeds,
                vert_p_embeds,
            ], dim=1)  # [B, N+M+K, embed_dim]

        output = self.transformer(embeds, src_key_padding_mask=mask)
        return output
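

# Token layout sketch: with the config defaults assumed in main() below
# (N=64 faces, M=64 edges, K=8 vertices), the embedder produces a sequence of
# N+M+K = 136 tokens per sample, i.e. output shape [B, 136, 768]. A boolean
# mask of shape [B, 136] with True at padded positions excludes them from
# attention.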


class SDFTransformer(nn.Module):
    """SDF transformer encoder."""

    def __init__(self, embed_dim: int = 768, num_layers: int = 6):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=8,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,
            norm_first=False,  # post-norm, matching the feature embedder
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers)

    def forward(self, x, mask=None):
        return self.transformer(x, src_key_padding_mask=mask)


class SDFHead(nn.Module):
    """SDF prediction head."""

    def __init__(self, embed_dim: int = 768 * 2):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.LayerNorm(embed_dim // 2),
            nn.ReLU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.LayerNorm(embed_dim // 4),
            nn.ReLU(),
            nn.Linear(embed_dim // 4, 1),
            nn.Tanh(),  # bounds predictions to [-1, 1]
        )

    def forward(self, x):
        return self.mlp(x)
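

# Design note: the final Tanh clamps predicted distances to [-1, 1], which
# assumes ground-truth SDF values are normalized to roughly that range.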


class BRepToSDF(nn.Module):
    def __init__(
        self,
        brep_feature_dim: int = 48,
        use_cf: bool = True,
        embed_dim: int = 768,
        latent_dim: int = 256,
    ):
        super().__init__()
        # Load configuration
        self.config = get_default_config()
        self.embed_dim = embed_dim

        # 1. Query-point encoder
        self.query_encoder = nn.Sequential(
            nn.Linear(3, embed_dim // 4),
            nn.LayerNorm(embed_dim // 4),
            nn.ReLU(),
            nn.Linear(embed_dim // 4, embed_dim // 2),
            nn.LayerNorm(embed_dim // 2),
            nn.ReLU(),
            nn.Linear(embed_dim // 2, embed_dim),
        )
        # 2. B-rep feature embedder
        self.feature_embedder = BRepFeatureEmbedder(use_cf=use_cf)
        # 3. Feature-fusion transformer
        self.transformer = SDFTransformer(
            embed_dim=embed_dim,
            num_layers=6,
        )
        # 4. SDF prediction head
        self.sdf_head = SDFHead(embed_dim=embed_dim * 2)
    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None):
        """
        Args:
            surf_z: surface point samples [B, N, num_surf_points, 3]
            edge_z: edge point samples [B, M, num_edge_points, 3]
            surf_p: surface points [B, N, 6]
            edge_p: edge points [B, M, 6]
            vert_p: vertex points [B, K, 6]
            query_points: query points [B, Q, 3]
            mask: attention mask
        Returns:
            sdf: [B, Q, 1]
        """
        B, Q, _ = query_points.shape

        # 1. B-rep feature embedding
        brep_features = self.feature_embedder(
            surf_z, edge_z, surf_p, edge_p, vert_p, mask
        )  # [B, N+M+K, embed_dim]
        # 2. Feature fusion across the B-rep tokens
        brep_features = self.transformer(brep_features, mask)  # [B, N+M+K, embed_dim]
        # 3. Query-point encoding
        query_features = self.query_encoder(query_points)  # [B, Q, embed_dim]
        # 4. Pool a global feature
        global_features = brep_features.mean(dim=1)  # [B, embed_dim]
        # 5. Broadcast it to every query point
        expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1)  # [B, Q, embed_dim]
        # 6. Concatenate query and global features
        combined_features = torch.cat([
            expanded_features,  # [B, Q, embed_dim]
            query_features,     # [B, Q, embed_dim]
        ], dim=-1)  # [B, Q, embed_dim*2]
        # 7. SDF prediction
        sdf = self.sdf_head(combined_features)  # [B, Q, 1]
        return sdf
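

# Note: the same key-padding mask is shared by the feature embedder and the
# fusion transformer, since both operate on the same [B, N+M+K] token sequence.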


def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
    """SDF loss: L1 distance plus an eikonal (unit-gradient) constraint."""
    # L1 loss on the predicted distances
    l1_loss = F.l1_loss(pred_sdf, gt_sdf)

    # Gradient (eikonal) constraint: the SDF gradient norm should be 1.
    # Requires `points` to have been created with requires_grad=True and
    # `pred_sdf` to have been computed from them.
    grad = torch.autograd.grad(
        pred_sdf.sum(),
        points,
        create_graph=True,
    )[0]
    grad_constraint = F.mse_loss(
        torch.norm(grad, dim=-1),
        torch.ones_like(pred_sdf.squeeze(-1)),
    )

    return l1_loss + grad_weight * grad_constraint
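

# Minimal usage sketch (illustrative): the eikonal term differentiates the
# prediction w.r.t. the query points, so they must require gradients before
# the forward pass:
#   query_points = torch.randn(B, Q, 3, requires_grad=True)
#   pred_sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
#   loss = sdf_loss(pred_sdf, gt_sdf, query_points)
#   loss.backward()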


def main():
    # Load configuration
    config = get_default_config()

    # Build the model from the config
    model = BRepToSDF(
        brep_feature_dim=config.model.brep_feature_dim,  # 48
        use_cf=config.model.use_cf,                      # True
        embed_dim=config.model.embed_dim,                # 768
        latent_dim=config.model.latent_dim,              # 256
    )

    # Data parameters from the config
    batch_size = config.train.batch_size  # 32
    num_surfs = config.data.max_face      # 64
    num_edges = config.data.max_edge      # 64
    num_verts = 8                         # fixed vertex count
    num_queries = 1000                    # fixed query-point count

    # Generate example data
    surf_z = torch.randn(
        batch_size,
        num_surfs,
        config.model.num_surf_points,  # 16
        3,
    )  # [B, N, num_surf_points, 3]
    edge_z = torch.randn(
        batch_size,
        num_edges,
        config.model.num_edge_points,  # 4
        3,
    )  # [B, M, num_edge_points, 3]

    # Remaining inputs
    surf_p = torch.randn(batch_size, num_surfs, 6)
    edge_p = torch.randn(batch_size, num_edges, 6)
    vert_p = torch.randn(batch_size, num_verts, 6)
    query_points = torch.randn(batch_size, num_queries, 3)

    # Forward pass
    sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)

    # Print configuration and shape information
    print("\nConfiguration:")
    print(f"Batch Size: {batch_size}")
    print(f"Embed Dimension: {config.model.embed_dim}")
    print(f"Surface Points: {config.model.num_surf_points}")
    print(f"Edge Points: {config.model.num_edge_points}")
    print(f"Max Faces: {config.data.max_face}")
    print(f"Max Edges: {config.data.max_edge}")

    print("\nInput shapes:")
    print(f"surf_z: {surf_z.shape}")              # [32, 64, 16, 3]
    print(f"edge_z: {edge_z.shape}")              # [32, 64, 4, 3]
    print(f"surf_p: {surf_p.shape}")              # [32, 64, 6]
    print(f"edge_p: {edge_p.shape}")              # [32, 64, 6]
    print(f"vert_p: {vert_p.shape}")              # [32, 8, 6]
    print(f"query_points: {query_points.shape}")  # [32, 1000, 3]

    print(f"\nOutput SDF shape: {sdf.shape}")     # [32, 1000, 1]


if __name__ == "__main__":
    main()