4 changed files with 417 additions and 45 deletions
@ -0,0 +1,113 @@ |
|||||
|
import torch |
||||
|
import torch.optim as optim |
||||
|
import numpy as np |
||||
|
from utils.logger import logger |
||||
|
|
||||
|
class LearningRateSchedule: |
||||
|
def get_learning_rate(self, epoch): |
||||
|
pass |
||||
|
|
||||
|
class StepLearningRateSchedule(LearningRateSchedule): |
||||
|
def __init__(self, initial, interval, factor): |
||||
|
""" |
||||
|
初始化步进学习率调度器 |
||||
|
:param initial_lr: 初始学习率 |
||||
|
:param interval: 衰减间隔 |
||||
|
:param factor: 衰减因子 |
||||
|
""" |
||||
|
self.initial = initial |
||||
|
self.interval = interval |
||||
|
self.factor = factor |
||||
|
|
||||
|
def get_learning_rate(self, epoch): |
||||
|
""" |
||||
|
获取当前学习率 |
||||
|
:param epoch: 当前训练周期 |
||||
|
:return: 当前学习率 |
||||
|
""" |
||||
|
return np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6) |
||||
|
|
||||
|
class LearningRateScheduler: |
||||
|
def __init__(self, lr_schedules, weight_decay, network_params): |
||||
|
try: |
||||
|
self.lr_schedules = self.get_learning_rate_schedules(lr_schedules) |
||||
|
self.weight_decay = weight_decay |
||||
|
|
||||
|
self.startepoch = 0 |
||||
|
self.optimizer = torch.optim.Adam([{ |
||||
|
"params": network_params, |
||||
|
"lr": self.lr_schedules[0].get_learning_rate(0), |
||||
|
"weight_decay": self.weight_decay |
||||
|
}]) |
||||
|
self.best_loss = float('inf') |
||||
|
self.patience = 10 |
||||
|
self.decay_factor = 0.5 |
||||
|
initial_lr = self.lr_schedules[0].get_learning_rate(0) |
||||
|
self.lr = initial_lr |
||||
|
self.epochs_since_improvement = 0 |
||||
|
|
||||
|
except Exception as e: |
||||
|
logger.error(f"Error setting up optimizer: {str(e)}") |
||||
|
raise |
||||
|
|
||||
|
def step(self, current_loss): |
||||
|
""" |
||||
|
更新学习率 |
||||
|
:param current_loss: 当前验证损失 |
||||
|
""" |
||||
|
if current_loss < self.best_loss: |
||||
|
self.best_loss = current_loss |
||||
|
self.epochs_since_improvement = 0 |
||||
|
else: |
||||
|
self.epochs_since_improvement += 1 |
||||
|
|
||||
|
if self.epochs_since_improvement >= self.patience: |
||||
|
self.lr *= self.decay_factor |
||||
|
for param_group in self.optimizer.param_groups: |
||||
|
param_group['lr'] = self.lr |
||||
|
print(f"学习率更新为: {self.lr:.6f}") |
||||
|
self.epochs_since_improvement = 0 |
||||
|
|
||||
|
def reset(self): |
||||
|
""" |
||||
|
重置学习率为初始值 |
||||
|
""" |
||||
|
self.lr = self.initial_lr |
||||
|
for param_group in self.optimizer.param_groups: |
||||
|
param_group['lr'] = self.lr |
||||
|
|
||||
|
@staticmethod |
||||
|
def get_learning_rate_schedules(schedule_specs): |
||||
|
""" |
||||
|
获取学习率调度策略 |
||||
|
:param schedule_specs: 学习率调度配置 |
||||
|
:return: 学习率调度列表 |
||||
|
""" |
||||
|
schedules = [] |
||||
|
|
||||
|
for spec in schedule_specs: |
||||
|
if spec["Type"] == "Step": |
||||
|
schedules.append( |
||||
|
StepLearningRateSchedule( |
||||
|
spec["Initial"], |
||||
|
spec["Interval"], |
||||
|
spec["Factor"], |
||||
|
) |
||||
|
) |
||||
|
else: |
||||
|
raise Exception( |
||||
|
'no known learning rate schedule of type "{}"'.format( |
||||
|
spec["Type"] |
||||
|
) |
||||
|
) |
||||
|
|
||||
|
return schedules |
||||
|
|
||||
|
def adjust_learning_rate(self, epoch): |
||||
|
""" |
||||
|
根据当前周期调整学习率 |
||||
|
:param epoch: 当前训练周期 |
||||
|
""" |
||||
|
for i, param_group in enumerate(self.optimizer.param_groups): |
||||
|
param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch) # 使用当前学习率更新优化器的学习率 |
||||
|
|
Loading…
Reference in new issue