From d42b3b46dd440d15c1373ae49617fe9b9d211e79 Mon Sep 17 00:00:00 2001
From: mckay
Date: Sun, 23 Feb 2025 19:35:00 +0800
Subject: [PATCH] NOTE it's base

---
 code/conversion/learning_rate.py | 113 ++++++++++++++++++
 code/conversion/loss.py          | 152 +++++++++++++++++++++---
 code/conversion/train.py         | 195 +++++++++++++++++++++++++++----
 code/utils/logger.py             |   2 +-
 4 files changed, 417 insertions(+), 45 deletions(-)
 create mode 100644 code/conversion/learning_rate.py

diff --git a/code/conversion/learning_rate.py b/code/conversion/learning_rate.py
new file mode 100644
index 0000000..ea8908c
--- /dev/null
+++ b/code/conversion/learning_rate.py
@@ -0,0 +1,113 @@
+import torch
+import torch.optim as optim
+import numpy as np
+from utils.logger import logger
+
+class LearningRateSchedule:
+    def get_learning_rate(self, epoch):
+        pass
+
+class StepLearningRateSchedule(LearningRateSchedule):
+    def __init__(self, initial, interval, factor):
+        """
+        Initialize the step learning-rate schedule.
+        :param initial: initial learning rate
+        :param interval: decay interval in epochs
+        :param factor: decay factor
+        """
+        self.initial = initial
+        self.interval = interval
+        self.factor = factor
+
+    def get_learning_rate(self, epoch):
+        """
+        Get the learning rate for the given epoch.
+        :param epoch: current training epoch
+        :return: current learning rate
+        """
+        return np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)
+
+class LearningRateScheduler:
+    def __init__(self, lr_schedules, weight_decay, network_params):
+        try:
+            self.lr_schedules = self.get_learning_rate_schedules(lr_schedules)
+            self.weight_decay = weight_decay
+
+            self.startepoch = 0
+            self.optimizer = torch.optim.Adam([{
+                "params": network_params,
+                "lr": self.lr_schedules[0].get_learning_rate(0),
+                "weight_decay": self.weight_decay
+            }])
+            self.best_loss = float('inf')
+            self.patience = 10
+            self.decay_factor = 0.5
+            self.initial_lr = self.lr_schedules[0].get_learning_rate(0)  # kept so reset() can restore it
+            self.lr = self.initial_lr
+            self.epochs_since_improvement = 0
+
+        except Exception as e:
+            logger.error(f"Error setting up optimizer: {str(e)}")
+            raise
+
+    def step(self, current_loss):
+        """
+        Update the learning rate based on the current loss.
+        :param current_loss: current validation loss
+        """
+        if current_loss < self.best_loss:
+            self.best_loss = current_loss
+            self.epochs_since_improvement = 0
+        else:
+            self.epochs_since_improvement += 1
+
+        if self.epochs_since_improvement >= self.patience:
+            self.lr *= self.decay_factor
+            for param_group in self.optimizer.param_groups:
+                param_group['lr'] = self.lr
+            print(f"Learning rate updated to: {self.lr:.6f}")
+            self.epochs_since_improvement = 0
+
+    def reset(self):
+        """
+        Reset the learning rate to its initial value.
+        """
+        self.lr = self.initial_lr
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = self.lr
+
+    @staticmethod
+    def get_learning_rate_schedules(schedule_specs):
+        """
+        Build the learning-rate schedules.
+        :param schedule_specs: learning-rate schedule configuration
+        :return: list of schedules
+        """
+        schedules = []
+
+        for spec in schedule_specs:
+            if spec["Type"] == "Step":
+                schedules.append(
+                    StepLearningRateSchedule(
+                        spec["Initial"],
+                        spec["Interval"],
+                        spec["Factor"],
+                    )
+                )
+            else:
+                raise Exception(
+                    'no known learning rate schedule of type "{}"'.format(
+                        spec["Type"]
+                    )
+                )
+
+        return schedules
+
+    def adjust_learning_rate(self, epoch):
+        """
+        Adjust the learning rate according to the current epoch.
+        :param epoch: current training epoch
+        """
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)  # apply the scheduled learning rate to the optimizer
+
diff --git a/code/conversion/loss.py b/code/conversion/loss.py
index 8bf045b..b48a666 100644
--- a/code/conversion/loss.py
+++ b/code/conversion/loss.py
@@ -1,9 +1,57 @@
-import torch
+import os
+import sys
+import time
+
+project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
+sys.path.append(project_dir)
+os.chdir(project_dir)
+import torch
+from model.network import gradient
 
 class LossManager:
