# Brep2SDF training entry point.
import os

import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader

from brep2sdf.config.default_config import get_default_config, load_config
from brep2sdf.data.data import BRepSDFDataset
from brep2sdf.networks.loss import Brep2SDFLoss
from brep2sdf.networks.network import BRepToSDF
from brep2sdf.utils.logger import logger
def main():
    """Entry point: load the default config, prepare output directories,
    optionally initialize wandb (offline), then run training."""
    # Load configuration
    config = get_default_config()

    # Create all save directories up front so later torch.save calls cannot fail
    os.makedirs(config.data.model_save_dir, exist_ok=True)
    os.makedirs(config.data.result_save_dir, exist_ok=True)

    # Initialize wandb; on any failure fall back to running without it
    if config.log.use_wandb:
        try:
            wandb.init(
                project=config.log.project_name,
                name=config.log.run_name,
                config=config.__dict__,
                settings=wandb.Settings(
                    init_timeout=180,      # extended timeout
                    _disable_stats=True,   # disable system stats collection
                    _disable_meta=True,    # disable metadata collection
                ),
                mode="offline"  # offline mode: no network required
            )
            logger.info("Wandb initialized in offline mode")
        except Exception as e:
            logger.warning(f"Failed to initialize wandb: {str(e)}")
            config.log.use_wandb = False
            logger.warning("Continuing without wandb logging")

    # Build the trainer and start training
    trainer = Trainer(config)
    trainer.train()
