Browse Source

Refactor ReconstructionRunner initialization in run.py for improved structure and error handling

- Split the initialization process into multiple private methods for better readability and maintainability.
- Added detailed logging for each step of the initialization process, including error handling for missing parameters and file loading issues.
- Enhanced configuration and directory setup with clearer error messages and structured logging.
- Improved data loading methods to handle both single and list-based data inputs more robustly.
- Introduced methods for setting up the CSG tree and computing local sigma values, with appropriate logging for each operation.
main
mckay 2 months ago
parent
commit
d425165c32
  1. 448
      code/conversion/run.py

448
code/conversion/run.py

@ -314,157 +314,327 @@ class ReconstructionRunner:
axis = 2)
def __init__(self, **kwargs):
    """Initialize the runner by executing each setup stage in order.

    The stages are: basic settings, config and experiment directories,
    data loading, CSG tree setup, local sigma computation, and
    network/optimizer construction. Every stage stores its results on
    ``self``; the stage helpers own all attribute setup, so nothing is
    initialized here directly.

    Args:
        **kwargs: Options forwarded unchanged to the stage helpers. The
            helpers read ``expname``, ``gpu_index``, ``eval``, ``conf``,
            ``nepochs`` and ``points_batch``; ``flag_list`` is optional;
            ``file_prefix`` and ``folder_prefix`` are read in list mode;
            ``checkpoint`` is read when existing checkpoints are found.

    Raises:
        Exception: Any error raised by a stage is logged and re-raised
            unchanged.
    """
    try:
        # 1. Basic settings
        self._initialize_basic_settings(kwargs)
        # 2. Config file and experiment directories
        self._setup_config_and_directories(kwargs)
        # 3. Data loading (dispatches on self.flag_list internally)
        self._load_data(kwargs)
        # 4. CSG tree setup
        self._setup_csg_tree()
        # 5. Local sigma computation
        self._compute_local_sigma()
        # 6. Network and optimizer setup
        self._setup_network_and_optimizer(kwargs)
        print("Initialization completed successfully")
    except Exception as e:
        logger.error(f"Error during initialization: {str(e)}")
        raise
def _initialize_basic_settings(self, kwargs):
    """Record the working directory, run options and GPU count on ``self``.

    Args:
        kwargs: Mapping of constructor options. ``expname``, ``gpu_index``
            and ``eval`` are required; ``flag_list`` defaults to ``False``
            when absent.

    Raises:
        KeyError: If a required option is missing (logged, then re-raised).
        Exception: Any other failure is logged and re-raised.
    """
    try:
        self.home_dir = os.path.abspath(os.getcwd())
        self.flag_list = kwargs.get('flag_list', False)
        self.expname = kwargs['expname']
        self.GPU_INDEX = kwargs['gpu_index']
        self.num_of_gpus = torch.cuda.device_count()
        self.eval = kwargs['eval']
        logger.debug("Basic settings initialized successfully")
    except Exception as e:
        # A missing required key gets its own message; every other failure
        # shares the generic one. Either way the error propagates.
        if isinstance(e, KeyError):
            logger.error(f"Missing required parameter: {str(e)}")
        else:
            logger.error(f"Error in basic settings initialization: {str(e)}")
        raise
