You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
41 lines
1.9 KiB
41 lines
1.9 KiB
import torch
|
|
import numpy as np
|
|
import os
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from utils.logger import logger
|
|
from utils.general import gradient
|
|
|
|
|
|
class NHREPNet_Training:
|
|
def run_nhrepnet_training(self):
|
|
print("running") # 输出训练开始的提示信息
|
|
self.data = self.data.cuda() # 将数据移动到 GPU 上
|
|
self.data.requires_grad_() # 设置数据以便计算梯度
|
|
self.feature_mask = self.feature_mask.cuda() # 将特征掩码移动到 GPU 上
|
|
n_branch = int(torch.max(self.feature_mask).item()) # 计算分支数量
|
|
n_batchsize = self.points_batch # 设置批次大小
|
|
n_patch_batch = n_batchsize // n_branch # 计算每个分支的补丁批次大小
|
|
|
|
# 初始化补丁 ID 列表
|
|
patch_id = [np.where(self.feature_mask.cpu().numpy() == i + 1)[0] for i in range(n_branch)]
|
|
|
|
for epoch in range(15000): # 开始训练循环
|
|
indices = torch.cat([torch.tensor(patch_id[i][np.random.choice(len(patch_id[i]), n_patch_batch, replace=True)]).cuda() for i in range(n_branch)]).cuda() # 随机选择补丁的索引
|
|
cur_data = self.data[indices] # 根据索引获取当前数据
|
|
mnfld_pnts = cur_data[:, :self.d_in] # 提取流形点
|
|
|
|
# 前向传播
|
|
mnfld_pred_all = self.network(mnfld_pnts) # 进行前向传播,计算流形点的预测值
|
|
mnfld_pred = mnfld_pred_all[:, 0] # 提取流形预测结果
|
|
|
|
# 计算损失
|
|
loss = (mnfld_pred.abs()).mean() # 计算流形损失
|
|
|
|
self.optimizer.zero_grad() # 清零优化器的梯度
|
|
loss.backward() # 反向传播计算梯度
|
|
self.optimizer.step() # 更新模型参数
|
|
|
|
if epoch % 100 == 0: # 每 100 轮记录损失
|
|
print(f'Epoch [{epoch}/{self.nepochs}], Loss: {loss.item():.4f}') # 输出当前轮次的损失
|
|
|
|
self.tracing() #
|