|
|
@@ -76,7 +76,7 @@ class UNetMidBlock1D(nn.Module):
|
|
|
x = attn(x) |
|
|
|
return x |
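
# Sketch (illustrative, not from this file): one way the `x = attn(x)` step in a
# 1D mid block can be realized with nn.MultiheadAttention over a [B, C, L]
# feature map. The class below is a hypothetical helper, not the project's
# actual attention block.
class _SelfAttention1d(nn.Module):
    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        # channels must be divisible by num_heads
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x):                    # x: [B, C, L]
        seq = x.transpose(1, 2)              # [B, L, C]: attend over the length dim
        out, _ = self.attn(seq, seq, seq)    # self-attention
        return out.transpose(1, 2)           # [B, C, L]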
|
|
|
|
|
|
|
class _Encoder1D(nn.Module):
    """Convolutional 1D encoder (superseded by the flatten-based Encoder1D below)."""
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
in_channels: int = 3, |
|
|
@@ -114,47 +114,40 @@ class _Encoder1D(nn.Module):
|
|
|
x = self.mid_block(x) |
|
|
|
x = self.conv_out(x) |
|
|
|
return x |
|
|
|
|
|
|
|
class Encoder1D(nn.Module):
    """1D encoder: flattens per-point coordinates into a fixed-size vector."""
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
        num_points: int = 2,   # number of input points
        in_channels: int = 3,  # xyz coordinates
        out_dim: int = 6,      # flattened dimension (2 * 3)
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
|
|
|
self.num_points = num_points |
|
|
|
self.in_channels = in_channels |
|
|
|
self.out_dim = out_dim |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        assert num_points * in_channels == out_dim, \
            f"Flatten dimension {out_dim} must equal num_points({num_points}) * in_channels({in_channels})"
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
|
""" |
|
|
|
Args: |
|
|
|
x: [B, F, E, num_points, channels] |
|
|
|
Returns: |
|
|
|
x: [B, F, E, out_dim] |
|
|
|
""" |
|
|
|
orig_shape = x.shape[:-2] # [B, F, E] |
|
|
|
|
|
|
|
        # reshape so all batch-like dims are processed at once
|
|
|
x = x.reshape(-1, self.num_points, self.in_channels) # [B*F*E, num_points, channels] |
|
|
|
|
|
|
|
        # flatten the point and channel dimensions
|
|
|
x = x.reshape(x.size(0), -1) # [B*F*E, num_points*channels] |
|
|
|
|
|
|
|
        # restore the batch dimensions
|
|
|
x = x.reshape(*orig_shape, self.out_dim) # [B, F, E, out_dim] |
|
|
|
|
|
|
|
return x |
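
# Minimal usage sketch (illustrative, not part of the original file): a
# shape-check for the flatten-based Encoder1D, using the default sizes
# (2 points * 3 channels -> out_dim 6) and made-up batch dims B=2, F=4, E=8.
def _demo_encoder1d():
    enc = Encoder1D(num_points=2, in_channels=3, out_dim=6)
    x = torch.randn(2, 4, 8, 2, 3)    # [B, F, E, num_points, xyz]
    y = enc(x)
    assert y.shape == (2, 4, 8, 6)    # [B, F, E, out_dim]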
|
|
|
|
|
|
|
class BRepFeatureEmbedder(nn.Module): |
|
|
@@ -177,62 +170,37 @@ class BRepFeatureEmbedder(nn.Module):
|
|
|
logger.info(f" num_edge_points: {self.num_edge_points}") |
|
|
|
logger.info(f" embed_dim: {self.embed_dim}") |
|
|
|
logger.info(f" use_cf: {self.use_cf}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        surf_out_dim = 3 * self.num_surf_points  # flattened surface feature dim (num_surf_points * 3)
        edge_out_dim = 3 * self.num_edge_points  # flattened edge feature dim (num_edge_points * 3)
        self.surfz_encoder = Encoder1D(num_points=self.num_surf_points, in_channels=3, out_dim=surf_out_dim)
        self.edgez_encoder = Encoder1D(num_points=self.num_edge_points, in_channels=3, out_dim=edge_out_dim)
|
|
|
|
|
|
|
        # The embedders below take [num_points, 3]-shaped point inputs:
        # Encoder1D flattens the points, then an MLP projects to embed_dim.
|
|
|
self.surfz_embed = nn.Sequential( |
|
|
|
self.surfz_encoder, # [B, 16, 3] -> [B, 48] |
|
|
|
nn.Linear(surf_out_dim, self.embed_dim), # [B, 48] -> [B, embed_dim] |
|
|
|
nn.LayerNorm(self.embed_dim), |
|
|
|
nn.SiLU(), |
|
|
|
nn.Linear(self.embed_dim, self.embed_dim), |
|
|
|
) |
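
        # The flatten-then-MLP scheme replaces the earlier convolutional
        # encoders; with only a handful of points per face/edge, an MLP over
        # raw coordinates is presumably sufficient.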
|
|
|
|
|
|
|
        # The other embedding layers are unchanged
|
|
|
|
|
|
|
self.surfp_embed = nn.Sequential( |
|
|
|
nn.Linear(6, self.embed_dim), |
|
|
|
nn.LayerNorm(self.embed_dim), |
|
|
|
nn.SiLU(), |
|
|
|
nn.Linear(self.embed_dim, self.embed_dim), |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
self.edgez_embed = nn.Sequential( |
|
|
|
self.edgez_encoder, |
|
|
|
nn.Linear(edge_out_dim, self.embed_dim), |
|
|
|
nn.LayerNorm(self.embed_dim), |
|
|
|
nn.SiLU(), |
|
|
|
nn.Linear(self.embed_dim, self.embed_dim), |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
self.edgep_embed = nn.Sequential( |
|
|
|
nn.Linear(6, self.embed_dim), |
|
|
@@ -243,86 +211,87 @@ class BRepFeatureEmbedder(nn.Module):
|
|
|
|
|
|
|
        # Vertex position embedder
        self.vertp_embed = nn.Sequential(
            nn.Linear(3, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim)
        )
|
|
|
|
|
|
|
        # Extra projection layer for vertex features
|
|
|
self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
|
|
|
|
|
        # Transformer encoder over the concatenated edge/surface tokens
        self.transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=self.embed_dim,
                nhead=8,                             # number of attention heads; must divide embed_dim
                dim_feedforward=4 * self.embed_dim,  # FFN width, conventionally 4x embed_dim
                dropout=0.1,
                activation='gelu',
                batch_first=False                    # sequence-first: forward() uses transpose(0, 1)
            ),
            num_layers=6                             # number of encoder layers
        )
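
        # Note: with batch_first=False the transformer expects [seq, B, dim];
        # forward() transposes accordingly. Positional information enters via
        # the surfp/edgep coordinate embeddings rather than a separate
        # positional encoding.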
|
|
|
|
|
|
|
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs): |
|
|
|
|
|
|
""" |
|
|
|
Args: |
|
|
|
edge_ncs: [B, F, E, num_edge_points, 3] - 边点云 |
|
|
|
edge_pos: [B, F, E, 6] - 边位置编码 |
|
|
|
edge_mask: [B, F, E] - 边掩码 |
|
|
|
surf_ncs: [B, F, num_surf_points, 3] - 面点云 |
|
|
|
surf_pos: [B, F, 6] - 面位置编码 |
|
|
|
vertex_pos: [B, F, E, 2, 3] - 顶点位置 |
|
|
|
""" |
|
|
|
B, F, E = edge_pos.shape[:3] |
|
|
|
|
|
|
|
        # 1. Vertex features
|
|
|
vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim] |
|
|
|
vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim] |
|
|
|
vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim] |
|
|
|
|
|
|
|
        # 2. Edge features
        logger.debug(f"edge_ncs shape: {edge_ncs.shape}")
|
|
|
edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim] |
|
|
|
edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim] |
|
|
|
|
|
|
|
        # 3. Surface features
|
|
|
surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim] |
|
|
|
surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim] |
|
|
|
|
|
|
|
        # 4. Combine features
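        # use_cf toggles content features: when False, only the position
        # encodings are fed to the transformer (see the two branches below).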
|
|
|
if self.use_cf: |
|
|
|
            # Edge tokens: content + position + pooled vertex features
            edge_features = edge_embeds + edge_p_embeds + vertex_embed  # [B, F, E, embed_dim]
            edge_features = edge_features.reshape(B, F*E, -1)           # [B, F*E, embed_dim]
|
|
|
|
|
|
|
            # Surface tokens: content + position
            surf_features = surf_embeds + surf_p_embeds  # [B, F, embed_dim]
|
|
|
|
|
|
|
            # Concatenate edge and surface tokens along the sequence dim
            embeds = torch.cat([
                edge_features,  # [B, F*E, embed_dim]
                surf_features   # [B, F, embed_dim]
            ], dim=1)           # [B, F*E+F, embed_dim]
|
|
|
else: |
|
|
|
            # Position encodings only
|
|
|
            edge_features = edge_p_embeds.reshape(B, F*E, -1)  # [B, F*E, embed_dim]
|
|
|
            embeds = torch.cat([
                edge_features,  # [B, F*E, embed_dim]
                surf_p_embeds   # [B, F, embed_dim]
            ], dim=1)           # [B, F*E+F, embed_dim]
|
|
|
|
|
|
|
        # 5. Padding mask
|
|
|
if edge_mask is not None: |
|
|
|
            # Build a key-padding mask matching the token sequence.
            # src_key_padding_mask treats True as "ignore"; assuming edge_mask
            # marks padded edges as True, surface tokens (never padded) get False.
            edge_mask = edge_mask.reshape(B, -1)                                      # [B, F*E]
            surf_mask = torch.zeros(B, F, device=edge_mask.device, dtype=torch.bool)  # [B, F]
            mask = torch.cat([edge_mask, surf_mask], dim=1)                           # [B, F*E+F]
|
|
|
else: |
|
|
|
mask = None |
|
|
|
|
|
|
|
        # 6. Transformer encoding
|
|
|
output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask) |
|
|
|
        return output.transpose(0, 1)  # [B, F*E+F, embed_dim]
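
# Sketch (illustrative, not from the original file): how src_key_padding_mask
# behaves in the transformer call above. True marks positions to ignore, which
# is why padded edges carry True and surface tokens carry False.
def _demo_padding_mask():
    layer = nn.TransformerEncoderLayer(d_model=16, nhead=4)
    enc = nn.TransformerEncoder(layer, num_layers=1)
    tokens = torch.randn(5, 2, 16)                 # [seq_len, B, d_model]
    pad = torch.zeros(2, 5, dtype=torch.bool)      # [B, seq_len]
    pad[:, -1] = True                              # ignore each sequence's last token
    return enc(tokens, src_key_padding_mask=pad)   # [5, 2, 16]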
|
|
|
|
|
|
|
class SDFTransformer(nn.Module): |
|
|
|
"""SDF Transformer编码器""" |
|
|
|