def _setup_config_and_directories(self, kwargs):
    """Resolve the configuration and create the experiment directories.

    Args:
        kwargs: Mapping of constructor options. ``conf`` is either a file
            name (parsed from under ``./conversion/``) or an already-built
            config object that is stored as-is.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        # Configuration: parse from file only when a file name was given.
        conf_arg = kwargs['conf']
        if not isinstance(conf_arg, str):
            self.conf = conf_arg
        else:
            self.conf_filename = './conversion/' + conf_arg
            self.conf = ConfigFactory.parse_file(self.conf_filename)

        # Experiment directories: <home>/exps and <home>/exps/<expname>.
        self.exps_folder_name = 'exps'
        self.expdir = utils.concat_home_dir(
            os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        exps_root = utils.concat_home_dir(
            os.path.join(self.home_dir, self.exps_folder_name))
        utils.mkdir_ifnotexists(exps_root)
        utils.mkdir_ifnotexists(self.expdir)

        logger.debug("Config and directories setup completed")
    except Exception as e:
        logger.error(f"Error in config and directory setup: {str(e)}")
        raise
def _load_data(self, kwargs):
    """Load the point data and feature mask for the current mode.

    Dispatches to list-mode or single-file loading based on
    ``self.flag_list``. When the module-level ``args.baseline`` is set, the
    feature mask is replaced with an all-ones float tensor holding one
    entry per row of ``self.data``.

    Args:
        kwargs: Mapping of constructor options, forwarded to the list-mode
            loader.

    Raises:
        FileNotFoundError: If a data file is missing (logged, re-raised).
        Exception: Any other failure is logged and re-raised.
    """
    try:
        if self.flag_list:
            self._load_data_from_list(kwargs)
        else:
            self._load_single_data()

        if args.baseline:
            n_points = self.data.shape[0]
            self.feature_mask = torch.ones(n_points).float()

        logger.info(f"Data loading finished. Data shape: {self.data.shape}")
    except FileNotFoundError as e:
        logger.error(f"Data file not found: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"Error in data loading: {str(e)}")
        raise
def _load_single_data(self):
    """Load the single point cloud and feature mask named in the config.

    Reads ``train.input_path``, ``train.feature_mask_path`` and
    ``train.foldername`` from ``self.conf`` and stores the results on
    ``self.input_file``, ``self.data``, ``self.feature_mask_file``,
    ``self.feature_mask`` and ``self.foldername``.

    List-mode inputs are not handled here; this method takes no ``kwargs``
    and so cannot build per-file paths.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        self.input_file = self.conf.get_string('train.input_path')
        self.data = utils.load_point_cloud_by_file_extension(self.input_file)
        self.feature_mask_file = self.conf.get_string('train.feature_mask_path')
        self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
        self.foldername = self.conf.get_string('train.foldername')
    except Exception as e:
        logger.error(f"Error loading single data file: {str(e)}")
        raise
def _load_data_from_list(self, kwargs):
    """Load one list-mode sample identified by ``kwargs['file_prefix']``.

    The point cloud is read from ``<train.input_path>/<file_prefix>.xyz``
    and the feature mask path is ``<train.input_path>/<file_prefix>_mask.txt``.
    The mask is only loaded when the module-level ``args.baseline`` is not
    set, and feature samples are only loaded when ``args.feature_sample``
    is set. The output folder name is ``folder_prefix + file_prefix``.

    CSG tree setup is not done here; it has its own stage
    (``_setup_csg_tree``), which derives its path from ``self.input_file``.

    Args:
        kwargs: Mapping of constructor options; ``file_prefix`` and
            ``folder_prefix`` are required.

    Raises:
        FileNotFoundError: If the ``.xyz`` file does not exist.
            ``self.flag_data_load`` is set to ``False`` first.
        Exception: Any other failure is logged and re-raised.
    """
    try:
        self.input_file = os.path.join(
            self.conf.get_string('train.input_path'),
            kwargs['file_prefix']+'.xyz'
        )
        if not os.path.exists(self.input_file):
            # Fail loudly: a silent return would leave self.data unset and
            # the caller would crash later on self.data.shape.
            self.flag_data_load = False
            raise FileNotFoundError(
                f"Data file not found: {self.input_file}, absolute path: {os.path.abspath(self.input_file)}")
        self.flag_data_load = True
        self.data = utils.load_point_cloud_by_file_extension(self.input_file)

        self.feature_mask_file = os.path.join(
            self.conf.get_string('train.input_path'),
            kwargs['file_prefix']+'_mask.txt'
        )
        if not args.baseline:
            self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
        if args.feature_sample:
            self._load_feature_samples(kwargs)

        self.foldername = kwargs['folder_prefix'] + kwargs['file_prefix']
    except Exception as e:
        logger.error(f"Error loading data from list: {str(e)}")
        raise
