
refactor & style: configure training via config, and add checkpoint-saving info

Branch: main
Author: 王琛涵, 4 months ago
Parent commit: 3555cf5359
Changed files:
  1. brep2sdf/config/default_config.py (12 lines changed)
  2. brep2sdf/train.py (152 lines changed)

brep2sdf/config/default_config.py (12 lines changed)

@@ -24,7 +24,17 @@ class DataConfig:
     brep_dir: str = '/home/wch/brep2sdf/test_data/pkl'
     sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf'
     valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl'
-    save_dir: str = 'checkpoints'
+    # Save paths
+    save_dir: str = 'checkpoints'                 # base directory for saved artifacts
+    model_save_dir: str = 'checkpoints/models'    # model files
+    log_save_dir: str = 'checkpoints/logs'        # log files
+    result_save_dir: str = 'checkpoints/results'  # results
+    # File naming
+    model_name: str = 'brep2sdf'                  # model name, used in filenames
+    checkpoint_format: str = '{model_name}_epoch_{epoch:03d}.pth'  # checkpoint filename template
+    best_model_name: str = '{model_name}_best.pth'                 # best-model filename template

 @dataclass
 class TrainConfig:
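The two new naming fields are plain `str.format` templates; a quick sketch of how they expand with the defaults above:

```python
# Sketch: expanding the new filename templates with str.format.
model_name = 'brep2sdf'
checkpoint_format = '{model_name}_epoch_{epoch:03d}.pth'
best_model_name = '{model_name}_best.pth'

print(checkpoint_format.format(model_name=model_name, epoch=7))
# brep2sdf_epoch_007.pth  ({epoch:03d} zero-pads the epoch to three digits)
print(best_model_name.format(model_name=model_name))
# brep2sdf_best.pth
```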

brep2sdf/train.py (152 lines changed)

@@ -6,50 +6,25 @@ from torch.utils.data import DataLoader
 from brep2sdf.data.data import BRepSDFDataset
 from brep2sdf.networks.encoder import BRepToSDF, sdf_loss
 from brep2sdf.utils.logger import logger
+from brep2sdf.config.default_config import get_default_config, load_config
 import wandb

 def main():
-    # Store configuration parameters in a dict
-    config = {
-        # Data paths
-        'brep_dir': '/home/wch/brep2sdf/test_data/pkl',
-        'sdf_dir': '/home/wch/brep2sdf/test_data/sdf',
-        'valid_data_dir': '/home/wch/brep2sdf/test_data/result/pkl',
-        'save_dir': 'checkpoints',
-        # Training parameters
-        'batch_size': 32,
-        'num_workers': 4,
-        'num_epochs': 100,
-        'learning_rate': 1e-4,
-        'min_lr': 1e-6,
-        'weight_decay': 0.01,
-        'grad_weight': 0.1,
-        'max_grad_norm': 1.0,
-        # Model parameters
-        'brep_feature_dim': 48,
-        'use_cf': True,
-        'embed_dim': 768,
-        'latent_dim': 256,
-        # wandb parameters
-        'use_wandb': True,
-        'project_name': 'brep2sdf',
-        'run_name': 'training_run',
-        'log_interval': 10
-    }
+    # Load the config
+    config = get_default_config()

-    # Create the save directory
-    os.makedirs(config['save_dir'], exist_ok=True)
+    # Create all save directories
+    os.makedirs(config.data.model_save_dir, exist_ok=True)
+    os.makedirs(config.data.log_save_dir, exist_ok=True)
+    os.makedirs(config.data.result_save_dir, exist_ok=True)

-    # Initialize wandb (with timeout settings and offline mode)
-    if config['use_wandb']:
+    # Initialize wandb
+    if config.log.use_wandb:
         try:
             wandb.init(
-                project=config['project_name'],
-                name=config['run_name'],
-                config=config,
+                project=config.log.project_name,
+                name=config.log.run_name,
+                config=config.__dict__,
                 settings=wandb.Settings(
                     init_timeout=180,     # longer init timeout
                     _disable_stats=True,  # disable system stats
@@ -60,7 +35,7 @@ def main():
             logger.info("Wandb initialized in offline mode")
         except Exception as e:
             logger.warning(f"Failed to initialize wandb: {str(e)}")
-            config['use_wandb'] = False  # disable wandb
+            config.log.use_wandb = False
             logger.warning("Continuing without wandb logging")

     # Initialize the trainer and start training
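For reference, a self-contained sketch of the init-with-fallback pattern this hunk implements. The `cfg.log.*` fields mirror the config accesses above; the `wandb.Settings(init_timeout=...)` field follows the surrounding context, and the helper name and everything else here is illustrative rather than the repo's exact code:

```python
import wandb

def try_init_wandb(cfg) -> bool:
    """Try to start a wandb run; on failure, disable wandb and keep training."""
    try:
        wandb.init(
            project=cfg.log.project_name,
            name=cfg.log.run_name,
            config=cfg.__dict__,
            settings=wandb.Settings(init_timeout=180),  # allow a slow handshake
        )
        return True
    except Exception as e:
        cfg.log.use_wandb = False  # downstream code checks this flag before wandb.log
        print(f"wandb init failed, continuing without it: {e}")
        return False
```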
@@ -75,15 +50,15 @@ class Trainer:
         # Initialize datasets
         self.train_dataset = BRepSDFDataset(
-            brep_dir=config['brep_dir'],
-            sdf_dir=config['sdf_dir'],
-            valid_data_dir=config['valid_data_dir'],
+            brep_dir=config.data.brep_dir,
+            sdf_dir=config.data.sdf_dir,
+            valid_data_dir=config.data.valid_data_dir,
             split='train'
         )
         self.val_dataset = BRepSDFDataset(
-            brep_dir=config['brep_dir'],
-            sdf_dir=config['sdf_dir'],
-            valid_data_dir=config['valid_data_dir'],
+            brep_dir=config.data.brep_dir,
+            sdf_dir=config.data.sdf_dir,
+            valid_data_dir=config.data.valid_data_dir,
             split='val'
         )
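Throughout the file, dict lookups are replaced by attribute access on a nested config. Only `DataConfig` appears in this commit; a minimal sketch of the assumed overall shape, with sub-config names inferred from the `config.data.*` / `config.train.*` accesses in this file:

```python
from dataclasses import dataclass, field

@dataclass
class TrainConfig:  # assumed; fields inferred from their uses in train.py
    batch_size: int = 32
    num_workers: int = 4

@dataclass
class Config:       # assumed top-level container returned by get_default_config()
    train: TrainConfig = field(default_factory=TrainConfig)
    # data, model, and log sub-configs omitted from this sketch

def get_default_config() -> 'Config':
    return Config()

config = get_default_config()
assert config.train.batch_size == 32  # attribute access replaces config['batch_size']
```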
@@ -93,37 +68,37 @@ class Trainer:
         # Initialize data loaders
         self.train_loader = DataLoader(
             self.train_dataset,
-            batch_size=config['batch_size'],
+            batch_size=config.train.batch_size,
             shuffle=True,
-            num_workers=config['num_workers']
+            num_workers=config.train.num_workers
         )
         self.val_loader = DataLoader(
             self.val_dataset,
-            batch_size=config['batch_size'],
+            batch_size=config.train.batch_size,
             shuffle=False,
-            num_workers=config['num_workers']
+            num_workers=config.train.num_workers
         )

         # Initialize the model
         self.model = BRepToSDF(
-            brep_feature_dim=config['brep_feature_dim'],
-            use_cf=config['use_cf'],
-            embed_dim=config['embed_dim'],
-            latent_dim=config['latent_dim']
+            brep_feature_dim=config.model.brep_feature_dim,
+            use_cf=config.model.use_cf,
+            embed_dim=config.model.embed_dim,
+            latent_dim=config.model.latent_dim
         ).to(self.device)

         # Initialize the optimizer
         self.optimizer = optim.AdamW(
             self.model.parameters(),
-            lr=config['learning_rate'],
-            weight_decay=config['weight_decay']
+            lr=config.train.learning_rate,
+            weight_decay=config.train.weight_decay
         )

         # Learning-rate scheduler
         self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
             self.optimizer,
-            T_max=config['num_epochs'],
-            eta_min=config['min_lr']
+            T_max=config.train.num_epochs,
+            eta_min=config.train.min_lr
         )

     def train_epoch(self, epoch):
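`CosineAnnealingLR` decays the learning rate from its initial value to `eta_min` over `T_max` steps along a half cosine. A sketch of the schedule, following the standard PyTorch formula:

```python
import math

def cosine_lr(t: int, base_lr: float = 1e-4, eta_min: float = 1e-6, t_max: int = 100) -> float:
    """lr_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / t_max)) / 2"""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2

print(cosine_lr(0))    # 1e-4: starts at the base learning rate
print(cosine_lr(50))   # ~5.05e-5: roughly midway through the decay
print(cosine_lr(100))  # 1e-6: reaches eta_min at t_max
```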
@@ -148,7 +123,7 @@ class Trainer:
                 pred_sdf,
                 gt_sdf,
                 query_points,
-                grad_weight=self.config['grad_weight']
+                grad_weight=self.config.train.grad_weight
             )

             # Backward pass
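`sdf_loss` itself lives in `brep2sdf.networks.encoder` and is not shown in this commit. As a rough illustration of what a `grad_weight` term typically controls, here is a hypothetical SDF loss combining an L1 value term with an eikonal (unit-gradient) penalty; this is an assumption about the general technique, not the repo's actual implementation:

```python
import torch
import torch.nn.functional as F

def sdf_loss_sketch(pred_sdf, gt_sdf, query_points, grad_weight=0.1):
    """Hypothetical: L1 on SDF values plus an eikonal penalty pushing |grad| toward 1."""
    value_loss = F.l1_loss(pred_sdf, gt_sdf)
    # query_points must have requires_grad=True and pred_sdf must be computed from them
    grads = torch.autograd.grad(pred_sdf.sum(), query_points, create_graph=True)[0]
    eikonal = ((grads.norm(dim=-1) - 1.0) ** 2).mean()
    return value_loss + grad_weight * eikonal
```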
@@ -158,19 +133,19 @@ class Trainer:
             # Gradient clipping
             torch.nn.utils.clip_grad_norm_(
                 self.model.parameters(),
-                self.config['max_grad_norm']
+                self.config.train.max_grad_norm
             )
             self.optimizer.step()

             total_loss += loss.item()

             # Print training progress
-            if (batch_idx + 1) % self.config['log_interval'] == 0:
+            if (batch_idx + 1) % self.config.log.log_interval == 0:
                 logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
                             f'Loss: {loss.item():.6f}')

             # Log to wandb
-            if self.config['use_wandb'] and (batch_idx + 1) % self.config['log_interval'] == 0:
+            if self.config.log.use_wandb and (batch_idx + 1) % self.config.log.log_interval == 0:
                 wandb.log({
                     'batch_loss': loss.item(),
                     'batch': batch_idx,
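A side note on the clipping call above: `clip_grad_norm_` rescales gradients in place and returns the total norm measured before clipping, which can be useful to log next to `batch_loss` (optional; not part of this commit):

```python
import torch

model = torch.nn.Linear(4, 1)              # stand-in for self.model
model(torch.randn(8, 4)).sum().backward()

# Rescales grads in place; the return value is the pre-clip total norm.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f'grad norm before clipping: {float(total_norm):.4f}')
```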
@@ -203,7 +178,7 @@ class Trainer:
                 pred_sdf,
                 gt_sdf,
                 query_points,
-                grad_weight=self.config['grad_weight']
+                grad_weight=self.config.train.grad_weight
             )

             total_loss += loss.item()
@@ -211,7 +186,7 @@ class Trainer:
         avg_loss = total_loss / len(self.val_loader)
         logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')

-        if self.config['use_wandb']:
+        if self.config.log.use_wandb:
             wandb.log({
                 'val_loss': avg_loss,
                 'epoch': epoch
@@ -222,35 +197,68 @@ class Trainer:
         best_val_loss = float('inf')
         logger.info("Starting training...")

-        for epoch in range(1, self.config['num_epochs'] + 1):
+        for epoch in range(1, self.config.train.num_epochs + 1):
             train_loss = self.train_epoch(epoch)
-            val_loss = self.validate(epoch)
-            self.scheduler.step()

-            # Save the best model
-            if val_loss < best_val_loss:
-                best_val_loss = val_loss
-                model_path = os.path.join(self.config['save_dir'], 'best_model.pth')
-                torch.save({
-                    'epoch': epoch,
-                    'model_state_dict': self.model.state_dict(),
-                    'optimizer_state_dict': self.optimizer.state_dict(),
-                    'val_loss': val_loss,
-                }, model_path)
-                logger.info(f'Saved best model with val_loss: {val_loss:.6f}')
+            # Validate periodically
+            if epoch % self.config.train.val_freq == 0:
+                val_loss = self.validate(epoch)
+
+                # Save the best model
+                if val_loss < best_val_loss:
+                    best_val_loss = val_loss
+                    best_model_path = os.path.join(
+                        self.config.data.model_save_dir,
+                        self.config.data.best_model_name.format(
+                            model_name=self.config.data.model_name
+                        )
+                    )
+                    torch.save({
+                        'epoch': epoch,
+                        'model_state_dict': self.model.state_dict(),
+                        'optimizer_state_dict': self.optimizer.state_dict(),
+                        'val_loss': val_loss,
+                    }, best_model_path)
+                    logger.info(f'Saved best model with val_loss: {val_loss:.6f}')
+
+            # Periodically save a checkpoint
+            if epoch % self.config.train.save_freq == 0:
+                checkpoint_path = os.path.join(
+                    self.config.data.model_save_dir,
+                    self.config.data.checkpoint_format.format(
+                        model_name=self.config.data.model_name,
+                        epoch=epoch
+                    )
+                )
+                torch.save({
+                    'epoch': epoch,
+                    'model_state_dict': self.model.state_dict(),
+                    'optimizer_state_dict': self.optimizer.state_dict(),
+                    'scheduler_state_dict': self.scheduler.state_dict(),
+                    'train_loss': train_loss,
+                    'val_loss': val_loss if epoch % self.config.train.val_freq == 0 else None,
+                }, checkpoint_path)
+                logger.info(f'Saved checkpoint at epoch {epoch}')
+
+            self.scheduler.step()

             # Log training info
-            logger.info(f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
-                        f'Val Loss: {val_loss:.6f}\tLR: {self.scheduler.get_last_lr()[0]:.6f}')
+            log_info = f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
+            if epoch % self.config.train.val_freq == 0:
+                log_info += f'Val Loss: {val_loss:.6f}\t'
+            log_info += f'LR: {self.scheduler.get_last_lr()[0]:.6f}'
+            logger.info(log_info)

             # Log to wandb
-            if self.config['use_wandb']:
-                wandb.log({
-                    'train_loss': train_loss,
-                    'val_loss': val_loss,
-                    'learning_rate': self.scheduler.get_last_lr()[0],
-                    'epoch': epoch
-                })
+            if self.config.log.use_wandb:
+                log_dict = {
+                    'train_loss': train_loss,
+                    'learning_rate': self.scheduler.get_last_lr()[0],
+                    'epoch': epoch
+                }
+                if epoch % self.config.train.val_freq == 0:
+                    log_dict['val_loss'] = val_loss
+                wandb.log(log_dict)

 if __name__ == '__main__':
     main()
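Since the periodic checkpoint now stores optimizer and scheduler state next to the model weights, an interrupted run can be resumed. A minimal sketch of the matching load side, assuming the model, optimizer, and scheduler are constructed exactly as in `Trainer.__init__` (the path below is illustrative):

```python
import torch

def load_checkpoint(path, model, optimizer, scheduler):
    """Restore the state dicts written by the periodic checkpoint above."""
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    return ckpt['epoch'] + 1  # the next epoch to run

# start_epoch = load_checkpoint('checkpoints/models/brep2sdf_epoch_050.pth',
#                               model, optimizer, scheduler)
```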