|
|
@ -5,6 +5,7 @@ import torch.nn.functional as F |
|
|
|
from dataclasses import dataclass |
|
|
|
from typing import Dict, Optional, Tuple, Union |
|
|
|
from brep2sdf.config.default_config import get_default_config |
|
|
|
from brep2sdf.utils.logger import logger |
|
|
|
|
|
|
|
class ResConvBlock(nn.Module): |
|
|
|
"""残差卷积块""" |
|
|
@ -119,16 +120,24 @@ class Encoder1D(nn.Module): |
|
|
|
|
|
|
|
class BRepFeatureEmbedder(nn.Module): |
|
|
|
"""B-rep特征嵌入器""" |
|
|
|
def __init__(self, use_cf: bool = True): |
|
|
|
def __init__(self, config=None): |
|
|
|
super().__init__() |
|
|
|
# 获取配置 |
|
|
|
self.config = get_default_config() |
|
|
|
self.embed_dim = 768 |
|
|
|
self.use_cf = use_cf |
|
|
|
if config is None: |
|
|
|
self.config = get_default_config() |
|
|
|
else: |
|
|
|
self.config = config |
|
|
|
|
|
|
|
self.num_surf_points = self.config.model.num_surf_points |
|
|
|
self.num_edge_points = self.config.model.num_edge_points |
|
|
|
self.embed_dim = self.config.model.embed_dim |
|
|
|
self.use_cf = self.config.model.use_cf |
|
|
|
|
|
|
|
# 使用配置中的采样点数 |
|
|
|
self.num_surf_points = self.config.model.num_surf_points # 16 |
|
|
|
self.num_edge_points = self.config.model.num_edge_points # 4 |
|
|
|
# 打印初始化信息 |
|
|
|
logger.info(f"BRepFeatureEmbedder config:") |
|
|
|
logger.info(f" num_surf_points: {self.num_surf_points}") |
|
|
|
logger.info(f" num_edge_points: {self.num_edge_points}") |
|
|
|
logger.info(f" embed_dim: {self.embed_dim}") |
|
|
|
logger.info(f" use_cf: {self.use_cf}") |
|
|
|
|
|
|
|
# Transformer编码器层 |
|
|
|
layer = nn.TransformerEncoderLayer( |
|
|
@ -182,59 +191,93 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
nn.Linear(self.embed_dim, self.embed_dim), |
|
|
|
) |
|
|
|
|
|
|
|
def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None): |
|
|
|
""" |
|
|
|
Args: |
|
|
|
surf_z: 表面点云 [B, N, num_surf_points, 3] |
|
|
|
edge_z: 边点云 [B, M, num_edge_points, 3] |
|
|
|
surf_p: 表面点 [B, N, 6] |
|
|
|
edge_p: 边点 [B, M, 6] |
|
|
|
vert_p: 顶点点 [B, K, 6] |
|
|
|
mask: 注意力掩码 |
|
|
|
""" |
|
|
|
# 获取批次大小和其他维度 |
|
|
|
B = surf_z.size(0) |
|
|
|
N = surf_z.size(1) |
|
|
|
M = edge_z.size(1) |
|
|
|
K = vert_p.size(1) |
|
|
|
|
|
|
|
# 重塑点云数据用于1D编码器 |
|
|
|
surf_z = surf_z.reshape(B*N, self.num_surf_points, 3).transpose(1, 2) # [B*N, 3, num_surf_points] |
|
|
|
edge_z = edge_z.reshape(B*M, self.num_edge_points, 3).transpose(1, 2) # [B*M, 3, num_edge_points] |
|
|
|
|
|
|
|
# 特征嵌入 |
|
|
|
surf_embeds = self.surfz_embed(surf_z) # [B*N, embed_dim, num_points] |
|
|
|
edge_embeds = self.edgez_embed(edge_z) # [B*M, embed_dim, num_points] |
|
|
|
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, data_class=None): |
|
|
|
"""B-rep特征嵌入器的前向传播 |
|
|
|
|
|
|
|
# 全局池化得到每个面/边的特征 |
|
|
|
surf_embeds = surf_embeds.mean(dim=-1) # [B*N, embed_dim] |
|
|
|
edge_embeds = edge_embeds.mean(dim=-1) # [B*M, embed_dim] |
|
|
|
|
|
|
|
# 重塑回批次维度 |
|
|
|
surf_embeds = surf_embeds.reshape(B, N, -1) # [B, N, embed_dim] |
|
|
|
edge_embeds = edge_embeds.reshape(B, M, -1) # [B, M, embed_dim] |
|
|
|
Args: |
|
|
|
edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3] |
|
|
|
edge_pos: 边位置 [B, max_face, max_edge, 6] |
|
|
|
edge_mask: 边掩码 [B, max_face, max_edge] |
|
|
|
surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3] |
|
|
|
surf_pos: 面位置 [B, max_face, 6] |
|
|
|
vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3] |
|
|
|
|
|
|
|
# 点嵌入 |
|
|
|
surf_p_embeds = self.surfp_embed(surf_p) # [B, N, embed_dim] |
|
|
|
edge_p_embeds = self.edgep_embed(edge_p) # [B, M, embed_dim] |
|
|
|
vert_p_embeds = self.vertp_embed(vert_p) # [B, K, embed_dim] |
|
|
|
Returns: |
|
|
|
embeds: [B, max_face*(max_edge+1), embed_dim] |
|
|
|
""" |
|
|
|
B = self.config.train.batch_size |
|
|
|
max_face = self.config.data.max_face |
|
|
|
max_edge = self.config.data.max_edge |
|
|
|
|
|
|
|
# 组合所有嵌入 |
|
|
|
if self.use_cf: |
|
|
|
embeds = torch.cat([ |
|
|
|
surf_embeds + surf_p_embeds, |
|
|
|
edge_embeds + edge_p_embeds, |
|
|
|
vert_p_embeds |
|
|
|
], dim=1) # [B, N+M+K, embed_dim] |
|
|
|
else: |
|
|
|
embeds = torch.cat([ |
|
|
|
surf_p_embeds, |
|
|
|
edge_p_embeds, |
|
|
|
vert_p_embeds |
|
|
|
], dim=1) # [B, N+M+K, embed_dim] |
|
|
|
try: |
|
|
|
# 1. 处理边特征 |
|
|
|
# 重塑边点云以适应1D编码器 |
|
|
|
edge_ncs = edge_ncs.reshape(B*max_face*max_edge, -1, 3).transpose(1, 2) # [B*max_face*max_edge, 3, num_edge_points] |
|
|
|
edge_embeds = self.edgez_embed(edge_ncs) # [B*max_face*max_edge, embed_dim, num_edge_points] |
|
|
|
edge_embeds = edge_embeds.mean(dim=-1) # [B*max_face*max_edge, embed_dim] |
|
|
|
edge_embeds = edge_embeds.reshape(B, max_face, max_edge, -1) # [B, max_face, max_edge, embed_dim] |
|
|
|
|
|
|
|
# 2. 处理面特征 |
|
|
|
surf_ncs = surf_ncs.reshape(B*max_face, -1, 3).transpose(1, 2) # [B*max_face, 3, num_surf_points] |
|
|
|
surf_embeds = self.surfz_embed(surf_ncs) # [B*max_face, embed_dim, num_surf_points] |
|
|
|
surf_embeds = surf_embeds.mean(dim=-1) # [B*max_face, embed_dim] |
|
|
|
surf_embeds = surf_embeds.reshape(B, max_face, -1) # [B, max_face, embed_dim] |
|
|
|
|
|
|
|
# 3. 处理位置编码 |
|
|
|
# 边位置编码 |
|
|
|
edge_pos = edge_pos.reshape(B*max_face*max_edge, -1) # [B*max_face*max_edge, 6] |
|
|
|
edge_p_embeds = self.edgep_embed(edge_pos) # [B*max_face*max_edge, embed_dim] |
|
|
|
edge_p_embeds = edge_p_embeds.reshape(B, max_face, max_edge, -1) # [B, max_face, max_edge, embed_dim] |
|
|
|
|
|
|
|
# 面位置编码 |
|
|
|
surf_p_embeds = self.surfp_embed(surf_pos) # [B, max_face, embed_dim] |
|
|
|
|
|
|
|
output = self.transformer(embeds, src_key_padding_mask=mask) |
|
|
|
return output |
|
|
|
# 4. 组合特征 |
|
|
|
if self.use_cf: |
|
|
|
# 边特征 |
|
|
|
edge_features = edge_embeds + edge_p_embeds # [B, max_face, max_edge, embed_dim] |
|
|
|
edge_features = edge_features.reshape(B, max_face*max_edge, -1) # [B, max_face*max_edge, embed_dim] |
|
|
|
|
|
|
|
# 面特征 |
|
|
|
surf_features = surf_embeds + surf_p_embeds # [B, max_face, embed_dim] |
|
|
|
|
|
|
|
# 组合所有特征 |
|
|
|
embeds = torch.cat([ |
|
|
|
edge_features, # [B, max_face*max_edge, embed_dim] |
|
|
|
surf_features # [B, max_face, embed_dim] |
|
|
|
], dim=1) # [B, max_face*(max_edge+1), embed_dim] |
|
|
|
else: |
|
|
|
# 只使用位置编码 |
|
|
|
edge_features = edge_p_embeds.reshape(B, max_face*max_edge, -1) # [B, max_face*max_edge, embed_dim] |
|
|
|
embeds = torch.cat([ |
|
|
|
edge_features, # [B, max_face*max_edge, embed_dim] |
|
|
|
surf_p_embeds # [B, max_face, embed_dim] |
|
|
|
], dim=1) # [B, max_face*(max_edge+1), embed_dim] |
|
|
|
|
|
|
|
# 5. 处理掩码 |
|
|
|
if edge_mask is not None: |
|
|
|
# 扩展掩码以匹配特征维度 |
|
|
|
edge_mask = edge_mask.reshape(B, -1) # [B, max_face*max_edge] |
|
|
|
surf_mask = torch.ones(B, max_face, device=edge_mask.device, dtype=torch.bool) # [B, max_face] |
|
|
|
mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, max_face*(max_edge+1)] |
|
|
|
else: |
|
|
|
mask = None |
|
|
|
|
|
|
|
# 6. Transformer处理 |
|
|
|
output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask) |
|
|
|
return output.transpose(0, 1) # 确保输出维度为 [B, seq_len, embed_dim] |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error in BRepFeatureEmbedder forward pass:") |
|
|
|
logger.error(f" Error message: {str(e)}") |
|
|
|
logger.error(f" Input shapes:") |
|
|
|
logger.error(f" edge_ncs: {edge_ncs.shape}") |
|
|
|
logger.error(f" edge_pos: {edge_pos.shape}") |
|
|
|
logger.error(f" edge_mask: {edge_mask.shape}") |
|
|
|
logger.error(f" surf_ncs: {surf_ncs.shape}") |
|
|
|
logger.error(f" surf_pos: {surf_pos.shape}") |
|
|
|
logger.error(f" vertex_pos: {vertex_pos.shape}") |
|
|
|
raise |
|
|
|
|
|
|
|
class SDFTransformer(nn.Module): |
|
|
|
"""SDF Transformer编码器""" |
|
|
@ -296,7 +339,7 @@ class BRepToSDF(nn.Module): |
|
|
|
) |
|
|
|
|
|
|
|
# 2. B-rep特征编码器 |
|
|
|
self.feature_embedder = BRepFeatureEmbedder(use_cf=use_cf) |
|
|
|
self.brep_embedder = BRepFeatureEmbedder() |
|
|
|
|
|
|
|
# 3. 特征融合Transformer |
|
|
|
self.transformer = SDFTransformer( |
|
|
@ -307,45 +350,68 @@ class BRepToSDF(nn.Module): |
|
|
|
# 4. SDF预测头 |
|
|
|
self.sdf_head = SDFHead(embed_dim=embed_dim*2) |
|
|
|
|
|
|
|
def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None): |
|
|
|
""" |
|
|
|
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None): |
|
|
|
"""B-rep到SDF的前向传播 |
|
|
|
|
|
|
|
Args: |
|
|
|
surf_z: 表面特征 [B, N, 48] |
|
|
|
edge_z: 边特征 [B, M, 12] |
|
|
|
surf_p: 表面点 [B, N, 6] |
|
|
|
edge_p: 边点 [B, M, 6] |
|
|
|
vert_p: 顶点点 [B, K, 6] |
|
|
|
query_points: 查询点 [B, Q, 3] |
|
|
|
mask: 注意力掩码 |
|
|
|
edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3] |
|
|
|
edge_pos: 边位置 [B, max_face, max_edge, 6] |
|
|
|
edge_mask: 边掩码 [B, max_face, max_edge] |
|
|
|
surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3] |
|
|
|
surf_pos: 面位置 [B, max_face, 6] |
|
|
|
vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3] |
|
|
|
query_points: 查询点 [B, num_queries, 3] |
|
|
|
data_class: (可选) 类别标签 |
|
|
|
|
|
|
|
Returns: |
|
|
|
sdf: [B, Q, 1] |
|
|
|
sdf: 预测的SDF值 [B, num_queries, 1] |
|
|
|
""" |
|
|
|
B, Q, _ = query_points.shape |
|
|
|
|
|
|
|
# 1. B-rep特征嵌入 |
|
|
|
brep_features = self.feature_embedder( |
|
|
|
surf_z, edge_z, surf_p, edge_p, vert_p, mask |
|
|
|
) # [B, N+M+K, embed_dim] |
|
|
|
|
|
|
|
# 2. 查询点编码 |
|
|
|
query_features = self.query_encoder(query_points) # [B, Q, embed_dim] |
|
|
|
|
|
|
|
# 3. 提取全局特征 |
|
|
|
global_features = brep_features.mean(dim=1) # [B, embed_dim] |
|
|
|
|
|
|
|
# 4. 为每个查询点准备特征 |
|
|
|
expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1) # [B, Q, embed_dim] |
|
|
|
|
|
|
|
# 5. 连接查询点特征和全局特征 |
|
|
|
combined_features = torch.cat([ |
|
|
|
expanded_features, # [B, Q, embed_dim] |
|
|
|
query_features # [B, Q, embed_dim] |
|
|
|
], dim=-1) # [B, Q, embed_dim*2] |
|
|
|
B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries |
|
|
|
|
|
|
|
# 6. SDF预测 |
|
|
|
sdf = self.sdf_head(combined_features) # [B, Q, 1] |
|
|
|
|
|
|
|
return sdf |
|
|
|
try: |
|
|
|
# 1. B-rep特征编码 |
|
|
|
brep_features = self.brep_embedder( |
|
|
|
edge_ncs=edge_ncs, # [B, max_face, max_edge, num_edge_points, 3] |
|
|
|
edge_pos=edge_pos, # [B, max_face, max_edge, 6] |
|
|
|
edge_mask=edge_mask, # [B, max_face, max_edge] |
|
|
|
surf_ncs=surf_ncs, # [B, max_face, num_surf_points, 3] |
|
|
|
surf_pos=surf_pos, # [B, max_face, 6] |
|
|
|
vertex_pos=vertex_pos, # [B, max_face, max_edge, 2, 3] |
|
|
|
data_class=data_class |
|
|
|
) # [B, max_face*(max_edge+1), embed_dim] |
|
|
|
|
|
|
|
# 2. 查询点编码 |
|
|
|
query_features = self.query_encoder(query_points) # [B, Q, embed_dim] |
|
|
|
|
|
|
|
# 3. 提取全局特征 |
|
|
|
global_features = brep_features.mean(dim=1) # [B, embed_dim] |
|
|
|
|
|
|
|
# 4. 为每个查询点准备特征 |
|
|
|
expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1) # [B, Q, embed_dim] |
|
|
|
|
|
|
|
# 5. 连接查询点特征和全局特征 |
|
|
|
combined_features = torch.cat([ |
|
|
|
expanded_features, # [B, Q, embed_dim] |
|
|
|
query_features # [B, Q, embed_dim] |
|
|
|
], dim=-1) # [B, Q, embed_dim*2] |
|
|
|
|
|
|
|
# 6. SDF预测 |
|
|
|
sdf = self.sdf_head(combined_features) # [B, Q, 1] |
|
|
|
|
|
|
|
return sdf |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error in BRepToSDF forward pass:") |
|
|
|
logger.error(f" Error message: {str(e)}") |
|
|
|
logger.error(f" Input shapes:") |
|
|
|
logger.error(f" edge_ncs: {edge_ncs.shape}") |
|
|
|
logger.error(f" edge_pos: {edge_pos.shape}") |
|
|
|
logger.error(f" edge_mask: {edge_mask.shape}") |
|
|
|
logger.error(f" surf_ncs: {surf_ncs.shape}") |
|
|
|
logger.error(f" surf_pos: {surf_pos.shape}") |
|
|
|
logger.error(f" vertex_pos: {vertex_pos.shape}") |
|
|
|
logger.error(f" query_points: {query_points.shape}") |
|
|
|
raise |
|
|
|
|
|
|
|
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): |
|
|
|
"""SDF损失函数""" |
|
|
@ -384,47 +450,73 @@ def main(): |
|
|
|
num_verts = 8 # 顶点数保持固定 |
|
|
|
num_queries = 1000 # 查询点数保持固定 |
|
|
|
|
|
|
|
# 生成示例数据 |
|
|
|
surf_z = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
config.model.num_surf_points, # 16 |
|
|
|
# 更新测试数据维度 |
|
|
|
edge_ncs = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, # max_face |
|
|
|
num_edges, # max_edge |
|
|
|
config.model.num_edge_points, |
|
|
|
3 |
|
|
|
) # [B, max_face, max_edge, num_edge_points, 3] |
|
|
|
|
|
|
|
edge_pos = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
num_edges, |
|
|
|
6 |
|
|
|
) # [B, max_face, max_edge, 6] |
|
|
|
|
|
|
|
edge_mask = torch.ones( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
num_edges, |
|
|
|
dtype=torch.bool |
|
|
|
) # [B, max_face, max_edge] |
|
|
|
|
|
|
|
surf_ncs = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
config.model.num_surf_points, |
|
|
|
3 |
|
|
|
) # [B, N, num_surf_points, 3] |
|
|
|
) # [B, max_face, num_surf_points, 3] |
|
|
|
|
|
|
|
surf_pos = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
6 |
|
|
|
) # [B, max_face, 6] |
|
|
|
|
|
|
|
edge_z = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_edges, |
|
|
|
config.model.num_edge_points, # 4 |
|
|
|
vertex_pos = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
num_edges, |
|
|
|
2, |
|
|
|
3 |
|
|
|
) # [B, M, num_edge_points, 3] |
|
|
|
) # [B, max_face, max_edge, 2, 3] |
|
|
|
|
|
|
|
# 其他输入 |
|
|
|
surf_p = torch.randn(batch_size, num_surfs, 6) |
|
|
|
edge_p = torch.randn(batch_size, num_edges, 6) |
|
|
|
vert_p = torch.randn(batch_size, num_verts, 6) |
|
|
|
query_points = torch.randn(batch_size, num_queries, 3) |
|
|
|
|
|
|
|
# 前向传播 |
|
|
|
sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) |
|
|
|
|
|
|
|
# 打印形状信息和配置信息 |
|
|
|
print("\nConfiguration:") |
|
|
|
print(f"Batch Size: {batch_size}") |
|
|
|
print(f"Embed Dimension: {config.model.embed_dim}") |
|
|
|
print(f"Surface Points: {config.model.num_surf_points}") |
|
|
|
print(f"Edge Points: {config.model.num_edge_points}") |
|
|
|
print(f"Max Faces: {config.data.max_face}") |
|
|
|
print(f"Max Edges: {config.data.max_edge}") |
|
|
|
# 更新前向传播调用 |
|
|
|
sdf = model( |
|
|
|
edge_ncs=edge_ncs, |
|
|
|
edge_pos=edge_pos, |
|
|
|
edge_mask=edge_mask, |
|
|
|
surf_ncs=surf_ncs, |
|
|
|
surf_pos=surf_pos, |
|
|
|
vertex_pos=vertex_pos, |
|
|
|
query_points=query_points |
|
|
|
) |
|
|
|
|
|
|
|
# 更新打印信息 |
|
|
|
print("\nInput shapes:") |
|
|
|
print(f"surf_z: {surf_z.shape}") # [32, 64, 16, 3] |
|
|
|
print(f"edge_z: {edge_z.shape}") # [32, 64, 4, 3] |
|
|
|
print(f"surf_p: {surf_p.shape}") # [32, 64, 6] |
|
|
|
print(f"edge_p: {edge_p.shape}") # [32, 64, 6] |
|
|
|
print(f"vert_p: {vert_p.shape}") # [32, 8, 6] |
|
|
|
print(f"query_points: {query_points.shape}") # [32, 1000, 3] |
|
|
|
print(f"\nOutput SDF shape: {sdf.shape}") # [32, 1000, 1] |
|
|
|
print(f"edge_ncs: {edge_ncs.shape}") |
|
|
|
print(f"edge_pos: {edge_pos.shape}") |
|
|
|
print(f"edge_mask: {edge_mask.shape}") |
|
|
|
print(f"surf_ncs: {surf_ncs.shape}") |
|
|
|
print(f"surf_pos: {surf_pos.shape}") |
|
|
|
print(f"vertex_pos: {vertex_pos.shape}") |
|
|
|
print(f"query_points: {query_points.shape}") |
|
|
|
print(f"\nOutput SDF shape: {sdf.shape}") |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
main() |