4 changed files with 417 additions and 45 deletions
			
			
		@ -0,0 +1,113 @@ | 
				
			|||
import torch | 
				
			|||
import torch.optim as optim | 
				
			|||
import numpy as np | 
				
			|||
from utils.logger import logger | 
				
			|||
 | 
				
			|||
class LearningRateSchedule: | 
				
			|||
    def get_learning_rate(self, epoch): | 
				
			|||
        pass | 
				
			|||
 | 
				
			|||
class StepLearningRateSchedule(LearningRateSchedule): | 
				
			|||
    def __init__(self, initial, interval, factor): | 
				
			|||
        """ | 
				
			|||
        初始化步进学习率调度器 | 
				
			|||
        :param initial_lr: 初始学习率 | 
				
			|||
        :param interval: 衰减间隔 | 
				
			|||
        :param factor: 衰减因子 | 
				
			|||
        """ | 
				
			|||
        self.initial = initial | 
				
			|||
        self.interval = interval | 
				
			|||
        self.factor = factor | 
				
			|||
 | 
				
			|||
    def get_learning_rate(self, epoch): | 
				
			|||
        """ | 
				
			|||
        获取当前学习率 | 
				
			|||
        :param epoch: 当前训练周期 | 
				
			|||
        :return: 当前学习率 | 
				
			|||
        """ | 
				
			|||
        return np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6) | 
				
			|||
 | 
				
			|||
class LearningRateScheduler: | 
				
			|||
    def __init__(self, lr_schedules, weight_decay, network_params): | 
				
			|||
        try: | 
				
			|||
            self.lr_schedules = self.get_learning_rate_schedules(lr_schedules) | 
				
			|||
            self.weight_decay = weight_decay | 
				
			|||
             | 
				
			|||
            self.startepoch = 0 | 
				
			|||
            self.optimizer = torch.optim.Adam([{ | 
				
			|||
                "params": network_params, | 
				
			|||
                "lr": self.lr_schedules[0].get_learning_rate(0), | 
				
			|||
                "weight_decay": self.weight_decay | 
				
			|||
            }]) | 
				
			|||
            self.best_loss = float('inf') | 
				
			|||
            self.patience = 10 | 
				
			|||
            self.decay_factor = 0.5 | 
				
			|||
            initial_lr = self.lr_schedules[0].get_learning_rate(0) | 
				
			|||
            self.lr = initial_lr | 
				
			|||
            self.epochs_since_improvement = 0 | 
				
			|||
             | 
				
			|||
        except Exception as e: | 
				
			|||
            logger.error(f"Error setting up optimizer: {str(e)}") | 
				
			|||
            raise | 
				
			|||
 | 
				
			|||
    def step(self, current_loss): | 
				
			|||
        """ | 
				
			|||
        更新学习率 | 
				
			|||
        :param current_loss: 当前验证损失 | 
				
			|||
        """ | 
				
			|||
        if current_loss < self.best_loss: | 
				
			|||
            self.best_loss = current_loss | 
				
			|||
            self.epochs_since_improvement = 0 | 
				
			|||
        else: | 
				
			|||
            self.epochs_since_improvement += 1 | 
				
			|||
 | 
				
			|||
        if self.epochs_since_improvement >= self.patience: | 
				
			|||
            self.lr *= self.decay_factor | 
				
			|||
            for param_group in self.optimizer.param_groups: | 
				
			|||
                param_group['lr'] = self.lr | 
				
			|||
            print(f"学习率更新为: {self.lr:.6f}") | 
				
			|||
            self.epochs_since_improvement = 0 | 
				
			|||
 | 
				
			|||
    def reset(self): | 
				
			|||
        """ | 
				
			|||
        重置学习率为初始值 | 
				
			|||
        """ | 
				
			|||
        self.lr = self.initial_lr | 
				
			|||
        for param_group in self.optimizer.param_groups: | 
				
			|||
            param_group['lr'] = self.lr | 
				
			|||
 | 
				
			|||
    @staticmethod | 
				
			|||
    def get_learning_rate_schedules(schedule_specs): | 
				
			|||
        """ | 
				
			|||
        获取学习率调度策略 | 
				
			|||
        :param schedule_specs: 学习率调度配置 | 
				
			|||
        :return: 学习率调度列表 | 
				
			|||
        """ | 
				
			|||
        schedules = [] | 
				
			|||
 | 
				
			|||
        for spec in schedule_specs: | 
				
			|||
            if spec["Type"] == "Step": | 
				
			|||
                schedules.append( | 
				
			|||
                    StepLearningRateSchedule( | 
				
			|||
                        spec["Initial"], | 
				
			|||
                        spec["Interval"], | 
				
			|||
                        spec["Factor"], | 
				
			|||
                    ) | 
				
			|||
                ) | 
				
			|||
            else: | 
				
			|||
                raise Exception( | 
				
			|||
                    'no known learning rate schedule of type "{}"'.format( | 
				
			|||
                        spec["Type"] | 
				
			|||
                    ) | 
				
			|||
                ) | 
				
			|||
 | 
				
			|||
        return schedules | 
				
			|||
 | 
				
			|||
    def adjust_learning_rate(self, epoch): | 
				
			|||
        """ | 
				
			|||
        根据当前周期调整学习率 | 
				
			|||
        :param epoch: 当前训练周期 | 
				
			|||
        """ | 
				
			|||
        for i, param_group in enumerate(self.optimizer.param_groups): | 
				
			|||
            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)  # 使用当前学习率更新优化器的学习率 | 
				
			|||
 | 
				
			|||
					Loading…
					
					
				
		Reference in new issue