@ -314,157 +314,327 @@ class ReconstructionRunner:
axis = 2 )
def __init__(self, **kwargs):
    """Initialise the runner by executing the staged setup helpers in order.

    All of the work is delegated to private helpers so that each stage
    logs its own failure.  The stages run in both single-file and list
    mode; ``_load_data`` is the place that dispatches on the list-mode
    flag, so no gating is done here.

    Args:
        **kwargs: constructor options.  They are forwarded untouched to
            the helpers; see each helper for the keys it reads.

    Raises:
        Exception: whatever a stage raised, after it has been logged.
    """
    try:
        # 1. Basic settings (working directory, flags, GPU info).
        self._initialize_basic_settings(kwargs)
        # 2. Configuration object and experiment directories.
        self._setup_config_and_directories(kwargs)
        # 3. Point cloud and feature mask.
        self._load_data(kwargs)
        # 4. CSG tree description.
        self._setup_csg_tree()
        # 5. Per-point local sigma.
        self._compute_local_sigma()
        # 6. Network, optimizer and optional checkpoint restore.
        self._setup_network_and_optimizer(kwargs)
        print(" Initialization completed successfully ")
    except Exception as e:
        logger.error(f" Error during initialization: { str ( e ) } ")
        raise
def _initialize_basic_settings(self, kwargs):
    """Store the working directory, run flags and GPU information.

    Args:
        kwargs: constructor options.  Keys read with ``[]`` are
            mandatory; the list-mode flag is optional and defaults to
            ``False``.

    Raises:
        KeyError: a mandatory key is missing (logged, then re-raised).
        Exception: any other failure (logged, then re-raised).
    """
    try:
        cwd = os.getcwd()
        self.home_dir = os.path.abspath(cwd)
        self.flag_list = kwargs.get(' flag_list ', False)
        self.expname = kwargs[' expname ']
        self.GPU_INDEX = kwargs[' gpu_index ']
        # Only the number of visible devices is recorded here.
        self.num_of_gpus = torch.cuda.device_count()
        self.eval = kwargs[' eval ']
        logger.debug(" Basic settings initialized successfully ")
    except Exception as e:
        # A missing mandatory key gets a dedicated message; every other
        # failure uses the generic one.  Both paths re-raise.
        if isinstance(e, KeyError):
            logger.error(f" Missing required parameter: { str ( e ) } ")
        else:
            logger.error(f" Error in basic settings initialization: { str ( e ) } ")
        raise
def _setup_config_and_directories(self, kwargs):
    """Resolve the configuration and create the experiment directories.

    When the configuration option is a string it is treated as a file
    name, prefixed with the conversion folder and parsed with
    ``ConfigFactory.parse_file``; any other value is stored as the
    configuration object directly.

    Args:
        kwargs: constructor options; the configuration entry is read.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        conf_arg = kwargs[' conf ']
        if not isinstance(conf_arg, str):
            self.conf = conf_arg
        else:
            self.conf_filename = ' ./conversion/ ' + conf_arg
            self.conf = ConfigFactory.parse_file(self.conf_filename)

        # Experiment folders: <home>/<exps>/ and <home>/<exps>/<expname>/.
        self.exps_folder_name = ' exps '
        self.expdir = utils.concat_home_dir(
            os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        exps_root = utils.concat_home_dir(
            os.path.join(self.home_dir, self.exps_folder_name))
        # Parent first, then the experiment folder itself.
        for folder in (exps_root, self.expdir):
            utils.mkdir_ifnotexists(folder)
        logger.debug(" Config and directories setup completed ")
    except Exception as e:
        logger.error(f" Error in config and directory setup: { str ( e ) } ")
        raise
def _load_data(self, kwargs):
    """Load the point cloud and the feature mask.

    Dispatches to the list-mode loader when ``self.flag_list`` is truthy
    and to the single-file loader otherwise.  In baseline mode the
    feature mask is replaced by a float tensor of ones with one entry
    per row of ``self.data``.

    Args:
        kwargs: constructor options, forwarded to the list-mode loader.

    Raises:
        FileNotFoundError: logged with a dedicated message, then re-raised.
        Exception: any other failure is logged and re-raised.
    """
    try:
        if self.flag_list:
            self._load_data_from_list(kwargs)
        else:
            self._load_single_data()

        if args.baseline:
            point_count = self.data.shape[0]
            self.feature_mask = torch.ones(point_count).float()
        logger.info(f" Data loading finished. Data shape: { self . data . shape } ")
    except Exception as e:
        if isinstance(e, FileNotFoundError):
            logger.error(f" Data file not found: { str ( e ) } ")
        else:
            logger.error(f" Error in data loading: { str ( e ) } ")
        raise
def _load_single_data(self):
    """Load a single point cloud, its feature mask and the output folder name.

    Every path and name comes from the parsed configuration; this method
    takes no constructor options.  A stray ``else:`` branch that sat
    between the ``try`` body and its ``except`` clause (invalid clause
    order, and it read an undefined ``kwargs``) has been removed; the
    list-mode path it computed belongs to ``_load_data_from_list``.

    Sets ``input_file``, ``data``, ``feature_mask_file``, ``feature_mask``
    and ``foldername`` on the instance.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        self.input_file = self.conf.get_string(' train.input_path ')
        self.data = utils.load_point_cloud_by_file_extension(self.input_file)
        self.feature_mask_file = self.conf.get_string(' train.feature_mask_path ')
        self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
        # The checkpoint/plot directories are later built from foldername.
        self.foldername = self.conf.get_string(' train.foldername ')
    except Exception as e:
        logger.error(f" Error loading single data file: { str ( e ) } ")
        raise