-    def __init__(self):
-        pass
+    def __init__(self, ablation, **condition_kwargs):
+        self.weights = {
+            "manifold": 1,
+            "feature_manifold": 1, # the original paper gives this the same weight as the manifold term
+            "normals": 1,
+            "eikonal": 1,
+            "offsurface": 1,
+            "consistency": 1,
+            "correction": 1,
+        }
+        self.condition_kwargs = condition_kwargs
+        self.ablation = ablation # used for ablation studies
+
+    def _get_condition_kwargs(self, key):
+        """
+        Fetch a condition argument. Expected keys:
+        ab: loss/ablation type [overall, patch, off, cons, cc, cor]
+        siren: whether SIREN is used
+        epoch: current epoch
+        baseline: whether this is the baseline
+        """
+        if key in self.condition_kwargs:
+            return self.condition_kwargs[key]
+        else:
+            raise ValueError(f"Key {key} not found in condition_kwargs")
+
+
+    def pre_process(self, mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
+        """
+        Pre-processing: split the overall prediction h from the per-branch predictions and compute gradients.
+        """
+        mnfld_pred_h = mnfld_pred_all[:,0] # overall prediction at the manifold points
+        nonmnfld_pred_h = nonmnfld_pred_all[:,0] # overall prediction at the non-manifold points
+        mnfld_grad = gradient(mnfld_pnts, mnfld_pred_h) # gradient at the manifold points
+
+        all_fi = torch.zeros([n_batchsize, 1], device = 'cuda') # per-branch prediction for each manifold point
+        for i in range(n_branch - 1):
+            all_fi[i * n_patch_batch : (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch : (i + 1) * n_patch_batch, i + 1] # branch i
+        # last patch
+        all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch] # last branch
+
+        return mnfld_pred_h, nonmnfld_pred_h, mnfld_grad, all_fi
 
     def position_loss(self, outputs):
         """
@@ -17,60 +65,126 @@ class LossManager:
         manifold_loss = (outputs.abs()).mean() # manifold loss
         return manifold_loss
 
-    def normals_loss(self, cur_data: torch.Tensor, mnfld_pnts: torch.Tensor, all_fi: torch.Tensor, patch_sup: bool) -> torch.Tensor:
+    def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, all_fi: torch.Tensor, patch_sup: bool = True) -> torch.Tensor:
         """
         Compute the normals loss.
-        :param cur_data: current data, including the normals
+        :param normals: ground-truth normals
         :param mnfld_pnts: manifold points
         :param all_fi: per-branch predictions
         :param patch_sup: whether patch supervision is used
         :return: the normals loss
         """
-        # extract the normals
-        normals = cur_data[:, -self.d_in:]
-
+        # NOTE: the original code has additional logic here
         # compute the branch gradient
         branch_grad = gradient(mnfld_pnts, all_fi[:, 0]) # branch gradient
 
         # compute the normals loss
         normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # normals loss
-        return self.normals_lambda * normals_loss # return the weighted normals loss
+        return normals_loss # weighting is applied in compute_loss
 
-    def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred_all):
+    def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred):
         """
         Compute the Eikonal loss.
         """
         grad_loss_h = torch.zeros(1).cuda() # initialize the Eikonal loss
-        single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred_all[:,0]) # gradient at the non-manifold points
+        single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred) # gradient at the non-manifold points
         eikonal_loss = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() # Eikonal loss
         return eikonal_loss
 
-    def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred_all):
+    def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred):
         """
         Eo: penalizes points far from the surface whose predicted value is close to 0.
         """
-        offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred_all[:,0])).mean() # off-surface loss
+        offsurface_loss = torch.zeros(1).cuda()
+        if not self.ablation == 'off':
+            offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred)).mean() # off-surface loss
         return offsurface_loss
 
-    def consistency_loss(self, mnfld_pnts, mnfld_pred_all, all_fi):
+    def consistency_loss(self, mnfld_pnts, mnfld_pred, all_fi):
         """
         Penalizes manifold points where the overall prediction h and the branch prediction disagree.
         """
-        mnfld_consistency_loss = (mnfld_pred - all_fi[:,0]).abs().mean() # manifold consistency loss
+        mnfld_consistency_loss = torch.zeros(1).cuda()
+        if not (self.ablation == 'cons' or self.ablation == 'cc'):
+            mnfld_consistency_loss = (mnfld_pred - all_fi[:,0]).abs().mean() # manifold consistency loss
         return mnfld_consistency_loss
 
-    def compute_loss(self, outputs):
+    def correction_loss(self, mnfld_pnts, mnfld_pred, all_fi, th_closeness = 1e-5, a_correction = 100):
+        """
+        Correction loss.
+        """
+        correction_loss = torch.zeros(1).cuda() # initialize the correction loss
+        if not (self.ablation == 'cor' or self.ablation == 'cc'):
+            mismatch_id = torch.abs(mnfld_pred - all_fi[:,0]) > th_closeness # indices where h and the branch prediction disagree
+            if mismatch_id.sum() != 0: # if any mismatches exist
+                correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:,0])[mismatch_id]).mean() # correction loss
+        return correction_loss
+
+    def compute_loss(self, mnfld_pnts, normals, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
         """
         Compute the total loss.
         :param mnfld_pred_all: model predictions at the manifold points (other arguments as in pre_process)
         :return: the total loss value
         """
-
+        mnfld_pred, nonmnfld_pred, mnfld_grad, all_fi = self.pre_process(mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last)
+
+        manifold_loss = torch.zeros(1).cuda()
         # compute the manifold loss (mean absolute value of the predictions)
-        manifold_loss = (outputs.abs()).mean() # manifold loss
-        return manifold_loss
+        if not self.ablation == 'overall':
+            manifold_loss = (mnfld_pred.abs()).mean() # manifold loss
+        '''
+        if args.feature_sample: # if feature sampling is enabled
+            feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda() # randomly pick feature points
+            feature_pnts = self.feature_data[feature_indices] # gather the feature point data
+            feature_mask_pair = self.feature_data_mask_pair[feature_indices] # gather the feature mask pairs
+            feature_pred_all = self.network(feature_pnts) # forward pass on the feature points
+            feature_pred = feature_pred_all[:,0] # overall prediction for the feature points
+            feature_mnfld_loss = feature_pred.abs().mean() # feature manifold loss
+            loss = loss + weight_mnfld_h * feature_mnfld_loss # add the weighted feature manifold loss
+
+            # patch loss:
+            feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:,0].tolist()] # left-side feature ids
+            feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:,1].tolist()] # right-side feature ids
+            feature_fis_left = feature_pred_all[feature_id_left] # left-side feature predictions
+            feature_fis_right = feature_pred_all[feature_id_right] # right-side feature predictions
+            feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean() # patch loss
+            loss += feature_loss_patch # add the patch loss
+
+            # consistency loss:
+            feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean() # consistency loss
+        '''
+        manifold_loss_patch = torch.zeros(1).cuda()
+        if self.ablation == 'patch':
+            manifold_loss_patch = all_fi[:,0].abs().mean()
+
+        # normals loss
+        normals_loss = self.normals_loss(normals, mnfld_pnts, all_fi, patch_sup=True)
+
+        # Eikonal loss (on the overall prediction at the non-manifold points)
+        eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred)
+
+        # off-surface loss
+        offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred)
+
+        # consistency loss
+        consistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi)
+
+        # correction loss
+        correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
+
+        # total loss
+        total_loss = (self.weights["manifold"] * manifold_loss +
+                      # self.weights["feature_manifold"] * feature_manifold_loss +
+                      manifold_loss_patch +
+                      self.weights["normals"] * normals_loss +
+                      self.weights["eikonal"] * eikonal_loss +
+                      self.weights["offsurface"] * offsurface_loss +
+                      self.weights["consistency"] * consistency_loss +
+                      self.weights["correction"] * correction_loss)
+        return total_loss
+
+
diff --git a/code/conversion/train.py b/code/conversion/train.py
index a46a4fd..2a28498 100644
--- a/code/conversion/train.py
+++ b/code/conversion/train.py
@@ -9,68 +9,213 @@ os.chdir(project_dir)
 import torch
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
-from torch.optim.lr_scheduler import StepLR
 from tqdm import tqdm
+from pyhocon import ConfigFactory
+from scipy.spatial import cKDTree
+
 from utils.logger import logger
+from utils.general import get_class
 from data_loader import NHREP_Dataset
 from loss import LossManager
+from learning_rate import LearningRateScheduler
 from model.network import NHRepNet # import NHRepNet
+from model.sample import Sampler
 
 class NHREPNet_Training:
     def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False, if_feature_sample: bool = False):
+        self.conf = ConfigFactory.parse_file('./conversion/setup.conf')
+        self.sampler = Sampler.get_sampler(
+            self.conf.get_string('network.sampler.sampler_type'))(
+                global_sigma=self.conf.get_float('network.sampler.properties.global_sigma'),
+                local_sigma=self.conf.get_float('network.sampler.properties.local_sigma')
+            )
         self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         # initialize the model
-        d_in = 6 # input dimension, e.g. 3D coordinates
-        dims_sdf = [256, 256, 256] # hidden layer dimensions
-        csg_tree, _ = self.dataset.get_csg_tree()
-        self.loss_manager = LossManager()
-        self.model = NHRepNet(d_in, dims_sdf, csg_tree).to(self.device) # instantiate the model and move it to the device
-
-        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001) # Adam optimizer
-        self.scheduler = StepLR(self.optimizer, step_size=1000, gamma=0.1) # learning-rate scheduler
+        self.d_in = 3 # input dimension: x, y, z
+        self.dims_sdf = [256, 256, 256] # hidden layer dimensions
+
+
         self.nepochs = 15000 # number of training epochs
         self.writer = SummaryWriter() # TensorBoard writer
 
     def run_nhrepnet_training(self):
+        # data preparation
+        logger.info("Preparing data")
+        self.data = self.dataset.get_data().to(self.device).requires_grad_() # x, y, z, nx, ny, nz
+        feature_mask_cpu = self.dataset.get_feature_mask().numpy() # feature mask (CPU copy)
+        self.feature_mask = torch.from_numpy(feature_mask_cpu).to(self.device) # feature mask
+        self.points_batch = 16384 # batch size
+
+
+        n_branch = int(torch.max(self.feature_mask).item()) # number of branches
+        n_batchsize = self.points_batch # batch size
+        n_patch_batch = n_batchsize // n_branch # patch batch size per branch
+        n_patch_last = n_batchsize - n_patch_batch * (n_branch - 1) # patch size of the last branch
+        # 1. prepare the training data
+        # 1.1 patch indices for each branch
+        patch_id, patch_id_n = self.compute_patch(n_branch, n_patch_batch, n_patch_last, feature_mask_cpu)
+
+        # 1.2 branch masks
+        branch_mask, single_branch_mask_gt, single_branch_mask_id = self.get_branch_mask(n_branch, n_patch_batch, n_patch_last)
+
+
+        # 1.3 initialize the model
+        csg_tree, flag_convex = self.dataset.get_csg_tree()
+        self.model = get_class(self.conf.get_string('train.network_class'))(
+            d_in=self.d_in,
+            n_branch=n_branch,
+            csg_tree=csg_tree,
+            flag_convex=flag_convex,
+            **self.conf.get_config('network.inputs')
+        ).to(self.device)
+        self.scheduler = LearningRateScheduler(self.conf.get_list('train.learning_rate_schedule'), self.conf.get_float('train.weight_decay'), self.model.parameters())
+        self.loss_manager = LossManager(ablation="none")
+
+        logger.info("Starting training")
         self.model.train() # set the model to training mode
         for epoch in range(self.nepochs): # training loop
             try:
-                self.train_one_epoch(epoch)
-                self.scheduler.step() # update the learning rate
+                self.train_one_epoch(epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize)
             except Exception as e:
                 logger.error(f"Error during training: {str(e)}")
                 break
 
-    def train_one_epoch(self, epoch):
+    def train_one_epoch(self, epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize):
         logger.info(f"Epoch {epoch}/{self.nepochs} started")
-        total_loss = 0.0
+        # sample the indices for this batch
+        indices = self.get_indices(patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch)
+
+        # gather the batch data
+        cur_data = self.data[indices] # x, y, z, nx, ny, nz
+        mnfld_pnts = cur_data[:, :self.d_in] # manifold points
+        self.compute_local_sigma()
+        mnfld_sigma = self.local_sigma[indices] # per-point sigma values
 
-        # fetch the input data
-        input_data = self.dataset.get_data().to(self.device) # move the data to the device
-        logger.info(f"Input data: {input_data.shape}")
+        nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze() # sample non-manifold (off-surface) points
+
+        # TODO: logging
+
+        # 2. forward pass
+
+        self.scheduler.adjust_learning_rate(epoch)
+
+        #logger.info(f"mnfld_pnts: {mnfld_pnts.shape}")
+        #logger.info(f"nonmnfld_pnts: {nonmnfld_pnts.shape}")
 
         # forward pass
-        outputs = self.model(input_data) # forward pass through the model
-        logger.info(f"Output data: {outputs.shape}")
-
+        mnfld_pred_all = self.model(mnfld_pnts) # forward pass on the manifold points
+        nonmnfld_pred_all = self.model(nonmnfld_pnts) # forward pass on the non-manifold points
+
+        #logger.info(f"mnfld_pred_all: {mnfld_pred_all.shape}")
+        #logger.info(f"nonmnfld_pred_all: {nonmnfld_pred_all.shape}")
+
+        normals = cur_data[:, -self.d_in:]
 
         # compute the loss
-        loss = self.loss_manager.compute_loss(outputs) # compute the loss
-        total_loss += loss.item()
+        loss = self.loss_manager.compute_loss(
+            mnfld_pnts = mnfld_pnts,
+            normals = normals,
+            mnfld_pred_all = mnfld_pred_all,
+            nonmnfld_pnts = nonmnfld_pnts,
+            nonmnfld_pred_all = nonmnfld_pred_all,
+            n_batchsize = n_batchsize,
+            n_branch = n_branch,
+            n_patch_batch = n_patch_batch,
+            n_patch_last = n_patch_last,
+        ) # compute the loss
+
+        self.scheduler.step(loss.item()) # plateau-based learning-rate update
 
         # backward pass
-        self.optimizer.zero_grad() # clear the gradients
+        self.scheduler.optimizer.zero_grad() # clear the gradients
         loss.backward() # backward pass
-        self.optimizer.step() # update the parameters
+        self.scheduler.optimizer.step() # update the parameters
 
-        avg_loss = total_loss
+        avg_loss = loss.item()
         logger.info(f'Epoch [{epoch}/{self.nepochs}], Average Loss: {avg_loss:.4f}')
         self.writer.add_scalar('Loss/train', avg_loss, epoch) # log the loss to TensorBoard
 
+#============================ forward pass: data preparation ============================
+    def compute_patch(self, n_branch, n_patch_batch, n_patch_last, feature_mask_cpu):
+        '''
+        Collect the point indices (and their counts) belonging to each branch.
+        '''
+        patch_id = []
+        patch_id_n = []
+        for i in range(n_branch):
+            patch_id = patch_id + [np.where(feature_mask_cpu == i + 1)[0]]
+            patch_id_n = patch_id_n + [patch_id[i].shape[0]]
+        return patch_id, patch_id_n
+
+    def get_branch_mask(self, n_branch, n_patch_batch, n_patch_last):
+        '''
+        branch_mask: one row per branch, one column per sample in the batch; marks which samples belong to each branch.
+        single_branch_mask_gt: one row per sample, one column per branch; one-hot branch membership of each sample.
+        single_branch_mask_id: branch index of each sample.
+        '''
+        n_batchsize = n_patch_batch * (n_branch - 1) + n_patch_last # masks must cover the full batch, not a single patch
+        branch_mask = torch.zeros(n_branch, n_batchsize).cuda()
+        single_branch_mask_gt = torch.zeros(n_batchsize, n_branch).cuda()
+        single_branch_mask_id = torch.zeros([n_batchsize], dtype=torch.long).cuda()
+        for i in range(n_branch - 1):
+            branch_mask[i, i * n_patch_batch : (i + 1) * n_patch_batch] = 1.0
+            single_branch_mask_gt[i * n_patch_batch : (i + 1) * n_patch_batch, i] = 1.0
+            single_branch_mask_id[i * n_patch_batch : (i + 1) * n_patch_batch] = i
+        branch_mask[n_branch - 1, (n_branch - 1) * n_patch_batch:] = 1.0
+        single_branch_mask_gt[(n_branch - 1) * n_patch_batch:, (n_branch - 1)] = 1.0
+        single_branch_mask_id[(n_branch - 1) * n_patch_batch:] = (n_branch - 1)
+        return branch_mask, single_branch_mask_gt, single_branch_mask_id
+
+    def get_indices(self, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch):
+        indices = torch.empty(0, dtype=torch.int64).cuda()
+        for i in range(n_branch - 1):
+            indices_nonfeature = torch.tensor(patch_id[i][np.random.choice(patch_id_n[i], n_patch_batch, True)]).cuda()
+            indices = torch.cat((indices, indices_nonfeature), 0)
+        # last patch
+        indices_nonfeature = torch.tensor(patch_id[n_branch - 1][np.random.choice(patch_id_n[n_branch - 1], n_patch_last, True)]).cuda()
+        indices = torch.cat((indices, indices_nonfeature), 0)
+        return indices
+
+    def compute_local_sigma(self):
+        """Compute each point's local sigma (distance to its 50th nearest neighbour)."""
+        try:
+            sigma_set = []
+            data_cpu = self.data.detach().cpu().numpy() # NOTE: still includes the normal columns
+            ptree = cKDTree(data_cpu)
+            logger.debug("KD tree constructed")
+
+            for p in np.array_split(data_cpu, 100, axis=0):
+                d = ptree.query(p, 50 + 1)
+                sigma_set.append(d[0][:, -1])
+
+            sigmas = np.concatenate(sigma_set)
+            self.local_sigma = torch.from_numpy(sigmas).float().cuda()
+        except Exception as e:
+            logger.error(f"Error computing local sigma: {str(e)}")
+            raise
+
+
+
+
+#============================ save the model ============================
+    def save_checkpoints(self, epoch):
+        torch.save(
+            {"epoch": epoch, "model_state_dict": self.model.state_dict()},
+            os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
+        torch.save(
+            {"epoch": epoch, "model_state_dict": self.model.state_dict()},
+            os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
+        torch.save(
+            {"epoch": epoch, "optimizer_state_dict": self.scheduler.optimizer.state_dict()},
+            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
+        torch.save(
+            {"epoch": epoch, "optimizer_state_dict": self.scheduler.optimizer.state_dict()},
+            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))
 
 if __name__ == "__main__":
     data_dir = '../data/input_data' # data directory
diff --git a/code/utils/logger.py b/code/utils/logger.py
index 8c3ff91..37fc418 100644
--- a/code/utils/logger.py
+++ b/code/utils/logger.py
@@ -138,7 +138,7 @@ class Logger:
         """Warning message."""
         self._log(logging.WARNING, msg)
 
-    def error(self, msg, include_trace=False):
+    def error(self, msg, include_trace=True):
         """Error message."""
         self._log(logging.ERROR, msg, exc_info=include_trace)