You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

114 lines
3.6 KiB

import torch
import torch.optim as optim
import numpy as np
from utils.logger import logger
class LearningRateSchedule:
    """
    Base interface for learning rate schedules.

    Subclasses override ``get_learning_rate`` to map an epoch to a rate.
    """

    def get_learning_rate(self, epoch):
        """
        Return the learning rate for the given epoch.

        :param epoch: current training epoch
        :return: learning rate to use for that epoch
        :raises NotImplementedError: always; subclasses must override this
        """
        # Fail loudly instead of silently returning None, which would
        # otherwise end up as the optimizer's learning rate.
        raise NotImplementedError(
            "{} must implement get_learning_rate".format(type(self).__name__)
        )
class StepLearningRateSchedule(LearningRateSchedule):
    """
    Step decay: multiply the rate by ``factor`` once every ``interval`` epochs,
    never going below ``min_lr``.
    """

    def __init__(self, initial, interval, factor, min_lr=5.0e-6):
        """
        Initialize the step learning rate schedule.

        :param initial: initial learning rate
        :param interval: number of epochs between decays; a value of zero
            makes ``get_learning_rate`` raise ``ZeroDivisionError``
        :param factor: multiplicative decay factor applied at each interval
        :param min_lr: lower bound on the returned learning rate
        """
        self.initial = initial
        self.interval = interval
        self.factor = factor
        self.min_lr = min_lr

    def get_learning_rate(self, epoch):
        """
        Return the learning rate for the given epoch.

        :param epoch: current training epoch
        :return: ``initial * factor ** (epoch // interval)``, clamped from
            below at ``min_lr``
        """
        completed_steps = epoch // self.interval
        decayed = self.initial * (self.factor ** completed_steps)
        return np.maximum(decayed, self.min_lr)
class LearningRateScheduler:
    """
    Own an Adam optimizer and adjust its learning rate.

    Two mechanisms are provided:

    * ``adjust_learning_rate(epoch)`` writes the configured schedule's rate
      into each optimizer parameter group.
    * ``step(current_loss)`` multiplies the tracked rate by ``decay_factor``
      after ``patience`` consecutive calls without a new best loss.
    """

    def __init__(self, lr_schedules, weight_decay, network_params,
                 patience=10, decay_factor=0.5):
        """
        Build the schedules and the Adam optimizer.

        :param lr_schedules: iterable of schedule specs, see
            ``get_learning_rate_schedules``
        :param weight_decay: weight decay passed to the optimizer
        :param network_params: parameters handed to ``torch.optim.Adam``
        :param patience: number of consecutive non-improving ``step`` calls
            before the learning rate is decayed
        :param decay_factor: multiplier applied to the learning rate on decay
        :raises Exception: any setup error is logged and re-raised unchanged
        """
        try:
            self.lr_schedules = self.get_learning_rate_schedules(lr_schedules)
            self.weight_decay = weight_decay
            self.startepoch = 0
            # The first schedule at epoch 0 defines the starting rate.
            initial_lr = self.lr_schedules[0].get_learning_rate(0)
            self.optimizer = torch.optim.Adam([{
                "params": network_params,
                "lr": initial_lr,
                "weight_decay": self.weight_decay,
            }])
            self.best_loss = float('inf')
            self.patience = patience
            self.decay_factor = decay_factor
            # Kept on the instance so reset() can restore it.
            self.initial_lr = initial_lr
            self.lr = initial_lr
            self.epochs_since_improvement = 0
        except Exception as e:
            logger.error(f"Error setting up optimizer: {str(e)}")
            raise

    def _set_optimizer_lr(self, lr):
        """Write ``lr`` into every parameter group of the optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def step(self, current_loss):
        """
        Update the learning rate from the latest loss.

        A loss strictly below the best seen so far resets the patience
        counter. Otherwise the counter is incremented, and once it reaches
        ``patience`` the learning rate is multiplied by ``decay_factor`` and
        the counter starts over.

        :param current_loss: current validation loss
        """
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.epochs_since_improvement = 0
            return

        self.epochs_since_improvement += 1
        if self.epochs_since_improvement < self.patience:
            return

        self.lr *= self.decay_factor
        self._set_optimizer_lr(self.lr)
        print(f"学习率更新为: {self.lr:.6f}")
        self.epochs_since_improvement = 0

    def reset(self):
        """
        Reset the learning rate to its initial value.

        Only the learning rate is restored; the best loss and the patience
        counter are left as they are.
        """
        self.lr = self.initial_lr
        self._set_optimizer_lr(self.lr)

    @staticmethod
    def get_learning_rate_schedules(schedule_specs):
        """
        Build schedule objects from their specifications.

        :param schedule_specs: iterable of mappings with a ``"Type"`` key;
            ``"Step"`` specs also need ``"Initial"``, ``"Interval"`` and
            ``"Factor"``
        :return: list of schedule objects, in the order of the specs
        :raises ValueError: if a spec has an unknown ``"Type"``
        """
        schedules = []
        for spec in schedule_specs:
            if spec["Type"] != "Step":
                raise ValueError(
                    'no known learning rate schedule of type "{}"'.format(
                        spec["Type"]
                    )
                )
            schedules.append(
                StepLearningRateSchedule(
                    spec["Initial"],
                    spec["Interval"],
                    spec["Factor"],
                )
            )
        return schedules

    def adjust_learning_rate(self, epoch):
        """
        Set each parameter group's learning rate from its schedule.

        Parameter group ``i`` uses ``self.lr_schedules[i]``.

        :param epoch: current training epoch
        """
        # NOTE(review): this does not update self.lr, so a later decay in
        # step() starts from the plateau-tracked value rather than the
        # scheduled one - confirm that this is intended.
        for i, param_group in enumerate(self.optimizer.param_groups):
            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)