def _load_data_from_list(self, kwargs):
    """Load the point cloud and feature mask for one entry of a file list.

    The input file is ``<train.input_path>/<file_prefix> + suffix``.  A
    missing input file now raises instead of returning early: an early
    return only left this helper, and the caller then failed later on
    the unset ``self.data``.  The error is a ``FileNotFoundError`` (a
    subclass of ``Exception``) so that the dedicated handler in
    ``_load_data`` reports it.

    CSG-tree parsing is not done here; ``_setup_csg_tree`` owns it.
    Feature samples are loaded through ``_load_feature_samples`` only.

    Args:
        kwargs: constructor options; the file prefix and folder prefix
            entries are read.

    Raises:
        FileNotFoundError: the input file does not exist.
        Exception: any other failure is logged and re-raised.
    """
    try:
        self.input_file = os.path.join(
            self.conf.get_string(' train.input_path '),
            kwargs[' file_prefix '] + ' .xyz ',
        )
        if not os.path.exists(self.input_file):
            # Keep the flag meaningful for anyone who catches the error.
            self.flag_data_load = False
            raise FileNotFoundError(
                f" Data file not found: { self . input_file } , absolute path: { os . path . abspath ( self . input_file ) } ")
        self.flag_data_load = True
        self.data = utils.load_point_cloud_by_file_extension(self.input_file)

        self.feature_mask_file = os.path.join(
            self.conf.get_string(' train.input_path '),
            kwargs[' file_prefix '] + ' _mask.txt ',
        )
        # In baseline mode the caller substitutes an all-ones mask.
        if not args.baseline:
            self.feature_mask = utils.load_feature_mask(self.feature_mask_file)

        if args.feature_sample:
            self._load_feature_samples(kwargs)

        self.foldername = kwargs[' folder_prefix '] + kwargs[' file_prefix ']
    except Exception as e:
        logger.error(f" Error loading data from list: { str ( e ) } ")
        raise
def _load_feature_samples(self, kwargs):
    """Load the feature sample points and their mask pairs onto the GPU.

    Two text files next to the input point cloud are read with
    ``np.loadtxt``: the feature points (stored as a float32 tensor in
    ``self.feature_data``) and the mask pairs (stored as an int64 tensor
    in ``self.feature_data_mask_pair``).

    This method does nothing else.  Statements that used to trail it
    (baseline mask, local sigma, checkpoint directories, network,
    optimizer and checkpoint restore) repeated work that the dedicated
    ``_load_data``, ``_compute_local_sigma``,
    ``_setup_checkpoints_directories``, ``_setup_network_parameters``,
    ``_create_network``, ``_setup_optimizer`` and
    ``_load_checkpoints_if_continue`` helpers already perform, so they
    were removed.

    Args:
        kwargs: constructor options; the file prefix entry is read.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        input_fs_file = os.path.join(
            self.conf.get_string(' train.input_path '),
            kwargs[' file_prefix '] + ' _feature.xyz ',
        )
        self.feature_data = np.loadtxt(input_fs_file)
        self.feature_data = torch.tensor(
            self.feature_data,
            dtype=torch.float32,
            device=' cuda ',
        )

        fs_mask_file = os.path.join(
            self.conf.get_string(' train.input_path '),
            kwargs[' file_prefix '] + ' _feature_mask.txt ',
        )
        self.feature_data_mask_pair = torch.tensor(
            np.loadtxt(fs_mask_file),
            dtype=torch.int64,
            device=' cuda ',
        )
    except Exception as e:
        logger.error(f" Error loading feature samples: { str ( e ) } ")
        raise
def _setup_csg_tree(self):
    """Populate ``self.csg_tree`` and ``self.csg_flag_convex``.

    In baseline mode a fixed single-entry tree and a ``True`` convex
    flag are used.  Otherwise both values are read from a configuration
    file whose name is the input file name with its last four characters
    replaced by the CSG suffix.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        if not args.baseline:
            # Drop the last four characters (presumably the extension)
            # before appending the CSG suffix.
            csg_conf_file = self.input_file[:-4] + ' _csg.conf '
            csg_config = ConfigFactory.parse_file(csg_conf_file)
            self.csg_tree = csg_config.get_list(' csg.list ')
            self.csg_flag_convex = csg_config.get_int(' csg.flag_convex ')
        else:
            self.csg_tree = [0]
            self.csg_flag_convex = True
        logger.info(f" CSG tree: { self . csg_tree } ")
        logger.info(f" CSG convex flag: { self . csg_flag_convex } ")
    except Exception as e:
        logger.error(f" Error in CSG tree setup: { str ( e ) } ")
        raise
