|
|
@ -180,19 +180,37 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
|
|
|
|
# Transformer编码器层 |
|
|
|
layer = nn.TransformerEncoderLayer( |
|
|
|
d_model=self.embed_dim, |
|
|
|
nhead=12, |
|
|
|
norm_first=False, |
|
|
|
dim_feedforward=1024, |
|
|
|
dropout=0.1 |
|
|
|
d_model=self.embed_dim, |
|
|
|
nhead=8, # 从12减少到8,使每个head的维度更大 |
|
|
|
norm_first=True, # 改为True,先进行归一化 |
|
|
|
dim_feedforward=self.embed_dim * 4, # 增大FFN维度 |
|
|
|
dropout=0.1, |
|
|
|
activation=F.gelu # 使用GELU激活函数 |
|
|
|
) |
|
|
|
|
|
|
|
# 添加初始化方法 |
|
|
|
def _init_weights(module): |
|
|
|
if isinstance(module, nn.Linear): |
|
|
|
torch.nn.init.xavier_uniform_(module.weight, gain=1/math.sqrt(2)) |
|
|
|
if module.bias is not None: |
|
|
|
torch.nn.init.zeros_(module.bias) |
|
|
|
elif isinstance(module, nn.LayerNorm): |
|
|
|
torch.nn.init.ones_(module.weight) |
|
|
|
torch.nn.init.zeros_(module.bias) |
|
|
|
|
|
|
|
self.transformer = nn.TransformerEncoder( |
|
|
|
layer, |
|
|
|
num_layers=12, |
|
|
|
layer, |
|
|
|
num_layers=6, # 从12减少到6层 |
|
|
|
norm=nn.LayerNorm(self.embed_dim), |
|
|
|
enable_nested_tensor=False |
|
|
|
) |
|
|
|
|
|
|
|
# 应用初始化 |
|
|
|
self.transformer.apply(_init_weights) |
|
|
|
|
|
|
|
# 添加位置编码 |
|
|
|
self.pos_embedding = nn.Parameter(torch.randn(1, 1000, self.embed_dim) * 0.02) # 最大序列长度设为1000 |
|
|
|
|
|
|
|
# 修改为处理[num_points, 3]形状的输入 |
|
|
|
self.surfz_embed = Encoder1D( |
|
|
|
in_channels=3, |
|
|
@ -223,117 +241,162 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
nn.Linear(self.embed_dim, self.embed_dim), |
|
|
|
) |
|
|
|
|
|
|
|
# 修改vertp_embed的结构 |
|
|
|
self.vertp_embed = nn.Sequential( |
|
|
|
nn.Linear(6, self.embed_dim), |
|
|
|
nn.LayerNorm(self.embed_dim), |
|
|
|
nn.SiLU(), |
|
|
|
nn.Linear(self.embed_dim, self.embed_dim), |
|
|
|
nn.Linear(3, self.embed_dim // 2), |
|
|
|
nn.LayerNorm(self.embed_dim // 2), |
|
|
|
nn.ReLU(), |
|
|
|
nn.Linear(self.embed_dim // 2, self.embed_dim) |
|
|
|
) |
|
|
|
|
|
|
|
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, data_class=None): |
|
|
|
"""B-rep特征嵌入器的前向传播 |
|
|
|
|
|
|
|
Args: |
|
|
|
edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3] |
|
|
|
edge_pos: 边位置 [B, max_face, max_edge, 6] |
|
|
|
edge_mask: 边掩码 [B, max_face, max_edge] |
|
|
|
surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3] |
|
|
|
surf_pos: 面位置 [B, max_face, 6] |
|
|
|
vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3] |
|
|
|
# 添加一个额外的投影层 |
|
|
|
self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
|
|
|
|
|
Returns: |
|
|
|
embeds: [B, max_face*(max_edge+1), embed_dim] |
|
|
|
""" |
|
|
|
B = self.config.train.batch_size |
|
|
|
max_face = self.config.data.max_face |
|
|
|
max_edge = self.config.data.max_edge |
|
|
|
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs): |
|
|
|
B, F, E, _, _ = edge_ncs.shape |
|
|
|
|
|
|
|
try: |
|
|
|
# 1. 处理边特征 |
|
|
|
# 重塑边点云以适应1D编码器 |
|
|
|
edge_ncs = edge_ncs.reshape(B*max_face*max_edge, -1, 3).transpose(1, 2) # [B*max_face*max_edge, 3, num_edge_points] |
|
|
|
edge_embeds = self.edgez_embed(edge_ncs) # [B*max_face*max_edge, embed_dim, num_edge_points] |
|
|
|
edge_embeds = edge_embeds.mean(dim=-1) # [B*max_face*max_edge, embed_dim] |
|
|
|
edge_embeds = edge_embeds.reshape(B, max_face, max_edge, -1) # [B, max_face, max_edge, embed_dim] |
|
|
|
|
|
|
|
# 2. 处理面特征 |
|
|
|
surf_ncs = surf_ncs.reshape(B*max_face, -1, 3).transpose(1, 2) # [B*max_face, 3, num_surf_points] |
|
|
|
surf_embeds = self.surfz_embed(surf_ncs) # [B*max_face, embed_dim, num_surf_points] |
|
|
|
surf_embeds = surf_embeds.mean(dim=-1) # [B*max_face, embed_dim] |
|
|
|
surf_embeds = surf_embeds.reshape(B, max_face, -1) # [B, max_face, embed_dim] |
|
|
|
|
|
|
|
# 3. 处理位置编码 |
|
|
|
# 边位置编码 |
|
|
|
edge_pos = edge_pos.reshape(B*max_face*max_edge, -1) # [B*max_face*max_edge, 6] |
|
|
|
edge_p_embeds = self.edgep_embed(edge_pos) # [B*max_face*max_edge, embed_dim] |
|
|
|
edge_p_embeds = edge_p_embeds.reshape(B, max_face, max_edge, -1) # [B, max_face, max_edge, embed_dim] |
|
|
|
|
|
|
|
# 面位置编码 |
|
|
|
surf_p_embeds = self.surfp_embed(surf_pos) # [B, max_face, embed_dim] |
|
|
|
|
|
|
|
# 4. 组合特征 |
|
|
|
if self.use_cf: |
|
|
|
# 边特征 |
|
|
|
edge_features = edge_embeds + edge_p_embeds # [B, max_face, max_edge, embed_dim] |
|
|
|
edge_features = edge_features.reshape(B, max_face*max_edge, -1) # [B, max_face*max_edge, embed_dim] |
|
|
|
|
|
|
|
# 面特征 |
|
|
|
surf_features = surf_embeds + surf_p_embeds # [B, max_face, embed_dim] |
|
|
|
|
|
|
|
# 组合所有特征 |
|
|
|
embeds = torch.cat([ |
|
|
|
edge_features, # [B, max_face*max_edge, embed_dim] |
|
|
|
surf_features # [B, max_face, embed_dim] |
|
|
|
], dim=1) # [B, max_face*(max_edge+1), embed_dim] |
|
|
|
else: |
|
|
|
# 只使用位置编码 |
|
|
|
edge_features = edge_p_embeds.reshape(B, max_face*max_edge, -1) # [B, max_face*max_edge, embed_dim] |
|
|
|
embeds = torch.cat([ |
|
|
|
edge_features, # [B, max_face*max_edge, embed_dim] |
|
|
|
surf_p_embeds # [B, max_face, embed_dim] |
|
|
|
], dim=1) # [B, max_face*(max_edge+1), embed_dim] |
|
|
|
|
|
|
|
# 5. 处理掩码 |
|
|
|
if edge_mask is not None: |
|
|
|
# 扩展掩码以匹配特征维度 |
|
|
|
edge_mask = edge_mask.reshape(B, -1) # [B, max_face*max_edge] |
|
|
|
surf_mask = torch.ones(B, max_face, device=edge_mask.device, dtype=torch.bool) # [B, max_face] |
|
|
|
mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, max_face*(max_edge+1)] |
|
|
|
else: |
|
|
|
mask = None |
|
|
|
# 处理顶点特征 |
|
|
|
vertex_features = vertex_pos.view(B*F*E*2, -1) # 展平顶点坐标 |
|
|
|
vertex_embed = self.vertp_embed(vertex_features) |
|
|
|
vertex_embed = self.vertex_proj(vertex_embed) # 添加投影 |
|
|
|
vertex_embed = vertex_embed.view(B, F, E, 2, -1) # 恢复形状 |
|
|
|
|
|
|
|
# 确保顶点特征参与后续计算 |
|
|
|
edge_features = torch.cat([ |
|
|
|
self.edgep_embed(edge_pos.view(B*F*E, -1)).view(B, F, E, -1), |
|
|
|
vertex_embed.mean(dim=3) # 将顶点特征平均池化 |
|
|
|
], dim=-1) |
|
|
|
|
|
|
|
# 1. 处理边特征 |
|
|
|
# 重塑边点云以适应1D编码器 |
|
|
|
edge_ncs = edge_ncs.reshape(B*F*E, -1, 3).transpose(1, 2) # [B*max_face*max_edge, 3, num_edge_points] |
|
|
|
edge_embeds = self.edgez_embed(edge_ncs) # [B*max_face*max_edge, embed_dim, num_edge_points] |
|
|
|
edge_embeds = edge_embeds.mean(dim=-1) # [B*max_face*max_edge, embed_dim] |
|
|
|
edge_embeds = edge_embeds.reshape(B, F, E, -1) # [B, max_face, max_edge, embed_dim] |
|
|
|
|
|
|
|
# 2. 处理面特征 |
|
|
|
surf_ncs = surf_ncs.reshape(B*F, -1, 3).transpose(1, 2) # [B*max_face, 3, num_surf_points] |
|
|
|
surf_embeds = self.surfz_embed(surf_ncs) # [B*max_face, embed_dim, num_surf_points] |
|
|
|
surf_embeds = surf_embeds.mean(dim=-1) # [B*max_face, embed_dim] |
|
|
|
surf_embeds = surf_embeds.reshape(B, F, -1) # [B, max_face, embed_dim] |
|
|
|
|
|
|
|
# 3. 处理位置编码 |
|
|
|
# 边位置编码 |
|
|
|
edge_pos = edge_pos.reshape(B*F*E, -1) # [B*max_face*max_edge, 6] |
|
|
|
edge_p_embeds = self.edgep_embed(edge_pos) # [B*max_face*max_edge, embed_dim] |
|
|
|
edge_p_embeds = edge_p_embeds.reshape(B, F, E, -1) # [B, max_face, max_edge, embed_dim] |
|
|
|
|
|
|
|
# 面位置编码 |
|
|
|
surf_p_embeds = self.surfp_embed(surf_pos) # [B, max_face, embed_dim] |
|
|
|
|
|
|
|
# 4. 组合特征 |
|
|
|
if self.use_cf: |
|
|
|
# 边特征 |
|
|
|
edge_features = edge_embeds + edge_p_embeds # [B, max_face, max_edge, embed_dim] |
|
|
|
edge_features = edge_features.reshape(B, F*E, -1) # [B, max_face*max_edge, embed_dim] |
|
|
|
|
|
|
|
# 6. Transformer处理 |
|
|
|
output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask) |
|
|
|
return output.transpose(0, 1) # 确保输出维度为 [B, seq_len, embed_dim] |
|
|
|
# 面特征 |
|
|
|
surf_features = surf_embeds + surf_p_embeds # [B, max_face, embed_dim] |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error in BRepFeatureEmbedder forward pass:") |
|
|
|
logger.error(f" Error message: {str(e)}") |
|
|
|
logger.error(f" Input shapes:") |
|
|
|
logger.error(f" edge_ncs: {edge_ncs.shape}") |
|
|
|
logger.error(f" edge_pos: {edge_pos.shape}") |
|
|
|
logger.error(f" edge_mask: {edge_mask.shape}") |
|
|
|
logger.error(f" surf_ncs: {surf_ncs.shape}") |
|
|
|
logger.error(f" surf_pos: {surf_pos.shape}") |
|
|
|
logger.error(f" vertex_pos: {vertex_pos.shape}") |
|
|
|
raise |
|
|
|
# 组合所有特征 |
|
|
|
embeds = torch.cat([ |
|
|
|
edge_features, # [B, max_face*max_edge, embed_dim] |
|
|
|
surf_features # [B, max_face, embed_dim] |
|
|
|
], dim=1) # [B, max_face*(max_edge+1), embed_dim] |
|
|
|
else: |
|
|
|
# 只使用位置编码 |
|
|
|
edge_features = edge_p_embeds.reshape(B, F*E, -1) # [B, max_face*max_edge, embed_dim] |
|
|
|
embeds = torch.cat([ |
|
|
|
edge_features, # [B, max_face*max_edge, embed_dim] |
|
|
|
surf_p_embeds # [B, max_face, embed_dim] |
|
|
|
], dim=1) # [B, max_face*(max_edge+1), embed_dim] |
|
|
|
|
|
|
|
# 5. 处理掩码 |
|
|
|
if edge_mask is not None: |
|
|
|
# 扩展掩码以匹配特征维度 |
|
|
|
edge_mask = edge_mask.reshape(B, -1) # [B, max_face*max_edge] |
|
|
|
surf_mask = torch.ones(B, F, device=edge_mask.device, dtype=torch.bool) # [B, max_face] |
|
|
|
mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, max_face*(max_edge+1)] |
|
|
|
else: |
|
|
|
mask = None |
|
|
|
|
|
|
|
# 6. Transformer处理 |
|
|
|
output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask) |
|
|
|
return output.transpose(0, 1) # 确保输出维度为 [B, seq_len, embed_dim] |
|
|
|
|
|
|
|
class SDFTransformer(nn.Module): |
|
|
|
"""SDF Transformer编码器""" |
|
|
|
def __init__(self, embed_dim: int = 768, num_layers: int = 6): |
|
|
|
def __init__(self, embed_dim: int = 192, num_layers: int = 6): # 改为192以匹配BRepFeatureEmbedder |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
# 1. 添加位置编码 |
|
|
|
self.pos_embedding = nn.Parameter(torch.randn(1, 1000, embed_dim) * 0.02) |
|
|
|
|
|
|
|
# 2. 修改Transformer层配置 |
|
|
|
layer = nn.TransformerEncoderLayer( |
|
|
|
d_model=embed_dim, |
|
|
|
nhead=8, |
|
|
|
dim_feedforward=1024, |
|
|
|
nhead=4, # 减少头数,使每个头的维度更大 |
|
|
|
dim_feedforward=embed_dim * 2, # 减小FFN维度 |
|
|
|
dropout=0.1, |
|
|
|
batch_first=True, |
|
|
|
norm_first=False # 修改这里:设置为False |
|
|
|
norm_first=True, # 使用Pre-LN结构 |
|
|
|
activation=F.gelu |
|
|
|
) |
|
|
|
|
|
|
|
# 3. 添加梯度缩放因子 |
|
|
|
self.attention_scale = math.sqrt(embed_dim) |
|
|
|
|
|
|
|
# 4. 自定义初始化 |
|
|
|
def _init_weights(module): |
|
|
|
if isinstance(module, nn.Linear): |
|
|
|
# 使用较小的初始化范围 |
|
|
|
nn.init.xavier_uniform_(module.weight, gain=0.1) |
|
|
|
if module.bias is not None: |
|
|
|
nn.init.zeros_(module.bias) |
|
|
|
elif isinstance(module, nn.LayerNorm): |
|
|
|
nn.init.ones_(module.weight) |
|
|
|
nn.init.zeros_(module.bias) |
|
|
|
elif isinstance(module, nn.MultiheadAttention): |
|
|
|
# 特别处理注意力层 |
|
|
|
in_proj_weight = module.in_proj_weight |
|
|
|
out_proj_weight = module.out_proj.weight |
|
|
|
|
|
|
|
nn.init.xavier_uniform_(in_proj_weight, gain=0.1) |
|
|
|
nn.init.xavier_uniform_(out_proj_weight, gain=0.1) |
|
|
|
|
|
|
|
if module.in_proj_bias is not None: |
|
|
|
nn.init.zeros_(module.in_proj_bias) |
|
|
|
if module.out_proj.bias is not None: |
|
|
|
nn.init.zeros_(module.out_proj.bias) |
|
|
|
|
|
|
|
self.transformer = nn.TransformerEncoder( |
|
|
|
layer, |
|
|
|
num_layers, |
|
|
|
norm=nn.LayerNorm(embed_dim) |
|
|
|
) |
|
|
|
self.transformer = nn.TransformerEncoder(layer, num_layers) |
|
|
|
|
|
|
|
# 应用初始化 |
|
|
|
self.transformer.apply(_init_weights) |
|
|
|
|
|
|
|
# 5. 添加残差缩放 |
|
|
|
self.residual_scale = nn.Parameter(torch.ones(1) * 0.1) |
|
|
|
|
|
|
|
def forward(self, x, mask=None): |
|
|
|
return self.transformer(x, src_key_padding_mask=mask) |
|
|
|
# 添加位置编码 |
|
|
|
seq_len = x.size(1) |
|
|
|
x = x + self.pos_embedding[:, :seq_len, :] |
|
|
|
|
|
|
|
# 缩放注意力分数 |
|
|
|
x = x * self.attention_scale |
|
|
|
|
|
|
|
# 前向传播 |
|
|
|
for layer in self.transformer.layers: |
|
|
|
# 手动添加残差连接和缩放 |
|
|
|
identity = x |
|
|
|
x = layer(x, src_key_padding_mask=mask) |
|
|
|
x = identity + x * self.residual_scale |
|
|
|
|
|
|
|
return self.transformer.norm(x) |
|
|
|
|
|
|
|
class SDFHead(nn.Module): |
|
|
|
"""SDF预测头""" |
|
|
|