diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py new file mode 100644 index 0000000..fa473b2 --- /dev/null +++ b/brep2sdf/config/default_config.py @@ -0,0 +1,102 @@ +from dataclasses import dataclass +from typing import Tuple, Optional + +@dataclass +class ModelConfig: + """模型相关配置""" + brep_feature_dim: int = 48 + use_cf: bool = True + embed_dim: int = 768 + latent_dim: int = 256 + + # 点云采样配置 + num_surf_points: int = 16 # 每个面采样点数 + num_edge_points: int = 4 # 每条边采样点数 + +@dataclass +class DataConfig: + """数据相关配置""" + max_face: int = 70 + max_edge: int = 70 + bbox_scaled: float = 1.0 + + # 数据路径 + brep_dir: str = '/home/wch/brep2sdf/test_data/pkl' + sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf' + valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl' + save_dir: str = 'checkpoints' + +@dataclass +class TrainConfig: + """训练相关配置""" + # 基本训练参数 + batch_size: int = 32 + num_workers: int = 4 + num_epochs: int = 100 + learning_rate: float = 1e-4 + min_lr: float = 1e-6 + weight_decay: float = 0.01 + + # 梯度和损失相关 + grad_weight: float = 0.1 + max_grad_norm: float = 1.0 + + # 学习率调度器参数 + lr_scheduler: str = 'cosine' # ['cosine', 'linear', 'step'] + warmup_epochs: int = 5 + + # 保存和验证 + save_freq: int = 10 # 每多少个epoch保存一次 + val_freq: int = 1 # 每多少个epoch验证一次 + +@dataclass +class LogConfig: + """日志相关配置""" + # wandb配置 + use_wandb: bool = True + project_name: str = 'brep2sdf' + run_name: str = 'training_run' + log_interval: int = 10 + + # 本地日志 + log_dir: str = 'logs' + log_level: str = 'INFO' + +@dataclass +class Config: + """总配置类""" + model: ModelConfig = ModelConfig() + data: DataConfig = DataConfig() + train: TrainConfig = TrainConfig() + log: LogConfig = LogConfig() + + def __post_init__(self): + """初始化后的处理""" + # 可以在这里添加配置验证逻辑 + pass + + @classmethod + def from_dict(cls, config_dict: dict) -> 'Config': + """从字典创建配置""" + model_config = ModelConfig(**config_dict.get('model', {})) + data_config = DataConfig(**config_dict.get('data', {})) + train_config = TrainConfig(**config_dict.get('train', {})) + log_config = LogConfig(**config_dict.get('log', {})) + + return cls( + model=model_config, + data=data_config, + train=train_config, + log=log_config + ) + +def get_default_config() -> Config: + """获取默认配置""" + return Config() + +def load_config(config_path: str) -> Config: + """从文件加载配置""" + import yaml + with open(config_path, 'r') as f: + config_dict = yaml.safe_load(f) + return Config.from_dict(config_dict) \ No newline at end of file