Changed the optimizer and the learning rate schedule strategy

final
mckay committed 3 months ago · commit 0c40f0f5e3

  1. brep2sdf/networks/learning_rate.py (118)
  2. brep2sdf/train.py (65)

brep2sdf/networks/learning_rate.py (new file, 118 lines)

@@ -0,0 +1,118 @@
import torch
import torch.optim as optim
import numpy as np

from brep2sdf.utils.logger import logger


class LearningRateSchedule:
    def get_learning_rate(self, epoch):
        pass


class StepLearningRateSchedule(LearningRateSchedule):
    def __init__(self, initial, interval, factor):
        """
        Step-decay learning rate schedule.
        :param initial: initial learning rate
        :param interval: number of epochs between decays
        :param factor: multiplicative decay factor
        """
        self.initial = initial
        self.interval = interval
        self.factor = factor

    def get_learning_rate(self, epoch):
        """
        Return the learning rate for the given epoch.
        :param epoch: current training epoch
        :return: learning rate, clamped to a floor of 5.0e-6
        """
        return np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)


class LearningRateScheduler:
    def __init__(self, lr_schedules, weight_decay, network_params):
        try:
            self.lr_schedules = self.get_learning_rate_schedules(lr_schedules)
            self.weight_decay = weight_decay
            self.startepoch = 0
            self.optimizer = torch.optim.Adam([{
                "params": network_params,
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            }])
            self.best_loss = float('inf')
            self.patience = 20
            self.decay_factor = 0.5
            self.initial_lr = self.lr_schedules[0].get_learning_rate(0)
            self.lr = self.initial_lr
            self.epochs_since_improvement = 0
        except Exception as e:
            logger.error(f"Error setting up optimizer: {str(e)}")
            raise

    def step(self, current_loss, current_epoch):
        """
        Update the learning rate.
        :param current_loss: current validation loss
        :param current_epoch: current training epoch
        adjust_learning_rate applies the overall epoch-based lr update;
        the loss-based dynamic adjustment below is currently disabled.
        """
        self.adjust_learning_rate(current_epoch)
        '''
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.epochs_since_improvement = 0
        else:
            self.epochs_since_improvement += 1

        if self.epochs_since_improvement >= self.patience:
            self.lr *= self.decay_factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
            print(f"Learning rate updated to: {self.lr:.6f}")
            self.epochs_since_improvement = 0
        '''

    def reset(self):
        """
        Reset the learning rate to its initial value.
        """
        self.lr = self.initial_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    @staticmethod
    def get_learning_rate_schedules(schedule_specs):
        """
        Build learning rate schedules from the configuration.
        :param schedule_specs: list of schedule specifications
        :return: list of schedule objects
        """
        schedules = []
        for spec in schedule_specs:
            if spec["Type"] == "Step":
                schedules.append(
                    StepLearningRateSchedule(
                        spec["Initial"],
                        spec["Interval"],
                        spec["Factor"],
                    )
                )
            else:
                raise Exception(
                    'no known learning rate schedule of type "{}"'.format(
                        spec["Type"]
                    )
                )
        return schedules

    def adjust_learning_rate(self, epoch):
        """
        Adjust the learning rate according to the current epoch.
        :param epoch: current training epoch
        """
        for i, param_group in enumerate(self.optimizer.param_groups):
            # apply the i-th schedule to the i-th parameter group
            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)
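
For reference, get_learning_rate_schedules expects each entry of schedule_specs to be a dict with the keys Type, Initial, Interval, and Factor, and only the "Step" type is handled. A minimal sketch of the resulting decay, using illustrative values (the actual schedule lives in config.train.learning_rate_schedule and is not part of this diff):

from brep2sdf.networks.learning_rate import StepLearningRateSchedule

# Illustrative spec; real entries come from config.train.learning_rate_schedule.
spec = {"Type": "Step", "Initial": 1e-3, "Interval": 500, "Factor": 0.5}
schedule = StepLearningRateSchedule(spec["Initial"], spec["Interval"], spec["Factor"])

# lr = Initial * Factor ** (epoch // Interval), clamped to a floor of 5e-6
print(schedule.get_learning_rate(0))     # 0.001
print(schedule.get_learning_rate(750))   # 0.0005
print(schedule.get_learning_rate(1200))  # 0.00025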

brep2sdf/train.py

@@ -134,14 +134,9 @@ class Trainer:
         ).to(self.device)
         logger.gpu_memory_stats("模型初始化后")

-        # initialize the optimizer
-        self.optimizer = optim.AdamW(
-            self.model.parameters(),
-            lr=config.train.learning_rate,
-            weight_decay=config.train.weight_decay
-        )
+        #self.scheduler = LearningRateScheduler(self.conf.get_list('train.learning_rate_schedule'), self.conf.get_float('train.weight_decay'), self.model.parameters())
+        self.scheduler = LearningRateScheduler(config.train.learning_rate_schedule, config.train.weight_decay, self.model.parameters())

         self.loss_manager = LossManager(ablation="none")
         logger.gpu_memory_stats("训练器初始化后")
@@ -235,12 +230,11 @@ class Trainer:
             self.optimizer.zero_grad()
             mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
             nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)
-            logger.print_tensor_stats("mnfld_pred",mnfld_pred)
-            logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)

             if self.debug_mode:
                 # --- check the forward pass outputs ---
+                logger.print_tensor_stats("mnfld_pred",mnfld_pred)
+                logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
                 logger.gpu_memory_stats("前向传播后")

             # --- compute losses ---
@@ -268,11 +262,11 @@ class Trainer:
             # New: run the backward pass once the accumulation step count is reached
             if (step + 1) % self.config.train.accumulation_steps == 0:
-                accumulated_loss.backward()  # backward pass
-                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-                self.optimizer.step()
-                self.optimizer.zero_grad()
-                accumulated_loss = 0.0  # reset accumulated loss
+                self.scheduler.optimizer.zero_grad()  # clear gradients
+                loss.backward()  # backward pass
+                self.scheduler.optimizer.step()  # update parameters
+                self.scheduler.step(accumulated_loss, epoch)

             # logging unchanged ...
@@ -286,9 +280,11 @@ class Trainer:
         # New: handle any remaining loss that did not reach the accumulation step count
         if accumulated_loss != 0:
-            accumulated_loss.backward()  # backward pass
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-            self.optimizer.step()
+            self.scheduler.optimizer.zero_grad()  # clear gradients
+            loss.backward()  # backward pass
+            self.scheduler.optimizer.step()  # update parameters
+            self.scheduler.step(accumulated_loss, epoch)

         # compute and log the epoch loss
         logger.info(f'Train Epoch: {epoch:4d}]\t'
@@ -346,15 +342,13 @@ class Trainer:
         nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks

         # --- forward pass ---
-        self.optimizer.zero_grad()
         mnfld_pred = self.model(mnfld_pnts)
         nonmnfld_pred = self.model(nonmnfld_pnts)
-        logger.print_tensor_stats("mnfld_pred",mnfld_pred)
-        logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)

         if self.debug_mode:
             # --- check the forward pass outputs ---
+            logger.print_tensor_stats("mnfld_pred",mnfld_pred)
+            logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
             logger.gpu_memory_stats("前向传播后")

         # --- 2. check the model output ---
         #if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf')
@@ -397,22 +391,11 @@ class Trainer:
         # --- backward pass and optimization ---
         try:
-            loss.backward()  # backward pass
-            # --- 5. (optional) check gradients ---
-            # for name, param in self.model.named_parameters():
-            #     if param.grad is not None:
-            #         if check_tensor(param.grad, f"Gradient/{name}", epoch, step):
-            #             logger.warning(f"Epoch {epoch} Step {step}: Bad gradient for {name}. Consider clipping or zeroing.")
-            #             # e.g. param.grad.data.clamp_(-1, 1)  # value clipping
-            #             # or clip by norm before optimizer.step():
-            #             # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-            # --- (recommended) gradient clipping ---
-            # prevents exploding gradients, a possible source of inf/nan
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # norm clipping
-            self.optimizer.step()
+            self.scheduler.optimizer.zero_grad()  # clear gradients
+            loss.backward()  # backward pass
+            self.scheduler.optimizer.step()  # update parameters
+            self.scheduler.step(loss, epoch)
         except Exception as backward_e:
             logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
             # if you want to see which operation caused it, enable anomaly detection
@@ -552,7 +535,7 @@ class Trainer:
             torch.save({
                 'epoch': epoch,
                 'model_state_dict': self.model.state_dict(),
-                'optimizer_state_dict': self.optimizer.state_dict(),
+                'optimizer_state_dict': self.scheduler.optimizer.state_dict(),
                 'loss': train_loss,
             }, checkpoint_path)
@@ -561,7 +544,7 @@ class Trainer:
         try:
             checkpoint = torch.load(checkpoint_path)
             self.model.load_state_dict(checkpoint['model_state_dict'])
-            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            self.scheduler.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
             return checkpoint['epoch'] + 1
         except Exception as e:
             logger.error(f"加载checkpoint失败: {str(e)}")
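
Taken together, the Trainer now drives the optimizer through the scheduler: gradients are cleared and applied via self.scheduler.optimizer, and scheduler.step(loss, epoch) applies the epoch-based Step schedule through adjust_learning_rate. A minimal sketch of that flow, with a stand-in model and illustrative spec values (not taken from the repository):

import torch
import torch.nn as nn
from brep2sdf.networks.learning_rate import LearningRateScheduler

toy_model = nn.Linear(3, 1)  # stand-in network, only to have parameters to optimize
scheduler = LearningRateScheduler(
    [{"Type": "Step", "Initial": 1e-3, "Interval": 500, "Factor": 0.5}],  # illustrative spec
    weight_decay=1e-4,
    network_params=toy_model.parameters(),
)

for epoch in range(3):
    pnts = torch.randn(16, 3)
    loss = toy_model(pnts).pow(2).mean()   # stand-in for the SDF losses

    scheduler.optimizer.zero_grad()        # clear gradients on the scheduler-owned optimizer
    loss.backward()
    scheduler.optimizer.step()             # parameter update
    scheduler.step(loss.item(), epoch)     # epoch-based LR adjustment (adjust_learning_rate)

# Checkpoints now persist the scheduler-owned optimizer state:
state = {"optimizer_state_dict": scheduler.optimizer.state_dict()}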
