import torch
import torch.optim as optim
import numpy as np

from utils.logger import logger


class LearningRateSchedule:
    """Base class for learning-rate schedules."""

    def get_learning_rate(self, epoch):
        pass


class StepLearningRateSchedule(LearningRateSchedule):
    def __init__(self, initial, interval, factor):
        """
        Initialize the step learning-rate schedule.

        :param initial: initial learning rate
        :param interval: decay interval, in epochs
        :param factor: decay factor applied every interval
        """
        self.initial = initial
        self.interval = interval
        self.factor = factor

    def get_learning_rate(self, epoch):
        """
        Get the learning rate for the given epoch.

        :param epoch: current training epoch
        :return: current learning rate
        """
        # Decay the initial rate every `interval` epochs, clamped to a 5.0e-6 floor.
        return np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)

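
# Worked example of the step decay above (illustrative values, not taken from
# any config in this repo): with initial=1e-3, interval=500 and factor=0.5,
#   epochs    0-499  -> 1.0e-3
#   epochs  500-999  -> 5.0e-4
#   epochs 1000-1499 -> 2.5e-4
# and so on; from epoch 4000 onward (1e-3 * 0.5**8 < 5e-6) the 5.0e-6 floor
# takes over.
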
class LearningRateScheduler:
    def __init__(self, lr_schedules, weight_decay, network_params):
        try:
            self.lr_schedules = self.get_learning_rate_schedules(lr_schedules)
            self.weight_decay = weight_decay

            self.startepoch = 0
            self.optimizer = torch.optim.Adam([{
                "params": network_params,
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            }])
            self.best_loss = float('inf')
            self.patience = 20
            self.decay_factor = 0.5
            # Keep the initial rate on the instance so reset() can restore it.
            self.initial_lr = self.lr_schedules[0].get_learning_rate(0)
            self.lr = self.initial_lr
            self.epochs_since_improvement = 0

        except Exception as e:
            logger.error(f"Error setting up optimizer: {str(e)}")
            raise

    def step(self, current_loss, current_epoch):
        """
        Update the learning rate.

        :param current_loss: current validation loss
        :param current_epoch: current training epoch

        First, self.adjust_learning_rate applies the epoch-based update of the
        learning rate; a loss-based adjustment could then be layered on top
        (currently disabled, see the commented-out block below).
        """
        self.adjust_learning_rate(current_epoch)

        '''
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.epochs_since_improvement = 0
        else:
            self.epochs_since_improvement += 1

        if self.epochs_since_improvement >= self.patience:
            self.lr *= self.decay_factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
            print(f"Learning rate updated to: {self.lr:.6f}")
            self.epochs_since_improvement = 0
        '''

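    # If the loss-based branch above is re-enabled, an alternative (a sketch,
    # assuming the same patience/factor semantics) is PyTorch's built-in
    # plateau scheduler, created once in __init__ and stepped with the
    # validation loss each epoch:
    #
    #   plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #       self.optimizer, mode="min",
    #       factor=self.decay_factor, patience=self.patience)
    #   ...
    #   plateau.step(current_loss)
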
    def reset(self):
        """
        Reset the learning rate to its initial value.
        """
        self.lr = self.initial_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    @staticmethod
    def get_learning_rate_schedules(schedule_specs):
        """
        Build the learning-rate schedules from the configuration.

        :param schedule_specs: list of schedule specifications
        :return: list of learning-rate schedules
        """
        schedules = []

        for spec in schedule_specs:
            if spec["Type"] == "Step":
                schedules.append(
                    StepLearningRateSchedule(
                        spec["Initial"],
                        spec["Interval"],
                        spec["Factor"],
                    )
                )
            else:
                raise Exception(
                    'no known learning rate schedule of type "{}"'.format(
                        spec["Type"]
                    )
                )

        return schedules

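    # The expected spec format follows from the lookup above; for example
    # (illustrative values only):
    #
    #   [{"Type": "Step", "Initial": 1.0e-3, "Interval": 500, "Factor": 0.5}]
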
    def adjust_learning_rate(self, epoch):
        """
        Adjust the learning rate according to the current epoch.

        :param epoch: current training epoch
        """
        for i, param_group in enumerate(self.optimizer.param_groups):
            # Pull the rate for this epoch from the matching schedule and
            # write it into the optimizer's parameter group.
            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)
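

if __name__ == "__main__":
    # Minimal usage sketch (illustrative values; a real model and training
    # loop would replace the toy module and the fake loss below).
    import torch.nn as nn

    model = nn.Linear(3, 1)
    specs = [{"Type": "Step", "Initial": 1.0e-3, "Interval": 500, "Factor": 0.5}]
    scheduler = LearningRateScheduler(specs, weight_decay=1.0e-4,
                                      network_params=model.parameters())

    for epoch in range(3):
        fake_loss = 1.0 / (epoch + 1)      # stand-in for a validation loss
        scheduler.step(fake_loss, epoch)   # epoch-based LR update
        print(epoch, scheduler.optimizer.param_groups[0]["lr"])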