|
|
@ -315,27 +315,29 @@ class SDFHead(nn.Module): |
|
|
|
return self.mlp(x) |
|
|
|
|
|
|
|
class BRepToSDF(nn.Module): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
brep_feature_dim: int = 48, |
|
|
|
use_cf: bool = True, |
|
|
|
embed_dim: int = 768, |
|
|
|
latent_dim: int = 256 |
|
|
|
): |
|
|
|
def __init__(self, config=None): |
|
|
|
super().__init__() |
|
|
|
# 获取配置 |
|
|
|
self.config = get_default_config() |
|
|
|
self.embed_dim = embed_dim |
|
|
|
if config is None: |
|
|
|
self.config = get_default_config() |
|
|
|
else: |
|
|
|
self.config = config |
|
|
|
|
|
|
|
# 从配置中读取参数 |
|
|
|
self.embed_dim = self.config.model.embed_dim |
|
|
|
self.brep_feature_dim = self.config.model.brep_feature_dim |
|
|
|
self.latent_dim = self.config.model.latent_dim |
|
|
|
self.use_cf = self.config.model.use_cf |
|
|
|
|
|
|
|
# 1. 查询点编码器 |
|
|
|
self.query_encoder = nn.Sequential( |
|
|
|
nn.Linear(3, embed_dim//4), |
|
|
|
nn.LayerNorm(embed_dim//4), |
|
|
|
nn.Linear(3, self.embed_dim//4), |
|
|
|
nn.LayerNorm(self.embed_dim//4), |
|
|
|
nn.ReLU(), |
|
|
|
nn.Linear(embed_dim//4, embed_dim//2), |
|
|
|
nn.LayerNorm(embed_dim//2), |
|
|
|
nn.Linear(self.embed_dim//4, self.embed_dim//2), |
|
|
|
nn.LayerNorm(self.embed_dim//2), |
|
|
|
nn.ReLU(), |
|
|
|
nn.Linear(embed_dim//2, embed_dim) |
|
|
|
nn.Linear(self.embed_dim//2, self.embed_dim) |
|
|
|
) |
|
|
|
|
|
|
|
# 2. B-rep特征编码器 |
|
|
@ -343,12 +345,12 @@ class BRepToSDF(nn.Module): |
|
|
|
|
|
|
|
# 3. 特征融合Transformer |
|
|
|
self.transformer = SDFTransformer( |
|
|
|
embed_dim=embed_dim, |
|
|
|
num_layers=6 |
|
|
|
embed_dim=self.embed_dim, |
|
|
|
num_layers=6 # 这个参数也可以移到配置文件中 |
|
|
|
) |
|
|
|
|
|
|
|
# 4. SDF预测头 |
|
|
|
self.sdf_head = SDFHead(embed_dim=embed_dim*2) |
|
|
|
self.sdf_head = SDFHead(embed_dim=self.embed_dim*2) |
|
|
|
|
|
|
|
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None): |
|
|
|
"""B-rep到SDF的前向传播 |
|
|
@ -435,88 +437,47 @@ def main(): |
|
|
|
# 获取配置 |
|
|
|
config = get_default_config() |
|
|
|
|
|
|
|
# 从配置初始化模型 |
|
|
|
model = BRepToSDF( |
|
|
|
brep_feature_dim=config.model.brep_feature_dim, # 48 |
|
|
|
use_cf=config.model.use_cf, # True |
|
|
|
embed_dim=config.model.embed_dim, # 768 |
|
|
|
latent_dim=config.model.latent_dim # 256 |
|
|
|
) |
|
|
|
|
|
|
|
# 从配置获取数据参数 |
|
|
|
batch_size = config.train.batch_size # 32 |
|
|
|
num_surfs = config.data.max_face # 64 |
|
|
|
num_edges = config.data.max_edge # 64 |
|
|
|
num_verts = 8 # 顶点数保持固定 |
|
|
|
num_queries = 1000 # 查询点数保持固定 |
|
|
|
|
|
|
|
# 更新测试数据维度 |
|
|
|
edge_ncs = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, # max_face |
|
|
|
num_edges, # max_edge |
|
|
|
config.model.num_edge_points, |
|
|
|
3 |
|
|
|
) # [B, max_face, max_edge, num_edge_points, 3] |
|
|
|
|
|
|
|
edge_pos = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
num_edges, |
|
|
|
6 |
|
|
|
) # [B, max_face, max_edge, 6] |
|
|
|
|
|
|
|
edge_mask = torch.ones( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
num_edges, |
|
|
|
dtype=torch.bool |
|
|
|
) # [B, max_face, max_edge] |
|
|
|
# 初始化模型 |
|
|
|
model = BRepToSDF(config=config) |
|
|
|
|
|
|
|
surf_ncs = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
config.model.num_surf_points, |
|
|
|
3 |
|
|
|
) # [B, max_face, num_surf_points, 3] |
|
|
|
# 从配置获取参数 |
|
|
|
batch_size = config.train.batch_size |
|
|
|
max_face = config.data.max_face |
|
|
|
max_edge = config.data.max_edge |
|
|
|
num_surf_points = config.model.num_surf_points |
|
|
|
num_edge_points = config.model.num_edge_points |
|
|
|
|
|
|
|
surf_pos = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
6 |
|
|
|
) # [B, max_face, 6] |
|
|
|
# 生成测试数据 |
|
|
|
test_data = { |
|
|
|
'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3), |
|
|
|
'edge_pos': torch.randn(batch_size, max_face, max_edge, 6), |
|
|
|
'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool), |
|
|
|
'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3), |
|
|
|
'surf_pos': torch.randn(batch_size, max_face, 6), |
|
|
|
'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3), |
|
|
|
'query_points': torch.randn(batch_size, 1000, 3) # 1000个查询点 |
|
|
|
} |
|
|
|
|
|
|
|
vertex_pos = torch.randn( |
|
|
|
batch_size, |
|
|
|
num_surfs, |
|
|
|
num_edges, |
|
|
|
2, |
|
|
|
3 |
|
|
|
) # [B, max_face, max_edge, 2, 3] |
|
|
|
# 打印输入数据形状 |
|
|
|
logger.info("Input shapes:") |
|
|
|
for name, tensor in test_data.items(): |
|
|
|
logger.info(f" {name}: {tensor.shape}") |
|
|
|
|
|
|
|
query_points = torch.randn(batch_size, num_queries, 3) |
|
|
|
|
|
|
|
# 更新前向传播调用 |
|
|
|
sdf = model( |
|
|
|
edge_ncs=edge_ncs, |
|
|
|
edge_pos=edge_pos, |
|
|
|
edge_mask=edge_mask, |
|
|
|
surf_ncs=surf_ncs, |
|
|
|
surf_pos=surf_pos, |
|
|
|
vertex_pos=vertex_pos, |
|
|
|
query_points=query_points |
|
|
|
) |
|
|
|
|
|
|
|
# 更新打印信息 |
|
|
|
print("\nInput shapes:") |
|
|
|
print(f"edge_ncs: {edge_ncs.shape}") |
|
|
|
print(f"edge_pos: {edge_pos.shape}") |
|
|
|
print(f"edge_mask: {edge_mask.shape}") |
|
|
|
print(f"surf_ncs: {surf_ncs.shape}") |
|
|
|
print(f"surf_pos: {surf_pos.shape}") |
|
|
|
print(f"vertex_pos: {vertex_pos.shape}") |
|
|
|
print(f"query_points: {query_points.shape}") |
|
|
|
print(f"\nOutput SDF shape: {sdf.shape}") |
|
|
|
# 前向传播 |
|
|
|
try: |
|
|
|
sdf = model(**test_data) |
|
|
|
logger.info(f"\nOutput SDF shape: {sdf.shape}") |
|
|
|
|
|
|
|
# 计算模型参数量 |
|
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
|
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
|
logger.info(f"\nModel statistics:") |
|
|
|
logger.info(f" Total parameters: {total_params:,}") |
|
|
|
logger.info(f" Trainable parameters: {trainable_params:,}") |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error during forward pass: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
main() |