import torch
import torch.optim as optim
import numpy as np

from utils.logger import logger


class LearningRateSchedule:
    def get_learning_rate(self, epoch):
        pass


class StepLearningRateSchedule(LearningRateSchedule):
    def __init__(self, initial, interval, factor):
        """
        Initialize the step learning rate schedule.
        :param initial: initial learning rate
        :param interval: decay interval (in epochs)
        :param factor: decay factor
        """
        self.initial = initial
        self.interval = interval
        self.factor = factor

    def get_learning_rate(self, epoch):
        """
        Get the learning rate for the given epoch.
        :param epoch: current training epoch
        :return: current learning rate
        """
        # Decay the initial rate every `interval` epochs, clamped at a lower bound of 5e-6.
        return np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)


class LearningRateScheduler:
    def __init__(self, lr_schedules, weight_decay, network_params):
        try:
            self.lr_schedules = self.get_learning_rate_schedules(lr_schedules)
            self.weight_decay = weight_decay
            self.startepoch = 0
            self.optimizer = optim.Adam([
                {
                    "params": network_params,
                    "lr": self.lr_schedules[0].get_learning_rate(0),
                    "weight_decay": self.weight_decay,
                }
            ])
            self.best_loss = float("inf")
            self.patience = 10
            self.decay_factor = 0.5
            # Keep the initial rate so reset() can restore it later.
            self.initial_lr = self.lr_schedules[0].get_learning_rate(0)
            self.lr = self.initial_lr
            self.epochs_since_improvement = 0
        except Exception as e:
            logger.error(f"Error setting up optimizer: {str(e)}")
            raise

    def step(self, current_loss):
        """
        Update the learning rate based on the validation loss.
        :param current_loss: current validation loss
        """
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.epochs_since_improvement = 0
        else:
            self.epochs_since_improvement += 1

        # Decay the learning rate once the loss has not improved for `patience` epochs.
        if self.epochs_since_improvement >= self.patience:
            self.lr *= self.decay_factor
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = self.lr
            logger.info(f"Learning rate updated to: {self.lr:.6f}")
            self.epochs_since_improvement = 0

    def reset(self):
        """
        Reset the learning rate to its initial value.
        """
        self.lr = self.initial_lr
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.lr

    @staticmethod
    def get_learning_rate_schedules(schedule_specs):
        """
        Build learning rate schedules from their specifications.
        :param schedule_specs: list of schedule configuration dicts
        :return: list of learning rate schedules
        """
        schedules = []
        for spec in schedule_specs:
            if spec["Type"] == "Step":
                schedules.append(
                    StepLearningRateSchedule(
                        spec["Initial"],
                        spec["Interval"],
                        spec["Factor"],
                    )
                )
            else:
                raise Exception(
                    'no known learning rate schedule of type "{}"'.format(spec["Type"])
                )
        return schedules

    def adjust_learning_rate(self, epoch):
        """
        Adjust the learning rate according to the current epoch.
        :param epoch: current training epoch
        """
        # Pull the rate for this epoch from each param group's schedule.
        for i, param_group in enumerate(self.optimizer.param_groups):
            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)
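
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The network, hyperparameter values, and dummy validation loss below are
# assumptions for demonstration; the schedule spec format
# ({"Type": "Step", "Initial": ..., "Interval": ..., "Factor": ...}) follows
# get_learning_rate_schedules above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    # Hypothetical model; any nn.Module's parameters would work here.
    net = nn.Linear(8, 1)

    schedule_specs = [{"Type": "Step", "Initial": 1e-3, "Interval": 30, "Factor": 0.5}]
    scheduler = LearningRateScheduler(
        lr_schedules=schedule_specs,
        weight_decay=1e-4,
        network_params=net.parameters(),
    )

    for epoch in range(5):
        # Epoch-based decay driven by the Step schedule...
        scheduler.adjust_learning_rate(epoch)
        # ...and plateau-based decay driven by a (dummy) validation loss.
        dummy_val_loss = 1.0 / (epoch + 1)
        scheduler.step(dummy_val_loss)
        print(epoch, scheduler.optimizer.param_groups[0]["lr"])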