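"""Training entry point for the brep2sdf model.

Builds the datasets, data loaders, model, optimizer, and LR scheduler from
the default config, then runs the train/validate loop with checkpointing
and optional (offline) wandb logging.
"""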
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from brep2sdf.data.data import BRepSDFDataset
from brep2sdf.networks.encoder import BRepToSDF, sdf_loss
from brep2sdf.utils.logger import logger
from brep2sdf.config.default_config import get_default_config, load_config

import wandb


def main():
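    """Load the config, create save directories, set up wandb, and run the Trainer."""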
    # Load configuration
    config = get_default_config()

    # Create all save directories
    os.makedirs(config.data.model_save_dir, exist_ok=True)
    os.makedirs(config.data.result_save_dir, exist_ok=True)

    # Initialize wandb; fall back to plain logging on failure
    if config.log.use_wandb:
        try:
            wandb.init(
                project=config.log.project_name,
                name=config.log.run_name,
                config=config.__dict__,
                settings=wandb.Settings(
                    init_timeout=180,     # allow a longer init timeout
                    _disable_stats=True,  # disable system stats collection
                    _disable_meta=True,   # disable metadata collection
                ),
                mode="offline"  # run wandb in offline mode
            )
            logger.info("Wandb initialized in offline mode")
        except Exception as e:
            logger.warning(f"Failed to initialize wandb: {str(e)}")
            config.log.use_wandb = False
            logger.warning("Continuing without wandb logging")

    # Create the trainer and start training
    trainer = Trainer(config)
    trainer.train()


class Trainer:
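    """Wraps the datasets, model, optimizer, and train/validate loop for brep2sdf."""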
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        # Initialize the datasets
        self.train_dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='train'
        )
        self.val_dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='val'
        )

        logger.info(f"Train dataset size: {len(self.train_dataset)}")
        logger.info(f"Val dataset size: {len(self.val_dataset)}")

        # Initialize the data loaders
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config.train.batch_size,
            shuffle=True,
            num_workers=config.train.num_workers
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=config.train.batch_size,
            shuffle=False,
            num_workers=config.train.num_workers
        )

        # Initialize the model
        self.model = BRepToSDF(
            brep_feature_dim=config.model.brep_feature_dim,
            use_cf=config.model.use_cf,
            embed_dim=config.model.embed_dim,
            latent_dim=config.model.latent_dim
        ).to(self.device)

        # Initialize the optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.train.learning_rate,
            weight_decay=config.train.weight_decay
        )

        # Learning-rate scheduler (cosine annealing over the full run)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.train.num_epochs,
            eta_min=config.train.min_lr
        )

    def train_epoch(self, epoch):
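        """Run one training epoch and return the average batch loss."""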
        self.model.train()
        total_loss = 0

        for batch_idx, batch in enumerate(self.train_loader):
            # Move the batch to the device
            surf_z = batch['surf_z'].to(self.device)
            edge_z = batch['edge_z'].to(self.device)
            surf_p = batch['surf_p'].to(self.device)
            edge_p = batch['edge_p'].to(self.device)
            vert_p = batch['vert_p'].to(self.device)
            query_points = batch['points'].to(self.device)
            gt_sdf = batch['sdf'].to(self.device)

            # Forward pass
            pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)

            # Compute the loss
            loss = sdf_loss(
                pred_sdf,
                gt_sdf,
                query_points,
                grad_weight=self.config.train.grad_weight
            )

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.train.max_grad_norm
            )

            self.optimizer.step()
            total_loss += loss.item()

            # Log training progress (console and, if enabled, wandb)
            if (batch_idx + 1) % self.config.log.log_interval == 0:
                logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
                            f'Loss: {loss.item():.6f}')
                if self.config.log.use_wandb:
                    wandb.log({
                        'batch_loss': loss.item(),
                        'batch': batch_idx,
                        'epoch': epoch
                    })

        avg_loss = total_loss / len(self.train_loader)
        return avg_loss

    def validate(self, epoch):
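        """Evaluate on the validation set and return the average batch loss.

        Note: inference runs under torch.no_grad(); if sdf_loss's gradient
        term needs autograd on query_points, that term is assumed to be
        skipped (or zero) during validation.
        """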
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in self.val_loader:
                # Move the batch to the device
                surf_z = batch['surf_z'].to(self.device)
                edge_z = batch['edge_z'].to(self.device)
                surf_p = batch['surf_p'].to(self.device)
                edge_p = batch['edge_p'].to(self.device)
                vert_p = batch['vert_p'].to(self.device)
                query_points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)

                # Forward pass
                pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)

                # Compute the loss
                loss = sdf_loss(
                    pred_sdf,
                    gt_sdf,
                    query_points,
                    grad_weight=self.config.train.grad_weight
                )

                total_loss += loss.item()

        avg_loss = total_loss / len(self.val_loader)
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')

        if self.config.log.use_wandb:
            wandb.log({
                'val_loss': avg_loss,
                'epoch': epoch
            })
        return avg_loss

    def train(self):
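        """Run the full training loop with periodic validation and checkpointing."""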
        best_val_loss = float('inf')
        val_loss = None  # set once validation has run; prevents a NameError if a checkpoint is saved before the first validation
        logger.info("Starting training...")

        for epoch in range(1, self.config.train.num_epochs + 1):
            train_loss = self.train_epoch(epoch)

            # Validate periodically
            if epoch % self.config.train.val_freq == 0:
                val_loss = self.validate(epoch)

                # Save the best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_model_path = os.path.join(
                        self.config.data.model_save_dir,
                        self.config.data.best_model_name.format(
                            model_name=self.config.data.model_name
                        )
                    )
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'val_loss': val_loss,
                    }, best_model_path)
                    logger.info(f'Saved best model with val_loss: {val_loss:.6f}')

            # Save a checkpoint periodically
            if epoch % self.config.train.save_freq == 0:
                checkpoint_path = os.path.join(
                    self.config.data.model_save_dir,
                    self.config.data.checkpoint_format.format(
                        model_name=self.config.data.model_name,
                        epoch=epoch
                    )
                )
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss if epoch % self.config.train.val_freq == 0 else None,
                }, checkpoint_path)
                logger.info(f'Saved checkpoint at epoch {epoch}')
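
            # Resuming from one of these checkpoints is not implemented here;
            # a minimal sketch (assuming the same config and architecture):
            #   ckpt = torch.load(checkpoint_path, map_location=self.device)
            #   self.model.load_state_dict(ckpt['model_state_dict'])
            #   self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            #   self.scheduler.load_state_dict(ckpt['scheduler_state_dict'])
            #   start_epoch = ckpt['epoch'] + 1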

            self.scheduler.step()

            # Log epoch-level training info
            log_info = f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
            if epoch % self.config.train.val_freq == 0:
                log_info += f'Val Loss: {val_loss:.6f}\t'
            log_info += f'LR: {self.scheduler.get_last_lr()[0]:.6f}'
            logger.info(log_info)

            # Log to wandb
            if self.config.log.use_wandb:
                log_dict = {
                    'train_loss': train_loss,
                    'learning_rate': self.scheduler.get_last_lr()[0],
                    'epoch': epoch
                }
                if epoch % self.config.train.val_freq == 0:
                    log_dict['val_loss'] = val_loss
                wandb.log(log_dict)


if __name__ == '__main__':
    main()