def _load_feature_samples(self, kwargs):
    """Load the feature sample points and their mask pairs onto the GPU.

    Reads ``<train.input_path>/<file_prefix>_feature.xyz`` into
    ``self.feature_data`` (float32) and
    ``<train.input_path>/<file_prefix>_feature_mask.txt`` into
    ``self.feature_data_mask_pair`` (int64). Both files are parsed with
    ``np.loadtxt`` and both tensors are created on the ``'cuda'`` device.

    Args:
        kwargs: Mapping of constructor options; ``file_prefix`` is
            required.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        input_fs_file = os.path.join(
            self.conf.get_string('train.input_path'),
            kwargs['file_prefix']+'_feature.xyz'
        )
        self.feature_data = np.loadtxt(input_fs_file)
        self.feature_data = torch.tensor(
            self.feature_data,
            dtype=torch.float32,
            device='cuda'
        )
        fs_mask_file = os.path.join(
            self.conf.get_string('train.input_path'),
            kwargs['file_prefix']+'_feature_mask.txt'
        )
        self.feature_data_mask_pair = torch.tensor(
            np.loadtxt(fs_mask_file),
            dtype=torch.int64,
            device='cuda'
        )
    except Exception as e:
        logger.error(f"Error loading feature samples: {str(e)}")
        raise
def _setup_csg_tree(self):
    """Populate ``self.csg_tree`` and ``self.csg_flag_convex``.

    With the module-level ``args.baseline`` set, a fixed tree ``[0]`` with
    the convex flag ``True`` is used. Otherwise both values come from the
    ``csg.list`` and ``csg.flag_convex`` entries of a config file whose
    path is ``self.input_file`` with its last four characters replaced by
    ``_csg.conf``.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        if not args.baseline:
            # Drop the 4-character suffix of the input path
            # (presumably the ".xyz" extension -- confirm for other inputs).
            csg_config = ConfigFactory.parse_file(self.input_file[:-4] + '_csg.conf')
            self.csg_tree = csg_config.get_list('csg.list')
            self.csg_flag_convex = csg_config.get_int('csg.flag_convex')
        else:
            self.csg_tree = [0]
            self.csg_flag_convex = True

        logger.info(f"CSG tree: {self.csg_tree}")
        logger.info(f"CSG convex flag: {self.csg_flag_convex}")
    except Exception as e:
        logger.error(f"Error in CSG tree setup: {str(e)}")
        raise
def _compute_local_sigma(self):
    """Compute a per-point sigma from nearest-neighbour distances.

    Builds a KD-tree over ``self.data`` and, for every point, takes the
    distance to the last of its ``50 + 1`` nearest neighbours as returned
    by the tree query. The data is queried in 100 chunks and the results
    are concatenated into ``self.local_sigma``, a float tensor moved to the
    GPU with ``.cuda()``.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        ptree = cKDTree(self.data)
        logger.debug("KD tree constructed")

        # query() returns (distances, indices); keep the farthest of the
        # requested neighbours for each point in the chunk.
        per_chunk = [
            ptree.query(chunk, 50 + 1)[0][:, -1]
            for chunk in np.array_split(self.data, 100, axis=0)
        ]
        sigmas = np.concatenate(per_chunk)
        self.local_sigma = torch.from_numpy(sigmas).float().cuda()
    except Exception as e:
        logger.error(f"Error computing local sigma: {str(e)}")
        raise
def _setup_network_and_optimizer(self, kwargs):
    """Run the four network-related setup steps in sequence.

    Steps: checkpoint directories, network parameters, network creation,
    optimizer setup.

    Args:
        kwargs: Mapping of constructor options, forwarded to the parameter
            and optimizer steps.

    Raises:
        Exception: Any failure in a step is logged and re-raised.
    """
    try:
        # Output/checkpoint directories must exist before anything is saved
        # or looked up.
        self._setup_checkpoints_directories()

        # Sampler, loss weights and input dimension.
        self._setup_network_parameters(kwargs)

        # Instantiate the model.
        self._create_network()

        # Build the optimizer for the model just created.
        self._setup_optimizer(kwargs)

        logger.debug("Network and optimizer setup completed")
    except Exception as e:
        logger.error(f"Error in network and optimizer setup: {str(e)}")
        raise
def _setup_checkpoints_directories(self):
    """Create the per-run output directory tree and record its paths.

    Layout under ``<expdir>/<foldername>``: ``plots`` and ``checkpoints``,
    with ``ModelParameters`` and ``OptimizerParameters`` inside
    ``checkpoints``. Directories are created only if missing, via
    ``utils.mkdir_ifnotexists``.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        ensure_dir = utils.mkdir_ifnotexists

        self.cur_exp_dir = os.path.join(self.expdir, self.foldername)
        ensure_dir(self.cur_exp_dir)

        self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
        ensure_dir(self.plots_dir)

        self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
        ensure_dir(self.checkpoints_path)

        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"
        for subdir in (self.model_params_subdir, self.optimizer_params_subdir):
            ensure_dir(os.path.join(self.checkpoints_path, subdir))
    except Exception as e:
        logger.error(f"Error setting up checkpoint directories: {str(e)}")
        raise
