From d425165c328302050c8dd24db52ae5b86dcfb771 Mon Sep 17 00:00:00 2001
From: mckay
Date: Sun, 5 Jan 2025 17:41:30 +0800
Subject: [PATCH] Refactor ReconstructionRunner initialization in run.py for improved structure and error handling

- Split the initialization process into multiple private methods for better readability and maintainability.
- Added detailed logging for each step of the initialization process, including error handling for missing parameters and file loading issues.
- Enhanced configuration and directory setup with clearer error messages and structured logging.
- Improved data loading methods to handle both single and list-based data inputs more robustly.
- Introduced methods for setting up the CSG tree and computing local sigma values, with appropriate logging for each operation.
---
 code/conversion/run.py | 448 ++++++++++++++++++++++++++++-------------
 1 file changed, 310 insertions(+), 138 deletions(-)

diff --git a/code/conversion/run.py b/code/conversion/run.py
index 6dbad5b..1cb145e 100644
--- a/code/conversion/run.py
+++ b/code/conversion/run.py
@@ -314,157 +314,327 @@ class ReconstructionRunner:
                              axis = 2)
 
     def __init__(self, **kwargs):
-        self.home_dir = os.path.abspath(os.getcwd())
-        flag_list = False
-        if 'flag_list' in kwargs:
-            flag_list = True
-
-        # config setting
-        if type(kwargs['conf']) == str:
-            self.conf_filename = './conversion/' + kwargs['conf']
-            self.conf = ConfigFactory.parse_file(self.conf_filename)
-        else:
-            self.conf = kwargs['conf']
-
-        self.expname = kwargs['expname']
-
-        # GPU settings, currently we only support single-gpu training
-        self.GPU_INDEX = kwargs['gpu_index']
-        self.num_of_gpus = torch.cuda.device_count()
-        self.eval = kwargs['eval']
-
-        self.exps_folder_name = 'exps'
-        utils.mkdir_ifnotexists(utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name)))
-        self.expdir = utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name, self.expname))
-        utils.mkdir_ifnotexists(self.expdir)
-
-        if not flag_list:
+        try:
+            # 1. Initialize basic settings
+            self._initialize_basic_settings(kwargs)
+
+            # 2. Set up the configuration file and experiment directories
+            self._setup_config_and_directories(kwargs)
+
+            # 3. Load data
+            self._load_data(kwargs)
+
+            # 4. Set up the CSG tree
+            self._setup_csg_tree()
+
+            # 5. Compute local sigma
+            self._compute_local_sigma()
+
+            # 6. Set up the network and optimizer
+            self._setup_network_and_optimizer(kwargs)
+
+            print("Initialization completed successfully")
+
+        except Exception as e:
+            logger.error(f"Error during initialization: {str(e)}")
+            raise
+
+    def _initialize_basic_settings(self, kwargs):
+        """Initialize basic settings."""
+        try:
+            self.home_dir = os.path.abspath(os.getcwd())
+            self.flag_list = kwargs.get('flag_list', False)
+            self.expname = kwargs['expname']
+            self.GPU_INDEX = kwargs['gpu_index']
+            self.num_of_gpus = torch.cuda.device_count()
+            self.eval = kwargs['eval']
+            logger.debug("Basic settings initialized successfully")
+        except KeyError as e:
+            logger.error(f"Missing required parameter: {str(e)}")
+            raise
+        except Exception as e:
+            logger.error(f"Error in basic settings initialization: {str(e)}")
+            raise
+
+    def _setup_config_and_directories(self, kwargs):
+        """Set up the configuration file and create the required directories."""
+        try:
+            # Configuration settings
+            if isinstance(kwargs['conf'], str):
+                self.conf_filename = './conversion/' + kwargs['conf']
+                self.conf = ConfigFactory.parse_file(self.conf_filename)
+            else:
+                self.conf = kwargs['conf']
+
+            # Create the experiment directories
+            self.exps_folder_name = 'exps'
+            self.expdir = utils.concat_home_dir(os.path.join(
+                self.home_dir, self.exps_folder_name, self.expname))
+            utils.mkdir_ifnotexists(utils.concat_home_dir(
+                os.path.join(self.home_dir, self.exps_folder_name)))
+            utils.mkdir_ifnotexists(self.expdir)
+
+            logger.debug("Config and directories setup completed")
+        except Exception as e:
+            logger.error(f"Error in config and directory setup: {str(e)}")
+            raise
+
+    def _load_data(self, kwargs):
+        """Load the point cloud data and feature mask."""
+        try:
+            if not self.flag_list:
+                self._load_single_data()
+            else:
+                self._load_data_from_list(kwargs)
+
+            if args.baseline:
+                self.feature_mask = torch.ones(self.data.shape[0]).float()
+
+            logger.info(f"Data loading finished. Data shape: {self.data.shape}")
+        except FileNotFoundError as e:
+            logger.error(f"Data file not found: {str(e)}")
+            raise
+        except Exception as e:
+            logger.error(f"Error in data loading: {str(e)}")
+            raise
+
+    def _load_single_data(self):
+        """Load a single data file."""
+        try:
             self.input_file = self.conf.get_string('train.input_path')
             self.data = utils.load_point_cloud_by_file_extension(self.input_file)
             self.feature_mask_file = self.conf.get_string('train.feature_mask_path')
             self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
-        else:
-            self.input_file = os.path.join(self.conf.get_string('train.input_path'), kwargs['file_prefix']+'.xyz')
+            self.foldername = self.conf.get_string('train.foldername')
+        except Exception as e:
+            logger.error(f"Error loading single data file: {str(e)}")
+            raise
+
+    def _load_data_from_list(self, kwargs):
+        """Load data given a file prefix list."""
+        try:
+            self.input_file = os.path.join(
+                self.conf.get_string('train.input_path'),
+                kwargs['file_prefix']+'.xyz'
+            )
             if not os.path.exists(self.input_file):
                 self.flag_data_load = False
-                return
+                raise Exception(f"Data file not found: {self.input_file}, absolute path: {os.path.abspath(self.input_file)}")
+
+            self.flag_data_load = True
             self.data = utils.load_point_cloud_by_file_extension(self.input_file)
-            self.feature_mask_file = os.path.join(self.conf.get_string('train.input_path'), kwargs['file_prefix']+'_mask.txt')
+            self.feature_mask_file = os.path.join(
+                self.conf.get_string('train.input_path'),
+                kwargs['file_prefix']+'_mask.txt'
+            )
+
             if not args.baseline:
                 self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
-            
+
             if args.feature_sample:
-                input_fs_file = os.path.join(self.conf.get_string('train.input_path'), kwargs['file_prefix']+'_feature.xyz')
-                self.feature_data = np.loadtxt(input_fs_file)
-                self.feature_data = torch.tensor(self.feature_data, dtype = torch.float32, device = 'cuda')
-                fs_mask_file = os.path.join(self.conf.get_string('train.input_path'), kwargs['file_prefix']+'_feature_mask.txt')
-                self.feature_data_mask_pair = torch.tensor(np.loadtxt(fs_mask_file), dtype = torch.int64, device = 'cuda')
-        if args.baseline:
-            self.csg_tree = [0]
-            self.csg_flag_convex = True
-        else:
-            self.csg_tree = []
-            self.csg_tree = ConfigFactory.parse_file(self.input_file[:-4]+'_csg.conf').get_list('csg.list')
-            self.csg_flag_convex = ConfigFactory.parse_file(self.input_file[:-4]+'_csg.conf').get_int('csg.flag_convex')
-        print ("csg tree: ", self.csg_tree)
-        print ("csg convex flag: ", self.csg_flag_convex)
-
-        if not flag_list:
-            self.foldername = self.conf.get_string('train.foldername')
-        else:
+                self._load_feature_samples(kwargs)
+            self.foldername = kwargs['folder_prefix'] + kwargs['file_prefix']
+        except Exception as e:
+            logger.error(f"Error loading data from list: {str(e)}")
+            raise
+
+    def _load_feature_samples(self, kwargs):
+        """Load feature samples."""
+        try:
+            input_fs_file = os.path.join(
+                self.conf.get_string('train.input_path'),
+                kwargs['file_prefix']+'_feature.xyz'
+            )
+            self.feature_data = np.loadtxt(input_fs_file)
+            self.feature_data = torch.tensor(
+                self.feature_data,
+                dtype=torch.float32,
+                device='cuda'
+            )
+
+            fs_mask_file = os.path.join(
+                self.conf.get_string('train.input_path'),
+                kwargs['file_prefix']+'_feature_mask.txt'
+            )
+            self.feature_data_mask_pair = torch.tensor(
+                np.loadtxt(fs_mask_file),
+                dtype=torch.int64,
+                device='cuda'
+            )
+        except Exception as e:
+            logger.error(f"Error loading feature samples: {str(e)}")
+            raise
 
-        if args.baseline:
-            self.feature_mask = torch.ones(self.data.shape[0]).float()
-
-        print ("loading finished")
-        print ("data shape: ", self.data.shape)
-
-        sigma_set = []
-        ptree = cKDTree(self.data)
-        print ("kd tree constructed")
-
-        for p in np.array_split(self.data, 100, axis=0):
-            d = ptree.query(p, 50 + 1)
-            sigma_set.append(d[0][:, -1])
-
-        sigmas = np.concatenate(sigma_set)
-        self.local_sigma = torch.from_numpy(sigmas).float().cuda()
-
-        self.cur_exp_dir = os.path.join(self.expdir, self.foldername)
-        utils.mkdir_ifnotexists(self.cur_exp_dir)
-
-        self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
-        utils.mkdir_ifnotexists(self.plots_dir)
-
-        self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
-        utils.mkdir_ifnotexists(self.checkpoints_path)
-
-        self.model_params_subdir = "ModelParameters"
-        self.optimizer_params_subdir = "OptimizerParameters"
-
-        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
-        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
-        model_params_path = os.path.join(self.checkpoints_path, self.model_params_subdir)
-        ckpts = os.listdir(model_params_path)
-        #if ckpts exists, then continue
-        is_continue = False
-        if (len(ckpts)) != 0:
-            is_continue = True
-
-        self.nepochs = kwargs['nepochs']
-
-        self.points_batch = kwargs['points_batch']
-
-        self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
-        self.sampler = Sampler.get_sampler(self.conf.get_string('network.sampler.sampler_type'))(self.global_sigma,
-                                                                                                 self.local_sigma)
-        self.grad_lambda = self.conf.get_float('network.loss.lambda')
-        self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')
-
-        self.with_normals = self.normals_lambda > 0
-
-        self.d_in = self.conf.get_int('train.d_in')
-
-        self.network = utils.get_class(self.conf.get_string('train.network_class'))(d_in=self.d_in,
-                                                                                     n_branch = int(torch.max(self.feature_mask).item()),
-                                                                                     csg_tree = self.csg_tree,
-                                                                                     flag_convex = self.csg_flag_convex,
-                                                                                     **self.conf.get_config(
-                                                                                         'network.inputs'))
-
-
-        print (self.network)
-
-        if torch.cuda.is_available():
-            self.network.cuda()
-
-        self.lr_schedules = self.get_learning_rate_schedules(self.conf.get_list('train.learning_rate_schedule'))
-        self.weight_decay = self.conf.get_float('train.weight_decay')
-
-        self.startepoch = 0
-        self.optimizer = torch.optim.Adam(
-            [
-                {
-                    "params": self.network.parameters(),
-                    "lr": self.lr_schedules[0].get_learning_rate(0),
-                    "weight_decay": self.weight_decay
-                },
-            ])
-
-        # if continue load checkpoints
-        if is_continue:
-            old_checkpnts_dir = os.path.join(self.expdir, self.foldername, 'checkpoints')
-            print('loading checkpoint from: ', old_checkpnts_dir)
-            saved_model_state = torch.load(
-                os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
-            self.network.load_state_dict(saved_model_state["model_state_dict"])
-
-            data = torch.load(
-                os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
-            self.optimizer.load_state_dict(data["optimizer_state_dict"])
-            self.startepoch = saved_model_state['epoch']
+    def _setup_csg_tree(self):
+        """Set up the CSG tree."""
+        try:
+            if args.baseline:
+                self.csg_tree = [0]
+                self.csg_flag_convex = True
+            else:
+                csg_conf_file = self.input_file[:-4]+'_csg.conf'
+                csg_config = ConfigFactory.parse_file(csg_conf_file)
+                self.csg_tree = csg_config.get_list('csg.list')
+                self.csg_flag_convex = csg_config.get_int('csg.flag_convex')
+
+            logger.info(f"CSG tree: {self.csg_tree}")
+            logger.info(f"CSG convex flag: {self.csg_flag_convex}")
+        except Exception as e:
+            logger.error(f"Error in CSG tree setup: {str(e)}")
+            raise
+
+    def _compute_local_sigma(self):
+        """Compute local sigma values."""
+        try:
+            sigma_set = []
+            ptree = cKDTree(self.data)
+            logger.debug("KD tree constructed")
+
+            for p in np.array_split(self.data, 100, axis=0):
+                d = ptree.query(p, 50 + 1)
+                sigma_set.append(d[0][:, -1])
+
+            sigmas = np.concatenate(sigma_set)
+            self.local_sigma = torch.from_numpy(sigmas).float().cuda()
+        except Exception as e:
+            logger.error(f"Error computing local sigma: {str(e)}")
+            raise
+
+    def _setup_network_and_optimizer(self, kwargs):
+        """Set up the network and optimizer."""
+        try:
+            # Set up directories
+            self._setup_checkpoints_directories()
+
+            # Set network parameters
+            self._setup_network_parameters(kwargs)
+
+            # Create the network
+            self._create_network()
+
+            # Set up the optimizer
+            self._setup_optimizer(kwargs)
+
+            logger.debug("Network and optimizer setup completed")
+        except Exception as e:
+            logger.error(f"Error in network and optimizer setup: {str(e)}")
+            raise
+
+    def _setup_checkpoints_directories(self):
+        """Set up the checkpoint directories."""
+        try:
+            self.cur_exp_dir = os.path.join(self.expdir, self.foldername)
+            utils.mkdir_ifnotexists(self.cur_exp_dir)
+
+            self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
+            utils.mkdir_ifnotexists(self.plots_dir)
+
+            self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
+            utils.mkdir_ifnotexists(self.checkpoints_path)
+
+            self.model_params_subdir = "ModelParameters"
+            self.optimizer_params_subdir = "OptimizerParameters"
+
+            utils.mkdir_ifnotexists(os.path.join(
+                self.checkpoints_path, self.model_params_subdir))
+            utils.mkdir_ifnotexists(os.path.join(
+                self.checkpoints_path, self.optimizer_params_subdir))
+        except Exception as e:
+            logger.error(f"Error setting up checkpoint directories: {str(e)}")
+            raise
+
+    def _setup_network_parameters(self, kwargs):
+        """Set up network parameters."""
+        try:
+            self.nepochs = kwargs['nepochs']
+            self.points_batch = kwargs['points_batch']
+            self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
+
+            self.sampler = Sampler.get_sampler(
+                self.conf.get_string('network.sampler.sampler_type'))(
+                self.global_sigma,
+                self.local_sigma
+            )
+
+            self.grad_lambda = self.conf.get_float('network.loss.lambda')
+            self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')
+            self.with_normals = self.normals_lambda > 0
+            self.d_in = self.conf.get_int('train.d_in')
+        except Exception as e:
+            logger.error(f"Error setting up network parameters: {str(e)}")
+            raise
+
+    def _create_network(self):
+        """Create the network."""
+        try:
+            self.network = utils.get_class(self.conf.get_string('train.network_class'))(
+                d_in=self.d_in,
+                n_branch=int(torch.max(self.feature_mask).item()),
+                csg_tree=self.csg_tree,
+                flag_convex=self.csg_flag_convex,
+                **self.conf.get_config('network.inputs')
+            )
+
+            logger.info(f"Network created: {self.network}")
+
+            if torch.cuda.is_available():
+                self.network.cuda()
+        except Exception as e:
+            logger.error(f"Error creating network: {str(e)}")
+            raise
+
+    def _setup_optimizer(self, kwargs):
+        """Set up the optimizer."""
+        try:
+            self.lr_schedules = self.get_learning_rate_schedules(
+                self.conf.get_list('train.learning_rate_schedule'))
+            self.weight_decay = self.conf.get_float('train.weight_decay')
+
+            self.startepoch = 0
+            self.optimizer = torch.optim.Adam([{
+                "params": self.network.parameters(),
+                "lr": self.lr_schedules[0].get_learning_rate(0),
+                "weight_decay": self.weight_decay
+            }])
+
+            # If training is being resumed, load checkpoints
+            self._load_checkpoints_if_continue(kwargs)
+        except Exception as e:
+            logger.error(f"Error setting up optimizer: {str(e)}")
+            raise
+
+    def _load_checkpoints_if_continue(self, kwargs):
+        """Load checkpoints when resuming training."""
+        try:
+            model_params_path = os.path.join(
+                self.checkpoints_path, self.model_params_subdir)
+            ckpts = os.listdir(model_params_path)
+
+            if len(ckpts) != 0:
+                old_checkpnts_dir = os.path.join(
+                    self.expdir, self.foldername, 'checkpoints')
+                logger.info(f'Loading checkpoint from: {old_checkpnts_dir}')
+
+                # Load the model state
+                saved_model_state = torch.load(os.path.join(
+                    old_checkpnts_dir,
+                    'ModelParameters',
+                    f"{kwargs['checkpoint']}.pth"
+                ))
+                self.network.load_state_dict(saved_model_state["model_state_dict"])
+
+                # Load the optimizer state
+                data = torch.load(os.path.join(
+                    old_checkpnts_dir,
+                    'OptimizerParameters',
+                    f"{kwargs['checkpoint']}.pth"
+                ))
+                self.optimizer.load_state_dict(data["optimizer_state_dict"])
+                self.startepoch = saved_model_state['epoch']
+        except Exception as e:
+            logger.error(f"Error loading checkpoints: {str(e)}")
+            raise
 
     def get_learning_rate_schedules(self, schedule_specs):
@@ -509,6 +679,8 @@ class ReconstructionRunner:
             {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
             os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))
 
+
+
 if __name__ == '__main__':
 
     if args.gpu == "auto":
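
Note (editorial illustration, not part of the patch): _compute_local_sigma above keeps the behaviour of the old inline code, taking for every input point the distance to its 50th nearest neighbour, queried in 100 chunks through a cKDTree. A minimal standalone sketch of that pattern, using a random placeholder point cloud instead of the loaded data:

    import numpy as np
    from scipy.spatial import cKDTree

    points = np.random.rand(10000, 3)            # placeholder point cloud
    tree = cKDTree(points)

    sigma_chunks = []
    for chunk in np.array_split(points, 100, axis=0):
        # k = 50 + 1 because each query point is returned as its own nearest neighbour
        dists, _ = tree.query(chunk, k=50 + 1)
        sigma_chunks.append(dists[:, -1])        # distance to the 50th true neighbour

    local_sigma = np.concatenate(sigma_chunks)   # the patch then moves this onto the GPU as a torch tensor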