1 changed file with 256 additions and 0 deletions
@@ -0,0 +1,256 @@
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from brep2sdf.data.data import BRepSDFDataset
from brep2sdf.networks.encoder import BRepToSDF, sdf_loss
from brep2sdf.utils.logger import logger
import wandb

def main():
    # Keep all configuration parameters in a single dict
    config = {
        # Data paths
        'brep_dir': '/home/wch/brep2sdf/test_data/pkl',
        'sdf_dir': '/home/wch/brep2sdf/test_data/sdf',
        'valid_data_dir': '/home/wch/brep2sdf/test_data/result/pkl',
        'save_dir': 'checkpoints',

        # Training parameters
        'batch_size': 32,
        'num_workers': 4,
        'num_epochs': 100,
        'learning_rate': 1e-4,
        'min_lr': 1e-6,
        'weight_decay': 0.01,
        'grad_weight': 0.1,
        'max_grad_norm': 1.0,

        # Model parameters
        'brep_feature_dim': 48,
        'use_cf': True,
        'embed_dim': 768,
        'latent_dim': 256,

        # wandb parameters
        'use_wandb': True,
        'project_name': 'brep2sdf',
        'run_name': 'training_run',
        'log_interval': 10
    }

    # Create the checkpoint directory
    os.makedirs(config['save_dir'], exist_ok=True)

    # Initialize wandb (longer init timeout, offline mode)
    if config['use_wandb']:
        try:
            wandb.init(
                project=config['project_name'],
                name=config['run_name'],
                config=config,
                settings=wandb.Settings(
                    init_timeout=180,     # extend the init timeout
                    _disable_stats=True,  # disable system stats collection
                    _disable_meta=True,   # disable metadata collection
                ),
                mode="offline"  # run offline; metrics are stored locally
            )
            logger.info("Wandb initialized in offline mode")
        except Exception as e:
            logger.warning(f"Failed to initialize wandb: {str(e)}")
            config['use_wandb'] = False  # fall back to running without wandb
            logger.warning("Continuing without wandb logging")

    # Build the trainer and start training
    trainer = Trainer(config)
    trainer.train()

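# Note: with mode="offline" above, wandb only writes metrics to a local run
# directory (./wandb/offline-run-*); they can be uploaded afterwards with the
# `wandb sync <run-dir>` CLI command.
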
class Trainer:
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        # Datasets
        self.train_dataset = BRepSDFDataset(
            brep_dir=config['brep_dir'],
            sdf_dir=config['sdf_dir'],
            valid_data_dir=config['valid_data_dir'],
            split='train'
        )
        self.val_dataset = BRepSDFDataset(
            brep_dir=config['brep_dir'],
            sdf_dir=config['sdf_dir'],
            valid_data_dir=config['valid_data_dir'],
            split='val'
        )

        logger.info(f"Train dataset size: {len(self.train_dataset)}")
        logger.info(f"Val dataset size: {len(self.val_dataset)}")

        # Data loaders
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=config['num_workers']
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers']
        )

        # Model
        self.model = BRepToSDF(
            brep_feature_dim=config['brep_feature_dim'],
            use_cf=config['use_cf'],
            embed_dim=config['embed_dim'],
            latent_dim=config['latent_dim']
        ).to(self.device)

        # Optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        # Learning-rate scheduler: one cosine cycle over the full run,
        # decaying from learning_rate down to min_lr
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config['num_epochs'],
            eta_min=config['min_lr']
        )

    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0

        for batch_idx, batch in enumerate(self.train_loader):
            # Move the batch to the training device
            surf_z = batch['surf_z'].to(self.device)
            edge_z = batch['edge_z'].to(self.device)
            surf_p = batch['surf_p'].to(self.device)
            edge_p = batch['edge_p'].to(self.device)
            vert_p = batch['vert_p'].to(self.device)
            query_points = batch['points'].to(self.device)
            gt_sdf = batch['sdf'].to(self.device)

            # Forward pass
            pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)

            # Loss: SDF regression plus a gradient term weighted by grad_weight.
            # (If sdf_loss differentiates pred_sdf w.r.t. query_points, the points
            # may need query_points.requires_grad_(True) before the forward pass.)
            loss = sdf_loss(
                pred_sdf,
                gt_sdf,
                query_points,
                grad_weight=self.config['grad_weight']
            )

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config['max_grad_norm']
            )

            self.optimizer.step()
            total_loss += loss.item()

            # Log training progress (console and wandb share one interval check)
            if (batch_idx + 1) % self.config['log_interval'] == 0:
                logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
                            f'Loss: {loss.item():.6f}')

                if self.config['use_wandb']:
                    wandb.log({
                        'batch_loss': loss.item(),
                        'batch': batch_idx + 1,
                        'epoch': epoch
                    })

        avg_loss = total_loss / len(self.train_loader)
        return avg_loss

    def validate(self, epoch):
        self.model.eval()
        total_loss = 0

        # Note: if sdf_loss's gradient term needs autograd w.r.t. query_points,
        # it cannot be evaluated under torch.no_grad()
        with torch.no_grad():
            for batch in self.val_loader:
                # Move the batch to the training device
                surf_z = batch['surf_z'].to(self.device)
                edge_z = batch['edge_z'].to(self.device)
                surf_p = batch['surf_p'].to(self.device)
                edge_p = batch['edge_p'].to(self.device)
                vert_p = batch['vert_p'].to(self.device)
                query_points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)

                # Forward pass
                pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)

                # Loss
                loss = sdf_loss(
                    pred_sdf,
                    gt_sdf,
                    query_points,
                    grad_weight=self.config['grad_weight']
                )

                total_loss += loss.item()

        avg_loss = total_loss / len(self.val_loader)
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')

        if self.config['use_wandb']:
            wandb.log({
                'val_loss': avg_loss,
                'epoch': epoch
            })
        return avg_loss

    def train(self):
        best_val_loss = float('inf')
        logger.info("Starting training...")

        for epoch in range(1, self.config['num_epochs'] + 1):
            train_loss = self.train_epoch(epoch)
            val_loss = self.validate(epoch)
            self.scheduler.step()

            # Save the best model (by validation loss)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                model_path = os.path.join(self.config['save_dir'], 'best_model.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_loss': val_loss,
                }, model_path)
                logger.info(f'Saved best model with val_loss: {val_loss:.6f}')

            # Log the epoch summary
            logger.info(f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
                        f'Val Loss: {val_loss:.6f}\tLR: {self.scheduler.get_last_lr()[0]:.6f}')

            if self.config['use_wandb']:
                wandb.log({
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'learning_rate': self.scheduler.get_last_lr()[0],
                    'epoch': epoch
                })

if __name__ == '__main__':
    main()
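
For reference, a minimal sketch of restoring the checkpoint written by Trainer.train() for inference. It assumes the same architecture values as in `config` above, since the checkpoint stores only state dicts, not the model definition:

import torch
from brep2sdf.networks.encoder import BRepToSDF

# Rebuild the model with the training-time architecture parameters
model = BRepToSDF(brep_feature_dim=48, use_cf=True, embed_dim=768, latent_dim=256)

# Load the best checkpoint saved during training
checkpoint = torch.load('checkpoints/best_model.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Loaded epoch {checkpoint['epoch']} with val_loss {checkpoint['val_loss']:.6f}")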