1 changed files with 41 additions and 0 deletions
			
			
		@ -0,0 +1,41 @@ | 
				
			|||||
 | 
					import torch | 
				
			||||
 | 
					import numpy as np | 
				
			||||
 | 
					import os | 
				
			||||
 | 
					from torch.utils.tensorboard import SummaryWriter | 
				
			||||
 | 
					from utils.logger import logger | 
				
			||||
 | 
					from utils.general import gradient | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					class NHREPNet_Training: | 
				
			||||
 | 
					    def run_nhrepnet_training(self): | 
				
			||||
 | 
					        print("running")  # 输出训练开始的提示信息 | 
				
			||||
 | 
					        self.data = self.data.cuda()  # 将数据移动到 GPU 上 | 
				
			||||
 | 
					        self.data.requires_grad_()  # 设置数据以便计算梯度 | 
				
			||||
 | 
					        self.feature_mask = self.feature_mask.cuda()  # 将特征掩码移动到 GPU 上 | 
				
			||||
 | 
					        n_branch = int(torch.max(self.feature_mask).item())  # 计算分支数量 | 
				
			||||
 | 
					        n_batchsize = self.points_batch  # 设置批次大小 | 
				
			||||
 | 
					        n_patch_batch = n_batchsize // n_branch  # 计算每个分支的补丁批次大小 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        # 初始化补丁 ID 列表 | 
				
			||||
 | 
					        patch_id = [np.where(self.feature_mask.cpu().numpy() == i + 1)[0] for i in range(n_branch)] | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        for epoch in range(15000):  # 开始训练循环 | 
				
			||||
 | 
					            indices = torch.cat([torch.tensor(patch_id[i][np.random.choice(len(patch_id[i]), n_patch_batch, replace=True)]).cuda() for i in range(n_branch)]).cuda()  # 随机选择补丁的索引 | 
				
			||||
 | 
					            cur_data = self.data[indices]  # 根据索引获取当前数据 | 
				
			||||
 | 
					            mnfld_pnts = cur_data[:, :self.d_in]  # 提取流形点 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					            # 前向传播 | 
				
			||||
 | 
					            mnfld_pred_all = self.network(mnfld_pnts)  # 进行前向传播,计算流形点的预测值 | 
				
			||||
 | 
					            mnfld_pred = mnfld_pred_all[:, 0]  # 提取流形预测结果 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					            # 计算损失 | 
				
			||||
 | 
					            loss = (mnfld_pred.abs()).mean()  # 计算流形损失 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					            self.optimizer.zero_grad()  # 清零优化器的梯度 | 
				
			||||
 | 
					            loss.backward()  # 反向传播计算梯度 | 
				
			||||
 | 
					            self.optimizer.step()  # 更新模型参数 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					            if epoch % 100 == 0:  # 每 100 轮记录损失 | 
				
			||||
 | 
					                print(f'Epoch [{epoch}/{self.nepochs}], Loss: {loss.item():.4f}')  # 输出当前轮次的损失 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					        self.tracing()  # | 
				
			||||
					Loading…
					
					
				
		Reference in new issue