Browse Source

refactor: config配net

main
mckay 7 months ago
parent
commit
eff305e802
  1. 18
      brep2sdf/config/default_config.py
  2. 159
      brep2sdf/networks/encoder.py

18
brep2sdf/config/default_config.py

@ -4,15 +4,25 @@ from typing import Tuple, Optional
@dataclass @dataclass
class ModelConfig: class ModelConfig:
"""模型相关配置""" """模型相关配置"""
brep_feature_dim: int = 48 brep_feature_dim: int = 32
use_cf: bool = True use_cf: bool = True
embed_dim: int = 768 embed_dim: int = 384 # 3 的 倍数
latent_dim: int = 256 latent_dim: int = 64
# 点云采样配置 # 点云采样配置
num_surf_points: int = 16 # 每个面采样点数 num_surf_points: int = 16 # 每个面采样点数
num_edge_points: int = 4 # 每条边采样点数 num_edge_points: int = 4 # 每条边采样点数
# Transformer相关配置
num_transformer_layers: int = 6
num_attention_heads: int = 8
transformer_dim_feedforward: int = 512
transformer_dropout: float = 0.1
# 编码器配置
encoder_channels: Tuple[int] = (32, 64, 128)
encoder_layers_per_block: int = 1
@dataclass @dataclass
class DataConfig: class DataConfig:
"""数据相关配置""" """数据相关配置"""
@ -40,7 +50,7 @@ class DataConfig:
class TrainConfig: class TrainConfig:
"""训练相关配置""" """训练相关配置"""
# 基本训练参数 # 基本训练参数
batch_size: int = 1 batch_size: int = 8
num_workers: int = 4 num_workers: int = 4
num_epochs: int = 100 num_epochs: int = 100
learning_rate: float = 1e-4 learning_rate: float = 1e-4

159
brep2sdf/networks/encoder.py

