From 4180ad0eaea23f8e6bb9b5278f3240e8755e190a Mon Sep 17 00:00:00 2001
From: mckay
Date: Mon, 18 Nov 2024 23:45:39 +0800
Subject: [PATCH] feat: train script

---
 brep2sdf/train.py | 256 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 256 insertions(+)
 create mode 100644 brep2sdf/train.py

diff --git a/brep2sdf/train.py b/brep2sdf/train.py
new file mode 100644
index 0000000..8a1ca00
--- /dev/null
+++ b/brep2sdf/train.py
@@ -0,0 +1,256 @@
+import os
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from brep2sdf.data.data import BRepSDFDataset
+from brep2sdf.networks.encoder import BRepToSDF, sdf_loss
+from brep2sdf.utils.logger import logger
+import wandb
+
+def main():
+    # Keep all configuration parameters in a single dict
+    config = {
+        # Data paths
+        'brep_dir': '/home/wch/brep2sdf/test_data/pkl',
+        'sdf_dir': '/home/wch/brep2sdf/test_data/sdf',
+        'valid_data_dir': '/home/wch/brep2sdf/test_data/result/pkl',
+        'save_dir': 'checkpoints',
+
+        # Training parameters
+        'batch_size': 32,
+        'num_workers': 4,
+        'num_epochs': 100,
+        'learning_rate': 1e-4,
+        'min_lr': 1e-6,
+        'weight_decay': 0.01,
+        'grad_weight': 0.1,
+        'max_grad_norm': 1.0,
+
+        # Model parameters
+        'brep_feature_dim': 48,
+        'use_cf': True,
+        'embed_dim': 768,
+        'latent_dim': 256,
+
+        # wandb parameters
+        'use_wandb': True,
+        'project_name': 'brep2sdf',
+        'run_name': 'training_run',
+        'log_interval': 10
+    }
+
+    # Create the checkpoint directory
+    os.makedirs(config['save_dir'], exist_ok=True)
+
+    # Initialize wandb (with a longer init timeout, in offline mode)
+    if config['use_wandb']:
+        try:
+            wandb.init(
+                project=config['project_name'],
+                name=config['run_name'],
+                config=config,
+                settings=wandb.Settings(
+                    init_timeout=180,     # raise the init timeout
+                    _disable_stats=True,  # disable system stats collection
+                    _disable_meta=True,   # disable metadata collection
+                ),
+                mode="offline"  # run offline; sync later with `wandb sync`
+            )
+            logger.info("Wandb initialized in offline mode")
+        except Exception as e:
+            logger.warning(f"Failed to initialize wandb: {str(e)}")
+            config['use_wandb'] = False  # fall back to running without wandb
+            logger.warning("Continuing without wandb logging")
+
+    # Build the trainer and start training
+    trainer = Trainer(config)
+    trainer.train()
+
+class Trainer:
+    def __init__(self, config):
+        self.config = config
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        logger.info(f"Using device: {self.device}")
+
+        # Build the datasets
+        self.train_dataset = BRepSDFDataset(
+            brep_dir=config['brep_dir'],
+            sdf_dir=config['sdf_dir'],
+            valid_data_dir=config['valid_data_dir'],
+            split='train'
+        )
+        self.val_dataset = BRepSDFDataset(
+            brep_dir=config['brep_dir'],
+            sdf_dir=config['sdf_dir'],
+            valid_data_dir=config['valid_data_dir'],
+            split='val'
+        )
+
+        logger.info(f"Train dataset size: {len(self.train_dataset)}")
+        logger.info(f"Val dataset size: {len(self.val_dataset)}")
+
+        # Build the data loaders
+        self.train_loader = DataLoader(
+            self.train_dataset,
+            batch_size=config['batch_size'],
+            shuffle=True,
+            num_workers=config['num_workers']
+        )
+        self.val_loader = DataLoader(
+            self.val_dataset,
+            batch_size=config['batch_size'],
+            shuffle=False,
+            num_workers=config['num_workers']
+        )
+
+        # Build the model
+        self.model = BRepToSDF(
+            brep_feature_dim=config['brep_feature_dim'],
+            use_cf=config['use_cf'],
+            embed_dim=config['embed_dim'],
+            latent_dim=config['latent_dim']
+        ).to(self.device)
+
+        # Build the optimizer
+        self.optimizer = optim.AdamW(
+            self.model.parameters(),
+            lr=config['learning_rate'],
+            weight_decay=config['weight_decay']
+        )
+
+        # Learning-rate scheduler
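+        # (PyTorch's CosineAnnealingLR follows
+        #  lr_t = min_lr + (learning_rate - min_lr) * (1 + cos(pi * t / T_max)) / 2,
+        #  so with T_max = num_epochs and one scheduler.step() per epoch the LR
+        #  decays smoothly from 1e-4 down to 1e-6 over the run.)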
+        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
+            self.optimizer,
+            T_max=config['num_epochs'],
+            eta_min=config['min_lr']
+        )
+
+    def train_epoch(self, epoch):
+        self.model.train()
+        total_loss = 0
+
+        for batch_idx, batch in enumerate(self.train_loader):
+            # Unpack the batch and move it to the device
+            surf_z = batch['surf_z'].to(self.device)
+            edge_z = batch['edge_z'].to(self.device)
+            surf_p = batch['surf_p'].to(self.device)
+            edge_p = batch['edge_p'].to(self.device)
+            vert_p = batch['vert_p'].to(self.device)
+            query_points = batch['points'].to(self.device)
+            gt_sdf = batch['sdf'].to(self.device)
+
+            # Forward pass
+            pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
+
+            # Compute the loss (note: if sdf_loss's gradient term differentiates
+            # pred_sdf w.r.t. query_points via autograd, the points must have
+            # requires_grad=True before the forward pass)
+            loss = sdf_loss(
+                pred_sdf,
+                gt_sdf,
+                query_points,
+                grad_weight=self.config['grad_weight']
+            )
+
+            # Backward pass
+            self.optimizer.zero_grad()
+            loss.backward()
+
+            # Clip gradients to stabilize training
+            torch.nn.utils.clip_grad_norm_(
+                self.model.parameters(),
+                self.config['max_grad_norm']
+            )
+
+            self.optimizer.step()
+            total_loss += loss.item()
+
+            # Log training progress
+            if (batch_idx + 1) % self.config['log_interval'] == 0:
+                logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
+                            f'Loss: {loss.item():.6f}')
+
+            # Log to wandb
+            if self.config['use_wandb'] and (batch_idx + 1) % self.config['log_interval'] == 0:
+                wandb.log({
+                    'batch_loss': loss.item(),
+                    'batch': batch_idx,
+                    'epoch': epoch
+                })
+
+        avg_loss = total_loss / len(self.train_loader)
+        return avg_loss
+
+    def validate(self, epoch):
+        self.model.eval()
+        total_loss = 0
+
+        # Note: torch.no_grad() disables autograd, so any autograd-based
+        # gradient term inside sdf_loss cannot be evaluated here
+        with torch.no_grad():
+            for batch in self.val_loader:
+                # Unpack the batch and move it to the device
+                surf_z = batch['surf_z'].to(self.device)
+                edge_z = batch['edge_z'].to(self.device)
+                surf_p = batch['surf_p'].to(self.device)
+                edge_p = batch['edge_p'].to(self.device)
+                vert_p = batch['vert_p'].to(self.device)
+                query_points = batch['points'].to(self.device)
+                gt_sdf = batch['sdf'].to(self.device)
+
+                # Forward pass
+                pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
+
+                # Compute the loss
+                loss = sdf_loss(
+                    pred_sdf,
+                    gt_sdf,
+                    query_points,
+                    grad_weight=self.config['grad_weight']
+                )
+
+                total_loss += loss.item()
+
+        avg_loss = total_loss / len(self.val_loader)
+        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
+
+        if self.config['use_wandb']:
+            wandb.log({
+                'val_loss': avg_loss,
+                'epoch': epoch
+            })
+        return avg_loss
+
+    def train(self):
+        best_val_loss = float('inf')
+        logger.info("Starting training...")
+
+        for epoch in range(1, self.config['num_epochs'] + 1):
+            train_loss = self.train_epoch(epoch)
+            val_loss = self.validate(epoch)
+            self.scheduler.step()
+
+            # Save the best checkpoint
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                model_path = os.path.join(self.config['save_dir'], 'best_model.pth')
+                torch.save({
+                    'epoch': epoch,
+                    'model_state_dict': self.model.state_dict(),
+                    'optimizer_state_dict': self.optimizer.state_dict(),
+                    'val_loss': val_loss,
+                }, model_path)
+                logger.info(f'Saved best model with val_loss: {val_loss:.6f}')
+
+            # Log the epoch summary
+            logger.info(f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
+                        f'Val Loss: {val_loss:.6f}\tLR: {self.scheduler.get_last_lr()[0]:.6f}')
+
+            # Log to wandb
+            if self.config['use_wandb']:
+                wandb.log({
+                    'train_loss': train_loss,
+                    'val_loss': val_loss,
+                    'learning_rate': self.scheduler.get_last_lr()[0],
+                    'epoch': epoch
+                })
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
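
Note: sdf_loss is imported from brep2sdf/networks/encoder.py and is not part of
this patch. For reference only, here is a minimal sketch of one plausible
implementation, assuming an L1 data term plus an Eikonal-style gradient penalty
weighted by grad_weight (the name sdf_loss_sketch and this structure are
assumptions, not the repository's actual code). If the real loss uses autograd
like this, query_points must have requires_grad=True and the validation loop
above cannot run it under torch.no_grad():

import torch
import torch.nn.functional as F

def sdf_loss_sketch(pred_sdf, gt_sdf, query_points, grad_weight=0.1):
    # L1 data term on the predicted SDF values
    data_term = F.l1_loss(pred_sdf, gt_sdf)
    if grad_weight > 0 and query_points.requires_grad:
        # Differentiate the predicted SDF w.r.t. the query points
        grad = torch.autograd.grad(
            outputs=pred_sdf,
            inputs=query_points,
            grad_outputs=torch.ones_like(pred_sdf),
            create_graph=True,  # keep the graph so the penalty is trainable
        )[0]
        # Eikonal penalty: a true SDF has a unit-norm gradient everywhere
        grad_term = ((grad.norm(dim=-1) - 1.0) ** 2).mean()
        return data_term + grad_weight * grad_term
    return data_term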