1 changed files with 102 additions and 0 deletions
@ -0,0 +1,102 @@ |
|||||
|
from dataclasses import dataclass |
||||
|
from typing import Tuple, Optional |
||||
|
|
||||
|
@dataclass |
||||
|
class ModelConfig: |
||||
|
"""模型相关配置""" |
||||
|
brep_feature_dim: int = 48 |
||||
|
use_cf: bool = True |
||||
|
embed_dim: int = 768 |
||||
|
latent_dim: int = 256 |
||||
|
|
||||
|
# 点云采样配置 |
||||
|
num_surf_points: int = 16 # 每个面采样点数 |
||||
|
num_edge_points: int = 4 # 每条边采样点数 |
||||
|
|
||||
|
@dataclass |
||||
|
class DataConfig: |
||||
|
"""数据相关配置""" |
||||
|
max_face: int = 70 |
||||
|
max_edge: int = 70 |
||||
|
bbox_scaled: float = 1.0 |
||||
|
|
||||
|
# 数据路径 |
||||
|
brep_dir: str = '/home/wch/brep2sdf/test_data/pkl' |
||||
|
sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf' |
||||
|
valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl' |
||||
|
save_dir: str = 'checkpoints' |
||||
|
|
||||
|
@dataclass |
||||
|
class TrainConfig: |
||||
|
"""训练相关配置""" |
||||
|
# 基本训练参数 |
||||
|
batch_size: int = 32 |
||||
|
num_workers: int = 4 |
||||
|
num_epochs: int = 100 |
||||
|
learning_rate: float = 1e-4 |
||||
|
min_lr: float = 1e-6 |
||||
|
weight_decay: float = 0.01 |
||||
|
|
||||
|
# 梯度和损失相关 |
||||
|
grad_weight: float = 0.1 |
||||
|
max_grad_norm: float = 1.0 |
||||
|
|
||||
|
# 学习率调度器参数 |
||||
|
lr_scheduler: str = 'cosine' # ['cosine', 'linear', 'step'] |
||||
|
warmup_epochs: int = 5 |
||||
|
|
||||
|
# 保存和验证 |
||||
|
save_freq: int = 10 # 每多少个epoch保存一次 |
||||
|
val_freq: int = 1 # 每多少个epoch验证一次 |
||||
|
|
||||
|
@dataclass |
||||
|
class LogConfig: |
||||
|
"""日志相关配置""" |
||||
|
# wandb配置 |
||||
|
use_wandb: bool = True |
||||
|
project_name: str = 'brep2sdf' |
||||
|
run_name: str = 'training_run' |
||||
|
log_interval: int = 10 |
||||
|
|
||||
|
# 本地日志 |
||||
|
log_dir: str = 'logs' |
||||
|
log_level: str = 'INFO' |
||||
|
|
||||
|
@dataclass |
||||
|
class Config: |
||||
|
"""总配置类""" |
||||
|
model: ModelConfig = ModelConfig() |
||||
|
data: DataConfig = DataConfig() |
||||
|
train: TrainConfig = TrainConfig() |
||||
|
log: LogConfig = LogConfig() |
||||
|
|
||||
|
def __post_init__(self): |
||||
|
"""初始化后的处理""" |
||||
|
# 可以在这里添加配置验证逻辑 |
||||
|
pass |
||||
|
|
||||
|
@classmethod |
||||
|
def from_dict(cls, config_dict: dict) -> 'Config': |
||||
|
"""从字典创建配置""" |
||||
|
model_config = ModelConfig(**config_dict.get('model', {})) |
||||
|
data_config = DataConfig(**config_dict.get('data', {})) |
||||
|
train_config = TrainConfig(**config_dict.get('train', {})) |
||||
|
log_config = LogConfig(**config_dict.get('log', {})) |
||||
|
|
||||
|
return cls( |
||||
|
model=model_config, |
||||
|
data=data_config, |
||||
|
train=train_config, |
||||
|
log=log_config |
||||
|
) |
||||
|
|
||||
|
def get_default_config() -> Config: |
||||
|
"""获取默认配置""" |
||||
|
return Config() |
||||
|
|
||||
|
def load_config(config_path: str) -> Config: |
||||
|
"""从文件加载配置""" |
||||
|
import yaml |
||||
|
with open(config_path, 'r') as f: |
||||
|
config_dict = yaml.safe_load(f) |
||||
|
return Config.from_dict(config_dict) |
Loading…
Reference in new issue