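"""Training script for the BRepToSDF model.

Builds the B-Rep/SDF datasets and loaders, then trains with AdamW and a
cosine-annealing learning-rate schedule, validating, checkpointing, and
logging (optionally to wandb, in offline mode) at configurable intervals.
"""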
import os

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import wandb

from brep2sdf.data.data import BRepSDFDataset
from brep2sdf.networks.encoder import BRepToSDF, sdf_loss
from brep2sdf.utils.logger import logger
from brep2sdf.config.default_config import get_default_config


def main():
    # Load the configuration
    config = get_default_config()

    # Create all save directories
    os.makedirs(config.data.model_save_dir, exist_ok=True)
    os.makedirs(config.data.result_save_dir, exist_ok=True)

    # Initialize wandb
    if config.log.use_wandb:
        try:
            wandb.init(
                project=config.log.project_name,
                name=config.log.run_name,
                config=config.__dict__,
                settings=wandb.Settings(
                    init_timeout=180,     # increase the init timeout
                    _disable_stats=True,  # disable system stats collection
                    _disable_meta=True,   # disable metadata collection
                ),
                mode="offline"  # run in offline mode
            )
            logger.info("Wandb initialized in offline mode")
        except Exception as e:
            logger.warning(f"Failed to initialize wandb: {str(e)}")
            config.log.use_wandb = False
            logger.warning("Continuing without wandb logging")

    # Create the trainer and start training
    trainer = Trainer(config)
    trainer.train()


class Trainer:
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        # Build the datasets
        self.train_dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='train'
        )
        self.val_dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='val'
        )
        logger.info(f"Train dataset size: {len(self.train_dataset)}")
        logger.info(f"Val dataset size: {len(self.val_dataset)}")

        # Build the data loaders
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config.train.batch_size,
            shuffle=True,
            num_workers=config.train.num_workers
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=config.train.batch_size,
            shuffle=False,
            num_workers=config.train.num_workers
        )

        # Build the model
        self.model = BRepToSDF(
            brep_feature_dim=config.model.brep_feature_dim,
            use_cf=config.model.use_cf,
            embed_dim=config.model.embed_dim,
            latent_dim=config.model.latent_dim
        ).to(self.device)

        # Optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.train.learning_rate,
            weight_decay=config.train.weight_decay
        )

        # Learning-rate scheduler
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.train.num_epochs,
            eta_min=config.train.min_lr
        )

    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0

        for batch_idx, batch in enumerate(self.train_loader):
            # Move the batch to the device
            surf_z = batch['surf_z'].to(self.device)
            edge_z = batch['edge_z'].to(self.device)
            surf_p = batch['surf_p'].to(self.device)
            edge_p = batch['edge_p'].to(self.device)
            vert_p = batch['vert_p'].to(self.device)
            query_points = batch['points'].to(self.device)
            gt_sdf = batch['sdf'].to(self.device)

            # Forward pass
            pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)

            # Compute the loss
            loss = sdf_loss(
                pred_sdf,
                gt_sdf,
                query_points,
                grad_weight=self.config.train.grad_weight
            )

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.train.max_grad_norm
            )
            self.optimizer.step()

            total_loss += loss.item()

            # Log training progress (and mirror it to wandb if enabled)
            if (batch_idx + 1) % self.config.log.log_interval == 0:
                logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
                            f'Loss: {loss.item():.6f}')
                if self.config.log.use_wandb:
                    wandb.log({
                        'batch_loss': loss.item(),
                        'batch': batch_idx,
                        'epoch': epoch
                    })

        avg_loss = total_loss / len(self.train_loader)
        return avg_loss

    def validate(self, epoch):
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in self.val_loader:
                # Move the batch to the device
                surf_z = batch['surf_z'].to(self.device)
                edge_z = batch['edge_z'].to(self.device)
                surf_p = batch['surf_p'].to(self.device)
                edge_p = batch['edge_p'].to(self.device)
                vert_p = batch['vert_p'].to(self.device)
                query_points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)

                # Forward pass
                pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)

                # Compute the loss
                loss = sdf_loss(
                    pred_sdf,
                    gt_sdf,
                    query_points,
                    grad_weight=self.config.train.grad_weight
                )
                total_loss += loss.item()

        avg_loss = total_loss / len(self.val_loader)
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')

        if self.config.log.use_wandb:
            wandb.log({
                'val_loss': avg_loss,
                'epoch': epoch
            })
        return avg_loss

    def train(self):
        best_val_loss = float('inf')
        val_loss = None  # so checkpoints saved before the first validation are well-defined
        logger.info("Starting training...")

        for epoch in range(1, self.config.train.num_epochs + 1):
            train_loss = self.train_epoch(epoch)

            # Validate periodically
            if epoch % self.config.train.val_freq == 0:
                val_loss = self.validate(epoch)

                # Save the best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_model_path = os.path.join(
                        self.config.data.model_save_dir,
                        self.config.data.best_model_name.format(
                            model_name=self.config.data.model_name
                        )
                    )
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'val_loss': val_loss,
                    }, best_model_path)
                    logger.info(f'Saved best model with val_loss: {val_loss:.6f}')

            # Save a checkpoint periodically
            if epoch % self.config.train.save_freq == 0:
                checkpoint_path = os.path.join(
                    self.config.data.model_save_dir,
                    self.config.data.checkpoint_format.format(
                        model_name=self.config.data.model_name,
                        epoch=epoch
                    )
                )
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss if epoch % self.config.train.val_freq == 0 else None,
                }, checkpoint_path)
                logger.info(f'Saved checkpoint at epoch {epoch}')

            self.scheduler.step()

            # Log epoch-level training info
            log_info = f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
            if epoch % self.config.train.val_freq == 0:
                log_info += f'Val Loss: {val_loss:.6f}\t'
            log_info += f'LR: {self.scheduler.get_last_lr()[0]:.6f}'
            logger.info(log_info)

            # Log to wandb
            if self.config.log.use_wandb:
                log_dict = {
                    'train_loss': train_loss,
                    'learning_rate': self.scheduler.get_last_lr()[0],
                    'epoch': epoch
                }
                if epoch % self.config.train.val_freq == 0:
                    log_dict['val_loss'] = val_loss
                wandb.log(log_dict)


if __name__ == '__main__':
    main()
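
# A minimal sketch of restoring the best checkpoint saved above for evaluation.
# It assumes the same config (and therefore the same save paths) used during
# training, and mirrors the dict keys written by torch.save() in Trainer.train():
#
#   config = get_default_config()
#   ckpt_path = os.path.join(
#       config.data.model_save_dir,
#       config.data.best_model_name.format(model_name=config.data.model_name)
#   )
#   checkpoint = torch.load(ckpt_path, map_location='cpu')
#   model = BRepToSDF(
#       brep_feature_dim=config.model.brep_feature_dim,
#       use_cf=config.model.use_cf,
#       embed_dim=config.model.embed_dim,
#       latent_dim=config.model.latent_dim
#   )
#   model.load_state_dict(checkpoint['model_state_dict'])
#   model.eval()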