from dataclasses import dataclass from typing import Tuple, Optional @dataclass class ModelConfig: """模型相关配置""" brep_feature_dim: int = 16 use_cf: bool = True embed_dim: int = 768 # 3 的 倍数 latent_dim: int = 16 # 点云采样配置 num_surf_points: int = 8 # 每个面采样点数 num_edge_points: int = 2 # 每条边采样点数 # Transformer相关配置 num_transformer_layers: int = 4 num_attention_heads: int = 6 transformer_dim_feedforward: int = 512 transformer_dropout: float = 0.1 # 编码器配置 encoder_channels: Tuple[int] = (32, 64, 128) encoder_layers_per_block: int = 1 @dataclass class DataConfig: """数据相关配置""" max_face: int = 32 max_edge: int = 128 num_query_points: int = 32*32*32 # 限制查询点数量,sdf 采样点数 本来是 128*128*128 ,在data load时随机采样 bbox_scaled: float = 1.0 # pre_process origin_brep_dir: str = '/mnt/mynewdisk/dataset/furniture/step/furniture_dataset_step/' # 数据路径 brep_dir: str = '/home/wch/brep2sdf/test_data/pkl' sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf' valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl' # 保存路径 save_dir: str = '/home/wch/brep2sdf/checkpoints' # 模型保存基础目录 model_save_dir: str = '/home/wch/brep2sdf/checkpoints/models' # 模型文件保存目录 result_save_dir: str = '/home/wch/brep2sdf/checkpoints/results' # 结果保存目录 # 文件命名 model_name: str = 'brep2sdf' # 模型名称,用于文件命名 checkpoint_format: str = '{model_name}_epoch_{epoch:03d}.pth' # 检查点文件名格式 best_model_name: str = '{model_name}_best.pth' # 最佳模型文件名格式 @dataclass class TrainConfig: """训练相关配置""" # 基本训练参数 batch_size: int = 8 num_workers: int = 4 num_epochs: int = 100 learning_rate: float = 1 min_lr: float = 1e-1 weight_decay: float = 0.01 # 梯度和损失相关 max_grad_norm: float = 1.0 clamping_distance: float = 0.1 # 学习率调度器参数 lr_scheduler: str = 'cosine' # ['cosine', 'linear', 'step'] warmup_epochs: int = 5 # 保存和验证 save_freq: int = 10 # 每多少个epoch保存一次 val_freq: int = 1 # 每多少个epoch验证一次 @dataclass class TestConfig: vis_freq: int = 100 @dataclass class LogConfig: """日志相关配置""" # wandb配置 use_wandb: bool = True project_name: str = 'brep2sdf' run_name: str = 'training_run' log_interval: int = 10 # 本地日志 log_dir: str = '/home/wch/brep2sdf/logs' # 日志保存目录 log_level: str = 'INFO' # 日志级别 console_level: str = 'INFO' # 控制台日志级别 file_level: str = 'DEBUG' # 文件日志级别 @dataclass class Config: """总配置类""" model: ModelConfig = ModelConfig() data: DataConfig = DataConfig() train: TrainConfig = TrainConfig() test: TestConfig = TestConfig() log: LogConfig = LogConfig() def __post_init__(self): """初始化后的处理""" # 可以在这里添加配置验证逻辑 pass @classmethod def from_dict(cls, config_dict: dict) -> 'Config': """从字典创建配置""" model_config = ModelConfig(**config_dict.get('model', {})) data_config = DataConfig(**config_dict.get('data', {})) train_config = TrainConfig(**config_dict.get('train', {})) test_config = TestConfig(**config_dict.get('test', {})) log_config = LogConfig(**config_dict.get('log', {})) return cls( model=model_config, data=data_config, train=train_config, test=test_config, log=log_config ) def get_default_config() -> Config: """获取默认配置""" return Config() def load_config(config_path: str) -> Config: """从文件加载配置""" import yaml with open(config_path, 'r') as f: config_dict = yaml.safe_load(f) return Config.from_dict(config_dict)