diff --git a/brep2sdf/networks/learning_rate.py b/brep2sdf/networks/learning_rate.py
new file mode 100644
index 0000000..e2af628
--- /dev/null
+++ b/brep2sdf/networks/learning_rate.py
@@ -0,0 +1,118 @@
+import torch
+import torch.optim as optim
+import numpy as np
+from brep2sdf.utils.logger import logger
+
+class LearningRateSchedule:
+    def get_learning_rate(self, epoch):
+        pass
+
+class StepLearningRateSchedule(LearningRateSchedule):
+    def __init__(self, initial, interval, factor):
+        """
+        Initialize the step learning-rate schedule.
+        :param initial: initial learning rate
+        :param interval: decay interval (in epochs)
+        :param factor: decay factor
+        """
+        self.initial = initial
+        self.interval = interval
+        self.factor = factor
+
+    def get_learning_rate(self, epoch):
+        """
+        Get the current learning rate.
+        :param epoch: current training epoch
+        :return: current learning rate
+        """
+        return np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)
+
+class LearningRateScheduler:
+    def __init__(self, lr_schedules, weight_decay, network_params):
+        try:
+            self.lr_schedules = self.get_learning_rate_schedules(lr_schedules)
+            self.weight_decay = weight_decay
+
+            self.startepoch = 0
+            self.optimizer = torch.optim.Adam([{
+                "params": network_params,
+                "lr": self.lr_schedules[0].get_learning_rate(0),
+                "weight_decay": self.weight_decay
+            }])
+            self.best_loss = float('inf')
+            self.patience = 20
+            self.decay_factor = 0.5
+            self.initial_lr = self.lr_schedules[0].get_learning_rate(0)
+            self.lr = self.initial_lr
+            self.epochs_since_improvement = 0
+
+        except Exception as e:
+            logger.error(f"Error setting up optimizer: {str(e)}")
+            raise
+
+    def step(self, current_loss, current_epoch):
+        """
+        Update the learning rate.
+        :param current_loss: current validation loss
+        First apply the epoch-based schedule via self.adjust_learning_rate,
+        then (optionally) adapt the rate based on the loss (disabled below).
+        """
+        self.adjust_learning_rate(current_epoch)
+        '''
+        if current_loss < self.best_loss:
+            self.best_loss = current_loss
+            self.epochs_since_improvement = 0
+        else:
+            self.epochs_since_improvement += 1
+
+        if self.epochs_since_improvement >= self.patience:
+            self.lr *= self.decay_factor
+            for param_group in self.optimizer.param_groups:
+                param_group['lr'] = self.lr
+            print(f"Learning rate updated to: {self.lr:.6f}")
+            self.epochs_since_improvement = 0
+        '''
+
+    def reset(self):
+        """
+        Reset the learning rate to its initial value.
+        """
+        self.lr = self.initial_lr
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = self.lr
+
+    @staticmethod
+    def get_learning_rate_schedules(schedule_specs):
+        """
+        Build the learning-rate schedules.
+        :param schedule_specs: schedule configuration
+        :return: list of schedules
+        """
+        schedules = []
+
+        for spec in schedule_specs:
+            if spec["Type"] == "Step":
+                schedules.append(
+                    StepLearningRateSchedule(
+                        spec["Initial"],
+                        spec["Interval"],
+                        spec["Factor"],
+                    )
+                )
+            else:
+                raise Exception(
+                    'no known learning rate schedule of type "{}"'.format(
+                        spec["Type"]
+                    )
+                )
+
+        return schedules
+
+    def adjust_learning_rate(self, epoch):
+        """
+        Adjust the learning rate according to the current epoch.
+        :param epoch: current training epoch
+        """
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)  # apply the scheduled learning rate to the optimizer
+
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index eedccbf..daa0a62 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -134,14 +134,9 @@ class Trainer:
         ).to(self.device)
         logger.gpu_memory_stats("模型初始化后")
 
-        # 初始化优化器
-        self.optimizer = optim.AdamW(
-            self.model.parameters(),
-            lr=config.train.learning_rate,
-            weight_decay=config.train.weight_decay
-        )
-        #self.scheduler = LearningRateScheduler(self.conf.get_list('train.learning_rate_schedule'), self.conf.get_float('train.weight_decay'), self.model.parameters())
+
+        self.scheduler = LearningRateScheduler(config.train.learning_rate_schedule, config.train.weight_decay, self.model.parameters())
 
         self.loss_manager = LossManager(ablation="none")
 
         logger.gpu_memory_stats("训练器初始化后")
@@ -235,12 +230,11 @@ class Trainer:
             self.optimizer.zero_grad()
             mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
             nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)
-
-            logger.print_tensor_stats("mnfld_pred",mnfld_pred)
-            logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
-
+
             if self.debug_mode:
                 # --- 检查前向传播的输出 ---
+                logger.print_tensor_stats("mnfld_pred",mnfld_pred)
+                logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
                 logger.gpu_memory_stats("前向传播后")
 
             # --- 计算损失 ---
@@ -268,11 +262,11 @@ class Trainer:
 
             # 新增:达到累积步数时执行反向传播
             if (step + 1) % self.config.train.accumulation_steps == 0:
-                accumulated_loss.backward()
-                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-                self.optimizer.step()
-                self.optimizer.zero_grad()
-                accumulated_loss = 0.0  # 重置累积loss
+                # backward pass
+                self.scheduler.optimizer.zero_grad()  # clear gradients
+                loss.backward()  # backpropagate
+                self.scheduler.optimizer.step()  # update parameters
+                self.scheduler.step(accumulated_loss, epoch)
 
                 # 记录日志保持不变
                 ...
@@ -286,9 +280,11 @@ class Trainer:
 
             # 新增:处理最后未达到累积步数的剩余loss
             if accumulated_loss != 0:
-                accumulated_loss.backward()
-                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-                self.optimizer.step()
+                # backward pass
+                self.scheduler.optimizer.zero_grad()  # clear gradients
+                loss.backward()  # backpropagate
+                self.scheduler.optimizer.step()  # update parameters
+                self.scheduler.step(accumulated_loss, epoch)
 
         # 计算并记录epoch损失
         logger.info(f'Train Epoch: {epoch:4d}]\t'
@@ -346,15 +342,13 @@ class Trainer:
         nonmnfld_pnts.requires_grad_(True)  # 在检查之后启用梯度
 
         # --- 前向传播 ---
-        self.optimizer.zero_grad()
         mnfld_pred = self.model(mnfld_pnts)
         nonmnfld_pred = self.model(nonmnfld_pnts)
 
-        logger.print_tensor_stats("mnfld_pred",mnfld_pred)
-        logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
-
         if self.debug_mode:
             # --- 检查前向传播的输出 ---
+            logger.print_tensor_stats("mnfld_pred",mnfld_pred)
+            logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
             logger.gpu_memory_stats("前向传播后")
             # --- 2. 检查模型输出 ---
             #if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf')
@@ -397,22 +391,11 @@ class Trainer:
 
         # --- 反向传播和优化 ---
         try:
-            loss.backward()
-
-            # --- 5. (可选) 检查梯度 ---
-            # for name, param in self.model.named_parameters():
-            #     if param.grad is not None:
-            #         if check_tensor(param.grad, f"Gradient/{name}", epoch, step):
-            #             logger.warning(f"Epoch {epoch} Step {step}: Bad gradient for {name}. Consider clipping or zeroing.")
-            #             # 例如:param.grad.data.clamp_(-1, 1) # 值裁剪
-            #             # 或在 optimizer.step() 前进行范数裁剪:
-            #             # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-
-            # --- (推荐) 添加梯度裁剪 ---
-            # 防止梯度爆炸,这可能是导致 inf/nan 的原因之一
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # 范数裁剪
-
-            self.optimizer.step()
+            # backward pass
+            self.scheduler.optimizer.zero_grad()  # clear gradients
+            loss.backward()  # backpropagate
+            self.scheduler.optimizer.step()  # update parameters
+            self.scheduler.step(loss, epoch)
         except Exception as backward_e:
             logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
             # 如果你想看是哪个操作导致的,可以启用 anomaly detection
@@ -552,7 +535,7 @@ class Trainer:
         torch.save({
             'epoch': epoch,
             'model_state_dict': self.model.state_dict(),
-            'optimizer_state_dict': self.optimizer.state_dict(),
+            'optimizer_state_dict': self.scheduler.optimizer.state_dict(),
             'loss': train_loss,
         }, checkpoint_path)
 
@@ -561,7 +544,7 @@ class Trainer:
         try:
            checkpoint = torch.load(checkpoint_path)
            self.model.load_state_dict(checkpoint['model_state_dict'])
-           self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+           self.scheduler.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            return checkpoint['epoch'] + 1
         except Exception as e:
            logger.error(f"加载checkpoint失败: {str(e)}")
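
Usage note (outside the diff): the new LearningRateScheduler owns the Adam optimizer and is driven by a list of schedule specs, which is what config.train.learning_rate_schedule is expected to contain. The sketch below is a minimal, self-contained illustration of that contract; the toy model, the weight_decay value, and the concrete spec numbers are illustrative assumptions, not values taken from the repository's config.

    # Minimal sketch of driving LearningRateScheduler, assuming the brep2sdf package is importable.
    import torch
    from brep2sdf.networks.learning_rate import LearningRateScheduler

    model = torch.nn.Linear(3, 1)  # stand-in for the real network (assumption)

    # Expected shape of config.train.learning_rate_schedule, per get_learning_rate_schedules():
    lr_schedule_spec = [
        {"Type": "Step", "Initial": 1e-3, "Interval": 500, "Factor": 0.5}
    ]

    scheduler = LearningRateScheduler(lr_schedule_spec, 1e-4, model.parameters())

    for epoch in range(3):
        pred = model(torch.randn(8, 3))
        loss = pred.pow(2).mean()

        scheduler.optimizer.zero_grad()      # clear gradients
        loss.backward()                      # backpropagate
        scheduler.optimizer.step()           # update parameters
        scheduler.step(loss.item(), epoch)   # apply the epoch-based LR schedule

This mirrors the call order used in the updated train.py: the trainer goes through scheduler.optimizer for zero_grad/step and then calls scheduler.step(loss, epoch) to refresh the learning rate.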