Browse Source

NOTE it's base

NH-Rep
mckay 3 weeks ago
parent
commit
d42b3b46dd
  1. 113
      code/conversion/learning_rate.py
  2. 152
      code/conversion/loss.py
  3. 193
      code/conversion/train.py
  4. 2
      code/utils/logger.py

113
code/conversion/learning_rate.py

@ -0,0 +1,113 @@
import torch
import torch.optim as optim
import numpy as np
from utils.logger import logger
class LearningRateSchedule:
def get_learning_rate(self, epoch):
pass
class StepLearningRateSchedule(LearningRateSchedule):
def __init__(self, initial, interval, factor):
"""
初始化步进学习率调度器
:param initial_lr: 初始学习率
:param interval: 衰减间隔
:param factor: 衰减因子
"""
self.initial = initial
self.interval = interval
self.factor = factor
def get_learning_rate(self, epoch):
"""
获取当前学习率
:param epoch: 当前训练周期
:return: 当前学习率
"""
return np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)
class LearningRateScheduler:
def __init__(self, lr_schedules, weight_decay, network_params):
try:
self.lr_schedules = self.get_learning_rate_schedules(lr_schedules)
self.weight_decay = weight_decay
self.startepoch = 0
self.optimizer = torch.optim.Adam([{
"params": network_params,
"lr": self.lr_schedules[0].get_learning_rate(0),
"weight_decay": self.weight_decay
}])
self.best_loss = float('inf')
self.patience = 10
self.decay_factor = 0.5
initial_lr = self.lr_schedules[0].get_learning_rate(0)
self.lr = initial_lr
self.epochs_since_improvement = 0
except Exception as e:
logger.error(f"Error setting up optimizer: {str(e)}")
raise
def step(self, current_loss):
"""
更新学习率
:param current_loss: 当前验证损失
"""
if current_loss < self.best_loss:
self.best_loss = current_loss
self.epochs_since_improvement = 0
else:
self.epochs_since_improvement += 1
if self.epochs_since_improvement >= self.patience:
self.lr *= self.decay_factor
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.lr
print(f"学习率更新为: {self.lr:.6f}")
self.epochs_since_improvement = 0
def reset(self):
"""
重置学习率为初始值
"""
self.lr = self.initial_lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.lr
@staticmethod
def get_learning_rate_schedules(schedule_specs):
"""
获取学习率调度策略
:param schedule_specs: 学习率调度配置
:return: 学习率调度列表
"""
schedules = []
for spec in schedule_specs:
if spec["Type"] == "Step":
schedules.append(
StepLearningRateSchedule(
spec["Initial"],
spec["Interval"],
spec["Factor"],
)
)
else:
raise Exception(
'no known learning rate schedule of type "{}"'.format(
spec["Type"]
)
)
return schedules
def adjust_learning_rate(self, epoch):
"""
根据当前周期调整学习率
:param epoch: 当前训练周期
"""
for i, param_group in enumerate(self.optimizer.param_groups):
param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch) # 使用当前学习率更新优化器的学习率

152
code/conversion/loss.py

