You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
522 lines
19 KiB
522 lines
19 KiB
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from dataclasses import dataclass
|
|
from typing import Dict, Optional, Tuple, Union
|
|
from brep2sdf.config.default_config import get_default_config
|
|
from brep2sdf.utils.logger import logger
|
|
|
|
class ResConvBlock(nn.Module):
    """Residual 1-D convolution block.

    Two Conv1d -> GroupNorm stages with SiLU activations and an additive
    skip connection; the skip path gets a 1x1 projection whenever the
    input and output channel counts differ.
    """

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, mid_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, mid_channels)
        self.conv2 = nn.Conv1d(mid_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.act = nn.SiLU()

        # 1x1 projection so the skip path matches the output channel count.
        self.conv_shortcut = None
        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        shortcut = x if self.conv_shortcut is None else self.conv_shortcut(x)
        h = self.act(self.norm1(self.conv1(x)))
        h = self.norm2(self.conv2(h))
        return self.act(h + shortcut)
|
|
|
|
class SelfAttention1d(nn.Module):
    """Multi-head self-attention over 1-D feature maps with a fused qkv conv.

    NOTE(review): as written the attention matrix is (head_dim x head_dim) —
    similarities are summed over the length axis, so this mixes channels
    rather than attending over sequence positions. Confirm this is intended
    and not a missing transpose for length-wise attention.
    """

    def __init__(self, channels: int, num_head_channels: int):
        super().__init__()
        self.num_heads = channels // num_head_channels
        self.scale = num_head_channels ** -0.5
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        batch, channels, length = x.shape
        head_dim = channels // self.num_heads

        # Project once, then split into q/k/v heads.
        q, k, v = (
            self.qkv(x)
            .reshape(batch, 3, self.num_heads, head_dim, length)
            .unbind(dim=1)
        )

        # Scaled dot-product; the contraction runs over the length axis.
        weights = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)

        out = (weights @ v).reshape(batch, channels, length)
        return self.proj(out)
