|
|
@ -130,6 +130,7 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
self.num_surf_points = self.config.model.num_surf_points # 16 |
|
|
|
self.num_edge_points = self.config.model.num_edge_points # 4 |
|
|
|
|
|
|
|
# Transformer编码器层 |
|
|
|
layer = nn.TransformerEncoderLayer( |
|
|
|
d_model=self.embed_dim, |
|
|
|
nhead=12, |
|
|
@ -144,21 +145,22 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
enable_nested_tensor=False |
|
|
|
) |
|
|
|
|
|
|
|
# 修改输入维度以匹配采样点数 |
|
|
|
self.surfz_embed = nn.Sequential( |
|
|
|
nn.Linear(3 * self.num_surf_points, self.embed_dim), # 3 * 16 |
|
|
|
nn.LayerNorm(self.embed_dim), |
|
|
|
nn.SiLU(), |
|
|
|
nn.Linear(self.embed_dim, self.embed_dim), |
|
|
|
# 修改为处理[num_points, 3]形状的输入 |
|
|
|
self.surfz_embed = Encoder1D( |
|
|
|
in_channels=3, |
|
|
|
out_channels=self.embed_dim, |
|
|
|
block_out_channels=(64, 128, 256), |
|
|
|
layers_per_block=2 |
|
|
|
) |
|
|
|
|
|
|
|
self.edgez_embed = nn.Sequential( |
|
|
|
nn.Linear(3 * self.num_edge_points, self.embed_dim), # 3 * 4 |
|
|
|
nn.LayerNorm(self.embed_dim), |
|
|
|
nn.SiLU(), |
|
|
|
nn.Linear(self.embed_dim, self.embed_dim), |
|
|
|
self.edgez_embed = Encoder1D( |
|
|
|
in_channels=3, |
|
|
|
out_channels=self.embed_dim, |
|
|
|
block_out_channels=(64, 128, 256), |
|
|
|
layers_per_block=2 |
|
|
|
) |
|
|
|
|
|
|
|
# 其他嵌入层保持不变 |
|
|
|
self.surfp_embed = nn.Sequential( |
|
|
|
nn.Linear(6, self.embed_dim), |
|
|
|
nn.LayerNorm(self.embed_dim), |
|
|
@ -183,21 +185,37 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None): |
|
|
|
""" |
|
|
|
Args: |
|
|
|
surf_z: 表面特征 [B, N, num_surf_points*3] |
|
|
|
edge_z: 边特征 [B, M, num_edge_points*3] |
|
|
|
surf_z: 表面点云 [B, N, num_surf_points, 3] |
|
|
|
edge_z: 边点云 [B, M, num_edge_points, 3] |
|
|
|
surf_p: 表面点 [B, N, 6] |
|
|
|
edge_p: 边点 [B, M, 6] |
|
|
|
vert_p: 顶点点 [B, K, 6] |
|
|
|
mask: 注意力掩码 |
|
|
|
""" |
|
|
|
B, N, _, _ = surf_z.shape |
|
|
|
_, M, _, _ = edge_z.shape |
|
|
|
_, K, _ = vert_p.shape |
|
|
|
|
|
|
|
# 重塑点云数据用于1D编码器 |
|
|
|
surf_z = surf_z.reshape(B*N, 3, self.num_surf_points) # [B*N, 3, num_surf_points] |
|
|
|
edge_z = edge_z.reshape(B*M, 3, self.num_edge_points) # [B*M, 3, num_edge_points] |
|
|
|
|
|
|
|
# 特征嵌入 |
|
|
|
surf_embeds = self.surfz_embed(surf_z) |
|
|
|
edge_embeds = self.edgez_embed(edge_z) |
|
|
|
surf_embeds = self.surfz_embed(surf_z) # [B*N, embed_dim, num_points] |
|
|
|
edge_embeds = self.edgez_embed(edge_z) # [B*M, embed_dim, num_points] |
|
|
|
|
|
|
|
# 全局池化得到每个面/边的特征 |
|
|
|
surf_embeds = surf_embeds.mean(dim=-1) # [B*N, embed_dim] |
|
|
|
edge_embeds = edge_embeds.mean(dim=-1) # [B*M, embed_dim] |
|
|
|
|
|
|
|
# 重塑回批次维度 |
|
|
|
surf_embeds = surf_embeds.reshape(B, N, -1) # [B, N, embed_dim] |
|
|
|
edge_embeds = edge_embeds.reshape(B, M, -1) # [B, M, embed_dim] |
|
|
|
|
|
|
|
# 点嵌入 |
|
|
|
surf_p_embeds = self.surfp_embed(surf_p) |
|
|
|
edge_p_embeds = self.edgep_embed(edge_p) |
|
|
|
vert_p_embeds = self.vertp_embed(vert_p) |
|
|
|
surf_p_embeds = self.surfp_embed(surf_p) # [B, N, embed_dim] |
|
|
|
edge_p_embeds = self.edgep_embed(edge_p) # [B, M, embed_dim] |
|
|
|
vert_p_embeds = self.vertp_embed(vert_p) # [B, K, embed_dim] |
|
|
|
|
|
|
|
# 组合所有嵌入 |
|
|
|
if self.use_cf: |
|
|
@ -205,13 +223,13 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
surf_embeds + surf_p_embeds, |
|
|
|
edge_embeds + edge_p_embeds, |
|
|
|
vert_p_embeds |
|
|
|
], dim=1) |
|
|
|
], dim=1) # [B, N+M+K, embed_dim] |
|
|
|
else: |
|
|
|
embeds = torch.cat([ |
|
|
|
surf_p_embeds, |
|
|
|
edge_p_embeds, |
|
|
|
vert_p_embeds |
|
|
|
], dim=1) |
|
|
|
], dim=1) # [B, N+M+K, embed_dim] |
|
|
|
|
|
|
|
output = self.transformer(embeds, src_key_padding_mask=mask) |
|
|
|
return output |
|
|
@ -290,8 +308,8 @@ class BRepToSDF(nn.Module): |
|
|
|
def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None): |
|
|
|
""" |
|
|
|
Args: |
|
|
|
surf_z: 表面特征 [B, N, num_surf_points*3] |
|
|
|
edge_z: 边特征 [B, M, num_edge_points*3] |
|
|
|
surf_z: 表面特征 [B, N, 48] |
|
|
|
edge_z: 边特征 [B, M, 12] |
|
|
|
surf_p: 表面点 [B, N, 6] |
|
|
|
edge_p: 边点 [B, M, 6] |
|
|
|
vert_p: 顶点点 [B, K, 6] |
|
|
@ -346,24 +364,40 @@ def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): |
|
|
|
return l1_loss + grad_weight * grad_constraint |
|
|
|
|
|
|
|
def main(): |
|
|
|
# 初始化模型 |
|
|
|
# 获取配置 |
|
|
|
config = get_default_config() |
|
|
|
|
|
|
|
# 从配置初始化模型 |
|
|
|
model = BRepToSDF( |
|
|
|
brep_feature_dim=48, |
|
|
|
use_cf=True, |
|
|
|
embed_dim=768, |
|
|
|
latent_dim=256 |
|
|
|
brep_feature_dim=config.model.brep_feature_dim, # 48 |
|
|
|
use_cf=config.model.use_cf, # True |
|
|
|
embed_dim=config.model.embed_dim, # 768 |
|
|
|
latent_dim=config.model.latent_dim # 256 |
|
|
|
) |
|
|
|
|
|
|
|
# 示例输入 |
|
|
|
batch_size = 4 |
|
|
|
num_surfs = 10 |
|
|
|
num_edges = 20 |
|
|
|
num_verts = 8 |
|
|
|
num_queries = 1000 |
|
|
|
# 从配置获取数据参数 |
|
|
|
batch_size = config.train.batch_size # 32 |
|
|
|
num_surfs = config.data.max_face # 64 |
|
|
|
num_edges = config.data.max_edge # 64 |
|
|
|
num_verts = 8 # 顶点数保持固定 |
|
|
|
num_queries = 1000 # 查询点数保持固定 |
|
|
|
|
|
|
|
# 生成示例数据 |
|
|
|
surf_z = torch.randn(batch_size, num_surfs, 48) |
|
|
|
edge_z = torch.randn(batch_size, num_edges, 12) |
|
|
|
surf_z = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
config.model.num_surf_points, # 16 |
|
|
|
3 |
|
|
|
) # [B, N, num_surf_points, 3] |
|
|
|
|
|
|
|
edge_z = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_edges, |
|
|
|
config.model.num_edge_points, # 4 |
|
|
|
3 |
|
|
|
) # [B, M, num_edge_points, 3] |
|
|
|
|
|
|
|
# 其他输入 |
|
|
|
surf_p = torch.randn(batch_size, num_surfs, 6) |
|
|
|
edge_p = torch.randn(batch_size, num_edges, 6) |
|
|
|
vert_p = torch.randn(batch_size, num_verts, 6) |
|
|
@ -371,7 +405,24 @@ def main(): |
|
|
|
|
|
|
|
# 前向传播 |
|
|
|
sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) |
|
|
|
print(f"Output SDF shape: {sdf.shape}") |
|
|
|
|
|
|
|
# 打印形状信息和配置信息 |
|
|
|
print("\nConfiguration:") |
|
|
|
print(f"Batch Size: {batch_size}") |
|
|
|
print(f"Embed Dimension: {config.model.embed_dim}") |
|
|
|
print(f"Surface Points: {config.model.num_surf_points}") |
|
|
|
print(f"Edge Points: {config.model.num_edge_points}") |
|
|
|
print(f"Max Faces: {config.data.max_face}") |
|
|
|
print(f"Max Edges: {config.data.max_edge}") |
|
|
|
|
|
|
|
print("\nInput shapes:") |
|
|
|
print(f"surf_z: {surf_z.shape}") # [32, 64, 16, 3] |
|
|
|
print(f"edge_z: {edge_z.shape}") # [32, 64, 4, 3] |
|
|
|
print(f"surf_p: {surf_p.shape}") # [32, 64, 6] |
|
|
|
print(f"edge_p: {edge_p.shape}") # [32, 64, 6] |
|
|
|
print(f"vert_p: {vert_p.shape}") # [32, 8, 6] |
|
|
|
print(f"query_points: {query_points.shape}") # [32, 1000, 3] |
|
|
|
print(f"\nOutput SDF shape: {sdf.shape}") # [32, 1000, 1] |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
main() |