4 changed files with 417 additions and 45 deletions
			
			
		@ -0,0 +1,113 @@ | 
				
			|||||
 | 
					import torch | 
				
			||||
 | 
					import torch.optim as optim | 
				
			||||
 | 
					import numpy as np | 
				
			||||
 | 
					from utils.logger import logger | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					class LearningRateSchedule: | 
				
			||||
 | 
					    def get_learning_rate(self, epoch): | 
				
			||||
 | 
					        pass | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					class StepLearningRateSchedule(LearningRateSchedule): | 
				
			||||
 | 
					    def __init__(self, initial, interval, factor): | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        初始化步进学习率调度器 | 
				
			||||
 | 
					        :param initial_lr: 初始学习率 | 
				
			||||
 | 
					        :param interval: 衰减间隔 | 
				
			||||
 | 
					        :param factor: 衰减因子 | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        self.initial = initial | 
				
			||||
 | 
					        self.interval = interval | 
				
			||||
 | 
					        self.factor = factor | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def get_learning_rate(self, epoch): | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        获取当前学习率 | 
				
			||||
 | 
					        :param epoch: 当前训练周期 | 
				
			||||
 | 
					        :return: 当前学习率 | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        return np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					class LearningRateScheduler: | 
				
			||||
 | 
					    def __init__(self, lr_schedules, weight_decay, network_params): | 
				
			||||
 | 
					        try: | 
				
			||||
 | 
					            self.lr_schedules = self.get_learning_rate_schedules(lr_schedules) | 
				
			||||
 | 
					            self.weight_decay = weight_decay | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            self.startepoch = 0 | 
				
			||||
 | 
					            self.optimizer = torch.optim.Adam([{ | 
				
			||||
 | 
					                "params": network_params, | 
				
			||||
 | 
					                "lr": self.lr_schedules[0].get_learning_rate(0), | 
				
			||||
 | 
					                "weight_decay": self.weight_decay | 
				
			||||
 | 
					            }]) | 
				
			||||
 | 
					            self.best_loss = float('inf') | 
				
			||||
 | 
					            self.patience = 10 | 
				
			||||
 | 
					            self.decay_factor = 0.5 | 
				
			||||
 | 
					            initial_lr = self.lr_schedules[0].get_learning_rate(0) | 
				
			||||
 | 
					            self.lr = initial_lr | 
				
			||||
 | 
					            self.epochs_since_improvement = 0 | 
				
			||||
 | 
					             | 
				
			||||
 | 
					        except Exception as e: | 
				
			||||
 | 
					            logger.error(f"Error setting up optimizer: {str(e)}") | 
				
			||||
 | 
					            raise | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def step(self, current_loss): | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        更新学习率 | 
				
			||||
 | 
					        :param current_loss: 当前验证损失 | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        if current_loss < self.best_loss: | 
				
			||||
 | 
					            self.best_loss = current_loss | 
				
			||||
 | 
					            self.epochs_since_improvement = 0 | 
				
			||||
 | 
					        else: | 
				
			||||
 | 
					            self.epochs_since_improvement += 1 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        if self.epochs_since_improvement >= self.patience: | 
				
			||||
 | 
					            self.lr *= self.decay_factor | 
				
			||||
 | 
					            for param_group in self.optimizer.param_groups: | 
				
			||||
 | 
					                param_group['lr'] = self.lr | 
				
			||||
 | 
					            print(f"学习率更新为: {self.lr:.6f}") | 
				
			||||
 | 
					            self.epochs_since_improvement = 0 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def reset(self): | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        重置学习率为初始值 | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        self.lr = self.initial_lr | 
				
			||||
 | 
					        for param_group in self.optimizer.param_groups: | 
				
			||||
 | 
					            param_group['lr'] = self.lr | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    @staticmethod | 
				
			||||
 | 
					    def get_learning_rate_schedules(schedule_specs): | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        获取学习率调度策略 | 
				
			||||
 | 
					        :param schedule_specs: 学习率调度配置 | 
				
			||||
 | 
					        :return: 学习率调度列表 | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        schedules = [] | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        for spec in schedule_specs: | 
				
			||||
 | 
					            if spec["Type"] == "Step": | 
				
			||||
 | 
					                schedules.append( | 
				
			||||
 | 
					                    StepLearningRateSchedule( | 
				
			||||
 | 
					                        spec["Initial"], | 
				
			||||
 | 
					                        spec["Interval"], | 
				
			||||
 | 
					                        spec["Factor"], | 
				
			||||
 | 
					                    ) | 
				
			||||
 | 
					                ) | 
				
			||||
 | 
					            else: | 
				
			||||
 | 
					                raise Exception( | 
				
			||||
 | 
					                    'no known learning rate schedule of type "{}"'.format( | 
				
			||||
 | 
					                        spec["Type"] | 
				
			||||
 | 
					                    ) | 
				
			||||
 | 
					                ) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        return schedules | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def adjust_learning_rate(self, epoch): | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        根据当前周期调整学习率 | 
				
			||||
 | 
					        :param epoch: 当前训练周期 | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        for i, param_group in enumerate(self.optimizer.param_groups): | 
				
			||||
 | 
					            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)  # 使用当前学习率更新优化器的学习率 | 
				
			||||
 | 
					
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue