From 7f2049c4bde793c908dbcb4db78df5a30b12e7ac Mon Sep 17 00:00:00 2001 From: mckay Date: Thu, 27 Mar 2025 18:02:11 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=A0=E4=BF=9D=E5=AD=98checkpoint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/config/default_config.py | 35 +++++++++++++------- brep2sdf/train.py | 53 +++++++++++++++++++++++++++---- 2 files changed, 71 insertions(+), 17 deletions(-) diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index d7ac5b4..372f2eb 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -38,20 +38,24 @@ class DataConfig: # 数据路径 pkl_path = "/home/wch/brep2sdf/test_data/pkl/train/bathtub_0004.pkl" sdf_path = "/home/wch/brep2sdf/test_data/sdf/train/bathtub_0004.npz" - brep_dir: str = '/home/wch/brep2sdf/test_data/pkl' - sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf' - valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl' + # 基础路径配置 + base_data_dir: str = '/home/wch/brep2sdf/test_data' # 基础数据目录 - # 保存路径 - save_dir: str = '/home/wch/brep2sdf/checkpoints' # 模型保存基础目录 - model_save_dir: str = '/home/wch/brep2sdf/checkpoints/models' # 模型文件保存目录 - result_save_dir: str = '/home/wch/brep2sdf/checkpoints/results' # 结果保存目录 - # 文件命名 - model_name: str = 'brep2sdf' # 模型名称,用于文件命名 - checkpoint_format: str = '{model_name}_epoch_{epoch:03d}.pth' # 检查点文件名格式 - best_model_name: str = '{model_name}_best.pth' # 最佳模型文件名格式 + @property + def brep_dir(self) -> str: + return os.path.join(self.base_data_dir, 'pkl') + + @property + def sdf_dir(self) -> str: + return os.path.join(self.base_data_dir, 'sdf') + + @property + def valid_data_dir(self) -> str: + return os.path.join(self.base_data_dir, 'result/pkl') + + @dataclass class TrainConfig: @@ -77,6 +81,15 @@ class TrainConfig: save_freq: int = 10 # 每多少个epoch保存一次 val_freq: int = 1 # 每多少个epoch验证一次 + # 保存路径 + checkpoint_dir: str = '/home/wch/brep2sdf/checkpoints' # 结果保存目录 + + # 文件命名 + model_name: str = 'brep2sdf' # 模型名称,用于文件命名 + checkpoint_format: str = '{model_name}_epoch_{epoch:03d}.pth' # 检查点文件名格式 + best_model_name: str = '{model_name}_best.pth' # 最佳模型文件名格式 + + @dataclass class TestConfig: vis_freq: int = 100 diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 5271529..c0dabc0 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -1,5 +1,7 @@ import torch import torch.optim as optim +import time +import os from brep2sdf.config.default_config import get_default_config from brep2sdf.data.data import load_brep_file,load_sdf_file @@ -103,30 +105,69 @@ class Trainer: def train(self): best_val_loss = float('inf') logger.info("Starting training...") + start_time = time.time() for epoch in range(1, self.config.train.num_epochs + 1): + # 训练一个epoch train_loss = self.train_epoch(epoch) + + # 验证 ''' - # 定期验证 if epoch % self.config.train.val_freq == 0: val_loss = self.validate(epoch) + logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}') # 保存最佳模型 if val_loss < best_val_loss: best_val_loss = val_loss self._save_model(epoch, val_loss) + logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}') + else: + logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}') ''' - # 定期保存检查点 + # 保存检查点 if epoch % self.config.train.save_freq == 0: self._save_checkpoint(epoch, train_loss) + logger.info(f'Checkpoint saved at epoch {epoch}') + + # 训练完成 + total_time = time.time() - start_time + logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s') + logger.info(f'Best validation loss: {best_val_loss:.6f}') + def _save_model(self, epoch: int, val_loss: float): - # 保存最佳模型的逻辑 - pass + """保存最佳模型""" + save_path = os.path.join( + self.config.train.checkpoint_dir, + self.config.train.best_model_name.format( + model_name=config.train.model_name + ) + ) + torch.save({ + 'epoch': epoch, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'loss': val_loss, + 'config': self.config + }, save_path) def _save_checkpoint(self, epoch: int, train_loss: float): - # 保存检查点的逻辑 - pass + """保存训练检查点""" + checkpoint_path = os.path.join( + self.config.train.checkpoint_dir, + self.config.train.checkpoint_format.format( + model_name=self.config.train.model_name, + epoch=epoch + ) + ) + torch.save({ + 'epoch': epoch, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'loss': train_loss, + 'config': self.config + }, checkpoint_path) def main(): # 这里需要初始化配置