Refactor ReconstructionRunner initialization in run.py for improved structure and error handling

- Split the initialization process into multiple private methods for better readability and maintainability.
- Added detailed logging for each step of the initialization process, including error handling for missing parameters and file loading issues.
- Enhanced configuration and directory setup with clearer error messages and structured logging.
- Improved data loading methods to handle both single and list-based data inputs more robustly.
- Introduced methods for setting up the CSG tree and computing local sigma values, with appropriate logging for each operation. (A usage sketch of the refactored initializer follows below.)
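A minimal usage sketch of the refactored initializer, assuming hypothetical argument values and a hypothetical conf filename; the keyword names are the ones actually read by __init__ and its private helpers in the diff below.

# Minimal usage sketch -- keyword names follow the kwargs consumed by __init__ and
# its private helpers below; the values and the conf filename are hypothetical.
# Assumes the module-level argparse namespace args (baseline, feature_sample, gpu)
# has already been populated, as in run.py's __main__ block.
runner = ReconstructionRunner(
    conf='setup.conf',            # a string is parsed with ConfigFactory; a Config object is used as-is
    expname='example_exp',
    gpu_index=0,
    eval=False,
    nepochs=15000,
    points_batch=16384,
    checkpoint='latest',          # only read when existing checkpoints are found
    flag_list=True,               # build the input path from train.input_path and file_prefix
    file_prefix='example_shape',
    folder_prefix='nhrep_',
)
# Each of the six initialization steps logs its own errors and re-raises, so the
# constructor either completes fully or propagates the failure to the caller.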
NH-Rep · mckay committed 2 months ago · commit d425165c32

1 changed file: code/conversion/run.py (448 lines changed)


@@ -314,157 +314,327 @@ class ReconstructionRunner:
                              axis = 2)

    def __init__(self, **kwargs):
        try:
            # 1. Initialize basic settings
            self._initialize_basic_settings(kwargs)

            # 2. Set up the config file and experiment directories
            self._setup_config_and_directories(kwargs)

            # 3. Load data
            self._load_data(kwargs)

            # 4. Set up the CSG tree
            self._setup_csg_tree()

            # 5. Compute local sigma values
            self._compute_local_sigma()

            # 6. Set up the network and optimizer
            self._setup_network_and_optimizer(kwargs)

            print("Initialization completed successfully")

        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def _initialize_basic_settings(self, kwargs):
        """Initialize basic settings."""
        try:
            self.home_dir = os.path.abspath(os.getcwd())
            self.flag_list = kwargs.get('flag_list', False)
            self.expname = kwargs['expname']
            # GPU settings, currently we only support single-gpu training
            self.GPU_INDEX = kwargs['gpu_index']
            self.num_of_gpus = torch.cuda.device_count()
            self.eval = kwargs['eval']
            logger.debug("Basic settings initialized successfully")
        except KeyError as e:
            logger.error(f"Missing required parameter: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error in basic settings initialization: {str(e)}")
            raise

    def _setup_config_and_directories(self, kwargs):
        """Set up the config file and create the required directories."""
        try:
            # Config setting
            if isinstance(kwargs['conf'], str):
                self.conf_filename = './conversion/' + kwargs['conf']
                self.conf = ConfigFactory.parse_file(self.conf_filename)
            else:
                self.conf = kwargs['conf']

            # Create experiment directories
            self.exps_folder_name = 'exps'
            self.expdir = utils.concat_home_dir(os.path.join(
                self.home_dir, self.exps_folder_name, self.expname))
            utils.mkdir_ifnotexists(utils.concat_home_dir(
                os.path.join(self.home_dir, self.exps_folder_name)))
            utils.mkdir_ifnotexists(self.expdir)
            logger.debug("Config and directories setup completed")
        except Exception as e:
            logger.error(f"Error in config and directory setup: {str(e)}")
            raise

    def _load_data(self, kwargs):
        """Load the point cloud data and feature mask."""
        try:
            if not self.flag_list:
                self._load_single_data()
            else:
                self._load_data_from_list(kwargs)

            if args.baseline:
                self.feature_mask = torch.ones(self.data.shape[0]).float()

            logger.info(f"Data loading finished. Data shape: {self.data.shape}")
        except FileNotFoundError as e:
            logger.error(f"Data file not found: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error in data loading: {str(e)}")
            raise

    def _load_single_data(self):
        """Load a single data file."""
        try:
            self.input_file = self.conf.get_string('train.input_path')
            self.data = utils.load_point_cloud_by_file_extension(self.input_file)
            self.feature_mask_file = self.conf.get_string('train.feature_mask_path')
            self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
            self.foldername = self.conf.get_string('train.foldername')
        except Exception as e:
            logger.error(f"Error loading single data file: {str(e)}")
            raise

    def _load_data_from_list(self, kwargs):
        """Load data from a file list."""
        try:
            self.input_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '.xyz'
            )
            if not os.path.exists(self.input_file):
                self.flag_data_load = False
                raise Exception(f"Data file not found: {self.input_file}, absolute path: {os.path.abspath(self.input_file)}")
            self.flag_data_load = True
            self.data = utils.load_point_cloud_by_file_extension(self.input_file)
            self.feature_mask_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_mask.txt'
            )
            if not args.baseline:
                self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
            if args.feature_sample:
                self._load_feature_samples(kwargs)
            self.foldername = kwargs['folder_prefix'] + kwargs['file_prefix']
        except Exception as e:
            logger.error(f"Error loading data from list: {str(e)}")
            raise

    def _load_feature_samples(self, kwargs):
        """Load feature samples."""
        try:
            input_fs_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_feature.xyz'
            )
            self.feature_data = np.loadtxt(input_fs_file)
            self.feature_data = torch.tensor(
                self.feature_data,
                dtype=torch.float32,
                device='cuda'
            )
            fs_mask_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_feature_mask.txt'
            )
            self.feature_data_mask_pair = torch.tensor(
                np.loadtxt(fs_mask_file),
                dtype=torch.int64,
                device='cuda'
            )
        except Exception as e:
            logger.error(f"Error loading feature samples: {str(e)}")
            raise

    def _setup_csg_tree(self):
        """Set up the CSG tree."""
        try:
            if args.baseline:
                self.csg_tree = [0]
                self.csg_flag_convex = True
            else:
                csg_conf_file = self.input_file[:-4] + '_csg.conf'
                csg_config = ConfigFactory.parse_file(csg_conf_file)
                self.csg_tree = csg_config.get_list('csg.list')
                self.csg_flag_convex = csg_config.get_int('csg.flag_convex')

            logger.info(f"CSG tree: {self.csg_tree}")
            logger.info(f"CSG convex flag: {self.csg_flag_convex}")
        except Exception as e:
            logger.error(f"Error in CSG tree setup: {str(e)}")
            raise

    def _compute_local_sigma(self):
        """Compute the local sigma values."""
        try:
            sigma_set = []
            ptree = cKDTree(self.data)
            logger.debug("KD tree constructed")

            for p in np.array_split(self.data, 100, axis=0):
                d = ptree.query(p, 50 + 1)
                sigma_set.append(d[0][:, -1])

            sigmas = np.concatenate(sigma_set)
            self.local_sigma = torch.from_numpy(sigmas).float().cuda()
        except Exception as e:
            logger.error(f"Error computing local sigma: {str(e)}")
            raise

    def _setup_network_and_optimizer(self, kwargs):
        """Set up the network and optimizer."""
        try:
            # Set up checkpoint directories
            self._setup_checkpoints_directories()

            # Network parameters
            self._setup_network_parameters(kwargs)

            # Create the network
            self._create_network()

            # Set up the optimizer
            self._setup_optimizer(kwargs)

            logger.debug("Network and optimizer setup completed")
        except Exception as e:
            logger.error(f"Error in network and optimizer setup: {str(e)}")
            raise

    def _setup_checkpoints_directories(self):
        """Set up the checkpoint directories."""
        try:
            self.cur_exp_dir = os.path.join(self.expdir, self.foldername)
            utils.mkdir_ifnotexists(self.cur_exp_dir)

            self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
            utils.mkdir_ifnotexists(self.plots_dir)

            self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
            utils.mkdir_ifnotexists(self.checkpoints_path)

            self.model_params_subdir = "ModelParameters"
            self.optimizer_params_subdir = "OptimizerParameters"

            utils.mkdir_ifnotexists(os.path.join(
                self.checkpoints_path, self.model_params_subdir))
            utils.mkdir_ifnotexists(os.path.join(
                self.checkpoints_path, self.optimizer_params_subdir))
        except Exception as e:
            logger.error(f"Error setting up checkpoint directories: {str(e)}")
            raise

    def _setup_network_parameters(self, kwargs):
        """Set up the network parameters."""
        try:
            self.nepochs = kwargs['nepochs']
            self.points_batch = kwargs['points_batch']
            self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')

            self.sampler = Sampler.get_sampler(
                self.conf.get_string('network.sampler.sampler_type'))(
                self.global_sigma,
                self.local_sigma
            )

            self.grad_lambda = self.conf.get_float('network.loss.lambda')
            self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')
            self.with_normals = self.normals_lambda > 0
            self.d_in = self.conf.get_int('train.d_in')
        except Exception as e:
            logger.error(f"Error setting up network parameters: {str(e)}")
            raise

    def _create_network(self):
        """Create the network."""
        try:
            self.network = utils.get_class(self.conf.get_string('train.network_class'))(
                d_in=self.d_in,
                n_branch=int(torch.max(self.feature_mask).item()),
                csg_tree=self.csg_tree,
                flag_convex=self.csg_flag_convex,
                **self.conf.get_config('network.inputs')
            )
            logger.info(f"Network created: {self.network}")

            if torch.cuda.is_available():
                self.network.cuda()
        except Exception as e:
            logger.error(f"Error creating network: {str(e)}")
            raise

    def _setup_optimizer(self, kwargs):
        """Set up the optimizer."""
        try:
            self.lr_schedules = self.get_learning_rate_schedules(
                self.conf.get_list('train.learning_rate_schedule'))
            self.weight_decay = self.conf.get_float('train.weight_decay')
            self.startepoch = 0

            self.optimizer = torch.optim.Adam([{
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            }])

            # If resuming training, load the checkpoints
            self._load_checkpoints_if_continue(kwargs)
        except Exception as e:
            logger.error(f"Error setting up optimizer: {str(e)}")
            raise

    def _load_checkpoints_if_continue(self, kwargs):
        """Load checkpoints when resuming training."""
        try:
            model_params_path = os.path.join(
                self.checkpoints_path, self.model_params_subdir)
            ckpts = os.listdir(model_params_path)

            if len(ckpts) != 0:
                old_checkpnts_dir = os.path.join(
                    self.expdir, self.foldername, 'checkpoints')
                logger.info(f'Loading checkpoint from: {old_checkpnts_dir}')

                # Load the model state
                saved_model_state = torch.load(os.path.join(
                    old_checkpnts_dir,
                    'ModelParameters',
                    f"{kwargs['checkpoint']}.pth"
                ))
                self.network.load_state_dict(saved_model_state["model_state_dict"])

                # Load the optimizer state
                data = torch.load(os.path.join(
                    old_checkpnts_dir,
                    'OptimizerParameters',
                    f"{kwargs['checkpoint']}.pth"
                ))
                self.optimizer.load_state_dict(data["optimizer_state_dict"])
                self.startepoch = saved_model_state['epoch']
        except Exception as e:
            logger.error(f"Error loading checkpoints: {str(e)}")
            raise

    def get_learning_rate_schedules(self, schedule_specs):
@@ -509,6 +679,8 @@ class ReconstructionRunner:
            {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))

if __name__ == '__main__':

    if args.gpu == "auto":
