import os import sys import time project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) sys.path.append(project_dir) os.chdir(project_dir) import torch import numpy as np from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from pyhocon import ConfigFactory from scipy.spatial import cKDTree from utils.logger import logger from utils.general import get_class from data_loader import NHREP_Dataset from loss import LossManager from learning_rate import LearningRateScheduler from model.network import NHRepNet # 导入 NHRepNet from model.sample import Sampler class NHREPNet_Training: def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False, if_feature_sample: bool = False): self.conf = ConfigFactory.parse_file('./conversion/setup.conf') self.sampler = Sampler.get_sampler( self.conf.get_string('network.sampler.sampler_type'))( global_sigma=self.conf.get_float('network.sampler.properties.global_sigma'), local_sigma=self.conf.get_float('network.sampler.properties.local_sigma') ) self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 初始化模型 self.d_in = 3 # 输入维度,x, y, z. self.dims_sdf = [256, 256, 256] # 隐藏层维度 self.nepochs = 15000 # 训练轮数 self.writer = SummaryWriter() # TensorBoard 记录器 def run_nhrepnet_training(self): # 数据准备 logger.info("数据准备") self.data = self.dataset.get_data().to(self.device).requires_grad_() # x, y, z, nx, ny, nz feature_mask_cpu = self.dataset.get_feature_mask().numpy() # 特征掩码 self.feature_mask = torch.from_numpy(feature_mask_cpu).to(self.device) # 特征掩码 # 特征掩码 self.points_batch = 16384 # 批次大小 n_branch = int(torch.max(self.feature_mask).item()) # 计算分支数量 n_batchsize = self.points_batch # 设置批次大小 n_patch_batch = n_batchsize // n_branch # 计算每个分支的补丁批次大小 n_patch_last = n_batchsize - n_patch_batch * (n_branch - 1) # 计算最后一个分支的补丁大小 # 1,准备训练数据 # 1.1,计算每个分支的补丁数量 patch_id, patch_id_n = self.compute_patch(n_branch, n_patch_batch, n_patch_last, feature_mask_cpu) # 1.2,获取分支掩码 branch_mask, single_branch_mask_gt, single_branch_mask_id = self.get_branch_mask(n_branch, n_patch_batch, n_patch_last) # 1.3,初始化模型 csg_tree, flag_convex = self.dataset.get_csg_tree() self.model = get_class(self.conf.get_string('train.network_class'))( d_in=self.d_in, n_branch=n_branch, csg_tree=csg_tree, flag_convex=flag_convex, **self.conf.get_config('network.inputs') ).to(self.device) self.scheduler = LearningRateScheduler(self.conf.get_list('train.learning_rate_schedule'), self.conf.get_float('train.weight_decay'), self.model.parameters()) self.loss_manager = LossManager(ablation="none") logger.info("开始训练") self.model.train() # 设置模型为训练模式 for epoch in range(self.nepochs): # 开始训练循环 try: self.train_one_epoch(epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize) except Exception as e: logger.error(f"训练过程中发生错误: {str(e)}") break def train_one_epoch(self, epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize): logger.info(f"Epoch {epoch}/{self.nepochs} 开始") # 1.3,获取索引 indices = self.get_indices(patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch) # 1.4,获取数据 cur_data = self.data[indices] # x, y, z, nx, ny, nz mnfld_pnts = cur_data[:, :self.d_in] # 提取流形点 self.compute_local_sigma() mnfld_sigma = self.local_sigma[indices] # 提取噪声点 nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze() # 生成非流形点 #TODO 记录了log # 2,前向传播 self.scheduler.adjust_learning_rate(epoch) #logger.info(f"mnfld_pnts: {mnfld_pnts.shape}") #logger.info(f"nonmnfld_pnts: {nonmnfld_pnts.shape}") # 前向传播 mnfld_pred_all = self.model(mnfld_pnts) # 使用模型进行前向传播 nonmnfld_pred_all = self.model(nonmnfld_pnts) # 使用模型进行前向传播 #logger.info(f"mnfld_pred_all: {mnfld_pred_all.shape}") #logger.info(f"nonmnfld_pred_all: {nonmnfld_pred_all.shape}") normals = cur_data[:, -self.d_in:] # 计算损失 loss = self.loss_manager.compute_loss( mnfld_pnts = mnfld_pnts, normals = normals, mnfld_pred_all = mnfld_pred_all, nonmnfld_pnts = nonmnfld_pnts, nonmnfld_pred_all = nonmnfld_pred_all, n_batchsize = n_batchsize, n_branch = n_branch, n_patch_batch = n_patch_batch, n_patch_last = n_patch_last, ) # 计算损失 self.scheduler.step(loss) # 反向传播 self.scheduler.optimizer.zero_grad() # 清空梯度 loss.backward() # 反向传播 self.scheduler.optimizer.step() # 更新参数 avg_loss = loss.item() logger.info(f'Epoch [{epoch}/{self.nepochs}], Average Loss: {avg_loss:.4f}') self.writer.add_scalar('Loss/train', avg_loss, epoch) # 记录损失到 TensorBoard #============================ 前向传播 数据准备 ============================ def compute_patch(self, n_branch, n_patch_batch, n_patch_last, feature_mask_cpu): ''' 计算每个分支的补丁数量 ''' patch_id = [] patch_id_n = [] for i in range(n_branch): patch_id = patch_id + [np.where(feature_mask_cpu == i + 1)[0]] patch_id_n = patch_id_n + [patch_id[i].shape[0]] return patch_id, patch_id_n def get_branch_mask(self, n_branch, n_patch_batch, n_patch_last): ''' branch_mask: 分支掩码,用于表示每个分支在每个批次中的掩码。每一行对应一个分支,每一列对应一个样本。用于表示每个分支的补丁是否被选中。 single_branch_mask_gt: 单分支掩码,用于表示每个补丁属于哪个分支。每一行对应一个样本,每一列对应一个分支。用于表示每个补丁属于哪个分支。 single_branch_mask_id: 单分支 ID,用于表示每个补丁属于哪个分支。 作用: ''' branch_mask = torch.zeros(n_branch, n_patch_batch).cuda() single_branch_mask_gt = torch.zeros(n_patch_batch, n_branch).cuda() single_branch_mask_id = torch.zeros([n_patch_batch], dtype=torch.long).cuda() for i in range(n_branch - 1): branch_mask[i, i * n_patch_batch : (i + 1) * n_patch_batch] = 1.0 single_branch_mask_gt[i * n_patch_batch : (i + 1) * n_patch_batch, i] = 1.0 single_branch_mask_id[i * n_patch_batch : (i + 1) * n_patch_batch] = i branch_mask[n_branch - 1, (n_branch - 1) * n_patch_batch:] = 1.0 single_branch_mask_gt[(n_branch - 1) * n_patch_batch:, (n_branch - 1)] = 1.0 single_branch_mask_id[(n_branch - 1) * n_patch_batch:] = (n_branch - 1) return branch_mask, single_branch_mask_gt, single_branch_mask_id def get_indices(self, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch): indices = torch.empty(0, dtype=torch.int64).cuda() for i in range(n_branch - 1): indices_nonfeature = torch.tensor(patch_id[i][np.random.choice(patch_id_n[i], n_patch_batch, True)]).cuda() indices = torch.cat((indices, indices_nonfeature), 0) # last patch indices_nonfeature = torch.tensor(patch_id[n_branch - 1][np.random.choice(patch_id_n[n_branch - 1], n_patch_last, True)]).cuda() indices = torch.cat((indices, indices_nonfeature), 0) return indices def compute_local_sigma(self): """计算局部sigma值""" try: sigma_set = [] data_cpu = self.data.detach().cpu().numpy() ptree = cKDTree(data_cpu) logger.debug("KD tree constructed") for p in np.array_split(data_cpu, 100, axis=0): d = ptree.query(p, 50 + 1) sigma_set.append(d[0][:, -1]) sigmas = np.concatenate(sigma_set) self.local_sigma = torch.from_numpy(sigmas).float().cuda() except Exception as e: logger.error(f"Error computing local sigma: {str(e)}") raise #============================ 保存模型 ============================ def save_checkpoints(self, epoch): torch.save( {"epoch": epoch, "model_state_dict": self.network.state_dict()}, os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth")) torch.save( {"epoch": epoch, "model_state_dict": self.network.state_dict()}, os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth")) torch.save( {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()}, os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth")) torch.save( {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()}, os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth")) if __name__ == "__main__": data_dir = '../data/input_data' # 数据目录 name_prefix = 'broken_bullet_50k' train = NHREPNet_Training(data_dir, name_prefix, if_baseline=True, if_feature_sample=False) train.run_nhrepnet_training()