|
|
class Trainer:
    """Drives training and validation of the BRepToSDF model.

    Owns the datasets, data loaders, model, loss, optimizer, and LR
    scheduler, and implements the epoch loop with periodic validation,
    best-model saving, and checkpointing.
    """

    def __init__(self, config):
        """Build all training components from *config*.

        Args:
            config: project configuration object with ``data``, ``train``
                and ``log`` sections (see brep2sdf.config.default_config).
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # A clamping distance <= 0 disables SDF min/max clamping in the loss.
        clamping_distance = self.config.train.clamping_distance
        self.criterion = Brep2SDFLoss(
            batch_size=config.train.batch_size,
            enforce_minmax=(clamping_distance > 0),
            clamping_distance=clamping_distance
        )

        logger.info(f"Using device: {self.device}")

        # Initialize datasets
        self.train_dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='train'
        )
        self.val_dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='val'
        )

        logger.info(f"Train dataset size: {len(self.train_dataset)}")
        logger.info(f"Val dataset size: {len(self.val_dataset)}")

        # Initialize data loaders.
        # NOTE(review): pin_memory was deliberately forced off — the original
        # computed `use_pin_memory = self.device.type == 'cuda'` but never
        # applied it. Confirm before re-enabling pinned memory.
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config.train.batch_size,
            shuffle=True,
            num_workers=config.train.num_workers,
            pin_memory=False
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=config.train.batch_size,
            shuffle=False,
            num_workers=config.train.num_workers,
            pin_memory=False
        )

        # Initialize model
        self.model = BRepToSDF(config).to(self.device)

        # Optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.train.learning_rate,
            weight_decay=config.train.weight_decay
        )

        # LR scheduler: cosine annealing over the whole run
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.train.num_epochs,
            eta_min=config.train.min_lr
        )

    def train_epoch(self, epoch):
        """Run one training epoch and return the average batch loss."""
        self.model.train()
        total_loss = 0

        for batch_idx, batch in enumerate(self.train_loader):
            # Reset gradients
            self.optimizer.zero_grad()

            # Move inputs to the device; geometric inputs track gradients
            surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True)
            edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True)
            surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True)
            edge_pos = batch['edge_pos'].to(self.device).requires_grad_(True)
            vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True)
            points = batch['points'].to(self.device).requires_grad_(True)

            # These inputs do not need gradients
            edge_mask = batch['edge_mask'].to(self.device)
            gt_sdf = batch['sdf'].to(self.device)

            # Forward pass
            pred_sdf = self.model(
                surf_ncs=surf_ncs,
                edge_ncs=edge_ncs,
                surf_pos=surf_pos,
                edge_pos=edge_pos,
                vertex_pos=vertex_pos,
                edge_mask=edge_mask,
                query_points=points
            )

            # Compute loss
            loss = self.criterion(
                pred_sdf=pred_sdf,
                gt_sdf=gt_sdf,
            )

            # Backward pass and optimization
            loss.backward()

            # Gradient clipping before the optimizer step
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.train.max_grad_norm
            )
            self.optimizer.step()

            total_loss += loss.item()

            # Log training progress at the configured interval
            if (batch_idx + 1) % self.config.log.log_interval == 0:
                logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
                            f'Loss: {loss.item():.6f}')

                if self.config.log.use_wandb:
                    wandb.log({
                        'batch_loss': loss.item(),
                        'batch': batch_idx,
                        'epoch': epoch
                    })

        avg_loss = total_loss / len(self.train_loader)
        return avg_loss

    def validate(self, epoch):
        """Evaluate on the validation set and return the average loss."""
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in self.val_loader:
                # Move inputs to the device (no gradients needed here)
                surf_ncs = batch['surf_ncs'].to(self.device)
                edge_ncs = batch['edge_ncs'].to(self.device)
                surf_pos = batch['surf_pos'].to(self.device)
                edge_pos = batch['edge_pos'].to(self.device)
                vertex_pos = batch['vertex_pos'].to(self.device)
                edge_mask = batch['edge_mask'].to(self.device)
                points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)

                # Forward pass
                pred_sdf = self.model(
                    surf_ncs=surf_ncs, edge_ncs=edge_ncs,
                    surf_pos=surf_pos, edge_pos=edge_pos,
                    vertex_pos=vertex_pos, edge_mask=edge_mask,
                    query_points=points
                )

                # Compute loss
                loss = self.criterion(
                    pred_sdf=pred_sdf,
                    gt_sdf=gt_sdf,
                )

                total_loss += loss.item()

        avg_loss = total_loss / len(self.val_loader)
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')

        if self.config.log.use_wandb:
            wandb.log({
                'val_loss': avg_loss,
                'epoch': epoch
            })
        return avg_loss

    def train(self):
        """Full training loop with periodic validation and checkpointing."""
        best_val_loss = float('inf')
        logger.info("Starting training...")
        # Initialize val_loss so checkpoints before the first validation
        # epoch still have a value to record.
        val_loss = float('inf')

        for epoch in range(1, self.config.train.num_epochs + 1):
            train_loss = self.train_epoch(epoch)

            # Periodic validation and best-model saving
            if epoch % self.config.train.val_freq == 0:
                val_loss = self.validate(epoch)

                # Save the best model seen so far
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_model_path = os.path.join(
                        self.config.data.model_save_dir,
                        self.config.data.best_model_name.format(
                            model_name=self.config.data.model_name
                        )
                    )
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'val_loss': val_loss,
                    }, best_model_path)
                    logger.info(f'Saved best model with val_loss: {val_loss:.6f}')

            # Periodic checkpointing
            if epoch % self.config.train.save_freq == 0:
                checkpoint_path = os.path.join(
                    self.config.data.model_save_dir,
                    self.config.data.checkpoint_format.format(
                        model_name=self.config.data.model_name,
                        epoch=epoch
                    )
                )
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'train_loss': train_loss,
                    # Only record val_loss when this epoch actually validated
                    'val_loss': val_loss if epoch % self.config.train.val_freq == 0 else None,
                }, checkpoint_path)
                logger.info(f'Saved checkpoint at epoch {epoch}')

            self.scheduler.step()

            # Log epoch summary
            log_info = f'Epoch: {epoch}\tTrain Loss: {train_loss:.6f}\t'
            if epoch % self.config.train.val_freq == 0:
                log_info += f'Val Loss: {val_loss:.6f}\t'
            log_info += f'LR: {self.scheduler.get_last_lr()[0]:.6f}'
            logger.info(log_info)

            # Log to wandb
            if self.config.log.use_wandb:
                log_dict = {
                    'train_loss': train_loss,
                    'learning_rate': self.scheduler.get_last_lr()[0],
                    'epoch': epoch
                }
                if epoch % self.config.train.val_freq == 0:
                    log_dict['val_loss'] = val_loss
                wandb.log(log_dict)
|
|
|
# Script entry guard: run training only when executed directly.
if __name__ == '__main__':
    main()