1 changed files with 102 additions and 0 deletions
@ -0,0 +1,102 @@ |
|||
from dataclasses import dataclass |
|||
from typing import Tuple, Optional |
|||
|
|||
@dataclass |
|||
class ModelConfig: |
|||
"""模型相关配置""" |
|||
brep_feature_dim: int = 48 |
|||
use_cf: bool = True |
|||
embed_dim: int = 768 |
|||
latent_dim: int = 256 |
|||
|
|||
# 点云采样配置 |
|||
num_surf_points: int = 16 # 每个面采样点数 |
|||
num_edge_points: int = 4 # 每条边采样点数 |
|||
|
|||
@dataclass |
|||
class DataConfig: |
|||
"""数据相关配置""" |
|||
max_face: int = 70 |
|||
max_edge: int = 70 |
|||
bbox_scaled: float = 1.0 |
|||
|
|||
# 数据路径 |
|||
brep_dir: str = '/home/wch/brep2sdf/test_data/pkl' |
|||
sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf' |
|||
valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl' |
|||
save_dir: str = 'checkpoints' |
|||
|
|||
@dataclass |
|||
class TrainConfig: |
|||
"""训练相关配置""" |
|||
# 基本训练参数 |
|||
batch_size: int = 32 |
|||
num_workers: int = 4 |
|||
num_epochs: int = 100 |
|||
learning_rate: float = 1e-4 |
|||
min_lr: float = 1e-6 |
|||
weight_decay: float = 0.01 |
|||
|
|||
# 梯度和损失相关 |
|||
grad_weight: float = 0.1 |
|||
max_grad_norm: float = 1.0 |
|||
|
|||
# 学习率调度器参数 |
|||
lr_scheduler: str = 'cosine' # ['cosine', 'linear', 'step'] |
|||
warmup_epochs: int = 5 |
|||
|
|||
# 保存和验证 |
|||
save_freq: int = 10 # 每多少个epoch保存一次 |
|||
val_freq: int = 1 # 每多少个epoch验证一次 |
|||
|
|||
@dataclass |
|||
class LogConfig: |
|||
"""日志相关配置""" |
|||
# wandb配置 |
|||
use_wandb: bool = True |
|||
project_name: str = 'brep2sdf' |
|||
run_name: str = 'training_run' |
|||
log_interval: int = 10 |
|||
|
|||
# 本地日志 |
|||
log_dir: str = 'logs' |
|||
log_level: str = 'INFO' |
|||
|
|||
@dataclass |
|||
class Config: |
|||
"""总配置类""" |
|||
model: ModelConfig = ModelConfig() |
|||
data: DataConfig = DataConfig() |
|||
train: TrainConfig = TrainConfig() |
|||
log: LogConfig = LogConfig() |
|||
|
|||
def __post_init__(self): |
|||
"""初始化后的处理""" |
|||
# 可以在这里添加配置验证逻辑 |
|||
pass |
|||
|
|||
@classmethod |
|||
def from_dict(cls, config_dict: dict) -> 'Config': |
|||
"""从字典创建配置""" |
|||
model_config = ModelConfig(**config_dict.get('model', {})) |
|||
data_config = DataConfig(**config_dict.get('data', {})) |
|||
train_config = TrainConfig(**config_dict.get('train', {})) |
|||
log_config = LogConfig(**config_dict.get('log', {})) |
|||
|
|||
return cls( |
|||
model=model_config, |
|||
data=data_config, |
|||
train=train_config, |
|||
log=log_config |
|||
) |
|||
|
|||
def get_default_config() -> Config: |
|||
"""获取默认配置""" |
|||
return Config() |
|||
|
|||
def load_config(config_path: str) -> Config: |
|||
"""从文件加载配置""" |
|||
import yaml |
|||
with open(config_path, 'r') as f: |
|||
config_dict = yaml.safe_load(f) |
|||
return Config.from_dict(config_dict) |
Loading…
Reference in new issue