You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

119 lines
3.8 KiB

10 months ago
import torch
import torch.optim as optim
import numpy as np
from utils.logger import logger
class LearningRateSchedule:
    """Base interface for epoch-indexed learning-rate schedules.

    Subclasses must override :meth:`get_learning_rate`.
    """

    def get_learning_rate(self, epoch):
        """Return the learning rate to use at ``epoch``.

        :param epoch: current training epoch
        :raises NotImplementedError: always; subclasses must override this
        """
        # Previously `pass`, which silently returned None and let a missing
        # override propagate a None learning rate into the optimizer.
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_learning_rate()"
        )


class StepLearningRateSchedule(LearningRateSchedule):
    """Step decay: multiply the rate by ``factor`` once every ``interval`` epochs.

    The returned rate is clamped from below at ``min_lr``.
    """

    def __init__(self, initial, interval, factor, min_lr=5.0e-6):
        """
        Initialize the step learning-rate schedule.

        :param initial: learning rate at epoch 0
        :param interval: number of epochs between decays
        :param factor: multiplicative decay applied once per interval
        :param min_lr: lower bound on the returned rate (defaults to 5.0e-6,
            the floor that was previously hard-coded)
        """
        self.initial = initial
        self.interval = interval
        self.factor = factor
        self.min_lr = min_lr

    def get_learning_rate(self, epoch):
        """
        Return the learning rate for ``epoch``.

        :param epoch: current training epoch
        :return: ``max(initial * factor ** (epoch // interval), min_lr)``
        """
        decayed = self.initial * (self.factor ** (epoch // self.interval))
        return np.maximum(decayed, self.min_lr)
class LearningRateScheduler:
    """Owns an Adam optimizer and sets its learning rate from epoch schedules.

    Attributes set in ``__init__``:
        lr_schedules: schedule objects built from the config specs
        weight_decay: weight decay passed to Adam
        startepoch: starting epoch (initialized to 0)
        optimizer: ``torch.optim.Adam`` with a single param group
        initial_lr, lr: rate of the first schedule at epoch 0
        best_loss, patience, decay_factor, epochs_since_improvement:
            loss-plateau bookkeeping; not read by any method of this class
    """

    def __init__(self, lr_schedules, weight_decay, network_params):
        """
        Build the schedules and the Adam optimizer.

        :param lr_schedules: list of schedule spec dicts (see
            :meth:`get_learning_rate_schedules`)
        :param weight_decay: weight decay for Adam
        :param network_params: parameters to optimize
        :raises ValueError: if no schedule is configured or a spec has an
            unknown ``"Type"`` (the error is logged, then re-raised)
        """
        try:
            self.lr_schedules = self.get_learning_rate_schedules(lr_schedules)
            if not self.lr_schedules:
                raise ValueError("at least one learning rate schedule is required")
            self.weight_decay = weight_decay
            self.startepoch = 0
            # Stored on the instance (not as a local) so reset() can use it;
            # the original assigned only a local and reset() always failed.
            self.initial_lr = self.lr_schedules[0].get_learning_rate(0)
            self.lr = self.initial_lr
            self.optimizer = torch.optim.Adam([{
                "params": network_params,
                "lr": self.initial_lr,
                "weight_decay": self.weight_decay,
            }])
            self.best_loss = float('inf')
            self.patience = 20
            self.decay_factor = 0.5
            self.epochs_since_improvement = 0
        except Exception as e:
            logger.error(f"Error setting up optimizer: {str(e)}")
            raise

    def step(self, current_loss, current_epoch):
        """
        Update the learning rate for ``current_epoch``.

        Delegates to :meth:`adjust_learning_rate`, which sets each param
        group's rate from its epoch schedule. No loss-based (plateau) decay
        is applied: ``current_loss`` is accepted for interface compatibility
        but is currently unused.

        :param current_loss: current validation loss (unused)
        :param current_epoch: current training epoch
        """
        self.adjust_learning_rate(current_epoch)

    def reset(self):
        """
        Reset every param group's learning rate to the initial value.
        """
        self.lr = self.initial_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    @staticmethod
    def get_learning_rate_schedules(schedule_specs):
        """
        Build schedule objects from config specs.

        :param schedule_specs: iterable of dicts; each needs a ``"Type"`` key.
            ``"Step"`` specs also need ``"Initial"``, ``"Interval"`` and
            ``"Factor"``.
        :return: list of schedule objects, in spec order
        :raises ValueError: for an unrecognized ``"Type"`` (a subclass of
            ``Exception``, so existing ``except Exception`` handlers still work)
        """
        schedules = []
        for spec in schedule_specs:
            if spec["Type"] == "Step":
                schedules.append(
                    StepLearningRateSchedule(
                        spec["Initial"],
                        spec["Interval"],
                        spec["Factor"],
                    )
                )
            else:
                raise ValueError(
                    'no known learning rate schedule of type "{}"'.format(
                        spec["Type"]
                    )
                )
        return schedules

    def adjust_learning_rate(self, epoch):
        """
        Set each param group's learning rate from its schedule at ``epoch``.

        Param group ``i`` uses ``lr_schedules[i]``.

        :param epoch: current training epoch
        """
        for i, param_group in enumerate(self.optimizer.param_groups):
            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)