def _compute_local_sigma(self):
    """Compute one sigma per data point from KD-tree neighbour distances.

    A KD-tree is built over ``self.data``.  The data is queried in 100
    chunks for 50 + 1 neighbours each, and the last distance column of
    every query result is kept.  The concatenated distances are stored
    as a float CUDA tensor in ``self.local_sigma``.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        kd_tree = cKDTree(self.data)
        logger.debug(" KD tree constructed ")
        # Chunked queries; presumably to bound the memory of one query.
        chunks = np.array_split(self.data, 100, axis=0)
        farthest_dists = [
            kd_tree.query(chunk, 50 + 1)[0][:, -1] for chunk in chunks
        ]
        sigmas = np.concatenate(farthest_dists)
        self.local_sigma = torch.from_numpy(sigmas).float().cuda()
    except Exception as e:
        logger.error(f" Error computing local sigma: { str ( e ) } ")
        raise
def _setup_network_and_optimizer(self, kwargs):
    """Run the four network-related setup steps in their fixed order.

    The steps are: checkpoint directories, network parameters, network
    creation, optimizer setup.  Only the parameter and optimizer steps
    receive the constructor options.

    Args:
        kwargs: constructor options, forwarded to the steps that need them.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        # Each entry is (step name, positional arguments).
        steps = (
            ('_setup_checkpoints_directories', ()),
            ('_setup_network_parameters', (kwargs,)),
            ('_create_network', ()),
            ('_setup_optimizer', (kwargs,)),
        )
        for step_name, step_args in steps:
            getattr(self, step_name)(*step_args)
        logger.debug(" Network and optimizer setup completed ")
    except Exception as e:
        logger.error(f" Error in network and optimizer setup: { str ( e ) } ")
        raise
def _setup_checkpoints_directories(self):
    """Create the per-run output directories and remember their paths.

    Under ``<expdir>/<foldername>`` this creates a plots folder and a
    checkpoints folder, and inside the checkpoints folder one
    sub-folder for model parameters and one for optimizer parameters.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        self.cur_exp_dir = os.path.join(self.expdir, self.foldername)
        utils.mkdir_ifnotexists(self.cur_exp_dir)

        # Direct children of the run directory: (attribute, folder name).
        for attr_name, leaf in (('plots_dir', ' plots '),
                                ('checkpoints_path', ' checkpoints ')):
            child_dir = os.path.join(self.cur_exp_dir, leaf)
            setattr(self, attr_name, child_dir)
            utils.mkdir_ifnotexists(child_dir)

        self.model_params_subdir = " ModelParameters "
        self.optimizer_params_subdir = " OptimizerParameters "
        for subdir in (self.model_params_subdir, self.optimizer_params_subdir):
            utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, subdir))
    except Exception as e:
        logger.error(f" Error setting up checkpoint directories: { str ( e ) } ")
        raise
def _setup_network_parameters(self, kwargs):
    """Read training hyper-parameters and build the point sampler.

    Epoch count and batch size come from the constructor options; the
    sigma, loss weights and input dimension come from the configuration.
    The sampler class is looked up by its configured type name and
    instantiated with the global and local sigma.

    Args:
        kwargs: constructor options; epoch count and batch size are read.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        self.nepochs = kwargs[' nepochs ']
        self.points_batch = kwargs[' points_batch ']

        conf = self.conf
        self.global_sigma = conf.get_float(' network.sampler.properties.global_sigma ')
        sampler_cls = Sampler.get_sampler(
            conf.get_string(' network.sampler.sampler_type '))
        self.sampler = sampler_cls(self.global_sigma, self.local_sigma)

        self.grad_lambda = conf.get_float(' network.loss.lambda ')
        self.normals_lambda = conf.get_float(' network.loss.normals_lambda ')
        # The flag is on only when the normals weight is positive.
        self.with_normals = self.normals_lambda > 0
        self.d_in = conf.get_int(' train.d_in ')
    except Exception as e:
        logger.error(f" Error setting up network parameters: { str ( e ) } ")
        raise