@ -1,9 +1,57 @@
import torch import os
import sys
import time
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)
import torch
from model.network import gradient
class LossManager: class LossManager:
def __init__(self): def __init__(self, ablation, **condition_kwargs):
pass self.weights = {
"manifold": 1,
"feature_manifold": 1, # 原文里面和manifold的权重是一样的
"normals": 1,
"eikonal": 1,
"offsurface": 1,
"consistency": 1,
"correction": 1,
}
self.condition_kwargs = condition_kwargs
self.ablation = ablation # 消融实验用
def _get_condition_kwargs(self, key):
"""
获取条件参数, 期望
ab: 损失类型 overall, patch, off, cons, cc, cor,
siren: 是否使用SIREN
epoch: 当前epoch
baseline: 是否为baseline
"""
if key in self.condition_kwargs:
return self.condition_kwargs[key]
else:
raise ValueError(f"Key {key} not found in condition_kwargs")
def pre_process(self, mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
"""
预处理
"""
mnfld_pred_h = mnfld_pred_all[:,0] # 提取流形预测结果
nonmnfld_pred_h = nonmnfld_pred_all[:,0] # 提取非流形预测结果
mnfld_grad = gradient(mnfld_pnts, mnfld_pred_h) # 计算流形点的梯度
all_fi = torch.zeros([n_batchsize, 1], device = 'cuda') # 初始化所有流形预测值
for i in range(n_branch - 1):
all_fi[i * n_patch_batch : (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch : (i + 1) * n_patch_batch, i + 1] # 填充流形预测值
# last patch
all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch] # 填充最后一个分支的流形预测值
return mnfld_pred_h, nonmnfld_pred_h, mnfld_grad, all_fi
def position_loss(self, outputs): def position_loss(self, outputs):
""" """
@ -17,60 +65,126 @@ class LossManager:
manifold_loss = (outputs.abs()).mean() # 计算流型损失 manifold_loss = (outputs.abs()).mean() # 计算流型损失
return manifold_loss return manifold_loss
def normals_loss(self, cur_data: torch.Tensor, mnfld_pnts: torch.Tensor, all_fi: torch.Tensor, patch_sup: bool) -> torch.Tensor: def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, all_fi: torch.Tensor, patch_sup: bool = True) -> torch.Tensor:
""" """
计算法线损失 计算法线损失
:param cur_data: 当前数据包含法线信息 :param normals: 法线
:param mnfld_pnts: 流型点 :param mnfld_pnts: 流型点
:param all_fi: 所有流型预测值 :param all_fi: 所有流型预测值
:param patch_sup: 是否支持补丁 :param patch_sup: 是否支持补丁
:return: 计算得到的法线损失 :return: 计算得到的法线损失
""" """
# 提取法线 # NOTE 源代码 这里还有复杂逻辑
normals = cur_data[:, -self.d_in:]
# 计算分支梯度 # 计算分支梯度
branch_grad = gradient(mnfld_pnts, all_fi[:, 0]) # 计算分支梯度 branch_grad = gradient(mnfld_pnts, all_fi[:, 0]) # 计算分支梯度
# 计算法线损失 # 计算法线损失
normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # 计算法线损失 normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # 计算法线损失
return self.normals_lambda * normals_loss # 返回加权后的法线损失 return normals_loss # 返回加权后的法线损失
def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred_all): def eikonal_loss(self, nonmnfld_pnts, nonmnfld_pred):
""" """
计算Eikonal损失 计算Eikonal损失
""" """
grad_loss_h = torch.zeros(1).cuda() # 初始化 Eikonal 损失 grad_loss_h = torch.zeros(1).cuda() # 初始化 Eikonal 损失
single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred_all[:,0]) # 计算非流形点的梯度 single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred) # 计算非流形点的梯度
eikonal_loss = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() # 计算 Eikonal 损失 eikonal_loss = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() # 计算 Eikonal 损失
return eikonal_loss return eikonal_loss
def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred_all): def offsurface_loss(self, nonmnfld_pnts, nonmnfld_pred):
""" """
Eo Eo
惩罚远离表面但是预测值接近0的点 惩罚远离表面但是预测值接近0的点
""" """
offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred_all[:,0])).mean() # 计算离表面损失 offsurface_loss = torch.zeros(1).cuda()
if not self.ablation == 'off':
offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred)).mean() # 计算离表面损失
return offsurface_loss return offsurface_loss
def consistency_loss(self, mnfld_pnts, mnfld_pred_all, all_fi): def consistency_loss(self, mnfld_pnts, mnfld_pred, all_fi):
""" """
惩罚流形点预测值和非流形点预测值不一致的点 惩罚流形点预测值和非流形点预测值不一致的点
""" """
mnfld_consistency_loss = (mnfld_pred - all_fi[:,0]).abs().mean() # 计算流形一致性损失 mnfld_consistency_loss = torch.zeros(1).cuda()
if not (self.ablation == 'cons' or self.ablation == 'cc'):
mnfld_consistency_loss = (mnfld_pred - all_fi[:,0]).abs().mean() # 计算流形一致性损失
return mnfld_consistency_loss return mnfld_consistency_loss
def compute_loss(self, outputs): def correction_loss(self, mnfld_pnts, mnfld_pred, all_fi, th_closeness = 1e-5, a_correction = 100):
"""
修正损失
"""
correction_loss = torch.zeros(1).cuda() # 初始化修正损失
if not (self.ablation == 'cor' or self.ablation == 'cc'):
mismatch_id = torch.abs(mnfld_pred - all_fi[:,0]) > th_closeness # 计算不匹配的 ID
if mismatch_id.sum() != 0: # 如果存在不匹配
correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:,0])[mismatch_id]).mean() # 计算修正损失
return correction_loss
def compute_loss(self, mnfld_pnts, normals, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
""" """
计算流型损失的逻辑 计算流型损失的逻辑
:param outputs: 模型的输出 :param outputs: 模型的输出
:return: 计算得到的流型损失值 :return: 计算得到的流型损失值
""" """
mnfld_pred, nonmnfld_pred, mnfld_grad, all_fi = self.pre_process(mnfld_pnts, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last)
manifold_loss = torch.zeros(1).cuda()
# 计算流型损失(这里使用均方误差作为示例) # 计算流型损失(这里使用均方误差作为示例)
manifold_loss = (outputs.abs()).mean() # 计算流型损失 if not self.ablation == 'overall':
return manifold_loss manifold_loss = (mnfld_pred.abs()).mean() # 计算流型损失
'''
if args.feature_sample: # 如果启用了特征采样
feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda() # 随机选择特征点
feature_pnts = self.feature_data[feature_indices] # 获取特征点数据
feature_mask_pair = self.feature_data_mask_pair[feature_indices] # 获取特征掩码对
feature_pred_all = self.network(feature_pnts) # 进行前向传播,计算特征点的预测值
feature_pred = feature_pred_all[:,0] # 提取特征预测结果
feature_mnfld_loss = feature_pred.abs().mean() # 计算特征流形损失
loss = loss + weight_mnfld_h * feature_mnfld_loss # 将特征流形损失加权到总损失中
# patch loss:
feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:,0].tolist()] # 获取左侧特征 ID
feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:,1].tolist()] # 获取右侧特征 ID
feature_fis_left = feature_pred_all[feature_id_left] # 获取左侧特征预测值
feature_fis_right = feature_pred_all[feature_id_right] # 获取右侧特征预测值
feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean() # 计算补丁损失
loss += feature_loss_patch # 将补丁损失加权到总损失中
# consistency loss:
feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean() # 计算一致性损失
'''
manifold_loss_patch = torch.zeros(1).cuda()
if self.ablation == 'patch':
manifold_loss_patch = all_fi[:,0].abs().mean()
# 计算法线损失
normals_loss = self.normals_loss(normals, mnfld_pnts, all_fi, patch_sup=True)
# 计算Eikonal损失
eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred_all)
# 计算离表面损失
offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred_all)
# 计算一致性损失
consistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi)
# 计算修正损失
correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
# 计算总损失
total_loss = (self.weights["manifold"] * manifold_loss + \
#self.weights["feature_manifold"] * feature_manifold_loss + \
manifold_loss_patch + \
self.weights["normals"] * normals_loss + \
self.weights["eikonal"] * eikonal_loss + \
self.weights["offsurface"] * offsurface_loss + \
self.weights["consistency"] * consistency_loss + \
self.weights["correction"] * correction_loss)
return total_loss

