|
|
@@ -6,50 +6,25 @@ from torch.utils.data import DataLoader
 from brep2sdf.data.data import BRepSDFDataset
 from brep2sdf.networks.encoder import BRepToSDF, sdf_loss
 from brep2sdf.utils.logger import logger
 from brep2sdf.config.default_config import get_default_config, load_config
 import wandb
 
 def main():
-    # Store configuration parameters in a dict
-    config = {
-        # Data paths
-        'brep_dir': '/home/wch/brep2sdf/test_data/pkl',
-        'sdf_dir': '/home/wch/brep2sdf/test_data/sdf',
-        'valid_data_dir': '/home/wch/brep2sdf/test_data/result/pkl',
-        'save_dir': 'checkpoints',
-        # Training parameters
-        'batch_size': 32,
-        'num_workers': 4,
-        'num_epochs': 100,
-        'learning_rate': 1e-4,
-        'min_lr': 1e-6,
-        'weight_decay': 0.01,
-        'grad_weight': 0.1,
-        'max_grad_norm': 1.0,
-        # Model parameters
-        'brep_feature_dim': 48,
-        'use_cf': True,
-        'embed_dim': 768,
-        'latent_dim': 256,
-        # wandb parameters
-        'use_wandb': True,
-        'project_name': 'brep2sdf',
-        'run_name': 'training_run',
-        'log_interval': 10
-    }
     # Get the config
     config = get_default_config()
 
-    # Create the save directory
-    os.makedirs(config['save_dir'], exist_ok=True)
+    # Create all save directories
+    os.makedirs(config.data.model_save_dir, exist_ok=True)
+    os.makedirs(config.data.log_save_dir, exist_ok=True)
+    os.makedirs(config.data.result_save_dir, exist_ok=True)
 
-    # Initialize wandb (with timeout settings and offline mode)
-    if config['use_wandb']:
+    # Initialize wandb
+    if config.log.use_wandb:
         try:
             wandb.init(
-                project=config['project_name'],
-                name=config['run_name'],
-                config=config,
+                project=config.log.project_name,
+                name=config.log.run_name,
+                config=config.__dict__,
                 settings=wandb.Settings(
                     init_timeout=180,  # increase the init timeout
                     _disable_stats=True,  # disable system stats collection
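
The dict literal removed above was dead code: it was immediately overwritten by `get_default_config()`. For orientation, a minimal sketch of what the nested config object could look like; this is hypothetical, the real definitions live in brep2sdf/config/default_config.py, and the field names and defaults here are inferred from the accesses in this patch and the deleted dict (log_save_dir/result_save_dir defaults are guesses):

# Hypothetical sketch only -- not the actual default_config.py.
from dataclasses import dataclass, field

@dataclass
class DataConfig:
    brep_dir: str = '/home/wch/brep2sdf/test_data/pkl'
    sdf_dir: str = '/home/wch/brep2sdf/test_data/sdf'
    valid_data_dir: str = '/home/wch/brep2sdf/test_data/result/pkl'
    model_save_dir: str = 'checkpoints'
    log_save_dir: str = 'logs'       # guessed default
    result_save_dir: str = 'results'  # guessed default

@dataclass
class TrainConfig:
    batch_size: int = 32
    num_workers: int = 4
    num_epochs: int = 100
    learning_rate: float = 1e-4
    min_lr: float = 1e-6
    weight_decay: float = 0.01
    grad_weight: float = 0.1
    max_grad_norm: float = 1.0

@dataclass
class ModelConfig:
    brep_feature_dim: int = 48
    use_cf: bool = True
    embed_dim: int = 768
    latent_dim: int = 256

@dataclass
class LogConfig:
    use_wandb: bool = True
    project_name: str = 'brep2sdf'
    run_name: str = 'training_run'
    log_interval: int = 10

@dataclass
class Config:
    data: DataConfig = field(default_factory=DataConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    log: LogConfig = field(default_factory=LogConfig)

def get_default_config() -> Config:
    return Config()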
|
|
@@ -60,7 +35,7 @@ def main():
             logger.info("Wandb initialized in offline mode")
         except Exception as e:
             logger.warning(f"Failed to initialize wandb: {str(e)}")
-            config['use_wandb'] = False  # disable wandb
+            config.log.use_wandb = False
             logger.warning("Continuing without wandb logging")
 
     # Initialize the trainer and start training
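
A caveat with `config=config.__dict__` in the first hunk: `__dict__` is shallow, so nested sub-configs are passed to wandb as objects rather than plain dicts. If the config is a dataclass (as sketched above), `dataclasses.asdict` produces a fully nested dict that serializes cleanly. A minimal sketch, assuming that config shape:

from dataclasses import asdict

import wandb

wandb.init(
    project=config.log.project_name,
    name=config.log.run_name,
    config=asdict(config),  # recursive conversion, unlike config.__dict__
)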
|
|
@@ -75,15 +50,15 @@ class Trainer:
 
         # Initialize datasets
         self.train_dataset = BRepSDFDataset(
-            brep_dir=config['brep_dir'],
-            sdf_dir=config['sdf_dir'],
-            valid_data_dir=config['valid_data_dir'],
+            brep_dir=config.data.brep_dir,
+            sdf_dir=config.data.sdf_dir,
+            valid_data_dir=config.data.valid_data_dir,
             split='train'
         )
         self.val_dataset = BRepSDFDataset(
-            brep_dir=config['brep_dir'],
-            sdf_dir=config['sdf_dir'],
-            valid_data_dir=config['valid_data_dir'],
+            brep_dir=config.data.brep_dir,
+            sdf_dir=config.data.sdf_dir,
+            valid_data_dir=config.data.valid_data_dir,
             split='val'
         )
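
A quick way to sanity-check the config-driven wiring before a full training run; a sketch, assuming only the BRepSDFDataset constructor signature shown in this hunk:

from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import BRepSDFDataset

config = get_default_config()
ds = BRepSDFDataset(
    brep_dir=config.data.brep_dir,
    sdf_dir=config.data.sdf_dir,
    valid_data_dir=config.data.valid_data_dir,
    split='train',
)
print(f'{len(ds)} training samples; first sample: {type(ds[0])}')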
|
|
|
|
|
|
@@ -93,37 +68,37 @@ class Trainer:
         # Initialize data loaders
         self.train_loader = DataLoader(
             self.train_dataset,
-            batch_size=config['batch_size'],
+            batch_size=config.train.batch_size,
             shuffle=True,
-            num_workers=config['num_workers']
+            num_workers=config.train.num_workers
         )
         self.val_loader = DataLoader(
             self.val_dataset,
-            batch_size=config['batch_size'],
+            batch_size=config.train.batch_size,
             shuffle=False,
-            num_workers=config['num_workers']
+            num_workers=config.train.num_workers
         )
 
         # Initialize the model
         self.model = BRepToSDF(
-            brep_feature_dim=config['brep_feature_dim'],
-            use_cf=config['use_cf'],
-            embed_dim=config['embed_dim'],
-            latent_dim=config['latent_dim']
+            brep_feature_dim=config.model.brep_feature_dim,
+            use_cf=config.model.use_cf,
+            embed_dim=config.model.embed_dim,
+            latent_dim=config.model.latent_dim
         ).to(self.device)
 
         # Initialize the optimizer
         self.optimizer = optim.AdamW(
             self.model.parameters(),
-            lr=config['learning_rate'],
-            weight_decay=config['weight_decay']
+            lr=config.train.learning_rate,
+            weight_decay=config.train.weight_decay
         )
 
         # Learning-rate scheduler
         self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
             self.optimizer,
-            T_max=config['num_epochs'],
-            eta_min=config['min_lr']
+            T_max=config.train.num_epochs,
+            eta_min=config.train.min_lr
         )
 
     def train_epoch(self, epoch):
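
With T_max read from config.train.num_epochs, CosineAnnealingLR sweeps the learning rate from learning_rate down to min_lr over the whole run, following eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2. A standalone sketch of the curve, using the values from the deleted dict (1e-4 down to 1e-6 over 100 epochs):

import math

lr_max, lr_min, t_max = 1e-4, 1e-6, 100
for t in (0, 25, 50, 75, 100):
    lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t / t_max))
    print(f'epoch {t:3d}: lr = {lr:.2e}')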
|
|
@@ -148,7 +123,7 @@ class Trainer:
                 pred_sdf,
                 gt_sdf,
                 query_points,
-                grad_weight=self.config['grad_weight']
+                grad_weight=self.config.train.grad_weight
             )
 
             # Backward pass
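
The sdf_loss call takes query_points and a grad_weight, which suggests a gradient-based regularizer alongside the value term. The actual implementation lives in brep2sdf.networks.encoder; purely as a hedged illustration, a typical eikonal-style formulation looks like:

import torch
import torch.nn.functional as F

def sdf_loss_sketch(pred_sdf, gt_sdf, query_points, grad_weight=0.1):
    # Value term: predicted vs. ground-truth SDF samples.
    value_loss = F.l1_loss(pred_sdf, gt_sdf)
    # Eikonal term: a true SDF satisfies |grad f| == 1 everywhere.
    # Requires query_points.requires_grad_(True) before the forward pass.
    grad = torch.autograd.grad(pred_sdf.sum(), query_points, create_graph=True)[0]
    grad_loss = ((grad.norm(dim=-1) - 1.0) ** 2).mean()
    return value_loss + grad_weight * grad_loss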
|
|
@@ -158,19 +133,19 @@ class Trainer:
             # Gradient clipping
             torch.nn.utils.clip_grad_norm_(
                 self.model.parameters(),
-                self.config['max_grad_norm']
+                self.config.train.max_grad_norm
             )
 
             self.optimizer.step()
             total_loss += loss.item()
 
             # Print training progress
-            if (batch_idx + 1) % self.config['log_interval'] == 0:
+            if (batch_idx + 1) % self.config.log.log_interval == 0:
                 logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
                            f'Loss: {loss.item():.6f}')
 
             # Log to wandb
-            if self.config['use_wandb'] and (batch_idx + 1) % self.config['log_interval'] == 0:
+            if self.config.log.use_wandb and (batch_idx + 1) % self.config.log.log_interval == 0:
                 wandb.log({
                     'batch_loss': loss.item(),
                     'batch': batch_idx,
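
A side note on the clipping call in this hunk: clip_grad_norm_ returns the total gradient norm computed before clipping, so it can be logged for free instead of being discarded. A self-contained demo:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f'pre-clip gradient norm: {total_norm:.4f}')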
|
|
@@ -203,7 +178,7 @@ class Trainer:
                 pred_sdf,
                 gt_sdf,
                 query_points,
-                grad_weight=self.config['grad_weight']
+                grad_weight=self.config.train.grad_weight
             )
 
             total_loss += loss.item()
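
Note for this validation path: if sdf_loss differentiates predictions with respect to query_points (as in the earlier sketch), the validation forward pass cannot run under torch.no_grad(). torch.enable_grad() re-enables autograd locally inside a no_grad block; a minimal standalone demo:

import torch

x = torch.randn(4, 3)
with torch.no_grad():          # e.g. the surrounding validation loop
    with torch.enable_grad():  # locally re-enable autograd for the grad term
        x.requires_grad_(True)
        y = (x ** 2).sum()
        g = torch.autograd.grad(y, x)[0]
print(g.shape)  # torch.Size([4, 3])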
|
|
@@ -211,7 +186,7 @@ class Trainer:
         avg_loss = total_loss / len(self.val_loader)
         logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
 
-        if self.config['use_wandb']:
+        if self.config.log.use_wandb:
             wandb.log({
                 'val_loss': avg_loss,
                 'epoch': epoch
|
|
@@ -222,35 +197,68 @@ class Trainer:
         best_val_loss = float('inf')
         logger.info("Starting training...")
 
-        for epoch in range(1, self.config['num_epochs'] + 1):
+        for epoch in range(1, self.config.train.num_epochs + 1):
             train_loss = self.train_epoch(epoch)
-            val_loss = self.validate(epoch)
-            self.scheduler.step()
-
-            # Save the best model
-            if val_loss < best_val_loss:
-                best_val_loss = val_loss
-                model_path = os.path.join(self.config['save_dir'], 'best_model.pth')
-                torch.save({
-                    'epoch': epoch,
-                    'model_state_dict': self.model.state_dict(),
-                    'optimizer_state_dict': self.optimizer.state_dict(),
-                    'val_loss': val_loss,
-                }, model_path)
-                logger.info(f'Saved best model with val_loss: {val_loss:.6f}')
+
+            # Periodic validation
+            if epoch % self.config.train.val_freq == 0:
+                val_loss = self.validate(epoch)
+
+                # Save the best model
+                if val_loss < best_val_loss:
+                    best_val_loss = val_loss
+                    best_model_path = os.path.join(
+                        self.config.data.model_save_dir,
+                        self.config.data.best_model_name.format(
+                            model_name=self.config.data.model_name
+                        )
+                    )
+                    torch.save({
+                        'epoch': epoch,
+                        'model_state_dict': self.model.state_dict(),
+                        'optimizer_state_dict': self.optimizer.state_dict(),
+                        'val_loss': val_loss,
+                    }, best_model_path)
+                    logger.info(f'Saved best model with val_loss: {val_loss:.6f}')
+
+            # Periodically save a checkpoint
+            if epoch % self.config.train.save_freq == 0:
+                checkpoint_path = os.path.join(
+                    self.config.data.model_save_dir,
+                    self.config.data.checkpoint_format.format(
+                        model_name=self.config.data.model_name,
+                        epoch=epoch
+                    )
+                )
+                torch.save({
+                    'epoch': epoch,
+                    'model_state_dict': self.model.state_dict(),
+                    'optimizer_state_dict': self.optimizer.state_dict(),
+                    'scheduler_state_dict': self.scheduler.state_dict(),
+                    'train_loss': train_loss,
+                    'val_loss': val_loss if epoch % self.config.train.val_freq == 0 else None,
+                }, checkpoint_path)
+                logger.info(f'Saved checkpoint at epoch {epoch}')
+
+            self.scheduler.step()
 
             # Log training info
-            logger.info(f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
-                        f'Val Loss: {val_loss:.6f}\tLR: {self.scheduler.get_last_lr()[0]:.6f}')
+            log_info = f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
+            if epoch % self.config.train.val_freq == 0:
+                log_info += f'Val Loss: {val_loss:.6f}\t'
+            log_info += f'LR: {self.scheduler.get_last_lr()[0]:.6f}'
+            logger.info(log_info)
 
             # Log to wandb
-            if self.config['use_wandb']:
-                wandb.log({
+            if self.config.log.use_wandb:
+                log_dict = {
                     'train_loss': train_loss,
-                    'val_loss': val_loss,
                     'learning_rate': self.scheduler.get_last_lr()[0],
                     'epoch': epoch
-                })
+                }
+                if epoch % self.config.train.val_freq == 0:
+                    log_dict['val_loss'] = val_loss
+                wandb.log(log_dict)
 
 if __name__ == '__main__':
     main()
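
The new save logic composes filenames from config.data template strings via str.format. The real templates live in default_config.py; hypothetical values consistent with the format(...) calls above, plus a matching resume sketch:

# Hypothetical templates -- actual values come from config.data.
checkpoint_format = '{model_name}_epoch{epoch:04d}.pth'
best_model_name = '{model_name}_best.pth'

print(checkpoint_format.format(model_name='brep2sdf', epoch=50))
# -> brep2sdf_epoch0050.pth
print(best_model_name.format(model_name='brep2sdf'))
# -> brep2sdf_best.pth

# Resuming from such a checkpoint (a sketch; model, optimizer, and scheduler
# are assumed to be constructed as in Trainer.__init__ above):
import torch

ckpt = torch.load('checkpoints/brep2sdf_epoch0050.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch'] + 1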