|
|
@ -246,39 +246,39 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
'vertex_pos': vertex_pos |
|
|
|
} |
|
|
|
|
|
|
|
logger.info("\n=== 输入张量检查 ===") |
|
|
|
for name, tensor in input_tensors.items(): |
|
|
|
logger.print_tensor_stats(name, tensor) |
|
|
|
#logger.info("\n=== 输入张量检查 ===") |
|
|
|
#for name, tensor in input_tensors.items(): |
|
|
|
#logger.print_tensor_stats(name, tensor) |
|
|
|
|
|
|
|
# 1. 处理顶点特征 |
|
|
|
vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim] |
|
|
|
logger.print_tensor_stats('vertex_embed', vertex_embed) |
|
|
|
# logger.print_tensor_stats('vertex_embed', vertex_embed) |
|
|
|
vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim] |
|
|
|
logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed) |
|
|
|
# logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed) |
|
|
|
vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim] |
|
|
|
|
|
|
|
# 2. 处理边特征 |
|
|
|
edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim] |
|
|
|
logger.print_tensor_stats('edge_embeds', edge_embeds) |
|
|
|
# logger.print_tensor_stats('edge_embeds', edge_embeds) |
|
|
|
edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim] |
|
|
|
logger.print_tensor_stats('edge_p_embeds', edge_p_embeds) |
|
|
|
# logger.print_tensor_stats('edge_p_embeds', edge_p_embeds) |
|
|
|
|
|
|
|
# 3. 处理面特征 |
|
|
|
surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim] |
|
|
|
logger.print_tensor_stats('surf_embeds', surf_embeds) |
|
|
|
# logger.print_tensor_stats('surf_embeds', surf_embeds) |
|
|
|
surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim] |
|
|
|
logger.print_tensor_stats('surf_p_embeds', surf_p_embeds) |
|
|
|
# logger.print_tensor_stats('surf_p_embeds', surf_p_embeds) |
|
|
|
|
|
|
|
# 4. 组合特征 |
|
|
|
if self.use_cf: |
|
|
|
# 组合边特征 |
|
|
|
edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim] |
|
|
|
edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim] |
|
|
|
logger.print_tensor_stats('edge_features', edge_features) |
|
|
|
# logger.print_tensor_stats('edge_features', edge_features) |
|
|
|
|
|
|
|
# 组合面特征 |
|
|
|
surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim] |
|
|
|
logger.print_tensor_stats('surf_features', surf_features) |
|
|
|
# logger.print_tensor_stats('surf_features', surf_features) |
|
|
|
|
|
|
|
# 拼接所有特征 |
|
|
|
embeds = torch.cat([ |
|
|
@ -300,11 +300,11 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, F*E+F] |
|
|
|
else: |
|
|
|
mask = None |
|
|
|
logger.debug(f"embeds shape: {embeds.shape}") |
|
|
|
#logger.debug(f"embeds shape: {embeds.shape}") |
|
|
|
# 6. Transformer处理 |
|
|
|
output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask) |
|
|
|
logger.print_tensor_stats('output', output) |
|
|
|
logger.debug(f"output shape: {output.shape}") |
|
|
|
# logger.print_tensor_stats('output', output) |
|
|
|
#logger.debug(f"output shape: {output.shape}") |
|
|
|
return output.transpose(0, 1) # [B, F*E+F, embed_dim] |
|
|
|
|
|
|
|
class SDFTransformer(nn.Module): |
|
|
|