diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py
index ac31944..693df34 100644
--- a/brep2sdf/config/default_config.py
+++ b/brep2sdf/config/default_config.py
@@ -24,7 +24,17 @@ class DataConfig:
     brep_dir: str = '/home/wch/brep2sdf/test_data/pkl'
     sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf'
     valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl'
-    save_dir: str = 'checkpoints'
+
+    # Save paths
+    save_dir: str = 'checkpoints'  # base directory for saving
+    model_save_dir: str = 'checkpoints/models'  # directory for model files
+    log_save_dir: str = 'checkpoints/logs'  # directory for log files
+    result_save_dir: str = 'checkpoints/results'  # directory for results
+
+    # File naming
+    model_name: str = 'brep2sdf'  # model name, used in file names
+    checkpoint_format: str = '{model_name}_epoch_{epoch:03d}.pth'  # checkpoint filename format
+    best_model_name: str = '{model_name}_best.pth'  # best-model filename format
 
 @dataclass
 class TrainConfig:
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index 8a1ca00..da8484e 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -6,50 +6,25 @@ from torch.utils.data import DataLoader
 from brep2sdf.data.data import BRepSDFDataset
 from brep2sdf.networks.encoder import BRepToSDF, sdf_loss
 from brep2sdf.utils.logger import logger
+from brep2sdf.config.default_config import get_default_config, load_config
 import wandb
 
 def main():
-    # Store configuration parameters in a dict
-    config = {
-        # Data paths
-        'brep_dir': '/home/wch/brep2sdf/test_data/pkl',
-        'sdf_dir': '/home/wch/brep2sdf/test_data/sdf',
-        'valid_data_dir': '/home/wch/brep2sdf/test_data/result/pkl',
-        'save_dir': 'checkpoints',
-
-        # Training parameters
-        'batch_size': 32,
-        'num_workers': 4,
-        'num_epochs': 100,
-        'learning_rate': 1e-4,
-        'min_lr': 1e-6,
-        'weight_decay': 0.01,
-        'grad_weight': 0.1,
-        'max_grad_norm': 1.0,
-
-        # Model parameters
-        'brep_feature_dim': 48,
-        'use_cf': True,
-        'embed_dim': 768,
-        'latent_dim': 256,
-
-        # wandb parameters
-        'use_wandb': True,
-        'project_name': 'brep2sdf',
-        'run_name': 'training_run',
-        'log_interval': 10
-    }
+    # Load the configuration
+    config = get_default_config()
 
-    # Create the save directory
-    os.makedirs(config['save_dir'], exist_ok=True)
+    # Create all save directories
+    os.makedirs(config.data.model_save_dir, exist_ok=True)
+    os.makedirs(config.data.log_save_dir, exist_ok=True)
+    os.makedirs(config.data.result_save_dir, exist_ok=True)
 
-    # Initialize wandb (with timeout settings and offline mode)
-    if config['use_wandb']:
+    # Initialize wandb
+    if config.log.use_wandb:
         try:
             wandb.init(
-                project=config['project_name'],
-                name=config['run_name'],
-                config=config,
+                project=config.log.project_name,
+                name=config.log.run_name,
+                config=config.__dict__,
                 settings=wandb.Settings(
                     init_timeout=180,  # increase the timeout
                     _disable_stats=True,  # disable stats collection
@@ -60,7 +35,7 @@ def main():
             logger.info("Wandb initialized in offline mode")
         except Exception as e:
             logger.warning(f"Failed to initialize wandb: {str(e)}")
-            config['use_wandb'] = False  # disable wandb
+            config.log.use_wandb = False
             logger.warning("Continuing without wandb logging")
 
     # Initialize the trainer and start training
@@ -75,15 +50,15 @@ class Trainer:
 
         # Initialize datasets
         self.train_dataset = BRepSDFDataset(
-            brep_dir=config['brep_dir'],
-            sdf_dir=config['sdf_dir'],
-            valid_data_dir=config['valid_data_dir'],
+            brep_dir=config.data.brep_dir,
+            sdf_dir=config.data.sdf_dir,
+            valid_data_dir=config.data.valid_data_dir,
             split='train'
         )
         self.val_dataset = BRepSDFDataset(
-            brep_dir=config['brep_dir'],
-            sdf_dir=config['sdf_dir'],
-            valid_data_dir=config['valid_data_dir'],
+            brep_dir=config.data.brep_dir,
+            sdf_dir=config.data.sdf_dir,
+            valid_data_dir=config.data.valid_data_dir,
             split='val'
         )
 
@@ -93,37 +68,37 @@ class Trainer:
         # Initialize data loaders
         self.train_loader = DataLoader(
             self.train_dataset,
-            batch_size=config['batch_size'],
+            batch_size=config.train.batch_size,
             shuffle=True,
-            num_workers=config['num_workers']
+            num_workers=config.train.num_workers
         )
         self.val_loader = DataLoader(
             self.val_dataset,
-            batch_size=config['batch_size'],
+            batch_size=config.train.batch_size,
             shuffle=False,
-            num_workers=config['num_workers']
+            num_workers=config.train.num_workers
         )
 
         # Initialize the model
         self.model = BRepToSDF(
-            brep_feature_dim=config['brep_feature_dim'],
-            use_cf=config['use_cf'],
-            embed_dim=config['embed_dim'],
-            latent_dim=config['latent_dim']
+            brep_feature_dim=config.model.brep_feature_dim,
+            use_cf=config.model.use_cf,
+            embed_dim=config.model.embed_dim,
+            latent_dim=config.model.latent_dim
         ).to(self.device)
 
         # Initialize the optimizer
         self.optimizer = optim.AdamW(
             self.model.parameters(),
-            lr=config['learning_rate'],
-            weight_decay=config['weight_decay']
+            lr=config.train.learning_rate,
+            weight_decay=config.train.weight_decay
        )
 
         # Learning-rate scheduler
         self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
             self.optimizer,
-            T_max=config['num_epochs'],
-            eta_min=config['min_lr']
+            T_max=config.train.num_epochs,
+            eta_min=config.train.min_lr
         )
 
     def train_epoch(self, epoch):
@@ -148,7 +123,7 @@ class Trainer:
                 pred_sdf,
                 gt_sdf,
                 query_points,
-                grad_weight=self.config['grad_weight']
+                grad_weight=self.config.train.grad_weight
             )
 
             # Backward pass
@@ -158,19 +133,19 @@ class Trainer:
             # Gradient clipping
             torch.nn.utils.clip_grad_norm_(
                 self.model.parameters(),
-                self.config['max_grad_norm']
+                self.config.train.max_grad_norm
             )
 
             self.optimizer.step()
             total_loss += loss.item()
 
             # Print training progress
-            if (batch_idx + 1) % self.config['log_interval'] == 0:
+            if (batch_idx + 1) % self.config.log.log_interval == 0:
                 logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
                             f'Loss: {loss.item():.6f}')
 
             # Log to wandb
-            if self.config['use_wandb'] and (batch_idx + 1) % self.config['log_interval'] == 0:
+            if self.config.log.use_wandb and (batch_idx + 1) % self.config.log.log_interval == 0:
                 wandb.log({
                     'batch_loss': loss.item(),
                     'batch': batch_idx,
@@ -203,7 +178,7 @@ class Trainer:
                     pred_sdf,
                     gt_sdf,
                     query_points,
-                    grad_weight=self.config['grad_weight']
+                    grad_weight=self.config.train.grad_weight
                 )
 
                 total_loss += loss.item()
@@ -211,7 +186,7 @@ class Trainer:
         avg_loss = total_loss / len(self.val_loader)
         logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
 
-        if self.config['use_wandb']:
+        if self.config.log.use_wandb:
             wandb.log({
                 'val_loss': avg_loss,
                 'epoch': epoch
@@ -222,35 +197,68 @@ class Trainer:
         best_val_loss = float('inf')
         logger.info("Starting training...")
 
-        for epoch in range(1, self.config['num_epochs'] + 1):
+        for epoch in range(1, self.config.train.num_epochs + 1):
             train_loss = self.train_epoch(epoch)
-            val_loss = self.validate(epoch)
-            self.scheduler.step()
 
-            # Save the best model
-            if val_loss < best_val_loss:
-                best_val_loss = val_loss
-                model_path = os.path.join(self.config['save_dir'], 'best_model.pth')
+            # Validate periodically
+            if epoch % self.config.train.val_freq == 0:
+                val_loss = self.validate(epoch)
+
+                # Save the best model
+                if val_loss < best_val_loss:
+                    best_val_loss = val_loss
+                    best_model_path = os.path.join(
+                        self.config.data.model_save_dir,
+                        self.config.data.best_model_name.format(
+                            model_name=self.config.data.model_name
+                        )
+                    )
+                    torch.save({
+                        'epoch': epoch,
+                        'model_state_dict': self.model.state_dict(),
+                        'optimizer_state_dict': self.optimizer.state_dict(),
+                        'val_loss': val_loss,
+                    }, best_model_path)
+                    logger.info(f'Saved best model with val_loss: {val_loss:.6f}')
+
+            # Save a checkpoint periodically
+            if epoch % self.config.train.save_freq == 0:
+                checkpoint_path = os.path.join(
+                    self.config.data.model_save_dir,
+                    self.config.data.checkpoint_format.format(
+                        model_name=self.config.data.model_name,
+                        epoch=epoch
+                    )
+                )
                 torch.save({
                     'epoch': epoch,
                     'model_state_dict': self.model.state_dict(),
                     'optimizer_state_dict': self.optimizer.state_dict(),
-                    'val_loss': val_loss,
-                }, model_path)
-                logger.info(f'Saved best model with val_loss: {val_loss:.6f}')
+                    'scheduler_state_dict': self.scheduler.state_dict(),
+                    'train_loss': train_loss,
+                    'val_loss': val_loss if epoch % self.config.train.val_freq == 0 else None,
+                }, checkpoint_path)
+                logger.info(f'Saved checkpoint at epoch {epoch}')
+
+            self.scheduler.step()
 
             # Log training info
-            logger.info(f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
-                        f'Val Loss: {val_loss:.6f}\tLR: {self.scheduler.get_last_lr()[0]:.6f}')
+            log_info = f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
+            if epoch % self.config.train.val_freq == 0:
+                log_info += f'Val Loss: {val_loss:.6f}\t'
+            log_info += f'LR: {self.scheduler.get_last_lr()[0]:.6f}'
+            logger.info(log_info)
 
             # Log to wandb
-            if self.config['use_wandb']:
-                wandb.log({
+            if self.config.log.use_wandb:
+                log_dict = {
                     'train_loss': train_loss,
-                    'val_loss': val_loss,
                     'learning_rate': self.scheduler.get_last_lr()[0],
                     'epoch': epoch
-                })
+                }
+                if epoch % self.config.train.val_freq == 0:
+                    log_dict['val_loss'] = val_loss
+                wandb.log(log_dict)
 
 if __name__ == '__main__':
     main()
\ No newline at end of file
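
Reviewer note: after this patch, train.py reads config.train.val_freq, config.train.save_freq, and the config.model / config.log groups, but the default_config.py hunk above only adds new DataConfig fields. For reference, here is a minimal sketch of the config shape that train.py assumes. The nested-dataclass layout, the Config wrapper, get_default_config(), and the val_freq/save_freq defaults are assumptions not shown in this patch; every other default is taken from the dict this patch removes from train.py, and DataConfig is abbreviated to the fields used here.

    # Sketch only: layout and val_freq/save_freq defaults are assumed.
    from dataclasses import dataclass, field

    @dataclass
    class DataConfig:  # abbreviated; full version in the hunk above
        brep_dir: str = '/home/wch/brep2sdf/test_data/pkl'
        sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf'
        valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl'
        model_save_dir: str = 'checkpoints/models'
        log_save_dir: str = 'checkpoints/logs'
        result_save_dir: str = 'checkpoints/results'
        model_name: str = 'brep2sdf'
        checkpoint_format: str = '{model_name}_epoch_{epoch:03d}.pth'
        best_model_name: str = '{model_name}_best.pth'

    @dataclass
    class TrainConfig:
        batch_size: int = 32
        num_workers: int = 4
        num_epochs: int = 100
        learning_rate: float = 1e-4
        min_lr: float = 1e-6
        weight_decay: float = 0.01
        grad_weight: float = 0.1
        max_grad_norm: float = 1.0
        val_freq: int = 1    # assumed: validate every epoch
        save_freq: int = 10  # assumed: checkpoint every 10 epochs

    @dataclass
    class ModelConfig:
        brep_feature_dim: int = 48
        use_cf: bool = True
        embed_dim: int = 768
        latent_dim: int = 256

    @dataclass
    class LogConfig:
        use_wandb: bool = True
        project_name: str = 'brep2sdf'
        run_name: str = 'training_run'
        log_interval: int = 10

    @dataclass
    class Config:
        data: DataConfig = field(default_factory=DataConfig)
        train: TrainConfig = field(default_factory=TrainConfig)
        model: ModelConfig = field(default_factory=ModelConfig)
        log: LogConfig = field(default_factory=LogConfig)

    def get_default_config() -> Config:
        return Config()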
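
The two new DataConfig format strings are expanded with str.format, so the saved files look like this:

    '{model_name}_epoch_{epoch:03d}.pth'.format(model_name='brep2sdf', epoch=7)
    # -> 'brep2sdf_epoch_007.pth'
    '{model_name}_best.pth'.format(model_name='brep2sdf')
    # -> 'brep2sdf_best.pth'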
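
One caveat on config=config.__dict__ in wandb.init: if config is a nested dataclass as sketched above, __dict__ still holds the sub-config objects rather than plain values, so wandb would record their repr strings. dataclasses.asdict() recurses into plain nested dicts, which wandb can log as structured config. A possible follow-up, not part of this patch:

    import dataclasses

    wandb.init(
        project=config.log.project_name,
        name=config.log.run_name,
        config=dataclasses.asdict(config),  # {'data': {...}, 'train': {...}, ...}
    )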
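
Since periodic checkpoints now also store scheduler_state_dict, training can be resumed. A hypothetical helper, not part of this patch, matching the exact keys written by Trainer.train():

    import torch

    def load_checkpoint(path, model, optimizer, scheduler, device='cpu'):
        """Restore training state from a checkpoint saved by Trainer.train()."""
        ckpt = torch.load(path, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        # 'val_loss' is None for epochs where validation did not run
        return ckpt['epoch'], ckpt.get('val_loss')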