def _setup_network_parameters(self, kwargs):
    """Store training options, the point sampler and the loss weights.

    Args:
        kwargs: Mapping of constructor options; ``nepochs`` and
            ``points_batch`` are required.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        self.nepochs = kwargs['nepochs']
        self.points_batch = kwargs['points_batch']

        # The sampler class is selected by name from the config and built
        # from the global sigma plus the previously computed local sigma.
        self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
        sampler_cls = Sampler.get_sampler(
            self.conf.get_string('network.sampler.sampler_type'))
        self.sampler = sampler_cls(self.global_sigma, self.local_sigma)

        self.grad_lambda = self.conf.get_float('network.loss.lambda')
        self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')
        # Normals are used only when their loss weight is positive.
        self.with_normals = self.normals_lambda > 0

        self.d_in = self.conf.get_int('train.d_in')
    except Exception as e:
        logger.error(f"Error setting up network parameters: {str(e)}")
        raise
def _create_network(self):
    """Instantiate the network class named in the config.

    The class is looked up from ``train.network_class`` and constructed
    with the input dimension, a branch count equal to the maximum value in
    ``self.feature_mask``, the CSG tree and convex flag, and the
    ``network.inputs`` config section expanded as keyword arguments. The
    network is moved to the GPU when CUDA is available.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        network_cls = utils.get_class(self.conf.get_string('train.network_class'))
        self.network = network_cls(
            d_in=self.d_in,
            n_branch=int(torch.max(self.feature_mask).item()),
            csg_tree=self.csg_tree,
            flag_convex=self.csg_flag_convex,
            **self.conf.get_config('network.inputs')
        )
        logger.info(f"Network created: {self.network}")

        if torch.cuda.is_available():
            self.network.cuda()
    except Exception as e:
        logger.error(f"Error creating network: {str(e)}")
        raise
def _setup_optimizer(self, kwargs):
    """Create the Adam optimizer and resume from checkpoints if present.

    The learning-rate schedules and weight decay come from the config. The
    initial learning rate is the first schedule evaluated at epoch 0, and
    ``self.startepoch`` starts at 0 before any checkpoint is loaded.

    Args:
        kwargs: Mapping of constructor options, forwarded to the
            checkpoint loader.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        schedule_specs = self.conf.get_list('train.learning_rate_schedule')
        self.lr_schedules = self.get_learning_rate_schedules(schedule_specs)
        self.weight_decay = self.conf.get_float('train.weight_decay')
        self.startepoch = 0

        # A single parameter group covering the whole network.
        param_group = {
            "params": self.network.parameters(),
            "lr": self.lr_schedules[0].get_learning_rate(0),
            "weight_decay": self.weight_decay
        }
        self.optimizer = torch.optim.Adam([param_group])

        # Resume training state when checkpoints already exist.
        self._load_checkpoints_if_continue(kwargs)
    except Exception as e:
        logger.error(f"Error setting up optimizer: {str(e)}")
        raise
def _load_checkpoints_if_continue(self, kwargs):
    """Restore model and optimizer state when saved checkpoints exist.

    If the model-parameters checkpoint directory is empty, nothing happens.
    Otherwise the files ``<checkpoint>.pth`` under ``ModelParameters`` and
    ``OptimizerParameters`` are loaded, the network and optimizer state
    dicts are restored, and ``self.startepoch`` is set to the saved epoch.

    Args:
        kwargs: Mapping of constructor options; ``checkpoint`` names the
            checkpoint file (without extension) and is only read when
            checkpoints are present.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        saved_models_dir = os.path.join(
            self.checkpoints_path, self.model_params_subdir)
        if len(os.listdir(saved_models_dir)) == 0:
            # Fresh run: keep the freshly initialized network and optimizer.
            return

        old_checkpnts_dir = os.path.join(
            self.expdir, self.foldername, 'checkpoints')
        logger.info(f'Loading checkpoint from: {old_checkpnts_dir}')
        ckpt_filename = f"{kwargs['checkpoint']}.pth"

        # Model state
        model_ckpt_path = os.path.join(
            old_checkpnts_dir, 'ModelParameters', ckpt_filename)
        saved_model_state = torch.load(model_ckpt_path)
        self.network.load_state_dict(saved_model_state["model_state_dict"])

        # Optimizer state
        optim_ckpt_path = os.path.join(
            old_checkpnts_dir, 'OptimizerParameters', ckpt_filename)
        data = torch.load(optim_ckpt_path)
        self.optimizer.load_state_dict(data["optimizer_state_dict"])

        self.startepoch = saved_model_state['epoch']
    except Exception as e:
        logger.error(f"Error loading checkpoints: {str(e)}")
        raise
def get_learning_rate_schedules(self, schedule_specs):
@ -509,6 +679,8 @@ class ReconstructionRunner:
{"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))
if __name__ == '__main__':
if args.gpu == "auto":

Loading…
Cancel
Save