import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb

from brep2sdf.data.data import BRepSDFDataset
from brep2sdf.networks.network import BRepToSDF
from brep2sdf.networks.loss import Brep2SDFLoss
from brep2sdf.utils.logger import logger
from brep2sdf.config.default_config import get_default_config, load_config


def main():
    # Load configuration
    config = get_default_config()

    # Create all save directories
    os.makedirs(config.data.model_save_dir, exist_ok=True)
    os.makedirs(config.data.result_save_dir, exist_ok=True)

    # Initialize wandb
    if config.log.use_wandb:
        try:
            wandb.init(
                project=config.log.project_name,
                name=config.log.run_name,
                config=config.__dict__,
                settings=wandb.Settings(
                    init_timeout=180,     # longer timeout for slow networks
                    _disable_stats=True,  # disable system stats collection
                    _disable_meta=True,   # disable metadata collection
                ),
                mode="offline"  # offline mode; upload later with `wandb sync`
            )
            logger.info("Wandb initialized in offline mode")
        except Exception as e:
            logger.warning(f"Failed to initialize wandb: {str(e)}")
            config.log.use_wandb = False
            logger.warning("Continuing without wandb logging")

    # Build the trainer and start training
    trainer = Trainer(config)
    trainer.train()


class Trainer:
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # A positive clamping distance enables min/max clamping in the loss
        clamping_distance = self.config.train.clamping_distance
        self.criterion = Brep2SDFLoss(
            batch_size=config.train.batch_size,
            enforce_minmax=(clamping_distance > 0),
            clamping_distance=clamping_distance
        )
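
        # For reference, a DeepSDF-style clamped L1 objective is sketched below.
        # This is an assumption about what Brep2SDFLoss computes when
        # enforce_minmax is set, not its actual implementation:
        #
        #   def clamped_l1(pred_sdf, gt_sdf, delta):
        #       pred = torch.clamp(pred_sdf, -delta, delta)
        #       gt = torch.clamp(gt_sdf, -delta, delta)
        #       return torch.nn.functional.l1_loss(pred, gt)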

        use_pin_memory = self.device.type == 'cuda'  # pin host memory only when training on CUDA
        logger.info(f"Using device: {self.device}")

        # Initialize datasets
        self.train_dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='train'
        )
        self.val_dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='val'
        )
        logger.info(f"Train dataset size: {len(self.train_dataset)}")
        logger.info(f"Val dataset size: {len(self.val_dataset)}")

        # Initialize data loaders
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config.train.batch_size,
            shuffle=True,
            num_workers=config.train.num_workers,
            pin_memory=False  # use_pin_memory is currently disabled
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=config.train.batch_size,
            shuffle=False,
            num_workers=config.train.num_workers,
            pin_memory=False  # use_pin_memory is currently disabled
        )

        # Initialize model
        self.model = BRepToSDF(config).to(self.device)

        # Initialize optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.train.learning_rate,
            weight_decay=config.train.weight_decay
        )

        # Learning-rate scheduler
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.train.num_epochs,
            eta_min=config.train.min_lr
        )
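
        # CosineAnnealingLR anneals the LR from learning_rate down to min_lr over
        # num_epochs, following
        #   eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2,
        # so scheduler.step() must be called once per epoch (as done in train()).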

    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0

        for batch_idx, batch in enumerate(self.train_loader):
            # Zero the gradients
            self.optimizer.zero_grad()

            # Move inputs to the device; requires_grad_ keeps them in the graph
            surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True)
            edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True)
            surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True)
            edge_pos = batch['edge_pos'].to(self.device).requires_grad_(True)
            vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True)
            points = batch['points'].to(self.device).requires_grad_(True)
            # logger.print_tensor_stats("batch surf_ncs", surf_ncs)

            # These do not need gradients
            edge_mask = batch['edge_mask'].to(self.device)
            gt_sdf = batch['sdf'].to(self.device)

            # Forward pass
            pred_sdf = self.model(
                surf_ncs=surf_ncs,
                edge_ncs=edge_ncs,
                surf_pos=surf_pos,
                edge_pos=edge_pos,
                vertex_pos=vertex_pos,
                edge_mask=edge_mask,
                query_points=points
            )
            # logger.info("\n=== SDF value distribution ===")
            # logger.info(f"Pred SDF - min: {pred_sdf.min().item():.6f}, max: {pred_sdf.max().item():.6f}, mean: {pred_sdf.mean().item():.6f}")
            # logger.info(f"GT SDF - min: {gt_sdf.min().item():.6f}, max: {gt_sdf.max().item():.6f}, mean: {gt_sdf.mean().item():.6f}")

            # Compute loss
            loss = self.criterion(
                pred_sdf=pred_sdf,
                gt_sdf=gt_sdf,
            )
            # logger.print_tensor_stats("after loss batch surf_ncs", surf_ncs)

            # Backward pass and optimization
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.train.max_grad_norm
            )
            self.optimizer.step()

            total_loss += loss.item()

            # Log training progress
            if (batch_idx + 1) % self.config.log.log_interval == 0:
                logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
                            f'Loss: {loss.item():.6f}')
                if self.config.log.use_wandb:
                    wandb.log({
                        'batch_loss': loss.item(),
                        'batch': batch_idx,
                        'epoch': epoch
                    })

        avg_loss = total_loss / len(self.train_loader)
        return avg_loss
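
    # Note: requires_grad_(True) on the input tensors above only matters if the
    # loss differentiates with respect to them, e.g. an eikonal regularizer on
    # the query points (an assumption; the criterion call above only consumes
    # pred_sdf and gt_sdf):
    #
    #   grads = torch.autograd.grad(pred_sdf.sum(), points, create_graph=True)[0]
    #   eikonal = ((grads.norm(dim=-1) - 1.0) ** 2).mean()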

    def validate(self, epoch):
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in self.val_loader:
                # Move inputs to the device
                surf_ncs = batch['surf_ncs'].to(self.device)
                edge_ncs = batch['edge_ncs'].to(self.device)
                surf_pos = batch['surf_pos'].to(self.device)
                edge_pos = batch['edge_pos'].to(self.device)
                vertex_pos = batch['vertex_pos'].to(self.device)
                edge_mask = batch['edge_mask'].to(self.device)
                points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)

                # Forward pass
                pred_sdf = self.model(
                    surf_ncs=surf_ncs, edge_ncs=edge_ncs,
                    surf_pos=surf_pos, edge_pos=edge_pos,
                    vertex_pos=vertex_pos, edge_mask=edge_mask,
                    query_points=points
                )

                # Compute loss
                loss = self.criterion(
                    pred_sdf=pred_sdf,
                    gt_sdf=gt_sdf,
                )
                total_loss += loss.item()

        avg_loss = total_loss / len(self.val_loader)
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')

        if self.config.log.use_wandb:
            wandb.log({
                'val_loss': avg_loss,
                'epoch': epoch
            })
        return avg_loss

    def train(self):
        best_val_loss = float('inf')
        logger.info("Starting training...")

        val_loss = float('inf')  # initialize val_loss
        for epoch in range(1, self.config.train.num_epochs + 1):
            train_loss = self.train_epoch(epoch)

            # Validate periodically
            if epoch % self.config.train.val_freq == 0:
                val_loss = self.validate(epoch)

                # Save best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_model_path = os.path.join(
                        self.config.data.model_save_dir,
                        self.config.data.best_model_name.format(
                            model_name=self.config.data.model_name
                        )
                    )
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'val_loss': val_loss,
                    }, best_model_path)
                    logger.info(f'Saved best model with val_loss: {val_loss:.6f}')

            # Save checkpoints periodically
            if epoch % self.config.train.save_freq == 0:
                checkpoint_path = os.path.join(
                    self.config.data.model_save_dir,
                    self.config.data.checkpoint_format.format(
                        model_name=self.config.data.model_name,
                        epoch=epoch
                    )
                )
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss if epoch % self.config.train.val_freq == 0 else None,
                }, checkpoint_path)
                logger.info(f'Saved checkpoint at epoch {epoch}')

            self.scheduler.step()

            # Log training info
            log_info = f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
            if epoch % self.config.train.val_freq == 0:
                log_info += f'Val Loss: {val_loss:.6f}\t'
            log_info += f'LR: {self.scheduler.get_last_lr()[0]:.6f}'
            logger.info(log_info)

            # Log to wandb
            if self.config.log.use_wandb:
                log_dict = {
                    'train_loss': train_loss,
                    'learning_rate': self.scheduler.get_last_lr()[0],
                    'epoch': epoch
                }
                if epoch % self.config.train.val_freq == 0:
                    log_dict['val_loss'] = val_loss
                wandb.log(log_dict)


if __name__ == '__main__':
    main()
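
# A minimal resume sketch (hypothetical; the checkpoint keys match what
# Trainer.train() saves above):
#
#   trainer = Trainer(get_default_config())
#   ckpt = torch.load(checkpoint_path, map_location=trainer.device)
#   trainer.model.load_state_dict(ckpt['model_state_dict'])
#   trainer.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
#   trainer.scheduler.load_state_dict(ckpt['scheduler_state_dict'])
#   start_epoch = ckpt['epoch'] + 1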