import os

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import gradient
from utils.logger import logger


class NHREPNet_Training:
    def run_nhrepnet_training(self):
        print("running")  # signal that training has started
        self.data = self.data.cuda()                  # move the sample points to the GPU
        self.data.requires_grad_()                    # enable gradients w.r.t. the input points
        self.feature_mask = self.feature_mask.cuda()  # move the per-point patch mask to the GPU

        n_branch = int(torch.max(self.feature_mask).item())  # number of patches/branches
        n_batchsize = self.points_batch                       # total points per batch
        n_patch_batch = n_batchsize // n_branch               # points sampled from each patch per batch

        # Indices of the points belonging to each patch (mask values are 1-based).
        patch_id = [np.where(self.feature_mask.cpu().numpy() == i + 1)[0] for i in range(n_branch)]

        for epoch in range(self.nepochs):  # training loop
            # Sample n_patch_batch point indices (with replacement) from every patch.
            indices = torch.cat([
                torch.from_numpy(np.random.choice(patch_id[i], n_patch_batch, replace=True)).cuda()
                for i in range(n_branch)
            ])
            cur_data = cur_data = self.data[indices] if False else self.data[indices]  # gather the sampled points
            mnfld_pnts = cur_data[:, :self.d_in]  # on-manifold (surface) points

            # Forward pass
            mnfld_pred_all = self.network(mnfld_pnts)  # network output for the manifold points
            mnfld_pred = mnfld_pred_all[:, 0]          # column 0 is the overall implicit value

            # Manifold loss: the implicit function should vanish on the surface points.
            loss = mnfld_pred.abs().mean()

            self.optimizer.zero_grad()  # reset accumulated gradients
            loss.backward()             # back-propagate
            self.optimizer.step()       # update the network parameters

            if epoch % 100 == 0:  # log the loss every 100 epochs
                print(f'Epoch [{epoch}/{self.nepochs}], Loss: {loss.item():.4f}')

        self.tracing()  # post-training step (trace/export the trained network)
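

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The snippet above does not include a
# constructor, so the setup below is an assumption inferred from the attribute
# names the method uses (data, feature_mask, points_batch, d_in, nepochs,
# network, optimizer, tracing). The random data, the small MLP, and the Adam
# settings are placeholders, not the original implementation.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    trainer = NHREPNet_Training()
    trainer.d_in = 3                                    # input dimension (x, y, z)
    trainer.points_batch = 8192                         # total points per batch
    trainer.nepochs = 15000                             # epochs, as logged above
    trainer.data = torch.rand(10000, trainer.d_in)      # placeholder surface samples
    trainer.feature_mask = torch.randint(1, 4, (10000,)).float()  # patch ids 1..3
    trainer.network = nn.Sequential(                    # placeholder implicit network
        nn.Linear(trainer.d_in, 256), nn.Softplus(beta=100),
        nn.Linear(256, 256), nn.Softplus(beta=100),
        nn.Linear(256, 4),                              # column 0 = overall implicit value
    ).cuda()
    trainer.optimizer = torch.optim.Adam(trainer.network.parameters(), lr=1e-4)
    trainer.tracing = lambda: None                      # stand-in for the missing tracing()
    trainer.run_nhrepnet_training()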