1 changed files with 102 additions and 0 deletions
			
			
		@ -0,0 +1,102 @@ | 
				
			|||
from dataclasses import dataclass | 
				
			|||
from typing import Tuple, Optional | 
				
			|||
 | 
				
			|||
@dataclass | 
				
			|||
class ModelConfig: | 
				
			|||
    """模型相关配置""" | 
				
			|||
    brep_feature_dim: int = 48 | 
				
			|||
    use_cf: bool = True | 
				
			|||
    embed_dim: int = 768 | 
				
			|||
    latent_dim: int = 256 | 
				
			|||
     | 
				
			|||
    # 点云采样配置 | 
				
			|||
    num_surf_points: int = 16  # 每个面采样点数 | 
				
			|||
    num_edge_points: int = 4   # 每条边采样点数 | 
				
			|||
 | 
				
			|||
@dataclass | 
				
			|||
class DataConfig: | 
				
			|||
    """数据相关配置""" | 
				
			|||
    max_face: int = 70 | 
				
			|||
    max_edge: int = 70 | 
				
			|||
    bbox_scaled: float = 1.0 | 
				
			|||
     | 
				
			|||
    # 数据路径 | 
				
			|||
    brep_dir: str = '/home/wch/brep2sdf/test_data/pkl' | 
				
			|||
    sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf' | 
				
			|||
    valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl' | 
				
			|||
    save_dir: str = 'checkpoints' | 
				
			|||
 | 
				
			|||
@dataclass | 
				
			|||
class TrainConfig: | 
				
			|||
    """训练相关配置""" | 
				
			|||
    # 基本训练参数 | 
				
			|||
    batch_size: int = 32 | 
				
			|||
    num_workers: int = 4 | 
				
			|||
    num_epochs: int = 100 | 
				
			|||
    learning_rate: float = 1e-4 | 
				
			|||
    min_lr: float = 1e-6 | 
				
			|||
    weight_decay: float = 0.01 | 
				
			|||
     | 
				
			|||
    # 梯度和损失相关 | 
				
			|||
    grad_weight: float = 0.1 | 
				
			|||
    max_grad_norm: float = 1.0 | 
				
			|||
     | 
				
			|||
    # 学习率调度器参数 | 
				
			|||
    lr_scheduler: str = 'cosine'  # ['cosine', 'linear', 'step'] | 
				
			|||
    warmup_epochs: int = 5 | 
				
			|||
     | 
				
			|||
    # 保存和验证 | 
				
			|||
    save_freq: int = 10  # 每多少个epoch保存一次 | 
				
			|||
    val_freq: int = 1    # 每多少个epoch验证一次 | 
				
			|||
 | 
				
			|||
@dataclass | 
				
			|||
class LogConfig: | 
				
			|||
    """日志相关配置""" | 
				
			|||
    # wandb配置 | 
				
			|||
    use_wandb: bool = True | 
				
			|||
    project_name: str = 'brep2sdf' | 
				
			|||
    run_name: str = 'training_run' | 
				
			|||
    log_interval: int = 10 | 
				
			|||
     | 
				
			|||
    # 本地日志 | 
				
			|||
    log_dir: str = 'logs' | 
				
			|||
    log_level: str = 'INFO' | 
				
			|||
 | 
				
			|||
@dataclass | 
				
			|||
class Config: | 
				
			|||
    """总配置类""" | 
				
			|||
    model: ModelConfig = ModelConfig() | 
				
			|||
    data: DataConfig = DataConfig() | 
				
			|||
    train: TrainConfig = TrainConfig() | 
				
			|||
    log: LogConfig = LogConfig() | 
				
			|||
     | 
				
			|||
    def __post_init__(self): | 
				
			|||
        """初始化后的处理""" | 
				
			|||
        # 可以在这里添加配置验证逻辑 | 
				
			|||
        pass | 
				
			|||
 | 
				
			|||
    @classmethod | 
				
			|||
    def from_dict(cls, config_dict: dict) -> 'Config': | 
				
			|||
        """从字典创建配置""" | 
				
			|||
        model_config = ModelConfig(**config_dict.get('model', {})) | 
				
			|||
        data_config = DataConfig(**config_dict.get('data', {})) | 
				
			|||
        train_config = TrainConfig(**config_dict.get('train', {})) | 
				
			|||
        log_config = LogConfig(**config_dict.get('log', {})) | 
				
			|||
         | 
				
			|||
        return cls( | 
				
			|||
            model=model_config, | 
				
			|||
            data=data_config, | 
				
			|||
            train=train_config, | 
				
			|||
            log=log_config | 
				
			|||
        ) | 
				
			|||
 | 
				
			|||
def get_default_config() -> Config: | 
				
			|||
    """获取默认配置""" | 
				
			|||
    return Config() | 
				
			|||
 | 
				
			|||
def load_config(config_path: str) -> Config: | 
				
			|||
    """从文件加载配置""" | 
				
			|||
    import yaml | 
				
			|||
    with open(config_path, 'r') as f: | 
				
			|||
        config_dict = yaml.safe_load(f) | 
				
			|||
    return Config.from_dict(config_dict)  | 
				
			|||
					Loading…
					
					
				
		Reference in new issue