193
code/conversion/train.py

@ -9,68 +9,213 @@ os.chdir(project_dir)
import torch import torch
import numpy as np import numpy as np
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm from tqdm import tqdm
from pyhocon import ConfigFactory
from scipy.spatial import cKDTree
from utils.logger import logger from utils.logger import logger
from utils.general import get_class
from data_loader import NHREP_Dataset from data_loader import NHREP_Dataset
from loss import LossManager from loss import LossManager
from learning_rate import LearningRateScheduler
from model.network import NHRepNet # 导入 NHRepNet from model.network import NHRepNet # 导入 NHRepNet
from model.sample import Sampler
class NHREPNet_Training: class NHREPNet_Training:
def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False, if_feature_sample: bool = False): def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False, if_feature_sample: bool = False):
self.conf = ConfigFactory.parse_file('./conversion/setup.conf')
self.sampler = Sampler.get_sampler(
self.conf.get_string('network.sampler.sampler_type'))(
global_sigma=self.conf.get_float('network.sampler.properties.global_sigma'),
local_sigma=self.conf.get_float('network.sampler.properties.local_sigma')
)
self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample) self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 初始化模型 # 初始化模型
d_in = 6 # 输入维度,例如 3D 坐标 self.d_in = 3 # 输入维度,x, y, z.
dims_sdf = [256, 256, 256] # 隐藏层维度 self.dims_sdf = [256, 256, 256] # 隐藏层维度
csg_tree, _ = self.dataset.get_csg_tree()
self.loss_manager = LossManager()
self.model = NHRepNet(d_in, dims_sdf, csg_tree).to(self.device) # 实例化模型并移动到设备
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001) # Adam 优化器
self.scheduler = StepLR(self.optimizer, step_size=1000, gamma=0.1) # 学习率调度器
self.nepochs = 15000 # 训练轮数 self.nepochs = 15000 # 训练轮数
self.writer = SummaryWriter() # TensorBoard 记录器 self.writer = SummaryWriter() # TensorBoard 记录器
def run_nhrepnet_training(self): def run_nhrepnet_training(self):
# 数据准备
logger.info("数据准备")
self.data = self.dataset.get_data().to(self.device).requires_grad_() # x, y, z, nx, ny, nz
feature_mask_cpu = self.dataset.get_feature_mask().numpy() # 特征掩码
self.feature_mask = torch.from_numpy(feature_mask_cpu).to(self.device) # 特征掩码 # 特征掩码
self.points_batch = 16384 # 批次大小
n_branch = int(torch.max(self.feature_mask).item()) # 计算分支数量
n_batchsize = self.points_batch # 设置批次大小
n_patch_batch = n_batchsize // n_branch # 计算每个分支的补丁批次大小
n_patch_last = n_batchsize - n_patch_batch * (n_branch - 1) # 计算最后一个分支的补丁大小
# 1,准备训练数据
# 1.1,计算每个分支的补丁数量
patch_id, patch_id_n = self.compute_patch(n_branch, n_patch_batch, n_patch_last, feature_mask_cpu)
# 1.2,获取分支掩码
branch_mask, single_branch_mask_gt, single_branch_mask_id = self.get_branch_mask(n_branch, n_patch_batch, n_patch_last)
# 1.3,初始化模型
csg_tree, flag_convex = self.dataset.get_csg_tree()
self.model = get_class(self.conf.get_string('train.network_class'))(
d_in=self.d_in,
n_branch=n_branch,
csg_tree=csg_tree,
flag_convex=flag_convex,
**self.conf.get_config('network.inputs')
).to(self.device)
self.scheduler = LearningRateScheduler(self.conf.get_list('train.learning_rate_schedule'), self.conf.get_float('train.weight_decay'), self.model.parameters())
self.loss_manager = LossManager(ablation="none")
logger.info("开始训练") logger.info("开始训练")
self.model.train() # 设置模型为训练模式 self.model.train() # 设置模型为训练模式
for epoch in range(self.nepochs): # 开始训练循环 for epoch in range(self.nepochs): # 开始训练循环
try: try:
self.train_one_epoch(epoch) self.train_one_epoch(epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize)
self.scheduler.step() # 更新学习率
except Exception as e: except Exception as e:
logger.error(f"训练过程中发生错误: {str(e)}") logger.error(f"训练过程中发生错误: {str(e)}")
break break
def train_one_epoch(self, epoch): def train_one_epoch(self, epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize):
logger.info(f"Epoch {epoch}/{self.nepochs} 开始") logger.info(f"Epoch {epoch}/{self.nepochs} 开始")
total_loss = 0.0 # 1.3,获取索引
indices = self.get_indices(patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch)
# 1.4,获取数据
cur_data = self.data[indices] # x, y, z, nx, ny, nz
mnfld_pnts = cur_data[:, :self.d_in] # 提取流形点
self.compute_local_sigma()
mnfld_sigma = self.local_sigma[indices] # 提取噪声点
# 获取输入数据 nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze() # 生成非流形点
input_data = self.dataset.get_data().to(self.device) # 获取数据并移动到设备
logger.info(f"输入数据: {input_data.shape}") #TODO 记录了log
# 2,前向传播
self.scheduler.adjust_learning_rate(epoch)
#logger.info(f"mnfld_pnts: {mnfld_pnts.shape}")
#logger.info(f"nonmnfld_pnts: {nonmnfld_pnts.shape}")
# 前向传播 # 前向传播
outputs = self.model(input_data) # 使用模型进行前向传播 mnfld_pred_all = self.model(mnfld_pnts) # 使用模型进行前向传播
logger.info(f"输出数据: {outputs.shape}") nonmnfld_pred_all = self.model(nonmnfld_pnts) # 使用模型进行前向传播
#logger.info(f"mnfld_pred_all: {mnfld_pred_all.shape}")
#logger.info(f"nonmnfld_pred_all: {nonmnfld_pred_all.shape}")
normals = cur_data[:, -self.d_in:]
# 计算损失 # 计算损失
loss = self.loss_manager.compute_loss(outputs) # 计算损失 loss = self.loss_manager.compute_loss(
total_loss += loss.item() mnfld_pnts = mnfld_pnts,
normals = normals,
mnfld_pred_all = mnfld_pred_all,
nonmnfld_pnts = nonmnfld_pnts,
nonmnfld_pred_all = nonmnfld_pred_all,
n_batchsize = n_batchsize,
n_branch = n_branch,
n_patch_batch = n_patch_batch,
n_patch_last = n_patch_last,
) # 计算损失
self.scheduler.step(loss)
# 反向传播 # 反向传播
self.optimizer.zero_grad() # 清空梯度 self.scheduler.optimizer.zero_grad() # 清空梯度
loss.backward() # 反向传播 loss.backward() # 反向传播
self.optimizer.step() # 更新参数 self.scheduler.optimizer.step() # 更新参数
avg_loss = total_loss avg_loss = loss.item()
logger.info(f'Epoch [{epoch}/{self.nepochs}], Average Loss: {avg_loss:.4f}') logger.info(f'Epoch [{epoch}/{self.nepochs}], Average Loss: {avg_loss:.4f}')
self.writer.add_scalar('Loss/train', avg_loss, epoch) # 记录损失到 TensorBoard self.writer.add_scalar('Loss/train', avg_loss, epoch) # 记录损失到 TensorBoard
#============================ 前向传播 数据准备 ============================
def compute_patch(self, n_branch, n_patch_batch, n_patch_last, feature_mask_cpu):
'''
计算每个分支的补丁数量
'''
patch_id = []
patch_id_n = []
for i in range(n_branch):
patch_id = patch_id + [np.where(feature_mask_cpu == i + 1)[0]]
patch_id_n = patch_id_n + [patch_id[i].shape[0]]
return patch_id, patch_id_n
def get_branch_mask(self, n_branch, n_patch_batch, n_patch_last):
'''
branch_mask: 分支掩码用于表示每个分支在每个批次中的掩码每一行对应一个分支每一列对应一个样本用于表示每个分支的补丁是否被选中
single_branch_mask_gt: 单分支掩码用于表示每个补丁属于哪个分支每一行对应一个样本每一列对应一个分支用于表示每个补丁属于哪个分支
single_branch_mask_id: 单分支 ID用于表示每个补丁属于哪个分支
作用
'''
branch_mask = torch.zeros(n_branch, n_patch_batch).cuda()
single_branch_mask_gt = torch.zeros(n_patch_batch, n_branch).cuda()
single_branch_mask_id = torch.zeros([n_patch_batch], dtype=torch.long).cuda()
for i in range(n_branch - 1):
branch_mask[i, i * n_patch_batch : (i + 1) * n_patch_batch] = 1.0
single_branch_mask_gt[i * n_patch_batch : (i + 1) * n_patch_batch, i] = 1.0
single_branch_mask_id[i * n_patch_batch : (i + 1) * n_patch_batch] = i
branch_mask[n_branch - 1, (n_branch - 1) * n_patch_batch:] = 1.0
single_branch_mask_gt[(n_branch - 1) * n_patch_batch:, (n_branch - 1)] = 1.0
single_branch_mask_id[(n_branch - 1) * n_patch_batch:] = (n_branch - 1)
return branch_mask, single_branch_mask_gt, single_branch_mask_id
def get_indices(self, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch):
indices = torch.empty(0, dtype=torch.int64).cuda()
for i in range(n_branch - 1):
indices_nonfeature = torch.tensor(patch_id[i][np.random.choice(patch_id_n[i], n_patch_batch, True)]).cuda()
indices = torch.cat((indices, indices_nonfeature), 0)
# last patch
indices_nonfeature = torch.tensor(patch_id[n_branch - 1][np.random.choice(patch_id_n[n_branch - 1], n_patch_last, True)]).cuda()
indices = torch.cat((indices, indices_nonfeature), 0)
return indices
def compute_local_sigma(self):
"""计算局部sigma值"""
try:
sigma_set = []
data_cpu = self.data.detach().cpu().numpy()
ptree = cKDTree(data_cpu)
logger.debug("KD tree constructed")
for p in np.array_split(data_cpu, 100, axis=0):
d = ptree.query(p, 50 + 1)
sigma_set.append(d[0][:, -1])
sigmas = np.concatenate(sigma_set)
self.local_sigma = torch.from_numpy(sigmas).float().cuda()
except Exception as e:
logger.error(f"Error computing local sigma: {str(e)}")
raise
#============================ 保存模型 ============================
def save_checkpoints(self, epoch):
torch.save(
{"epoch": epoch, "model_state_dict": self.network.state_dict()},
os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
torch.save(
{"epoch": epoch, "model_state_dict": self.network.state_dict()},
os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
torch.save(
{"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
torch.save(
{"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))
if __name__ == "__main__": if __name__ == "__main__":
data_dir = '../data/input_data' # 数据目录 data_dir = '../data/input_data' # 数据目录

2
code/utils/logger.py

@ -138,7 +138,7 @@ class Logger:
"""警告信息""" """警告信息"""
self._log(logging.WARNING, msg) self._log(logging.WARNING, msg)
def error(self, msg, include_trace=False): def error(self, msg, include_trace=True):
"""错误信息""" """错误信息"""
self._log(logging.ERROR, msg, exc_info=include_trace) self._log(logging.ERROR, msg, exc_info=include_trace)

Loading…
Cancel
Save