import os
import sys

project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from pyhocon import ConfigFactory
from scipy.spatial import cKDTree

from utils.logger import logger
from utils.general import get_class
from data_loader import NHREP_Dataset
from loss import LossManager
from learning_rate import LearningRateScheduler
from model.sample import Sampler


class NHREPNet_Training:
    def __init__(self, name_prefix: str, conf, if_baseline: bool = False, if_feature_sample: bool = False):
        self.conf = conf
        self.name_prefix = name_prefix
        self.sampler = Sampler.get_sampler(
            self.conf.get_string('network.sampler.sampler_type'))(
            global_sigma=self.conf.get_float('network.sampler.properties.global_sigma'),
            local_sigma=self.conf.get_float('network.sampler.properties.local_sigma')
        )
        data_dir = self.conf.get_string('train.input_path')
        self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # model hyperparameters
        self.d_in = 3                    # input dimension: x, y, z
        self.dims_sdf = [256, 256, 256]  # hidden-layer widths
        self.nepochs = 15000             # number of training epochs

        folder = self.conf.get_string('train.folderprefix')
        self.writer = SummaryWriter(os.path.join("summary", folder, name_prefix))  # TensorBoard writer

        # checkpoint directories
        self.init_checkpoints()

    def run_nhrepnet_training(self):
        # data preparation
        logger.info("Preparing data")
        self.data = self.dataset.get_data().to(self.device).requires_grad_()  # x, y, z, nx, ny, nz
        feature_mask_cpu = self.dataset.get_feature_mask().numpy()  # per-point branch labels
        self.feature_mask = torch.from_numpy(feature_mask_cpu).to(self.device)
        self.points_batch = 16384  # batch size
        self.compute_local_sigma()

        n_branch = int(torch.max(self.feature_mask).item())          # number of branches
        n_batchsize = self.points_batch                              # total batch size
        n_patch_batch = n_batchsize // n_branch                      # samples per branch
        n_patch_last = n_batchsize - n_patch_batch * (n_branch - 1)  # samples of the last branch

        # 1.1 collect the point indices belonging to each branch
        patch_id, patch_id_n = self.compute_patch(n_branch, n_patch_batch, n_patch_last, feature_mask_cpu)
        # 1.2 build the branch masks
        branch_mask, single_branch_mask_gt, single_branch_mask_id = self.get_branch_mask(n_branch, n_patch_batch, n_patch_last)
        # 1.3 initialize the model
        csg_tree, flag_convex = self.dataset.get_csg_tree()
        self.model = get_class(self.conf.get_string('train.network_class'))(
            d_in=self.d_in,
            n_branch=n_branch,
            csg_tree=csg_tree,
            flag_convex=flag_convex,
            **self.conf.get_config('network.inputs')
        ).to(self.device)
        self.scheduler = LearningRateScheduler(self.conf.get_list('train.learning_rate_schedule'),
                                               self.conf.get_float('train.weight_decay'),
                                               self.model.parameters())
        self.loss_manager = LossManager(ablation="none")

        logger.info("Starting training")
        self.model.train()  # put the model in training mode
        for epoch in tqdm(range(self.nepochs), desc="training", unit="epoch"):
            try:
                self.train_one_epoch(epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize)
            except Exception as e:
                logger.error(f"Error during training: {str(e)}")
                break
        self.tracing()

    def train_one_epoch(self, epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize):
        # sample point indices for this batch
        indices = self.get_indices(patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch)
        cur_data = self.data[indices]            # x, y, z, nx, ny, nz
        mnfld_pnts = cur_data[:, :self.d_in]     # on-surface (manifold) points
        mnfld_sigma = self.local_sigma[indices]  # per-point noise scale
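        # Off-surface samples: each surface point is presumably jittered by its local
        # sigma, with additional globally scattered samples controlled by global_sigma
        # (see the network.sampler settings in the conf); these feed the off-surface
        # loss terms below.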
        nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()

        # forward pass
        mnfld_pred_all = self.model(mnfld_pnts)
        nonmnfld_pred_all = self.model(nonmnfld_pnts)
        normals = cur_data[:, -self.d_in:]

        # loss
        loss, loss_details = self.loss_manager.compute_loss(
            mnfld_pnts=mnfld_pnts,
            normals=normals,
            mnfld_pred_all=mnfld_pred_all,
            nonmnfld_pnts=nonmnfld_pnts,
            nonmnfld_pred_all=nonmnfld_pred_all,
            n_batchsize=n_batchsize,
            n_branch=n_branch,
            n_patch_batch=n_patch_batch,
            n_patch_last=n_patch_last,
        )

        # backward pass
        self.scheduler.step(loss, epoch)      # adjust the learning rate
        self.scheduler.optimizer.zero_grad()  # clear gradients
        loss.backward()                       # backpropagate
        self.scheduler.optimizer.step()       # update parameters

        avg_loss = loss.item()
        if epoch % 100 == 0:
            self.writer.add_scalar('Loss/train', avg_loss, epoch)  # log to TensorBoard
            for k, v in loss_details.items():
                self.writer.add_scalar('Loss/' + k, v.item(), epoch)
        if epoch % self.conf.get_int('train.checkpoint_frequency') == 0:
            self.save_checkpoints(epoch)

    # ============================ data preparation ============================
    def compute_patch(self, n_branch, n_patch_batch, n_patch_last, feature_mask_cpu):
        '''Collect, per branch, the indices of the points assigned to it and their counts.'''
        patch_id = []
        patch_id_n = []
        for i in range(n_branch):
            patch_id.append(np.where(feature_mask_cpu == i + 1)[0])
            patch_id_n.append(patch_id[i].shape[0])
        return patch_id, patch_id_n

    def get_branch_mask(self, n_branch, n_patch_batch, n_patch_last):
        '''
        branch_mask: one row per branch, one column per sample; marks which samples
            of a batch belong to each branch.
        single_branch_mask_gt: one row per sample, one column per branch; one-hot
            branch membership of each sample.
        single_branch_mask_id: branch index of each sample.
        '''
        # the masks must cover the whole batch: (n_branch - 1) regular patches plus the last one
        n_batchsize = n_patch_batch * (n_branch - 1) + n_patch_last
        branch_mask = torch.zeros(n_branch, n_batchsize, device=self.device)
        single_branch_mask_gt = torch.zeros(n_batchsize, n_branch, device=self.device)
        single_branch_mask_id = torch.zeros([n_batchsize], dtype=torch.long, device=self.device)
        for i in range(n_branch - 1):
            branch_mask[i, i * n_patch_batch : (i + 1) * n_patch_batch] = 1.0
            single_branch_mask_gt[i * n_patch_batch : (i + 1) * n_patch_batch, i] = 1.0
            single_branch_mask_id[i * n_patch_batch : (i + 1) * n_patch_batch] = i
        # the last branch takes the remaining n_patch_last samples
        branch_mask[n_branch - 1, (n_branch - 1) * n_patch_batch:] = 1.0
        single_branch_mask_gt[(n_branch - 1) * n_patch_batch:, n_branch - 1] = 1.0
        single_branch_mask_id[(n_branch - 1) * n_patch_batch:] = n_branch - 1
        return branch_mask, single_branch_mask_gt, single_branch_mask_id

    def get_indices(self, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch):
        indices = torch.empty(0, dtype=torch.int64, device=self.device)
        for i in range(n_branch - 1):
            indices_nonfeature = torch.tensor(
                patch_id[i][np.random.choice(patch_id_n[i], n_patch_batch, True)],
                device=self.device)
            indices = torch.cat((indices, indices_nonfeature), 0)
        # last patch
        indices_nonfeature = torch.tensor(
            patch_id[n_branch - 1][np.random.choice(patch_id_n[n_branch - 1], n_patch_last, True)],
            device=self.device)
        indices = torch.cat((indices, indices_nonfeature), 0)
        return indices

    def compute_local_sigma(self):
        """Compute a per-point sigma from nearest-neighbor distances."""
        try:
            sigma_set = []
            # use only the xyz coordinates, not the normals, for the distance queries
            data_cpu = self.data.detach().cpu().numpy()[:, :self.d_in]
            ptree = cKDTree(data_cpu)
            logger.debug("KD tree constructed")
            for p in np.array_split(data_cpu, 100, axis=0):
                d = ptree.query(p, 50 + 1)
                sigma_set.append(d[0][:, -1])
            sigmas = np.concatenate(sigma_set)
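            # sigmas now holds, for every point, the distance to its 50th nearest
            # neighbor (the query returns the point itself as the 0-th hit), so
            # denser regions get a proportionally smaller sampling radius.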
            self.local_sigma = torch.from_numpy(sigmas).float().to(self.device)
        except Exception as e:
            logger.error(f"Error computing local sigma: {str(e)}")
            raise

    # ============================ saving the model ============================
    def init_checkpoints(self):
        self.checkpoints_path = os.path.join("../exps/single_shape", self.name_prefix, "checkpoints")
        self.ModelParameters_path = os.path.join(self.checkpoints_path, "ModelParameters")
        self.OptimizerParameters_path = os.path.join(self.checkpoints_path, "OptimizerParameters")
        self.TorchScript_path = os.path.join(self.checkpoints_path, "TorchScript")
        # create the directories
        os.makedirs(self.ModelParameters_path, exist_ok=True)
        os.makedirs(self.OptimizerParameters_path, exist_ok=True)
        os.makedirs(self.TorchScript_path, exist_ok=True)

    def save_checkpoints(self, epoch):
        torch.save(
            {"epoch": epoch, "model_state_dict": self.model.state_dict()},
            os.path.join(self.ModelParameters_path, str(epoch) + ".pth"))
        torch.save(
            {"epoch": epoch, "model_state_dict": self.model.state_dict()},
            os.path.join(self.ModelParameters_path, "latest.pth"))
        torch.save(
            {"epoch": epoch, "optimizer_state_dict": self.scheduler.optimizer.state_dict()},
            os.path.join(self.OptimizerParameters_path, str(epoch) + ".pth"))
        torch.save(
            {"epoch": epoch, "optimizer_state_dict": self.scheduler.optimizer.state_dict()},
            os.path.join(self.OptimizerParameters_path, "latest.pth"))

    def tracing(self):
        # build a fresh instance for tracing and copy in the trained weights;
        # without load_state_dict the traced module would keep its random initialization
        csg_tree, flag_convex = self.dataset.get_csg_tree()
        network = get_class(self.conf.get_string('train.network_class'))(
            d_in=self.d_in,
            n_branch=int(torch.max(self.feature_mask).item()),
            csg_tree=csg_tree,
            flag_convex=flag_convex,
            **self.conf.get_config('network.inputs')
        ).to(self.device)
        network.load_state_dict(self.model.state_dict())
        network.eval()

        # trace
        example = torch.rand(224, 3).to(self.device)
        traced_script_module = torch.jit.trace(network, example)
        traced_script_module.save(os.path.join(self.TorchScript_path, "model_h.pt"))


if __name__ == "__main__":
    name_prefix = 'broken_bullet_50k'
    conf = ConfigFactory.parse_file('./conversion/setup.conf')
    try:
        train = NHREPNet_Training(name_prefix, conf, if_baseline=True, if_feature_sample=False)
        train.run_nhrepnet_training()
    except Exception as e:
        logger.error(str(e))
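# A minimal sketch of how the traced module written by tracing() could be loaded
# for inference elsewhere (paths mirror init_checkpoints(); adjust to your setup):
#
#   import torch
#   model = torch.jit.load("../exps/single_shape/broken_bullet_50k/checkpoints/TorchScript/model_h.pt")
#   with torch.no_grad():
#       h = model(torch.rand(1024, 3))  # implicit-function values at 1024 query points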