diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index 12c0a18..f5a0132 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -4,14 +4,24 @@ from typing import Tuple, Optional @dataclass class ModelConfig: """模型相关配置""" - brep_feature_dim: int = 48 + brep_feature_dim: int = 32 use_cf: bool = True - embed_dim: int = 768 - latent_dim: int = 256 + embed_dim: int = 384 # 3 的 倍数 + latent_dim: int = 64 # 点云采样配置 num_surf_points: int = 16 # 每个面采样点数 num_edge_points: int = 4 # 每条边采样点数 + + # Transformer相关配置 + num_transformer_layers: int = 6 + num_attention_heads: int = 8 + transformer_dim_feedforward: int = 512 + transformer_dropout: float = 0.1 + + # 编码器配置 + encoder_channels: Tuple[int] = (32, 64, 128) + encoder_layers_per_block: int = 1 @dataclass class DataConfig: @@ -40,7 +50,7 @@ class DataConfig: class TrainConfig: """训练相关配置""" # 基本训练参数 - batch_size: int = 1 + batch_size: int = 8 num_workers: int = 4 num_epochs: int = 100 learning_rate: float = 1e-4 diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index fb16b6c..2f336a3 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -315,27 +315,29 @@ class SDFHead(nn.Module): return self.mlp(x) class BRepToSDF(nn.Module): - def __init__( - self, - brep_feature_dim: int = 48, - use_cf: bool = True, - embed_dim: int = 768, - latent_dim: int = 256 - ): + def __init__(self, config=None): super().__init__() # 获取配置 - self.config = get_default_config() - self.embed_dim = embed_dim + if config is None: + self.config = get_default_config() + else: + self.config = config + + # 从配置中读取参数 + self.embed_dim = self.config.model.embed_dim + self.brep_feature_dim = self.config.model.brep_feature_dim + self.latent_dim = self.config.model.latent_dim + self.use_cf = self.config.model.use_cf # 1. 查询点编码器 self.query_encoder = nn.Sequential( - nn.Linear(3, embed_dim//4), - nn.LayerNorm(embed_dim//4), + nn.Linear(3, self.embed_dim//4), + nn.LayerNorm(self.embed_dim//4), nn.ReLU(), - nn.Linear(embed_dim//4, embed_dim//2), - nn.LayerNorm(embed_dim//2), + nn.Linear(self.embed_dim//4, self.embed_dim//2), + nn.LayerNorm(self.embed_dim//2), nn.ReLU(), - nn.Linear(embed_dim//2, embed_dim) + nn.Linear(self.embed_dim//2, self.embed_dim) ) # 2. B-rep特征编码器 @@ -343,12 +345,12 @@ class BRepToSDF(nn.Module): # 3. 特征融合Transformer self.transformer = SDFTransformer( - embed_dim=embed_dim, - num_layers=6 + embed_dim=self.embed_dim, + num_layers=6 # 这个参数也可以移到配置文件中 ) # 4. SDF预测头 - self.sdf_head = SDFHead(embed_dim=embed_dim*2) + self.sdf_head = SDFHead(embed_dim=self.embed_dim*2) def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None): """B-rep到SDF的前向传播 @@ -435,88 +437,47 @@ def main(): # 获取配置 config = get_default_config() - # 从配置初始化模型 - model = BRepToSDF( - brep_feature_dim=config.model.brep_feature_dim, # 48 - use_cf=config.model.use_cf, # True - embed_dim=config.model.embed_dim, # 768 - latent_dim=config.model.latent_dim # 256 - ) - - # 从配置获取数据参数 - batch_size = config.train.batch_size # 32 - num_surfs = config.data.max_face # 64 - num_edges = config.data.max_edge # 64 - num_verts = 8 # 顶点数保持固定 - num_queries = 1000 # 查询点数保持固定 - - # 更新测试数据维度 - edge_ncs = torch.randn( - batch_size, - num_surfs, # max_face - num_edges, # max_edge - config.model.num_edge_points, - 3 - ) # [B, max_face, max_edge, num_edge_points, 3] - - edge_pos = torch.randn( - batch_size, - num_surfs, - num_edges, - 6 - ) # [B, max_face, max_edge, 6] - - edge_mask = torch.ones( - batch_size, - num_surfs, - num_edges, - dtype=torch.bool - ) # [B, max_face, max_edge] + # 初始化模型 + model = BRepToSDF(config=config) - surf_ncs = torch.randn( - batch_size, - num_surfs, - config.model.num_surf_points, - 3 - ) # [B, max_face, num_surf_points, 3] + # 从配置获取参数 + batch_size = config.train.batch_size + max_face = config.data.max_face + max_edge = config.data.max_edge + num_surf_points = config.model.num_surf_points + num_edge_points = config.model.num_edge_points - surf_pos = torch.randn( - batch_size, - num_surfs, - 6 - ) # [B, max_face, 6] + # 生成测试数据 + test_data = { + 'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3), + 'edge_pos': torch.randn(batch_size, max_face, max_edge, 6), + 'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool), + 'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3), + 'surf_pos': torch.randn(batch_size, max_face, 6), + 'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3), + 'query_points': torch.randn(batch_size, 1000, 3) # 1000个查询点 + } - vertex_pos = torch.randn( - batch_size, - num_surfs, - num_edges, - 2, - 3 - ) # [B, max_face, max_edge, 2, 3] + # 打印输入数据形状 + logger.info("Input shapes:") + for name, tensor in test_data.items(): + logger.info(f" {name}: {tensor.shape}") - query_points = torch.randn(batch_size, num_queries, 3) - - # 更新前向传播调用 - sdf = model( - edge_ncs=edge_ncs, - edge_pos=edge_pos, - edge_mask=edge_mask, - surf_ncs=surf_ncs, - surf_pos=surf_pos, - vertex_pos=vertex_pos, - query_points=query_points - ) - - # 更新打印信息 - print("\nInput shapes:") - print(f"edge_ncs: {edge_ncs.shape}") - print(f"edge_pos: {edge_pos.shape}") - print(f"edge_mask: {edge_mask.shape}") - print(f"surf_ncs: {surf_ncs.shape}") - print(f"surf_pos: {surf_pos.shape}") - print(f"vertex_pos: {vertex_pos.shape}") - print(f"query_points: {query_points.shape}") - print(f"\nOutput SDF shape: {sdf.shape}") + # 前向传播 + try: + sdf = model(**test_data) + logger.info(f"\nOutput SDF shape: {sdf.shape}") + + # 计算模型参数量 + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"\nModel statistics:") + logger.info(f" Total parameters: {total_params:,}") + logger.info(f" Trainable parameters: {trainable_params:,}") + + except Exception as e: + logger.error(f"Error during forward pass: {str(e)}") + raise if __name__ == "__main__": main() \ No newline at end of file