# Training entry script for the brep2sdf project (web-scraper residue removed).
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from brep2sdf.data.data import BRepSDFDataset
from brep2sdf.networks.network import BRepToSDF
from brep2sdf.networks.loss import Brep2SDFLoss
from brep2sdf.utils.logger import logger
from brep2sdf.config.default_config import get_default_config, load_config
import wandb
def main():
    """Entry point: load config, prepare output directories, set up wandb, train."""
    config = get_default_config()

    # Make sure every output directory exists before anything writes to it.
    for directory in (config.data.model_save_dir, config.data.result_save_dir):
        os.makedirs(directory, exist_ok=True)

    # Best-effort wandb setup: on any failure, fall back to local logging only.
    if config.log.use_wandb:
        try:
            wandb.init(
                project=config.log.project_name,
                name=config.log.run_name,
                config=config.__dict__,
                settings=wandb.Settings(
                    init_timeout=180,     # generous timeout for slow environments
                    _disable_stats=True,  # skip system-stats collection
                    _disable_meta=True,   # skip metadata collection
                ),
                mode="offline"  # offline mode; runs can be synced later
            )
            logger.info("Wandb initialized in offline mode")
        except Exception as e:
            logger.warning(f"Failed to initialize wandb: {str(e)}")
            config.log.use_wandb = False
            logger.warning("Continuing without wandb logging")

    # Build the trainer and run the full training loop.
    Trainer(config).train()
