Browse Source

feat: 增加全局config

main
王琛涵 4 months ago
parent
commit
8db42d1ef7
  1. 102
      brep2sdf/config/default_config.py

102
brep2sdf/config/default_config.py

@ -0,0 +1,102 @@
from dataclasses import dataclass
from typing import Tuple, Optional
@dataclass
class ModelConfig:
"""模型相关配置"""
brep_feature_dim: int = 48
use_cf: bool = True
embed_dim: int = 768
latent_dim: int = 256
# 点云采样配置
num_surf_points: int = 16 # 每个面采样点数
num_edge_points: int = 4 # 每条边采样点数
@dataclass
class DataConfig:
"""数据相关配置"""
max_face: int = 70
max_edge: int = 70
bbox_scaled: float = 1.0
# 数据路径
brep_dir: str = '/home/wch/brep2sdf/test_data/pkl'
sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf'
valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl'
save_dir: str = 'checkpoints'
@dataclass
class TrainConfig:
"""训练相关配置"""
# 基本训练参数
batch_size: int = 32
num_workers: int = 4
num_epochs: int = 100
learning_rate: float = 1e-4
min_lr: float = 1e-6
weight_decay: float = 0.01
# 梯度和损失相关
grad_weight: float = 0.1
max_grad_norm: float = 1.0
# 学习率调度器参数
lr_scheduler: str = 'cosine' # ['cosine', 'linear', 'step']
warmup_epochs: int = 5
# 保存和验证
save_freq: int = 10 # 每多少个epoch保存一次
val_freq: int = 1 # 每多少个epoch验证一次
@dataclass
class LogConfig:
"""日志相关配置"""
# wandb配置
use_wandb: bool = True
project_name: str = 'brep2sdf'
run_name: str = 'training_run'
log_interval: int = 10
# 本地日志
log_dir: str = 'logs'
log_level: str = 'INFO'
@dataclass
class Config:
"""总配置类"""
model: ModelConfig = ModelConfig()
data: DataConfig = DataConfig()
train: TrainConfig = TrainConfig()
log: LogConfig = LogConfig()
def __post_init__(self):
"""初始化后的处理"""
# 可以在这里添加配置验证逻辑
pass
@classmethod
def from_dict(cls, config_dict: dict) -> 'Config':
"""从字典创建配置"""
model_config = ModelConfig(**config_dict.get('model', {}))
data_config = DataConfig(**config_dict.get('data', {}))
train_config = TrainConfig(**config_dict.get('train', {}))
log_config = LogConfig(**config_dict.get('log', {}))
return cls(
model=model_config,
data=data_config,
train=train_config,
log=log_config
)
def get_default_config() -> Config:
"""获取默认配置"""
return Config()
def load_config(config_path: str) -> Config:
"""从文件加载配置"""
import yaml
with open(config_path, 'r') as f:
config_dict = yaml.safe_load(f)
return Config.from_dict(config_dict)
Loading…
Cancel
Save