import os
import sys
import time

project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)

import argparse

# Parse args first and set the GPU id.
parser = argparse.ArgumentParser()
parser.add_argument('--points_batch', type=int, default=16384, help='point batch size')
parser.add_argument('--nepoch', type=int, default=15001, help='number of epochs to train for')
parser.add_argument('--conf', type=str, default='setup.conf')
parser.add_argument('--expname', type=str, default='single_shape')
parser.add_argument('--gpu', type=str, default='0', help='GPU id to use')
parser.add_argument('--is_continue', default=False, action="store_true", help='continue training from a checkpoint')
parser.add_argument('--checkpoint', default='latest', type=str)
parser.add_argument('--eval', default=False, action="store_true")
parser.add_argument('--summary', default=False, action="store_true", help='write tensorboard summary')
parser.add_argument('--baseline', default=False, action="store_true", help='run baseline')
parser.add_argument('--th_closeness', type=float, default=1e-5, help='threshold deciding whether two points are the same')
parser.add_argument('--cpu', default=False, action="store_true", help='save for cpu device')
parser.add_argument('--ab', default='none', type=str, help='ablation')
parser.add_argument('--siren', default=False, action="store_true", help='siren normal loss')
parser.add_argument('--pt', default='ptfile path', type=str)
parser.add_argument('--feature_sample', action="store_true", help='use feature curve samples')
parser.add_argument('--num_feature_sample', type=int, default=2048, help='number of feature samples per batch')
parser.add_argument('--all_feature_sample', type=int, default=10000, help='number of all feature samples')
args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(args.gpu)

from pyhocon import ConfigFactory
import numpy as np
import GPUtil
import torch
import utils.general as utils
from model.sample import Sampler
from model.network import gradient
from scipy.spatial import cKDTree
from utils.plots import plot_surface, plot_cuts_axis
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from utils.logger import logger

logger.info(f"project_dir: {project_dir}")


class ReconstructionRunner:

    def run_nhrepnet_training(self):
        print("running")  # signal that training setup has started
        self.data = self.data.cuda()  # move the point data to the GPU
        self.data.requires_grad_()  # enable gradients w.r.t. the input points
        feature_mask_cpu = self.feature_mask.numpy()  # keep a NumPy copy of the feature mask
        self.feature_mask = self.feature_mask.cuda()  # move the feature mask to the GPU
        n_branch = int(torch.max(self.feature_mask).item())  # number of branches (patches)
        n_batchsize = self.points_batch  # total batch size
        n_patch_batch = n_batchsize // n_branch  # per-branch patch batch size
        n_patch_last = n_batchsize - n_patch_batch * (n_branch - 1)  # the last branch takes the remainder
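        # Worked example (hypothetical numbers): with the default
        # points_batch=16384 and n_branch=3, n_patch_batch = 16384 // 3 = 5461
        # and n_patch_last = 16384 - 2 * 5461 = 5462, so each branch contributes
        # the same number of on-surface samples per step except the last, which
        # absorbs the remainder.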
        patch_sup = True  # supervise per-patch distances
        weight_mnfld_h = 1  # weight of the manifold loss on h
        weight_mnfld_cs = 1  # weight of the manifold consistency loss
        weight_correction = 1  # weight of the correction loss
        a_correction = 100  # scale factor inside the correction loss
        patch_id = []  # per-branch point indices
        patch_id_n = []  # per-branch point counts
        for i in range(n_branch):
            patch_id = patch_id + [np.where(feature_mask_cpu == i + 1)[0]]  # indices of points belonging to branch i
            patch_id_n = patch_id_n + [patch_id[i].shape[0]]  # number of points in branch i

        if self.eval:  # evaluation mode: plot shapes and return
            print("evaluating epoch: {0}".format(self.startepoch))
            my_path = os.path.join(self.cur_exp_dir, 'evaluation', str(self.startepoch))
            utils.mkdir_ifnotexists(os.path.join(self.cur_exp_dir, 'evaluation'))
            utils.mkdir_ifnotexists(my_path)  # make sure the current evaluation path exists
            for i in range(1):
                self.network.flag_output = i + 1  # select which implicit function the network outputs
                self.plot_shapes(epoch=self.startepoch, path=my_path, file_suffix="_" + str(i), with_cuts=True)
            self.network.flag_output = 0  # reset the output flag
            return

        print("training begin")
        if args.summary:
            writer = SummaryWriter(os.path.join("summary", self.foldername))

        # The branch mask is predefined: branch i owns a contiguous slice of the batch.
        branch_mask = torch.zeros(n_branch, n_batchsize).cuda()
        single_branch_mask_gt = torch.zeros(n_batchsize, n_branch).cuda()
        single_branch_mask_id = torch.zeros([n_batchsize], dtype=torch.long).cuda()
        for i in range(n_branch - 1):
            branch_mask[i, i * n_patch_batch: (i + 1) * n_patch_batch] = 1.0
            single_branch_mask_gt[i * n_patch_batch: (i + 1) * n_patch_batch, i] = 1.0
            single_branch_mask_id[i * n_patch_batch: (i + 1) * n_patch_batch] = i
        # last patch
        branch_mask[n_branch - 1, (n_branch - 1) * n_patch_batch:] = 1.0
        single_branch_mask_gt[(n_branch - 1) * n_patch_batch:, (n_branch - 1)] = 1.0
        single_branch_mask_id[(n_branch - 1) * n_patch_batch:] = (n_branch - 1)

        for epoch in range(self.startepoch, self.nepochs + 1):
            indices = torch.empty(0, dtype=torch.int64).cuda()
            for i in range(n_branch - 1):
                # sample per-branch point indices, with replacement
                indices_nonfeature = torch.tensor(patch_id[i][np.random.choice(patch_id_n[i], n_patch_batch, True)]).cuda()
                indices = torch.cat((indices, indices_nonfeature), 0)
            # last patch
            indices_nonfeature = torch.tensor(patch_id[n_branch - 1][np.random.choice(patch_id_n[n_branch - 1], n_patch_last, True)]).cuda()
            indices = torch.cat((indices, indices_nonfeature), 0)

            cur_data = self.data[indices]  # current batch
            mnfld_pnts = cur_data[:, :self.d_in]  # on-surface (manifold) points
            mnfld_sigma = self.local_sigma[indices]  # per-point noise scale

            if epoch % self.conf.get_int('train.checkpoint_frequency') == 0:
                self.save_checkpoints(epoch)

            if epoch % self.conf.get_int('train.plot_frequency') == 0:
                print('plot validation epoch: ', epoch)
                for i in range(n_branch + 1):
                    self.network.flag_output = i + 1
                    self.plot_shapes(epoch, file_suffix="_" + str(i), with_cuts=False)
                self.network.flag_output = 0

            self.network.train()
            self.adjust_learning_rate(epoch)
            nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()  # sample off-surface points

            # forward pass
            mnfld_pred_all = self.network(mnfld_pnts)
            nonmnfld_pred_all = self.network(nonmnfld_pnts)
            mnfld_pred = mnfld_pred_all[:, 0]
            nonmnfld_pred = nonmnfld_pred_all[:, 0]
            loss = 0.0
            mnfld_grad = gradient(mnfld_pnts, mnfld_pred)

            # manifold loss
            mnfld_loss = torch.zeros(1).cuda()
            if not args.ab == 'overall':
                mnfld_loss = (mnfld_pred.abs()).mean()  # |h| should vanish on the surface
            loss = loss + weight_mnfld_h * mnfld_loss

            # feature samples
            if args.feature_sample:
                feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda()
                feature_pnts = self.feature_data[feature_indices]
                feature_mask_pair = self.feature_data_mask_pair[feature_indices]
                feature_pred_all = self.network(feature_pnts)
                feature_pred = feature_pred_all[:, 0]
                feature_mnfld_loss = feature_pred.abs().mean()
                loss = loss + weight_mnfld_h * feature_mnfld_loss  # feature points also lie on the surface
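                # Layout assumed here (inferred from the indexing in this file):
                # column 0 of the network output is the composed implicit function h,
                # and columns 1..n_branch are the per-branch functions f_i. The
                # [rows, cols] fancy indexing below therefore gathers, for each
                # feature point, the prediction of the branch on either side of
                # the feature curve.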
                # patch loss:
                feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:, 0].tolist()]
                feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:, 1].tolist()]
                feature_fis_left = feature_pred_all[feature_id_left]
                feature_fis_right = feature_pred_all[feature_id_right]
                feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean()
                loss += feature_loss_patch

                # consistency loss:
                feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean()
                loss += weight_mnfld_cs * feature_loss_cons

            # gather, for each point, the prediction of the branch that owns it
            all_fi = torch.zeros([n_batchsize, 1], device='cuda')
            for i in range(n_branch - 1):
                all_fi[i * n_patch_batch: (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch: (i + 1) * n_patch_batch, i + 1]
            # last patch
            all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch]

            # manifold loss for patches
            mnfld_loss_patch = torch.zeros(1).cuda()
            if not args.ab == 'patch':
                if patch_sup:
                    mnfld_loss_patch = all_fi[:, 0].abs().mean()
            loss = loss + mnfld_loss_patch

            # correction loss
            correction_loss = torch.zeros(1).cuda()
            if not (args.ab == 'cor' or args.ab == 'cc') and epoch > 10000 and not args.baseline:
                mismatch_id = torch.abs(mnfld_pred - all_fi[:, 0]) > args.th_closeness  # points where h and the owning branch disagree
                if mismatch_id.sum() != 0:
                    correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:, 0])[mismatch_id]).mean()
            loss = loss + weight_correction * correction_loss

            # off-surface loss
            offsurface_loss = torch.zeros(1).cuda()
            if not args.ab == 'off':
                offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred[n_batchsize:])).mean()  # push off-surface predictions away from zero
            loss = loss + offsurface_loss

            # manifold consistency loss
            mnfld_consistency_loss = torch.zeros(1).cuda()
            if not (args.ab == 'cons' or args.ab == 'cc'):
                mnfld_consistency_loss = (mnfld_pred - all_fi[:, 0]).abs().mean()
            loss = loss + weight_mnfld_cs * mnfld_consistency_loss

            # eikonal loss for h
            grad_loss_h = torch.zeros(1).cuda()
            single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred_all[:, 0])
            grad_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
            loss = loss + self.grad_lambda * grad_loss_h
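            # The eikonal term drives |grad h| toward 1 at the sampled off-surface
            # points, encouraging h to approximate a signed distance function
            # instead of an arbitrary function with the right zero level set.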
            # normals loss
            normals_loss_h = torch.zeros(1).cuda()
            normals_loss = torch.zeros(1).cuda()
            normal_consistency_loss = torch.zeros(1).cuda()
            if not args.siren:
                if not args.ab == 'normal' and self.with_normals:
                    # supervise the gradients of the owning branches with ground-truth normals
                    normals = cur_data[:, -self.d_in:]
                    if patch_sup:
                        branch_grad = gradient(mnfld_pnts, all_fi[:, 0])
                        normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean()
                        loss = loss + self.normals_lambda * normals_loss
                        # monitored only, not used for the loss computation
                        mnfld_grad = gradient(mnfld_pnts, mnfld_pred_all[:, 0])
                        normal_consistency_loss = (mnfld_grad - branch_grad).abs().norm(2, dim=1).mean()
                else:
                    # no normals available: fall back to a unit-gradient penalty on the branch predictions
                    single_nonmnfld_grad = gradient(mnfld_pnts, all_fi[:, 0])
                    normals_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
                    loss = loss + self.normals_lambda * normals_loss_h
            else:
                # SIREN-style normal loss: compare gradient direction via cosine similarity
                normals = cur_data[:, -self.d_in:]
                normals_loss_h = (1 - F.cosine_similarity(mnfld_grad, normals, dim=-1)).mean()
                loss = loss + self.normals_lambda * normals_loss_h

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # tensorboard
            if args.summary and epoch % 100 == 0:
                writer.add_scalar('Loss/Total loss', loss.item(), epoch)
                writer.add_scalar('Loss/Manifold loss h', mnfld_loss.item(), epoch)
                writer.add_scalar('Loss/Manifold patch loss', mnfld_loss_patch.item(), epoch)
                writer.add_scalar('Loss/Manifold cons loss', mnfld_consistency_loss.item(), epoch)
                writer.add_scalar('Loss/Grad loss h', self.grad_lambda * grad_loss_h.item(), epoch)
                writer.add_scalar('Loss/Normal loss all', self.normals_lambda * normals_loss.item(), epoch)
                writer.add_scalar('Loss/Normal cs loss', self.normals_lambda * normal_consistency_loss.item(), epoch)
                writer.add_scalar('Loss/Assignment loss', correction_loss.item(), epoch)
                writer.add_scalar('Loss/Offsurface loss', offsurface_loss.item(), epoch)

            if epoch % self.conf.get_int('train.status_frequency') == 0:
                logger.info('Train Epoch: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\t Manifold loss: {:.6f}'
                            '\tManifold patch loss: {:.6f}\t grad loss h: {:.6f}\t normals loss all: {:.6f}\t normals loss h: {:.6f}\t Manifold consistency loss: {:.6f}\tCorrection loss: {:.6f}\t Offsurface loss: {:.6f}'.format(
                                epoch, self.nepochs, 100. * epoch / self.nepochs,
                                loss.item(), mnfld_loss.item(), mnfld_loss_patch.item(), grad_loss_h.item(), normals_loss.item(), normals_loss_h.item(), mnfld_consistency_loss.item(), correction_loss.item(), offsurface_loss.item()))
                if args.feature_sample:
                    logger.info('feature mnfld loss: {} patch loss: {} cons loss: {}'.format(feature_mnfld_loss.item(), feature_loss_patch.item(), feature_loss_cons.item()))

        self.tracing()  # export the trained network as a TorchScript module

    def tracing(self):
        # network definition
        device = torch.device('cuda')
        if args.cpu:
            device = torch.device('cpu')
        network = utils.get_class(self.conf.get_string('train.network_class'))(
            d_in=3,
            flag_output=1,
            n_branch=int(torch.max(self.feature_mask).item()),
            csg_tree=self.csg_tree,
            flag_convex=self.csg_flag_convex,
            **self.conf.get_config('network.inputs'))
        network.to(device)
        ckpt_prefix = 'exps/single_shape/'
        save_prefix = '{}/'.format(args.pt)
        if not os.path.exists(save_prefix):
            os.mkdir(save_prefix)
        if args.cpu:
            saved_model_state = torch.load(ckpt_prefix + self.foldername + '/checkpoints/ModelParameters/latest.pth', map_location=device)
        else:
            saved_model_state = torch.load(ckpt_prefix + self.foldername + '/checkpoints/ModelParameters/latest.pth')
        network.load_state_dict(saved_model_state["model_state_dict"])

        # trace
        example = torch.rand(224, 3).to(device)
        traced_script_module = torch.jit.trace(network, example)
        traced_script_module.save(save_prefix + self.foldername + "_model_h.pt")
        logger.info('converting to pt finished')
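    # A minimal sketch of consuming the traced module from Python (the path
    # mirrors what tracing() writes above; the query tensor shape is arbitrary):
    #
    #   net = torch.jit.load('<args.pt>/<foldername>_model_h.pt')
    #   values = net(torch.rand(1000, 3))  # implicit function values at 1000 points
    #
    # Since the network is constructed with flag_output=1 before tracing, the
    # traced module presumably returns the single selected output.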
    def plot_shapes(self, epoch, path=None, with_cuts=False, file_suffix="all"):
        # plot network validation shapes
        with torch.no_grad():
            self.network.eval()
            if not path:
                path = self.plots_dir
            indices = torch.tensor(np.random.choice(self.data.shape[0], self.points_batch, True))  # modified 0107, sample with replacement
            pnts = self.data[indices, :3]

            # draw non-manifold points
            mnfld_sigma = self.local_sigma[indices]
            nonmnfld_pnts = self.sampler.get_points(pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()
            pnts = nonmnfld_pnts

            plot_surface(with_points=True,
                         points=pnts,
                         decoder=self.network,
                         path=path,
                         epoch=epoch,
                         shapename=self.expname,
                         suffix=file_suffix,
                         **self.conf.get_config('plot'))

            if with_cuts:
                plot_cuts_axis(points=pnts,
                               decoder=self.network,
                               latent=None,
                               path=path,
                               epoch=epoch,
                               near_zero=False,
                               axis=2)

    def __init__(self, **kwargs):
        try:
            # 1. basic settings
            self._initialize_basic_settings(kwargs)
            # 2. config file and experiment directories
            self._setup_config_and_directories(kwargs)
            # 3. data loading
            self._load_data(kwargs)
            # 4. CSG tree setup
            self._setup_csg_tree()
            # 5. local sigma computation
            self._compute_local_sigma()
            # 6. network and optimizer setup
            self._setup_network_and_optimizer(kwargs)
            print("Initialization completed successfully")
        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def _initialize_basic_settings(self, kwargs):
        """Initialize basic settings."""
        try:
            self.home_dir = os.path.abspath(os.getcwd())
            self.flag_list = kwargs.get('flag_list', False)
            self.expname = kwargs['expname']
            self.GPU_INDEX = kwargs['gpu_index']
            self.num_of_gpus = torch.cuda.device_count()
            self.eval = kwargs['eval']
            logger.debug("Basic settings initialized successfully")
        except KeyError as e:
            logger.error(f"Missing required parameter: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error in basic settings initialization: {str(e)}")
            raise

    def _setup_config_and_directories(self, kwargs):
        """Set up the config file and create the required directories."""
        try:
            # config
            if isinstance(kwargs['conf'], str):
                self.conf_filename = './conversion/' + kwargs['conf']
                self.conf = ConfigFactory.parse_file(self.conf_filename)
            else:
                self.conf = kwargs['conf']

            # experiment directories
            self.exps_folder_name = 'exps'
            self.expdir = utils.concat_home_dir(os.path.join(
                self.home_dir, self.exps_folder_name, self.expname))
            utils.mkdir_ifnotexists(utils.concat_home_dir(
                os.path.join(self.home_dir, self.exps_folder_name)))
            utils.mkdir_ifnotexists(self.expdir)
            logger.debug("Config and directories setup completed")
        except Exception as e:
            logger.error(f"Error in config and directory setup: {str(e)}")
            raise

    def _load_data(self, kwargs):
        """Load the point data and the feature mask."""
        try:
            if not self.flag_list:
                self._load_single_data()
            else:
                self._load_data_from_list(kwargs)
            if args.baseline:
                self.feature_mask = torch.ones(self.data.shape[0]).float()
            logger.info(f"Data loading finished. Data shape: {self.data.shape}")
        except FileNotFoundError as e:
            logger.error(f"Data file not found: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error in data loading: {str(e)}")
            raise
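    # Input file conventions used by the loaders below (as constructed by their
    # os.path.join calls): in list mode, a model <prefix> expects <prefix>.xyz
    # for points (and normals), <prefix>_mask.txt for the per-point branch mask,
    # and, when --feature_sample is set, <prefix>_feature.xyz plus
    # <prefix>_feature_mask.txt for feature-curve samples and their left/right
    # branch id pairs.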
    def _load_single_data(self):
        """Load a single data file."""
        try:
            self.input_file = self.conf.get_string('train.input_path')
            self.data = utils.load_point_cloud_by_file_extension(self.input_file)
            self.feature_mask_file = self.conf.get_string('train.feature_mask_path')
            self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
            self.foldername = self.conf.get_string('train.foldername')
        except Exception as e:
            logger.error(f"Error loading single data file: {str(e)}")
            raise

    def _load_data_from_list(self, kwargs):
        """Load data for one model from the file list."""
        try:
            self.input_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '.xyz'
            )
            if not os.path.exists(self.input_file):
                self.flag_data_load = False
                raise Exception(f"Data file not found: {self.input_file}, absolute path: {os.path.abspath(self.input_file)}")
            self.flag_data_load = True
            self.data = utils.load_point_cloud_by_file_extension(self.input_file)
            self.feature_mask_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_mask.txt'
            )
            if not args.baseline:
                self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
            if args.feature_sample:
                self._load_feature_samples(kwargs)
            self.foldername = kwargs['folder_prefix'] + kwargs['file_prefix']
        except Exception as e:
            logger.error(f"Error loading data from list: {str(e)}")
            raise

    def _load_feature_samples(self, kwargs):
        """Load feature-curve samples."""
        try:
            input_fs_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_feature.xyz'
            )
            self.feature_data = np.loadtxt(input_fs_file)
            self.feature_data = torch.tensor(
                self.feature_data, dtype=torch.float32, device='cuda'
            )
            fs_mask_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_feature_mask.txt'
            )
            self.feature_data_mask_pair = torch.tensor(
                np.loadtxt(fs_mask_file), dtype=torch.int64, device='cuda'
            )
        except Exception as e:
            logger.error(f"Error loading feature samples: {str(e)}")
            raise

    def _setup_csg_tree(self):
        """Set up the CSG tree."""
        try:
            if args.baseline:
                self.csg_tree = [0]
                self.csg_flag_convex = True
            else:
                csg_conf_file = self.input_file[:-4] + '_csg.conf'
                csg_config = ConfigFactory.parse_file(csg_conf_file)
                self.csg_tree = csg_config.get_list('csg.list')
                self.csg_flag_convex = csg_config.get_int('csg.flag_convex')
            logger.info(f"CSG tree: {self.csg_tree}")
            logger.info(f"CSG convex flag: {self.csg_flag_convex}")
        except Exception as e:
            logger.error(f"Error in CSG tree setup: {str(e)}")
            raise

    def _compute_local_sigma(self):
        """Compute per-point local sigma (distance to the 50th nearest neighbor)."""
        try:
            sigma_set = []
            ptree = cKDTree(self.data)
            logger.debug("KD tree constructed")
            for p in np.array_split(self.data, 100, axis=0):
                d = ptree.query(p, 50 + 1)
                sigma_set.append(d[0][:, -1])
            sigmas = np.concatenate(sigma_set)
            self.local_sigma = torch.from_numpy(sigmas).float().cuda()
        except Exception as e:
            logger.error(f"Error computing local sigma: {str(e)}")
            raise

    def _setup_network_and_optimizer(self, kwargs):
        """Set up the network and the optimizer."""
        try:
            # directories
            self._setup_checkpoints_directories()
            # network parameters
            self._setup_network_parameters(kwargs)
            # network creation
            self._create_network()
            # optimizer
            self._setup_optimizer(kwargs)
            logger.debug("Network and optimizer setup completed")
        except Exception as e:
            logger.error(f"Error in network and optimizer setup: {str(e)}")
            raise
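    # Directory layout created below, rooted at exps/<expname>/<foldername>/:
    #   plots/                            validation plots
    #   checkpoints/ModelParameters/      <epoch>.pth and latest.pth
    #   checkpoints/OptimizerParameters/  <epoch>.pth and latest.pth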
    def _setup_checkpoints_directories(self):
        """Set up the checkpoint directories."""
        try:
            self.cur_exp_dir = os.path.join(self.expdir, self.foldername)
            utils.mkdir_ifnotexists(self.cur_exp_dir)
            self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
            utils.mkdir_ifnotexists(self.plots_dir)
            self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
            utils.mkdir_ifnotexists(self.checkpoints_path)
            self.model_params_subdir = "ModelParameters"
            self.optimizer_params_subdir = "OptimizerParameters"
            utils.mkdir_ifnotexists(os.path.join(
                self.checkpoints_path, self.model_params_subdir))
            utils.mkdir_ifnotexists(os.path.join(
                self.checkpoints_path, self.optimizer_params_subdir))
        except Exception as e:
            logger.error(f"Error setting up checkpoint directories: {str(e)}")
            raise

    def _setup_network_parameters(self, kwargs):
        """Set up the network parameters."""
        try:
            self.nepochs = kwargs['nepochs']
            self.points_batch = kwargs['points_batch']
            self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
            self.sampler = Sampler.get_sampler(
                self.conf.get_string('network.sampler.sampler_type'))(
                    self.global_sigma, self.local_sigma
            )
            self.grad_lambda = self.conf.get_float('network.loss.lambda')
            self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')
            self.with_normals = self.normals_lambda > 0
            self.d_in = self.conf.get_int('train.d_in')
        except Exception as e:
            logger.error(f"Error setting up network parameters: {str(e)}")
            raise

    def _create_network(self):
        """Create the network."""
        try:
            self.network = utils.get_class(self.conf.get_string('train.network_class'))(
                d_in=self.d_in,
                n_branch=int(torch.max(self.feature_mask).item()),
                csg_tree=self.csg_tree,
                flag_convex=self.csg_flag_convex,
                **self.conf.get_config('network.inputs')
            )
            logger.info(f"Network created: {self.network}")
            if torch.cuda.is_available():
                self.network.cuda()
        except Exception as e:
            logger.error(f"Error creating network: {str(e)}")
            raise

    def _setup_optimizer(self, kwargs):
        """Set up the optimizer."""
        try:
            self.lr_schedules = self.get_learning_rate_schedules(
                self.conf.get_list('train.learning_rate_schedule'))
            self.weight_decay = self.conf.get_float('train.weight_decay')
            self.startepoch = 0
            self.optimizer = torch.optim.Adam([{
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            }])
            # load checkpoints when continuing training
            self._load_checkpoints_if_continue(kwargs)
        except Exception as e:
            logger.error(f"Error setting up optimizer: {str(e)}")
            raise

    def _load_checkpoints_if_continue(self, kwargs):
        """Load checkpoints when continuing training."""
        try:
            model_params_path = os.path.join(
                self.checkpoints_path, self.model_params_subdir)
            ckpts = os.listdir(model_params_path)
            if len(ckpts) != 0:
                old_checkpnts_dir = os.path.join(
                    self.expdir, self.foldername, 'checkpoints')
                logger.info(f'Loading checkpoint from: {old_checkpnts_dir}')
                # model state
                saved_model_state = torch.load(os.path.join(
                    old_checkpnts_dir, 'ModelParameters',
                    f"{kwargs['checkpoint']}.pth"
                ))
                self.network.load_state_dict(saved_model_state["model_state_dict"])
                # optimizer state
                data = torch.load(os.path.join(
                    old_checkpnts_dir, 'OptimizerParameters',
                    f"{kwargs['checkpoint']}.pth"
                ))
                self.optimizer.load_state_dict(data["optimizer_state_dict"])
                self.startepoch = saved_model_state['epoch']
        except Exception as e:
            logger.error(f"Error loading checkpoints: {str(e)}")
            raise

    def get_learning_rate_schedules(self, schedule_specs):
        schedules = []
        for schedule_spec in schedule_specs:  # loop variable renamed to avoid shadowing the argument
            if schedule_spec["Type"] == "Step":
                schedules.append(
                    utils.StepLearningRateSchedule(
                        schedule_spec["Initial"],
                        schedule_spec["Interval"],
                        schedule_spec["Factor"],
                    )
                )
            else:
                raise Exception(
                    'no known learning rate schedule of type "{}"'.format(
                        schedule_spec["Type"]
                    )
                )
        return schedules
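    # A sketch of the matching 'train.learning_rate_schedule' entry in the HOCON
    # conf file (the keys are the ones read above; the values are illustrative):
    #
    #   learning_rate_schedule = [{
    #       "Type" : "Step",
    #       "Initial" : 0.005,
    #       "Interval" : 2000,
    #       "Factor" : 0.5
    #   }]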
    def adjust_learning_rate(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)

    def save_checkpoints(self, epoch):
        torch.save(
            {"epoch": epoch, "model_state_dict": self.network.state_dict()},
            os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
        torch.save(
            {"epoch": epoch, "model_state_dict": self.network.state_dict()},
            os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
        torch.save(
            {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
        torch.save(
            {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))


if __name__ == '__main__':
    if args.gpu == "auto":
        deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5,
                                        includeNan=False, excludeID=[], excludeUUID=[])
        gpu = deviceIDs[0]
    else:
        gpu = args.gpu

    nepoch = args.nepoch
    conf = ConfigFactory.parse_file('./conversion/' + args.conf)
    folderprefix = conf.get_string('train.folderprefix')
    fileprefix_list = conf.get_list('train.fileprefix_list')
    trainrunners = []
    for i in range(len(fileprefix_list)):
        fp = fileprefix_list[i]
        print('cur model: ', fp)
        begin = time.time()
        trainrunners.append(ReconstructionRunner(
            conf=args.conf,
            folder_prefix=folderprefix,
            file_prefix=fp,
            points_batch=args.points_batch,
            nepochs=nepoch,
            expname=args.expname,
            gpu_index=gpu,
            is_continue=args.is_continue,
            checkpoint=args.checkpoint,
            eval=args.eval,
            flag_list=True
        ))
        if trainrunners[i].flag_data_load:
            trainrunners[i].run_nhrepnet_training()
        end = time.time()
        dur = end - begin
        if args.baseline:
            fp = fp + "_bl"
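# Example invocations (a sketch: the script name is a placeholder, the flags are
# the ones defined by the argparse setup above, and setup.conf is the default
# conf file expected under ./conversion/):
#
#   python train_script.py --conf setup.conf --expname single_shape --gpu 0
#   python train_script.py --conf setup.conf --eval              # plot shapes only
#   python train_script.py --conf setup.conf --feature_sample --summary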