Browse Source

加保存checkpoint

final
mckay 2 months ago
parent
commit
7f2049c4bd
  1. 35
      brep2sdf/config/default_config.py
  2. 53
      brep2sdf/train.py

35
brep2sdf/config/default_config.py

@ -38,20 +38,24 @@ class DataConfig:
# 数据路径 # 数据路径
pkl_path = "/home/wch/brep2sdf/test_data/pkl/train/bathtub_0004.pkl" pkl_path = "/home/wch/brep2sdf/test_data/pkl/train/bathtub_0004.pkl"
sdf_path = "/home/wch/brep2sdf/test_data/sdf/train/bathtub_0004.npz" sdf_path = "/home/wch/brep2sdf/test_data/sdf/train/bathtub_0004.npz"
brep_dir: str = '/home/wch/brep2sdf/test_data/pkl'
sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf'
valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl'
# 基础路径配置
base_data_dir: str = '/home/wch/brep2sdf/test_data' # 基础数据目录
# 保存路径
save_dir: str = '/home/wch/brep2sdf/checkpoints' # 模型保存基础目录
model_save_dir: str = '/home/wch/brep2sdf/checkpoints/models' # 模型文件保存目录
result_save_dir: str = '/home/wch/brep2sdf/checkpoints/results' # 结果保存目录
# 文件命名 @property
model_name: str = 'brep2sdf' # 模型名称,用于文件命名 def brep_dir(self) -> str:
checkpoint_format: str = '{model_name}_epoch_{epoch:03d}.pth' # 检查点文件名格式 return os.path.join(self.base_data_dir, 'pkl')
best_model_name: str = '{model_name}_best.pth' # 最佳模型文件名格式
@property
def sdf_dir(self) -> str:
return os.path.join(self.base_data_dir, 'sdf')
@property
def valid_data_dir(self) -> str:
return os.path.join(self.base_data_dir, 'result/pkl')
@dataclass @dataclass
class TrainConfig: class TrainConfig:
@ -77,6 +81,15 @@ class TrainConfig:
save_freq: int = 10 # 每多少个epoch保存一次 save_freq: int = 10 # 每多少个epoch保存一次
val_freq: int = 1 # 每多少个epoch验证一次 val_freq: int = 1 # 每多少个epoch验证一次
# 保存路径
checkpoint_dir: str = '/home/wch/brep2sdf/checkpoints' # 结果保存目录
# 文件命名
model_name: str = 'brep2sdf' # 模型名称,用于文件命名
checkpoint_format: str = '{model_name}_epoch_{epoch:03d}.pth' # 检查点文件名格式
best_model_name: str = '{model_name}_best.pth' # 最佳模型文件名格式
@dataclass @dataclass
class TestConfig: class TestConfig:
vis_freq: int = 100 vis_freq: int = 100

53
brep2sdf/train.py

@ -1,5 +1,7 @@
import torch import torch
import torch.optim as optim import torch.optim as optim
import time
import os
from brep2sdf.config.default_config import get_default_config from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file,load_sdf_file from brep2sdf.data.data import load_brep_file,load_sdf_file
@ -103,30 +105,69 @@ class Trainer:
def train(self): def train(self):
best_val_loss = float('inf') best_val_loss = float('inf')
logger.info("Starting training...") logger.info("Starting training...")
start_time = time.time()
for epoch in range(1, self.config.train.num_epochs + 1): for epoch in range(1, self.config.train.num_epochs + 1):
# 训练一个epoch
train_loss = self.train_epoch(epoch) train_loss = self.train_epoch(epoch)
# 验证
''' '''
# 定期验证
if epoch % self.config.train.val_freq == 0: if epoch % self.config.train.val_freq == 0:
val_loss = self.validate(epoch) val_loss = self.validate(epoch)
logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')
# 保存最佳模型 # 保存最佳模型
if val_loss < best_val_loss: if val_loss < best_val_loss:
best_val_loss = val_loss best_val_loss = val_loss
self._save_model(epoch, val_loss) self._save_model(epoch, val_loss)
logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}')
else:
logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}')
''' '''
# 定期保存检查点 # 保存检查点
if epoch % self.config.train.save_freq == 0: if epoch % self.config.train.save_freq == 0:
self._save_checkpoint(epoch, train_loss) self._save_checkpoint(epoch, train_loss)
logger.info(f'Checkpoint saved at epoch {epoch}')
# 训练完成
total_time = time.time() - start_time
logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
logger.info(f'Best validation loss: {best_val_loss:.6f}')
def _save_model(self, epoch: int, val_loss: float): def _save_model(self, epoch: int, val_loss: float):
# 保存最佳模型的逻辑 """保存最佳模型"""
pass save_path = os.path.join(
self.config.train.checkpoint_dir,
self.config.train.best_model_name.format(
model_name=config.train.model_name
)
)
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'loss': val_loss,
'config': self.config
}, save_path)
def _save_checkpoint(self, epoch: int, train_loss: float): def _save_checkpoint(self, epoch: int, train_loss: float):
# 保存检查点的逻辑 """保存训练检查点"""
pass checkpoint_path = os.path.join(
self.config.train.checkpoint_dir,
self.config.train.checkpoint_format.format(
model_name=self.config.train.model_name,
epoch=epoch
)
)
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'loss': train_loss,
'config': self.config
}, checkpoint_path)
def main(): def main():
# 这里需要初始化配置 # 这里需要初始化配置

Loading…
Cancel
Save