|
@ -1,5 +1,7 @@ |
|
|
import torch |
|
|
import torch |
|
|
import torch.optim as optim |
|
|
import torch.optim as optim |
|
|
|
|
|
import time |
|
|
|
|
|
import os |
|
|
|
|
|
|
|
|
from brep2sdf.config.default_config import get_default_config |
|
|
from brep2sdf.config.default_config import get_default_config |
|
|
from brep2sdf.data.data import load_brep_file,load_sdf_file |
|
|
from brep2sdf.data.data import load_brep_file,load_sdf_file |
|
@ -103,30 +105,69 @@ class Trainer: |
|
|
def train(self):
    """Run the training loop for the configured number of epochs.

    Every epoch calls ``self.train_epoch(epoch)``. Whenever the epoch
    number is a multiple of ``config.train.save_freq`` a checkpoint is
    written through ``self._save_checkpoint`` and a log line is emitted.
    After the last epoch the elapsed wall-clock time and the best
    validation loss are logged.

    NOTE: the validation / best-model section is disabled (it was wrapped
    in a string literal and is kept below as comments), so
    ``best_val_loss`` is never updated here and the final log line
    reports ``inf``.
    """
    logger.info("Starting training...")
    best_val_loss = float('inf')
    start_time = time.time()

    last_epoch = self.config.train.num_epochs
    for epoch in range(1, last_epoch + 1):
        # Train for one epoch.
        train_loss = self.train_epoch(epoch)

        # Validation (disabled). Nesting below is reconstructed -- confirm
        # against the original before re-enabling.
        #
        # if epoch % self.config.train.val_freq == 0:
        #     val_loss = self.validate(epoch)
        #     logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')
        #
        #     # Save the best model.
        #     if val_loss < best_val_loss:
        #         best_val_loss = val_loss
        #         self._save_model(epoch, val_loss)
        #         logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}')
        # else:
        #     logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}')

        # Save a checkpoint only on multiples of save_freq.
        if epoch % self.config.train.save_freq != 0:
            continue
        self._save_checkpoint(epoch, train_loss)
        logger.info(f'Checkpoint saved at epoch {epoch}')

    # Training finished: report duration and best validation loss.
    total_time = time.time() - start_time
    logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
    logger.info(f'Best validation loss: {best_val_loss:.6f}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_model(self, epoch: int, val_loss: float):
    """Save the best model to disk.

    The target path is ``config.train.checkpoint_dir`` joined with
    ``config.train.best_model_name`` formatted with the configured
    ``model_name``. The saved dict holds the epoch, the model and
    optimizer state dicts, the loss and the config object.

    Args:
        epoch: Epoch number stored under the ``'epoch'`` key.
        val_loss: Validation loss stored under the ``'loss'`` key.
    """
    train_cfg = self.config.train
    # Read model_name from the trainer's own config, as _save_checkpoint
    # does. The previous bare ``config`` is not defined in this method and
    # would raise NameError unless a module-level ``config`` happens to exist.
    file_name = train_cfg.best_model_name.format(
        model_name=train_cfg.model_name
    )
    save_path = os.path.join(train_cfg.checkpoint_dir, file_name)
    torch.save({
        'epoch': epoch,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'loss': val_loss,
        'config': self.config
    }, save_path)
|
|
|
|
|
|
|
|
def _save_checkpoint(self, epoch: int, train_loss: float):
    """Save a training checkpoint to disk.

    The file name comes from ``config.train.checkpoint_format`` formatted
    with the configured ``model_name`` and the given ``epoch``; the file is
    placed in ``config.train.checkpoint_dir``.

    Args:
        epoch: Epoch number, used in the file name and stored under ``'epoch'``.
        train_loss: Training loss stored under the ``'loss'`` key.
    """
    file_name = self.config.train.checkpoint_format.format(
        model_name=self.config.train.model_name,
        epoch=epoch
    )
    checkpoint_path = os.path.join(self.config.train.checkpoint_dir, file_name)

    # Everything stored in the checkpoint file.
    state = {
        'epoch': epoch,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'loss': train_loss,
        'config': self.config
    }
    torch.save(state, checkpoint_path)
|
|
|
|
|
|
|
|
def main(): |
|
|
def main(): |
|
|
# 这里需要初始化配置 |
|
|
# 这里需要初始化配置 |
|
|