1 changed files with 41 additions and 0 deletions
@ -0,0 +1,41 @@ |
|||||
|
import torch |
||||
|
import numpy as np |
||||
|
import os |
||||
|
from torch.utils.tensorboard import SummaryWriter |
||||
|
from utils.logger import logger |
||||
|
from utils.general import gradient |
||||
|
|
||||
|
|
||||
|
class NHREPNet_Training: |
||||
|
def run_nhrepnet_training(self): |
||||
|
print("running") # 输出训练开始的提示信息 |
||||
|
self.data = self.data.cuda() # 将数据移动到 GPU 上 |
||||
|
self.data.requires_grad_() # 设置数据以便计算梯度 |
||||
|
self.feature_mask = self.feature_mask.cuda() # 将特征掩码移动到 GPU 上 |
||||
|
n_branch = int(torch.max(self.feature_mask).item()) # 计算分支数量 |
||||
|
n_batchsize = self.points_batch # 设置批次大小 |
||||
|
n_patch_batch = n_batchsize // n_branch # 计算每个分支的补丁批次大小 |
||||
|
|
||||
|
# 初始化补丁 ID 列表 |
||||
|
patch_id = [np.where(self.feature_mask.cpu().numpy() == i + 1)[0] for i in range(n_branch)] |
||||
|
|
||||
|
for epoch in range(15000): # 开始训练循环 |
||||
|
indices = torch.cat([torch.tensor(patch_id[i][np.random.choice(len(patch_id[i]), n_patch_batch, replace=True)]).cuda() for i in range(n_branch)]).cuda() # 随机选择补丁的索引 |
||||
|
cur_data = self.data[indices] # 根据索引获取当前数据 |
||||
|
mnfld_pnts = cur_data[:, :self.d_in] # 提取流形点 |
||||
|
|
||||
|
# 前向传播 |
||||
|
mnfld_pred_all = self.network(mnfld_pnts) # 进行前向传播,计算流形点的预测值 |
||||
|
mnfld_pred = mnfld_pred_all[:, 0] # 提取流形预测结果 |
||||
|
|
||||
|
# 计算损失 |
||||
|
loss = (mnfld_pred.abs()).mean() # 计算流形损失 |
||||
|
|
||||
|
self.optimizer.zero_grad() # 清零优化器的梯度 |
||||
|
loss.backward() # 反向传播计算梯度 |
||||
|
self.optimizer.step() # 更新模型参数 |
||||
|
|
||||
|
if epoch % 100 == 0: # 每 100 轮记录损失 |
||||
|
print(f'Epoch [{epoch}/{self.nepochs}], Loss: {loss.item():.4f}') # 输出当前轮次的损失 |
||||
|
|
||||
|
self.tracing() # |
Loading…
Reference in new issue