From 8db42d1ef733277b8ff9f861e3055e6f5bde2428 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=90=9B=E6=B6=B5?= Date: Tue, 19 Nov 2024 00:38:56 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E5=85=A8=E5=B1=80con?= =?UTF-8?q?fig?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/config/default_config.py | 102 ++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 brep2sdf/config/default_config.py diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py new file mode 100644 index 0000000..fa473b2 --- /dev/null +++ b/brep2sdf/config/default_config.py @@ -0,0 +1,102 @@ +from dataclasses import dataclass +from typing import Tuple, Optional + +@dataclass +class ModelConfig: + """模型相关配置""" + brep_feature_dim: int = 48 + use_cf: bool = True + embed_dim: int = 768 + latent_dim: int = 256 + + # 点云采样配置 + num_surf_points: int = 16 # 每个面采样点数 + num_edge_points: int = 4 # 每条边采样点数 + +@dataclass +class DataConfig: + """数据相关配置""" + max_face: int = 70 + max_edge: int = 70 + bbox_scaled: float = 1.0 + + # 数据路径 + brep_dir: str = '/home/wch/brep2sdf/test_data/pkl' + sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf' + valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl' + save_dir: str = 'checkpoints' + +@dataclass +class TrainConfig: + """训练相关配置""" + # 基本训练参数 + batch_size: int = 32 + num_workers: int = 4 + num_epochs: int = 100 + learning_rate: float = 1e-4 + min_lr: float = 1e-6 + weight_decay: float = 0.01 + + # 梯度和损失相关 + grad_weight: float = 0.1 + max_grad_norm: float = 1.0 + + # 学习率调度器参数 + lr_scheduler: str = 'cosine' # ['cosine', 'linear', 'step'] + warmup_epochs: int = 5 + + # 保存和验证 + save_freq: int = 10 # 每多少个epoch保存一次 + val_freq: int = 1 # 每多少个epoch验证一次 + +@dataclass +class LogConfig: + """日志相关配置""" + # wandb配置 + use_wandb: bool = True + project_name: str = 'brep2sdf' + run_name: str = 'training_run' + log_interval: int = 10 + + # 本地日志 + log_dir: str = 'logs' + log_level: str = 'INFO' + +@dataclass +class Config: + """总配置类""" + model: ModelConfig = ModelConfig() + data: DataConfig = DataConfig() + train: TrainConfig = TrainConfig() + log: LogConfig = LogConfig() + + def __post_init__(self): + """初始化后的处理""" + # 可以在这里添加配置验证逻辑 + pass + + @classmethod + def from_dict(cls, config_dict: dict) -> 'Config': + """从字典创建配置""" + model_config = ModelConfig(**config_dict.get('model', {})) + data_config = DataConfig(**config_dict.get('data', {})) + train_config = TrainConfig(**config_dict.get('train', {})) + log_config = LogConfig(**config_dict.get('log', {})) + + return cls( + model=model_config, + data=data_config, + train=train_config, + log=log_config + ) + +def get_default_config() -> Config: + """获取默认配置""" + return Config() + +def load_config(config_path: str) -> Config: + """从文件加载配置""" + import yaml + with open(config_path, 'r') as f: + config_dict = yaml.safe_load(f) + return Config.from_dict(config_dict) \ No newline at end of file