diff --git a/code/conversion/train.py b/code/conversion/train.py new file mode 100644 index 0000000..fddd09e --- /dev/null +++ b/code/conversion/train.py @@ -0,0 +1,41 @@ +import torch +import numpy as np +import os +from torch.utils.tensorboard import SummaryWriter +from utils.logger import logger +from utils.general import gradient + + +class NHREPNet_Training: + def run_nhrepnet_training(self): + print("running") # 输出训练开始的提示信息 + self.data = self.data.cuda() # 将数据移动到 GPU 上 + self.data.requires_grad_() # 设置数据以便计算梯度 + self.feature_mask = self.feature_mask.cuda() # 将特征掩码移动到 GPU 上 + n_branch = int(torch.max(self.feature_mask).item()) # 计算分支数量 + n_batchsize = self.points_batch # 设置批次大小 + n_patch_batch = n_batchsize // n_branch # 计算每个分支的补丁批次大小 + + # 初始化补丁 ID 列表 + patch_id = [np.where(self.feature_mask.cpu().numpy() == i + 1)[0] for i in range(n_branch)] + + for epoch in range(15000): # 开始训练循环 + indices = torch.cat([torch.tensor(patch_id[i][np.random.choice(len(patch_id[i]), n_patch_batch, replace=True)]).cuda() for i in range(n_branch)]).cuda() # 随机选择补丁的索引 + cur_data = self.data[indices] # 根据索引获取当前数据 + mnfld_pnts = cur_data[:, :self.d_in] # 提取流形点 + + # 前向传播 + mnfld_pred_all = self.network(mnfld_pnts) # 进行前向传播,计算流形点的预测值 + mnfld_pred = mnfld_pred_all[:, 0] # 提取流形预测结果 + + # 计算损失 + loss = (mnfld_pred.abs()).mean() # 计算流形损失 + + self.optimizer.zero_grad() # 清零优化器的梯度 + loss.backward() # 反向传播计算梯度 + self.optimizer.step() # 更新模型参数 + + if epoch % 100 == 0: # 每 100 轮记录损失 + print(f'Epoch [{epoch}/{self.nepochs}], Loss: {loss.item():.4f}') # 输出当前轮次的损失 + + self.tracing() # \ No newline at end of file