@ -315,27 +315,29 @@ class SDFHead(nn.Module):
return self.mlp(x) return self.mlp(x)
class BRepToSDF(nn.Module): class BRepToSDF(nn.Module):
def __init__( def __init__(self, config=None):
self,
brep_feature_dim: int = 48,
use_cf: bool = True,
embed_dim: int = 768,
latent_dim: int = 256
):
super().__init__() super().__init__()
# 获取配置 # 获取配置
self.config = get_default_config() if config is None:
self.embed_dim = embed_dim self.config = get_default_config()
else:
self.config = config
# 从配置中读取参数
self.embed_dim = self.config.model.embed_dim
self.brep_feature_dim = self.config.model.brep_feature_dim
self.latent_dim = self.config.model.latent_dim
self.use_cf = self.config.model.use_cf
# 1. 查询点编码器 # 1. 查询点编码器
self.query_encoder = nn.Sequential( self.query_encoder = nn.Sequential(
nn.Linear(3, embed_dim//4), nn.Linear(3, self.embed_dim//4),
nn.LayerNorm(embed_dim//4), nn.LayerNorm(self.embed_dim//4),
nn.ReLU(), nn.ReLU(),
nn.Linear(embed_dim//4, embed_dim//2), nn.Linear(self.embed_dim//4, self.embed_dim//2),
nn.LayerNorm(embed_dim//2), nn.LayerNorm(self.embed_dim//2),
nn.ReLU(), nn.ReLU(),
nn.Linear(embed_dim//2, embed_dim) nn.Linear(self.embed_dim//2, self.embed_dim)
) )
# 2. B-rep特征编码器 # 2. B-rep特征编码器
@ -343,12 +345,12 @@ class BRepToSDF(nn.Module):
# 3. 特征融合Transformer # 3. 特征融合Transformer
self.transformer = SDFTransformer( self.transformer = SDFTransformer(
embed_dim=embed_dim, embed_dim=self.embed_dim,
num_layers=6 num_layers=6 # 这个参数也可以移到配置文件中
) )
# 4. SDF预测头 # 4. SDF预测头
self.sdf_head = SDFHead(embed_dim=embed_dim*2) self.sdf_head = SDFHead(embed_dim=self.embed_dim*2)
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None): def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None):
"""B-rep到SDF的前向传播 """B-rep到SDF的前向传播
@ -435,88 +437,47 @@ def main():
# 获取配置 # 获取配置
config = get_default_config() config = get_default_config()
# 从配置初始化模型 # 初始化模型
model = BRepToSDF( model = BRepToSDF(config=config)
brep_feature_dim=config.model.brep_feature_dim, # 48
use_cf=config.model.use_cf, # True # 从配置获取参数
embed_dim=config.model.embed_dim, # 768 batch_size = config.train.batch_size
latent_dim=config.model.latent_dim # 256 max_face = config.data.max_face
) max_edge = config.data.max_edge
num_surf_points = config.model.num_surf_points
# 从配置获取数据参数 num_edge_points = config.model.num_edge_points
batch_size = config.train.batch_size # 32
num_surfs = config.data.max_face # 64 # 生成测试数据
num_edges = config.data.max_edge # 64 test_data = {
num_verts = 8 # 顶点数保持固定 'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3),
num_queries = 1000 # 查询点数保持固定 'edge_pos': torch.randn(batch_size, max_face, max_edge, 6),
'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool),
# 更新测试数据维度 'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3),
edge_ncs = torch.randn( 'surf_pos': torch.randn(batch_size, max_face, 6),
batch_size, 'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3),
num_surfs, # max_face 'query_points': torch.randn(batch_size, 1000, 3) # 1000个查询点
num_edges, # max_edge }
config.model.num_edge_points,
3 # 打印输入数据形状
) # [B, max_face, max_edge, num_edge_points, 3] logger.info("Input shapes:")
for name, tensor in test_data.items():
edge_pos = torch.randn( logger.info(f" {name}: {tensor.shape}")
batch_size,
num_surfs, # 前向传播
num_edges, try:
6 sdf = model(**test_data)
) # [B, max_face, max_edge, 6] logger.info(f"\nOutput SDF shape: {sdf.shape}")
edge_mask = torch.ones( # 计算模型参数量
batch_size, total_params = sum(p.numel() for p in model.parameters())
num_surfs, trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_edges, logger.info(f"\nModel statistics:")
dtype=torch.bool logger.info(f" Total parameters: {total_params:,}")
) # [B, max_face, max_edge] logger.info(f" Trainable parameters: {trainable_params:,}")
surf_ncs = torch.randn( except Exception as e:
batch_size, logger.error(f"Error during forward pass: {str(e)}")
num_surfs, raise
config.model.num_surf_points,
3
) # [B, max_face, num_surf_points, 3]
surf_pos = torch.randn(
batch_size,
num_surfs,
6
) # [B, max_face, 6]
vertex_pos = torch.randn(
batch_size,
num_surfs,
num_edges,
2,
3
) # [B, max_face, max_edge, 2, 3]
query_points = torch.randn(batch_size, num_queries, 3)
# 更新前向传播调用
sdf = model(
edge_ncs=edge_ncs,
edge_pos=edge_pos,
edge_mask=edge_mask,
surf_ncs=surf_ncs,
surf_pos=surf_pos,
vertex_pos=vertex_pos,
query_points=query_points
)
# 更新打印信息
print("\nInput shapes:")
print(f"edge_ncs: {edge_ncs.shape}")
print(f"edge_pos: {edge_pos.shape}")
print(f"edge_mask: {edge_mask.shape}")
print(f"surf_ncs: {surf_ncs.shape}")
print(f"surf_pos: {surf_pos.shape}")
print(f"vertex_pos: {vertex_pos.shape}")
print(f"query_points: {query_points.shape}")
print(f"\nOutput SDF shape: {sdf.shape}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Loading…
Cancel
Save