def _create_network(self):
    """Instantiate the configured network class and move it to the GPU.

    The class is resolved from its configured name.  It receives the
    input dimension, a branch count equal to the maximum value found in
    the feature mask, the CSG tree, the convex flag, and every entry of
    the network-inputs configuration section as extra keyword arguments.
    The network is moved to CUDA only when CUDA is available.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        network_cls = utils.get_class(self.conf.get_string(' train.network_class '))
        net_kwargs = dict(
            d_in=self.d_in,
            n_branch=int(torch.max(self.feature_mask).item()),
            csg_tree=self.csg_tree,
            flag_convex=self.csg_flag_convex,
            **self.conf.get_config(' network.inputs '),
        )
        self.network = network_cls(**net_kwargs)
        logger.info(f" Network created: { self . network } ")
        if torch.cuda.is_available():
            self.network.cuda()
    except Exception as e:
        logger.error(f" Error creating network: { str ( e ) } ")
        raise
def _setup_optimizer(self, kwargs):
    """Create the Adam optimizer and restore a checkpoint when one exists.

    Learning-rate schedules and weight decay come from the
    configuration.  The optimizer has a single parameter group whose
    initial learning rate is the first schedule evaluated at epoch 0.
    The start epoch is reset to 0 before the checkpoint restore step,
    which may overwrite it.

    Args:
        kwargs: constructor options, forwarded to the checkpoint loader.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        self.lr_schedules = self.get_learning_rate_schedules(
            self.conf.get_list(' train.learning_rate_schedule '))
        self.weight_decay = self.conf.get_float(' train.weight_decay ')
        self.startepoch = 0

        param_group = {
            " params ": self.network.parameters(),
            " lr ": self.lr_schedules[0].get_learning_rate(0),
            " weight_decay ": self.weight_decay,
        }
        self.optimizer = torch.optim.Adam([param_group])

        # Restore model/optimizer state if earlier checkpoints are present.
        self._load_checkpoints_if_continue(kwargs)
    except Exception as e:
        logger.error(f" Error setting up optimizer: { str ( e ) } ")
        raise
def _load_checkpoints_if_continue(self, kwargs):
    """Restore network and optimizer state when earlier checkpoints exist.

    The model-parameters folder is listed; if it is empty nothing
    happens.  Otherwise the model and optimizer files named after the
    requested checkpoint are loaded with ``torch.load``, their state
    dicts are applied, and the start epoch is taken from the model file.

    Args:
        kwargs: constructor options; the checkpoint entry is read.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        model_params_path = os.path.join(
            self.checkpoints_path, self.model_params_subdir)
        if len(os.listdir(model_params_path)) == 0:
            # Nothing saved yet: a fresh run keeps the defaults.
            return

        old_checkpnts_dir = os.path.join(
            self.expdir, self.foldername, ' checkpoints ')
        logger.info(f' Loading checkpoint from: { old_checkpnts_dir } ')

        def checkpoint_file(subdir):
            # Path of the requested checkpoint inside the given sub-folder.
            return os.path.join(
                old_checkpnts_dir, subdir, f" { kwargs [ ' checkpoint ' ] } .pth ")

        # Model state first, then optimizer state.
        saved_model_state = torch.load(checkpoint_file(' ModelParameters '))
        self.network.load_state_dict(saved_model_state[" model_state_dict "])

        data = torch.load(checkpoint_file(' OptimizerParameters '))
        self.optimizer.load_state_dict(data[" optimizer_state_dict "])

        self.startepoch = saved_model_state[' epoch ']
    except Exception as e:
        logger.error(f" Error loading checkpoints: { str ( e ) } ")
        raise
def get_learning_rate_schedules ( self , schedule_specs ) :
@ -509,6 +679,8 @@ class ReconstructionRunner:
{ " epoch " : epoch , " optimizer_state_dict " : self . optimizer . state_dict ( ) } ,
os . path . join ( self . checkpoints_path , self . optimizer_params_subdir , " latest.pth " ) )
if __name__ == ' __main__ ' :
if args . gpu == " auto " :