|
|
|
|
class UNetMidBlock1D(nn.Module):
    """U-Net bottleneck: three residual conv blocks, each followed by attention."""

    def __init__(self, in_channels: int, mid_channels: int):
        super().__init__()
        channel_plan = [
            (in_channels, mid_channels, mid_channels),
            (mid_channels, mid_channels, mid_channels),
            (mid_channels, mid_channels, in_channels),
        ]
        self.resnets = nn.ModuleList(
            [ResConvBlock(c_in, c_mid, c_out) for c_in, c_mid, c_out in channel_plan]
        )
        # NOTE(review): every attention layer is sized for mid_channels, but the
        # third resnet outputs in_channels — the shapes only line up when
        # in_channels == mid_channels (true for Encoder1D, the caller here).
        self.attentions = nn.ModuleList(
            [SelfAttention1d(mid_channels, mid_channels // 32) for _ in range(3)]
        )

    def forward(self, x):
        for attn, resnet in zip(self.attentions, self.resnets):
            x = attn(resnet(x))
        return x
|
|
|
|
class Encoder1D(nn.Module):
    """1-D point-cloud encoder.

    Stacked residual conv stages with average-pool downsampling between
    stages, an attention bottleneck, and a final GroupNorm/SiLU/Conv1d
    projection to ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 256,
        block_out_channels: Tuple[int] = (64, 128, 256),
        layers_per_block: int = 2,
    ):
        super().__init__()
        self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 3, padding=1)

        self.down_blocks = nn.ModuleList([])
        channels = block_out_channels[0]
        last_stage = block_out_channels[-1]
        for stage_channels in block_out_channels:
            stage = []
            for _ in range(layers_per_block):
                stage.append(ResConvBlock(channels, stage_channels, stage_channels))
                channels = stage_channels
            # Halve the temporal resolution everywhere except the last stage.
            if stage_channels != last_stage:
                stage.append(nn.AvgPool1d(2))
            self.down_blocks.append(nn.Sequential(*stage))

        self.mid_block = UNetMidBlock1D(
            in_channels=last_stage,
            mid_channels=last_stage,
        )

        self.conv_out = nn.Sequential(
            nn.GroupNorm(32, last_stage),
            nn.SiLU(),
            nn.Conv1d(last_stage, out_channels, 3, padding=1),
        )

    def forward(self, x):
        x = self.conv_in(x)
        for down in self.down_blocks:
            x = down(x)
        return self.conv_out(self.mid_block(x))
|
|
|
|
class BRepFeatureEmbedder(nn.Module):
    """Embeds B-rep edge/surface geometry and positions into a token sequence.

    Edge and surface point clouds are encoded with 1-D conv encoders and
    mean-pooled; bounding-box positions go through small MLPs. The combined
    tokens (one per edge slot plus one per face slot) are processed by a
    Transformer encoder.
    """

    def __init__(self, config=None):
        super().__init__()
        if config is None:
            self.config = get_default_config()
        else:
            self.config = config

        self.num_surf_points = self.config.model.num_surf_points
        self.num_edge_points = self.config.model.num_edge_points
        self.embed_dim = self.config.model.embed_dim
        self.use_cf = self.config.model.use_cf

        # Log the effective configuration once at construction time.
        logger.info("BRepFeatureEmbedder config:")
        logger.info(f"  num_surf_points: {self.num_surf_points}")
        logger.info(f"  num_edge_points: {self.num_edge_points}")
        logger.info(f"  embed_dim: {self.embed_dim}")
        logger.info(f"  use_cf: {self.use_cf}")

        # Transformer encoder (sequence-first: batch_first defaults to False).
        layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=12,
            norm_first=False,
            dim_feedforward=1024,
            dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(
            layer,
            num_layers=12,
            norm=nn.LayerNorm(self.embed_dim),
            enable_nested_tensor=False
        )

        # 1-D encoders consume [*, 3, num_points] point clouds.
        self.surfz_embed = Encoder1D(
            in_channels=3,
            out_channels=self.embed_dim,
            block_out_channels=(64, 128, 256),
            layers_per_block=2
        )

        self.edgez_embed = Encoder1D(
            in_channels=3,
            out_channels=self.embed_dim,
            block_out_channels=(64, 128, 256),
            layers_per_block=2
        )

        # Positional (bounding-box) embedders: 6-D box -> embed_dim.
        self.surfp_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )

        self.edgep_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )

        # Declared but not used by forward() below.
        self.vertp_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )

    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, data_class=None):
        """Forward pass of the B-rep feature embedder.

        Args:
            edge_ncs: normalized edge point clouds [B, max_face, max_edge, num_edge_points, 3]
            edge_pos: edge positions (boxes) [B, max_face, max_edge, 6]
            edge_mask: edge validity mask [B, max_face, max_edge]
                (assumed True = real edge — TODO confirm with callers)
            surf_ncs: normalized surface point clouds [B, max_face, num_surf_points, 3]
            surf_pos: face positions (boxes) [B, max_face, 6]
            vertex_pos: vertex positions [B, max_face, max_edge, 2, 3] (currently unused)
            data_class: optional class label (currently unused)

        Returns:
            embeds: [B, max_face*(max_edge+1), embed_dim]
        """
        # FIX: derive dimensions from the actual input instead of the config.
        # Using config.train.batch_size crashed on any batch smaller than the
        # configured size (e.g. the last batch of an epoch).
        B, max_face, max_edge = edge_ncs.shape[:3]

        try:
            # 1. Edge geometry -> per-edge embedding via the 1-D encoder.
            edge_ncs = edge_ncs.reshape(B * max_face * max_edge, -1, 3).transpose(1, 2)  # [B*F*E, 3, P]
            edge_embeds = self.edgez_embed(edge_ncs)   # [B*F*E, D, P]
            edge_embeds = edge_embeds.mean(dim=-1)     # pool over points -> [B*F*E, D]
            edge_embeds = edge_embeds.reshape(B, max_face, max_edge, -1)

            # 2. Surface geometry -> per-face embedding.
            surf_ncs = surf_ncs.reshape(B * max_face, -1, 3).transpose(1, 2)  # [B*F, 3, P]
            surf_embeds = self.surfz_embed(surf_ncs).mean(dim=-1)             # [B*F, D]
            surf_embeds = surf_embeds.reshape(B, max_face, -1)

            # 3. Positional encodings.
            edge_pos = edge_pos.reshape(B * max_face * max_edge, -1)          # [B*F*E, 6]
            edge_p_embeds = self.edgep_embed(edge_pos).reshape(B, max_face, max_edge, -1)
            surf_p_embeds = self.surfp_embed(surf_pos)                        # [B, F, D]

            # 4. Combine geometry and position features into one token sequence.
            if self.use_cf:
                edge_features = (edge_embeds + edge_p_embeds).reshape(B, max_face * max_edge, -1)
                surf_features = surf_embeds + surf_p_embeds
                embeds = torch.cat([edge_features, surf_features], dim=1)  # [B, F*(E+1), D]
            else:
                # Position-only variant.
                edge_features = edge_p_embeds.reshape(B, max_face * max_edge, -1)
                embeds = torch.cat([edge_features, surf_p_embeds], dim=1)  # [B, F*(E+1), D]

            # 5. Key-padding mask. FIX: nn.Transformer expects True = PAD
            # (position is ignored). The previous code passed the validity
            # mask unmodified and all-ones for faces, which masked every
            # valid token and could NaN the attention. Invert the edge mask
            # and always attend to face tokens.
            if edge_mask is not None:
                edge_padding = ~edge_mask.reshape(B, -1).bool()            # True where edge slot is empty
                surf_padding = torch.zeros(B, max_face, device=edge_mask.device, dtype=torch.bool)
                padding_mask = torch.cat([edge_padding, surf_padding], dim=1)
            else:
                padding_mask = None

            # 6. Transformer runs sequence-first, so transpose in and out.
            output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=padding_mask)
            return output.transpose(0, 1)  # [B, seq_len, embed_dim]

        except Exception as e:
            logger.error("Error in BRepFeatureEmbedder forward pass:")
            logger.error(f"  Error message: {str(e)}")
            logger.error("  Input shapes:")
            logger.error(f"    edge_ncs: {edge_ncs.shape}")
            logger.error(f"    edge_pos: {edge_pos.shape}")
            logger.error(f"    edge_mask: {edge_mask.shape}")
            logger.error(f"    surf_ncs: {surf_ncs.shape}")
            logger.error(f"    surf_pos: {surf_pos.shape}")
            logger.error(f"    vertex_pos: {vertex_pos.shape}")
            raise
|
|
|
|
class SDFTransformer(nn.Module):
    """Batch-first Transformer encoder used for feature fusion."""

    def __init__(self, embed_dim: int = 768, num_layers: int = 6):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=8,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,
            norm_first=False,  # post-norm layer ordering
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, x, mask=None):
        # `mask` is forwarded as the key-padding mask (True = ignore position).
        return self.transformer(x, src_key_padding_mask=mask)
|
|
|
|
class SDFHead(nn.Module):
    """MLP head mapping fused features to a single SDF value in (-1, 1)."""

    def __init__(self, embed_dim: int = 768*2):
        super().__init__()
        half = embed_dim // 2
        quarter = embed_dim // 4
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, half),
            nn.LayerNorm(half),
            nn.ReLU(),
            nn.Linear(half, quarter),
            nn.LayerNorm(quarter),
            nn.ReLU(),
            nn.Linear(quarter, 1),
            nn.Tanh(),  # bound predictions to (-1, 1)
        )

    def forward(self, x):
        return self.mlp(x)
|
|
|
|
class BRepToSDF(nn.Module):
    """End-to-end model: embeds a B-rep and predicts SDF values at query points."""

    def __init__(
        self,
        brep_feature_dim: int = 48,
        use_cf: bool = True,
        embed_dim: int = 768,
        latent_dim: int = 256
    ):
        super().__init__()
        self.config = get_default_config()
        self.embed_dim = embed_dim

        # 1. Query-point encoder: 3-D coordinate -> embed_dim feature.
        self.query_encoder = nn.Sequential(
            nn.Linear(3, embed_dim//4),
            nn.LayerNorm(embed_dim//4),
            nn.ReLU(),
            nn.Linear(embed_dim//4, embed_dim//2),
            nn.LayerNorm(embed_dim//2),
            nn.ReLU(),
            nn.Linear(embed_dim//2, embed_dim)
        )

        # 2. B-rep feature encoder.
        self.brep_embedder = BRepFeatureEmbedder()

        # 3. Feature-fusion transformer.
        # NOTE(review): constructed but never called in forward() — confirm
        # whether it was meant to process brep_features before pooling.
        self.transformer = SDFTransformer(
            embed_dim=embed_dim,
            num_layers=6
        )

        # 4. SDF prediction head (consumes [global ; query] concatenation).
        self.sdf_head = SDFHead(embed_dim=embed_dim*2)

    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None):
        """Predict SDF values at query points from B-rep geometry.

        Args:
            edge_ncs: normalized edge features [B, max_face, max_edge, num_edge_points, 3]
            edge_pos: edge positions [B, max_face, max_edge, 6]
            edge_mask: edge mask [B, max_face, max_edge]
            surf_ncs: normalized surface features [B, max_face, num_surf_points, 3]
            surf_pos: face positions [B, max_face, 6]
            vertex_pos: vertex positions [B, max_face, max_edge, 2, 3]
            query_points: query points [B, num_queries, 3]
            data_class: optional class labels

        Returns:
            sdf: predicted SDF values [B, num_queries, 1]
        """
        B, Q = query_points.shape[:2]

        try:
            # 1. Embed the B-rep into a token sequence.
            brep_features = self.brep_embedder(
                edge_ncs=edge_ncs,
                edge_pos=edge_pos,
                edge_mask=edge_mask,
                surf_ncs=surf_ncs,
                surf_pos=surf_pos,
                vertex_pos=vertex_pos,
                data_class=data_class
            )  # [B, max_face*(max_edge+1), embed_dim]

            # 2. Encode the query points.
            query_features = self.query_encoder(query_points)  # [B, Q, embed_dim]

            # 3-5. Mean-pool tokens into a global shape feature, broadcast it
            # to every query, and concatenate with per-query features.
            global_features = brep_features.mean(dim=1)  # [B, embed_dim]
            combined_features = torch.cat([
                global_features.unsqueeze(1).expand(-1, Q, -1),
                query_features
            ], dim=-1)  # [B, Q, embed_dim*2]

            # 6. Predict the signed distance for every query point.
            return self.sdf_head(combined_features)  # [B, Q, 1]

        except Exception as e:
            logger.error(f"Error in BRepToSDF forward pass:")
            logger.error(f"  Error message: {str(e)}")
            logger.error(f"  Input shapes:")
            logger.error(f"    edge_ncs: {edge_ncs.shape}")
            logger.error(f"    edge_pos: {edge_pos.shape}")
            logger.error(f"    edge_mask: {edge_mask.shape}")
            logger.error(f"    surf_ncs: {surf_ncs.shape}")
            logger.error(f"    surf_pos: {surf_pos.shape}")
            logger.error(f"    vertex_pos: {vertex_pos.shape}")
            logger.error(f"    query_points: {query_points.shape}")
            raise
|
|
|
|
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
    """L1 SDF loss plus an eikonal term pushing |∇sdf| toward 1.

    Args:
        pred_sdf: predicted SDF values, differentiable w.r.t. `points`.
        gt_sdf: ground-truth SDF values, same shape as pred_sdf.
        points: query points with requires_grad=True (needed for autograd.grad).
        grad_weight: weight of the gradient-norm constraint.
    """
    data_term = F.l1_loss(pred_sdf, gt_sdf)

    # Gradient of the predicted SDF w.r.t. the query points; create_graph
    # keeps it differentiable so the constraint can be trained through.
    (grad,) = torch.autograd.grad(
        pred_sdf.sum(),
        points,
        create_graph=True,
    )
    grad_norm = torch.norm(grad, dim=-1)
    eikonal_term = F.mse_loss(grad_norm, torch.ones_like(pred_sdf.squeeze(-1)))

    return data_term + grad_weight * eikonal_term
|
|
|
|
def main():
    """Smoke test: build BRepToSDF from config and run one forward pass on random data."""
    config = get_default_config()

    # Build the model from configured hyperparameters.
    model = BRepToSDF(
        brep_feature_dim=config.model.brep_feature_dim,  # 48
        use_cf=config.model.use_cf,                      # True
        embed_dim=config.model.embed_dim,                # 768
        latent_dim=config.model.latent_dim               # 256
    )

    # Data dimensions from config; vertex/query counts stay fixed.
    batch_size = config.train.batch_size  # 32
    num_surfs = config.data.max_face      # 64
    num_edges = config.data.max_edge      # 64
    num_verts = 8
    num_queries = 1000

    # Random inputs with the shapes the model expects.
    edge_ncs = torch.randn(
        batch_size, num_surfs, num_edges, config.model.num_edge_points, 3
    )  # [B, max_face, max_edge, num_edge_points, 3]
    edge_pos = torch.randn(batch_size, num_surfs, num_edges, 6)  # [B, max_face, max_edge, 6]
    edge_mask = torch.ones(
        batch_size, num_surfs, num_edges, dtype=torch.bool
    )  # [B, max_face, max_edge]
    surf_ncs = torch.randn(
        batch_size, num_surfs, config.model.num_surf_points, 3
    )  # [B, max_face, num_surf_points, 3]
    surf_pos = torch.randn(batch_size, num_surfs, 6)  # [B, max_face, 6]
    vertex_pos = torch.randn(
        batch_size, num_surfs, num_edges, 2, 3
    )  # [B, max_face, max_edge, 2, 3]
    query_points = torch.randn(batch_size, num_queries, 3)

    # One forward pass.
    sdf = model(
        edge_ncs=edge_ncs,
        edge_pos=edge_pos,
        edge_mask=edge_mask,
        surf_ncs=surf_ncs,
        surf_pos=surf_pos,
        vertex_pos=vertex_pos,
        query_points=query_points
    )

    # Report input/output shapes.
    print("\nInput shapes:")
    print(f"edge_ncs: {edge_ncs.shape}")
    print(f"edge_pos: {edge_pos.shape}")
    print(f"edge_mask: {edge_mask.shape}")
    print(f"surf_ncs: {surf_ncs.shape}")
    print(f"surf_pos: {surf_pos.shape}")
    print(f"vertex_pos: {vertex_pos.shape}")
    print(f"query_points: {query_points.shape}")
    print(f"\nOutput SDF shape: {sdf.shape}")


if __name__ == "__main__":
    main()
|