4 changed files with 417 additions and 45 deletions
@@ -0,0 +1,113 @@
import torch
import torch.optim as optim
import numpy as np
from utils.logger import logger


class LearningRateSchedule:
    def get_learning_rate(self, epoch):
        pass


class StepLearningRateSchedule(LearningRateSchedule):
    def __init__(self, initial, interval, factor):
        """
        Initialize a step learning rate schedule.
        :param initial: initial learning rate
        :param interval: decay interval in epochs
        :param factor: decay factor applied every interval
        """
        self.initial = initial
        self.interval = interval
        self.factor = factor

    def get_learning_rate(self, epoch):
        """
        Get the learning rate for the current epoch.
        :param epoch: current training epoch
        :return: current learning rate, clamped to a floor of 5.0e-6
        """
        return np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)


class LearningRateScheduler:
    def __init__(self, lr_schedules, weight_decay, network_params):
        try:
            self.lr_schedules = self.get_learning_rate_schedules(lr_schedules)
            self.weight_decay = weight_decay

            self.startepoch = 0
            self.optimizer = torch.optim.Adam([{
                "params": network_params,
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            }])
            self.best_loss = float('inf')
            self.patience = 10
            self.decay_factor = 0.5
            # keep the initial learning rate so reset() can restore it
            self.initial_lr = self.lr_schedules[0].get_learning_rate(0)
            self.lr = self.initial_lr
            self.epochs_since_improvement = 0

        except Exception as e:
            logger.error(f"Error setting up optimizer: {str(e)}")
            raise

    def step(self, current_loss):
        """
        Update the learning rate based on validation loss (plateau-style decay).
        :param current_loss: current validation loss
        """
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.epochs_since_improvement = 0
        else:
            self.epochs_since_improvement += 1

        if self.epochs_since_improvement >= self.patience:
            self.lr *= self.decay_factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
            print(f"Learning rate updated to: {self.lr:.6f}")
            self.epochs_since_improvement = 0

    def reset(self):
        """
        Reset the learning rate to its initial value.
        """
        self.lr = self.initial_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    @staticmethod
    def get_learning_rate_schedules(schedule_specs):
        """
        Build the learning rate schedules from a configuration list.
        :param schedule_specs: list of schedule configuration dicts
        :return: list of learning rate schedules
        """
        schedules = []

        for spec in schedule_specs:
            if spec["Type"] == "Step":
                schedules.append(
                    StepLearningRateSchedule(
                        spec["Initial"],
                        spec["Interval"],
                        spec["Factor"],
                    )
                )
            else:
                raise Exception(
                    'no known learning rate schedule of type "{}"'.format(
                        spec["Type"]
                    )
                )

        return schedules

    def adjust_learning_rate(self, epoch):
        """
        Adjust the optimizer's learning rates according to the current epoch.
        :param epoch: current training epoch
        """
        for i, param_group in enumerate(self.optimizer.param_groups):
            # update each parameter group from its corresponding schedule
            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)
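
For reference, a minimal usage sketch of the new scheduler: it assumes the classes above are in scope (the module path is not shown in this diff), and the model, batch, weight_decay value, and spec values are placeholders; only the spec keys ("Type", "Initial", "Interval", "Factor") come from get_learning_rate_schedules() above. adjust_learning_rate() applies the epoch-based step schedule, while step(val_loss) is the plateau-based alternative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumes LearningRateScheduler / StepLearningRateSchedule from the file above
# are importable; the model and data below are placeholders.
model = nn.Linear(10, 1)

# Keys match get_learning_rate_schedules(); values are illustrative only.
lr_schedules = [{"Type": "Step", "Initial": 1e-3, "Interval": 500, "Factor": 0.5}]

scheduler = LearningRateScheduler(
    lr_schedules=lr_schedules,
    weight_decay=1e-4,                  # placeholder weight decay
    network_params=model.parameters(),
)

x, y = torch.randn(8, 10), torch.randn(8, 1)   # placeholder batch

for epoch in range(3):
    scheduler.adjust_learning_rate(epoch)      # epoch-based step decay
    loss = F.mse_loss(model(x), y)
    scheduler.optimizer.zero_grad()
    loss.backward()
    scheduler.optimizer.step()
    # plateau-based alternative, driven by a validation loss:
    # scheduler.step(val_loss)
```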