You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
2.9 KiB
104 lines
2.9 KiB
from dataclasses import dataclass
|
|
from typing import Tuple, Optional
|
|
|
|
@dataclass
|
|
class ModelConfig:
|
|
"""模型相关配置"""
|
|
brep_feature_dim: int = 48
|
|
use_cf: bool = True
|
|
embed_dim: int = 768
|
|
latent_dim: int = 256
|
|
|
|
# 点云采样配置
|
|
num_surf_points: int = 16 # 每个面采样点数
|
|
num_edge_points: int = 4 # 每条边采样点数
|
|
|
|
@dataclass
|
|
class DataConfig:
|
|
"""数据相关配置"""
|
|
max_face: int = 64
|
|
max_edge: int = 64
|
|
bbox_scaled: float = 1.0
|
|
|
|
# 数据路径
|
|
brep_dir: str = '/home/wch/brep2sdf/test_data/pkl'
|
|
sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf'
|
|
valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl'
|
|
save_dir: str = 'checkpoints'
|
|
|
|
@dataclass
|
|
class TrainConfig:
|
|
"""训练相关配置"""
|
|
# 基本训练参数
|
|
batch_size: int = 32
|
|
num_workers: int = 4
|
|
num_epochs: int = 100
|
|
learning_rate: float = 1e-4
|
|
min_lr: float = 1e-6
|
|
weight_decay: float = 0.01
|
|
|
|
# 梯度和损失相关
|
|
grad_weight: float = 0.1
|
|
max_grad_norm: float = 1.0
|
|
|
|
# 学习率调度器参数
|
|
lr_scheduler: str = 'cosine' # ['cosine', 'linear', 'step']
|
|
warmup_epochs: int = 5
|
|
|
|
# 保存和验证
|
|
save_freq: int = 10 # 每多少个epoch保存一次
|
|
val_freq: int = 1 # 每多少个epoch验证一次
|
|
|
|
@dataclass
|
|
class LogConfig:
|
|
"""日志相关配置"""
|
|
# wandb配置
|
|
use_wandb: bool = True
|
|
project_name: str = 'brep2sdf'
|
|
run_name: str = 'training_run'
|
|
log_interval: int = 10
|
|
|
|
# 本地日志
|
|
log_dir: str = 'logs' # 日志保存目录
|
|
log_level: str = 'INFO' # 日志级别
|
|
console_level: str = 'INFO' # 控制台日志级别
|
|
file_level: str = 'DEBUG' # 文件日志级别
|
|
|
|
@dataclass
|
|
class Config:
|
|
"""总配置类"""
|
|
model: ModelConfig = ModelConfig()
|
|
data: DataConfig = DataConfig()
|
|
train: TrainConfig = TrainConfig()
|
|
log: LogConfig = LogConfig()
|
|
|
|
def __post_init__(self):
|
|
"""初始化后的处理"""
|
|
# 可以在这里添加配置验证逻辑
|
|
pass
|
|
|
|
@classmethod
|
|
def from_dict(cls, config_dict: dict) -> 'Config':
|
|
"""从字典创建配置"""
|
|
model_config = ModelConfig(**config_dict.get('model', {}))
|
|
data_config = DataConfig(**config_dict.get('data', {}))
|
|
train_config = TrainConfig(**config_dict.get('train', {}))
|
|
log_config = LogConfig(**config_dict.get('log', {}))
|
|
|
|
return cls(
|
|
model=model_config,
|
|
data=data_config,
|
|
train=train_config,
|
|
log=log_config
|
|
)
|
|
|
|
def get_default_config() -> Config:
|
|
"""获取默认配置"""
|
|
return Config()
|
|
|
|
def load_config(config_path: str) -> Config:
|
|
"""从文件加载配置"""
|
|
import yaml
|
|
with open(config_path, 'r') as f:
|
|
config_dict = yaml.safe_load(f)
|
|
return Config.from_dict(config_dict)
|