class Trainer:
    """Trains and validates the BRepToSDF model.

    Owns the train/val datasets and loaders, the loss, the optimizer and the
    LR scheduler, and drives the epoch loop with periodic validation and
    checkpointing.
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        # SDF clamping is enabled only when a positive distance is configured.
        clamping_distance = self.config.train.clamping_distance
        self.criterion = Brep2SDFLoss(
            batch_size=config.train.batch_size,
            enforce_minmax=(clamping_distance > 0),
            clamping_distance=clamping_distance
        )

        # Train/val splits read from the same source directories.
        self.train_dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='train'
        )
        self.val_dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='val'
        )
        logger.info(f"Train dataset size: {len(self.train_dataset)}")
        logger.info(f"Val dataset size: {len(self.val_dataset)}")

        # pin_memory is deliberately disabled; re-enable for CUDA only if
        # profiling shows host-to-device transfer is a bottleneck.
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config.train.batch_size,
            shuffle=True,
            num_workers=config.train.num_workers,
            pin_memory=False
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=config.train.batch_size,
            shuffle=False,
            num_workers=config.train.num_workers,
            pin_memory=False
        )

        self.model = BRepToSDF(config).to(self.device)
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.train.learning_rate,
            weight_decay=config.train.weight_decay
        )
        # Cosine annealing over the whole run, decaying down to min_lr.
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.train.num_epochs,
            eta_min=config.train.min_lr
        )

    def train_epoch(self, epoch):
        """Run one optimization pass over the training set.

        Args:
            epoch: 1-based epoch index, used only for logging.

        Returns:
            Mean batch loss over the epoch (float).
        """
        self.model.train()
        total_loss = 0
        for batch_idx, batch in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            # Geometry inputs are marked to require gradients; presumably the
            # loss/model differentiates w.r.t. them -- TODO confirm in
            # Brep2SDFLoss / BRepToSDF.
            surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True)
            edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True)
            surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True)
            edge_pos = batch['edge_pos'].to(self.device).requires_grad_(True)
            vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True)
            points = batch['points'].to(self.device).requires_grad_(True)
            # Mask and targets do not need gradients.
            edge_mask = batch['edge_mask'].to(self.device)
            gt_sdf = batch['sdf'].to(self.device)

            pred_sdf = self.model(
                surf_ncs=surf_ncs,
                edge_ncs=edge_ncs,
                surf_pos=surf_pos,
                edge_pos=edge_pos,
                vertex_pos=vertex_pos,
                edge_mask=edge_mask,
                query_points=points
            )
            loss = self.criterion(
                pred_sdf=pred_sdf,
                gt_sdf=gt_sdf,
            )

            loss.backward()
            # Clip the global gradient norm to stabilize training.
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.train.max_grad_norm
            )
            self.optimizer.step()

            total_loss += loss.item()

            # Progress logging every log_interval batches.
            if (batch_idx + 1) % self.config.log.log_interval == 0:
                logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
                            f'Loss: {loss.item():.6f}')
                if self.config.log.use_wandb:
                    wandb.log({
                        'batch_loss': loss.item(),
                        'batch': batch_idx,
                        'epoch': epoch
                    })

        avg_loss = total_loss / len(self.train_loader)
        return avg_loss

    def validate(self, epoch):
        """Evaluate on the validation set without gradients.

        Args:
            epoch: 1-based epoch index, used only for logging.

        Returns:
            Mean batch loss over the validation set (float).
        """
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for batch in self.val_loader:
                surf_ncs = batch['surf_ncs'].to(self.device)
                edge_ncs = batch['edge_ncs'].to(self.device)
                surf_pos = batch['surf_pos'].to(self.device)
                edge_pos = batch['edge_pos'].to(self.device)
                vertex_pos = batch['vertex_pos'].to(self.device)
                edge_mask = batch['edge_mask'].to(self.device)
                points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)

                pred_sdf = self.model(
                    surf_ncs=surf_ncs, edge_ncs=edge_ncs,
                    surf_pos=surf_pos, edge_pos=edge_pos,
                    vertex_pos=vertex_pos, edge_mask=edge_mask,
                    query_points=points
                )
                loss = self.criterion(
                    pred_sdf=pred_sdf,
                    gt_sdf=gt_sdf,
                )
                total_loss += loss.item()

        avg_loss = total_loss / len(self.val_loader)
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
        if self.config.log.use_wandb:
            wandb.log({
                'val_loss': avg_loss,
                'epoch': epoch
            })
        return avg_loss

    def _save_best_model(self, epoch, val_loss):
        # Persist the best-so-far model (selected by validation loss).
        best_model_path = os.path.join(
            self.config.data.model_save_dir,
            self.config.data.best_model_name.format(
                model_name=self.config.data.model_name
            )
        )
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
        }, best_model_path)
        logger.info(f'Saved best model with val_loss: {val_loss:.6f}')

    def _save_checkpoint(self, epoch, train_loss, val_loss):
        # Persist a periodic full checkpoint, including the scheduler state.
        checkpoint_path = os.path.join(
            self.config.data.model_save_dir,
            self.config.data.checkpoint_format.format(
                model_name=self.config.data.model_name,
                epoch=epoch
            )
        )
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_loss': train_loss,
            # NOTE(review): when this epoch did not run validation, None is
            # stored even if an earlier val_loss exists -- original behavior,
            # kept for compatibility.
            'val_loss': val_loss if epoch % self.config.train.val_freq == 0 else None,
        }, checkpoint_path)
        logger.info(f'Saved checkpoint at epoch {epoch}')

    def train(self):
        """Main epoch loop: train, periodically validate and checkpoint, step the LR."""
        best_val_loss = float('inf')
        logger.info("Starting training...")
        # val_loss starts at inf and is carried forward between validation epochs.
        val_loss = float('inf')

        for epoch in range(1, self.config.train.num_epochs + 1):
            train_loss = self.train_epoch(epoch)

            # Validate (and possibly update the best model) every val_freq epochs.
            if epoch % self.config.train.val_freq == 0:
                val_loss = self.validate(epoch)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self._save_best_model(epoch, val_loss)

            # Periodic checkpoint every save_freq epochs.
            if epoch % self.config.train.save_freq == 0:
                self._save_checkpoint(epoch, train_loss, val_loss)

            self.scheduler.step()

            # Per-epoch summary line.
            log_info = f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
            if epoch % self.config.train.val_freq == 0:
                log_info += f'Val Loss: {val_loss:.6f}\t'
            log_info += f'LR: {self.scheduler.get_last_lr()[0]:.6f}'
            logger.info(log_info)

            if self.config.log.use_wandb:
                log_dict = {
                    'train_loss': train_loss,
                    'learning_rate': self.scheduler.get_last_lr()[0],
                    'epoch': epoch
                }
                if epoch % self.config.train.val_freq == 0:
                    log_dict['val_loss'] = val_loss
                wandb.log(log_dict)
# Script entry point.
if __name__ == '__main__':
    main()