@ -314,157 +314,327 @@ class ReconstructionRunner: 
				
			 
			
		
	
		
		
			
				
					 
					 
					                          axis  =  2 )  
					 
					 
					                          axis  =  2 )  
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					    def  __init__ ( self ,  * * kwargs ) :  
					 
					 
					    def  __init__ ( self ,  * * kwargs ) :  
				
			 
			
		
	
		
		
			
				
					
					 
					 
					        self . home_dir  =  os . path . abspath ( os . getcwd ( ) )  
					 
					 
					        try :  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        flag_list  =  False  
					 
					 
					            # 1. 基础设置初始化  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        if  ' flag_list '  in  kwargs :  
					 
					 
					            self . _initialize_basic_settings ( kwargs )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            flag_list  =  True  
					 
					 
					             
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					
 
					 
					 
					            # 2. 配置文件和实验目录设置  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        # config setting  
					 
					 
					            self . _setup_config_and_directories ( kwargs )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        if  type ( kwargs [ ' conf ' ] )  ==  str :  
					 
					 
					             
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            self . conf_filename  =  ' ./conversion/ '  +  kwargs [ ' conf ' ]  
					 
					 
					            # 3. 数据加载  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            self . conf  =  ConfigFactory . parse_file ( self . conf_filename )  
					 
					 
					            self . _load_data ( kwargs )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        else :  
					 
					 
					             
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            self . conf  =  kwargs [ ' conf ' ]  
					 
					 
					            # 4. CSG树设置  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					
 
					 
					 
					            self . _setup_csg_tree ( )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        self . expname  =  kwargs [ ' expname ' ]  
					 
					 
					             
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					
 
					 
					 
					            # 5. 本地sigma计算  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        # GPU settings, currently we only support single-gpu training  
					 
					 
					            self . _compute_local_sigma ( )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        self . GPU_INDEX  =  kwargs [ ' gpu_index ' ]  
					 
					 
					             
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        self . num_of_gpus  =  torch . cuda . device_count ( )  
					 
					 
					            # 6. 网络和优化器设置  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        self . eval  =  kwargs [ ' eval ' ]  
					 
					 
					            self . _setup_network_and_optimizer ( kwargs )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					
 
					 
					 
					             
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        self . exps_folder_name  =  ' exps '  
					 
					 
					            print ( " Initialization completed successfully " )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        utils . mkdir_ifnotexists ( utils . concat_home_dir ( os . path . join ( self . home_dir ,  self . exps_folder_name ) ) )  
					 
					 
					             
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        self . expdir  =  utils . concat_home_dir ( os . path . join ( self . home_dir ,  self . exps_folder_name ,  self . expname ) )  
					 
					 
					        except  Exception  as  e :  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        utils . mkdir_ifnotexists ( self . expdir )  
					 
					 
					            logger . error ( f " Error during initialization:  { str ( e ) } " )  
				
			 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _initialize_basic_settings ( self ,  kwargs ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 初始化基础设置 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . home_dir  =  os . path . abspath ( os . getcwd ( ) )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . flag_list  =  kwargs . get ( ' flag_list ' ,  False )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . expname  =  kwargs [ ' expname ' ]  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . GPU_INDEX  =  kwargs [ ' gpu_index ' ]  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . num_of_gpus  =  torch . cuda . device_count ( )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . eval  =  kwargs [ ' eval ' ]  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . debug ( " Basic settings initialized successfully " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  KeyError  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Missing required parameter:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  Exception  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error in basic settings initialization:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _setup_config_and_directories ( self ,  kwargs ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 设置配置文件和创建必要的目录 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            # 配置设置  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            if  isinstance ( kwargs [ ' conf ' ] ,  str ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . conf_filename  =  ' ./conversion/ '  +  kwargs [ ' conf ' ]  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . conf  =  ConfigFactory . parse_file ( self . conf_filename )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            else :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . conf  =  kwargs [ ' conf ' ]  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            # 创建实验目录  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . exps_folder_name  =  ' exps '  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . expdir  =  utils . concat_home_dir ( os . path . join (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . home_dir ,  self . exps_folder_name ,  self . expname ) )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            utils . mkdir_ifnotexists ( utils . concat_home_dir (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                os . path . join ( self . home_dir ,  self . exps_folder_name ) ) )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            utils . mkdir_ifnotexists ( self . expdir )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . debug ( " Config and directories setup completed " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  Exception  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error in config and directory setup:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _load_data ( self ,  kwargs ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 加载数据和特征掩码 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            if  not  self . flag_list :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . _load_single_data ( )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            else :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . _load_data_from_list ( kwargs )  
				
			 
			
		
	
		
		
			
				
					 
					 
					             
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					
					 
					 
					        if  not  flag_list :  
					 
					 
					            if  args . baseline :  
				
			 
			
				
				
			
		
	
		
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . feature_mask  =  torch . ones ( self . data . shape [ 0 ] ) . float ( )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . info ( f " Data loading finished. Data shape:  { self . data . shape } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  FileNotFoundError  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Data file not found:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  Exception  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error in data loading:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _load_single_data ( self ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 加载单个数据文件 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . input_file  =  self . conf . get_string ( ' train.input_path ' )  
					 
					 
					            self . input_file  =  self . conf . get_string ( ' train.input_path ' )  
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . data  =  utils . load_point_cloud_by_file_extension ( self . input_file )  
					 
					 
					            self . data  =  utils . load_point_cloud_by_file_extension ( self . input_file )  
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . feature_mask_file  =  self . conf . get_string ( ' train.feature_mask_path ' )  
					 
					 
					            self . feature_mask_file  =  self . conf . get_string ( ' train.feature_mask_path ' )  
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . feature_mask  =  utils . load_feature_mask ( self . feature_mask_file )  
					 
					 
					            self . feature_mask  =  utils . load_feature_mask ( self . feature_mask_file )  
				
			 
			
		
	
		
		
			
				
					
					 
					 
					        else :  
					 
					 
					            self . foldername  =  self . conf . get_string ( ' train.foldername ' )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            self . input_file  =  os . path . join ( self . conf . get_string ( ' train.input_path ' ) ,  kwargs [ ' file_prefix ' ] + ' .xyz ' )  
					 
					 
					        except  Exception  as  e :  
				
			 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error loading single data file:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _load_data_from_list ( self ,  kwargs ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 从列表加载数据 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . input_file  =  os . path . join (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . conf . get_string ( ' train.input_path ' ) ,   
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                kwargs [ ' file_prefix ' ] + ' .xyz '  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            )  
				
			 
			
		
	
		
		
			
				
					 
					 
					            if  not  os . path . exists ( self . input_file ) :  
					 
					 
					            if  not  os . path . exists ( self . input_file ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					                self . flag_data_load  =  False  
					 
					 
					                self . flag_data_load  =  False  
				
			 
			
		
	
		
		
			
				
					
					 
					 
					                return  
					 
					 
					                raise  Exception ( f " Data file not found:  { self . input_file } , absolute path:  { os . path . abspath ( self . input_file ) } " )  
				
			 
			
				
				
			
		
	
		
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . flag_data_load  =  True  
					 
					 
					            self . flag_data_load  =  True  
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . data  =  utils . load_point_cloud_by_file_extension ( self . input_file )  
					 
					 
					            self . data  =  utils . load_point_cloud_by_file_extension ( self . input_file )  
				
			 
			
		
	
		
		
			
				
					
					 
					 
					            self . feature_mask_file  =  os . path . join ( self . conf . get_string ( ' train.input_path ' ) ,  kwargs [ ' file_prefix ' ] + ' _mask.txt ' )  
					 
					 
					            self . feature_mask_file  =  os . path . join (  
				
			 
			
				
				
			
		
	
		
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . conf . get_string ( ' train.input_path ' ) ,   
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                kwargs [ ' file_prefix ' ] + ' _mask.txt '  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					            if  not  args . baseline :  
					 
					 
					            if  not  args . baseline :  
				
			 
			
		
	
		
		
			
				
					 
					 
					                self . feature_mask  =  utils . load_feature_mask ( self . feature_mask_file )  
					 
					 
					                self . feature_mask  =  utils . load_feature_mask ( self . feature_mask_file )  
				
			 
			
		
	
		
		
			
				
					 
					 
					             
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					            if  args . feature_sample :  
					 
					 
					            if  args . feature_sample :  
				
			 
			
		
	
		
		
			
				
					
					 
					 
					                input_fs_file  =  os . path . join ( self . conf . get_string ( ' train.input_path ' ) ,  kwargs [ ' file_prefix ' ] + ' _feature.xyz ' )  
					 
					 
					                self . _load_feature_samples ( kwargs )  
				
			 
			
				
				
			
		
	
		
		
			
				
					 
					 
					                self . feature_data  =  np . loadtxt ( input_fs_file )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					                self . feature_data  =  torch . tensor ( self . feature_data ,  dtype  =  torch . float32 ,  device  =  ' cuda ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					                fs_mask_file  =  os . path . join ( self . conf . get_string ( ' train.input_path ' ) ,  kwargs [ ' file_prefix ' ] + ' _feature_mask.txt ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					                self . feature_data_mask_pair  =  torch . tensor ( np . loadtxt ( fs_mask_file ) ,  dtype  =  torch . int64 ,  device  =  ' cuda ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        if  args . baseline :  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . csg_tree  =  [ 0 ]  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . csg_flag_convex  =  True  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        else :  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . csg_tree  =  [ ]  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . csg_tree  =  ConfigFactory . parse_file ( self . input_file [ : - 4 ] + ' _csg.conf ' ) . get_list ( ' csg.list ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . csg_flag_convex  =  ConfigFactory . parse_file ( self . input_file [ : - 4 ] + ' _csg.conf ' ) . get_int ( ' csg.flag_convex ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        print  ( " csg tree:  " ,  self . csg_tree )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        print  ( " csg convex flag:  " ,  self . csg_flag_convex )  
					 
					 
					 
				
			 
			
		
	
		
		
	
		
		
			
				
					 
					 
					             
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					        if  not  flag_list :  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . foldername  =  self . conf . get_string ( ' train.foldername ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        else :  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . foldername  =  kwargs [ ' folder_prefix ' ]  +  kwargs [ ' file_prefix ' ]  
					 
					 
					            self . foldername  =  kwargs [ ' folder_prefix ' ]  +  kwargs [ ' file_prefix ' ]  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  Exception  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error loading data from list:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _load_feature_samples ( self ,  kwargs ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 加载特征样本 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            input_fs_file  =  os . path . join (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . conf . get_string ( ' train.input_path ' ) ,   
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                kwargs [ ' file_prefix ' ] + ' _feature.xyz '  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . feature_data  =  np . loadtxt ( input_fs_file )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . feature_data  =  torch . tensor (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . feature_data ,   
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                dtype = torch . float32 ,   
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                device = ' cuda '  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            )  
				
			 
			
		
	
		
		
			
				
					 
					 
					             
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					
					 
					 
					        if  args . baseline :  
					 
					 
					            fs_mask_file  =  os . path . join (  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            self . feature_mask  =  torch . ones ( self . data . shape [ 0 ] ) . float ( )  
					 
					 
					                self . conf . get_string ( ' train.input_path ' ) ,   
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					
 
					 
					 
					                kwargs [ ' file_prefix ' ] + ' _feature_mask.txt '  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        print  ( " loading finished " )  
					 
					 
					            )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        print  ( " data shape:  " ,  self . data . shape )  
					 
					 
					            self . feature_data_mask_pair  =  torch . tensor (  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					
 
					 
					 
					                np . loadtxt ( fs_mask_file ) ,   
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        sigma_set  =  [ ]  
					 
					 
					                dtype = torch . int64 ,   
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        ptree  =  cKDTree ( self . data )  
					 
					 
					                device = ' cuda '  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        print  ( " kd tree constructed " )  
					 
					 
					            )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					
 
					 
					 
					        except  Exception  as  e :  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        for  p  in  np . array_split ( self . data ,  100 ,  axis = 0 ) :  
					 
					 
					            logger . error ( f " Error loading feature samples:  { str ( e ) } " )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            d  =  ptree . query ( p ,  50  +  1 )  
					 
					 
					            raise  
				
			 
			
				
				
			
		
	
		
		
			
				
					 
					 
					            sigma_set . append ( d [ 0 ] [ : ,  - 1 ] )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        sigmas  =  np . concatenate ( sigma_set )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . local_sigma  =  torch . from_numpy ( sigmas ) . float ( ) . cuda ( )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . cur_exp_dir  =  os . path . join ( self . expdir ,  self . foldername )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        utils . mkdir_ifnotexists ( self . cur_exp_dir )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . plots_dir  =  os . path . join ( self . cur_exp_dir ,  ' plots ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        utils . mkdir_ifnotexists ( self . plots_dir )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . checkpoints_path  =  os . path . join ( self . cur_exp_dir ,  ' checkpoints ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        utils . mkdir_ifnotexists ( self . checkpoints_path )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . model_params_subdir  =  " ModelParameters "  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . optimizer_params_subdir  =  " OptimizerParameters "  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        utils . mkdir_ifnotexists ( os . path . join ( self . checkpoints_path ,  self . model_params_subdir ) )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        utils . mkdir_ifnotexists ( os . path . join ( self . checkpoints_path ,  self . optimizer_params_subdir ) )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        model_params_path  =  os . path . join ( self . checkpoints_path ,  self . model_params_subdir )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        ckpts  =  os . listdir ( model_params_path )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        #if ckpts exists, then continue  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        is_continue  =  False  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        if  ( len ( ckpts ) )  !=  0 :  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					            is_continue  =  True  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . nepochs  =  kwargs [ ' nepochs ' ]  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . points_batch  =  kwargs [ ' points_batch ' ]  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . global_sigma  =  self . conf . get_float ( ' network.sampler.properties.global_sigma ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . sampler  =  Sampler . get_sampler ( self . conf . get_string ( ' network.sampler.sampler_type ' ) ) ( self . global_sigma ,  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					                                                                                                 self . local_sigma )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . grad_lambda  =  self . conf . get_float ( ' network.loss.lambda ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . normals_lambda  =  self . conf . get_float ( ' network.loss.normals_lambda ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . with_normals  =  self . normals_lambda  >  0  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . d_in  =  self . conf . get_int ( ' train.d_in ' )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . network  =  utils . get_class ( self . conf . get_string ( ' train.network_class ' ) ) ( d_in = self . d_in ,   
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					                                                                                n_branch  =  int ( torch . max ( self . feature_mask ) . item ( ) ) ,  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					                                                                                csg_tree  =  self . csg_tree ,  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					                                                                                flag_convex  =  self . csg_flag_convex ,  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					                                                                                * * self . conf . get_config (  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					                                                                                    ' network.inputs ' ) )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        print  ( self . network )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        if  torch . cuda . is_available ( ) :  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					            self . network . cuda ( )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . lr_schedules  =  self . get_learning_rate_schedules ( self . conf . get_list ( ' train.learning_rate_schedule ' ) )  
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					 
					 
					        self . weight_decay  =  self . conf . get_float ( ' train.weight_decay ' )  
					 
					 
					 
				
			 
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					        self . startepoch  =  0  
					 
					 
					    def  _setup_csg_tree ( self ) :  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        self . optimizer  =  torch . optim . Adam (  
					 
					 
					        """ 设置CSG树 """  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            [  
					 
					 
					        try :  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					                {  
					 
					 
					            if  args . baseline :  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					                    " params " :  self . network . parameters ( ) ,  
					 
					 
					                self . csg_tree  =  [ 0 ]  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					                    " lr " :  self . lr_schedules [ 0 ] . get_learning_rate ( 0 ) ,  
					 
					 
					                self . csg_flag_convex  =  True  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					                    " weight_decay " :  self . weight_decay  
					 
					 
					            else :  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					                } ,  
					 
					 
					                csg_conf_file  =  self . input_file [ : - 4 ] + ' _csg.conf '  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            ] )  
					 
					 
					                csg_config  =  ConfigFactory . parse_file ( csg_conf_file )  
				
			 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . csg_tree  =  csg_config . get_list ( ' csg.list ' )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . csg_flag_convex  =  csg_config . get_int ( ' csg.flag_convex ' )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . info ( f " CSG tree:  { self . csg_tree } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . info ( f " CSG convex flag:  { self . csg_flag_convex } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  Exception  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error in CSG tree setup:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _compute_local_sigma ( self ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 计算局部sigma值 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            sigma_set  =  [ ]  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            ptree  =  cKDTree ( self . data )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . debug ( " KD tree constructed " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            for  p  in  np . array_split ( self . data ,  100 ,  axis = 0 ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                d  =  ptree . query ( p ,  50  +  1 )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                sigma_set . append ( d [ 0 ] [ : ,  - 1 ] )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            sigmas  =  np . concatenate ( sigma_set )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . local_sigma  =  torch . from_numpy ( sigmas ) . float ( ) . cuda ( )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  Exception  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error computing local sigma:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _setup_network_and_optimizer ( self ,  kwargs ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 设置网络和优化器 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            # 设置目录  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . _setup_checkpoints_directories ( )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            # 网络参数设置  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . _setup_network_parameters ( kwargs )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            # 创建网络  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . _create_network ( )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            # 设置优化器  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . _setup_optimizer ( kwargs )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . debug ( " Network and optimizer setup completed " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  Exception  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error in network and optimizer setup:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _setup_checkpoints_directories ( self ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 设置检查点目录 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . cur_exp_dir  =  os . path . join ( self . expdir ,  self . foldername )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            utils . mkdir_ifnotexists ( self . cur_exp_dir )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . plots_dir  =  os . path . join ( self . cur_exp_dir ,  ' plots ' )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            utils . mkdir_ifnotexists ( self . plots_dir )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . checkpoints_path  =  os . path . join ( self . cur_exp_dir ,  ' checkpoints ' )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            utils . mkdir_ifnotexists ( self . checkpoints_path )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . model_params_subdir  =  " ModelParameters "  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . optimizer_params_subdir  =  " OptimizerParameters "  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            utils . mkdir_ifnotexists ( os . path . join (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . checkpoints_path ,  self . model_params_subdir ) )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            utils . mkdir_ifnotexists ( os . path . join (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . checkpoints_path ,  self . optimizer_params_subdir ) )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  Exception  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error setting up checkpoint directories:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _setup_network_parameters ( self ,  kwargs ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 设置网络参数 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . nepochs  =  kwargs [ ' nepochs ' ]  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . points_batch  =  kwargs [ ' points_batch ' ]  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . global_sigma  =  self . conf . get_float ( ' network.sampler.properties.global_sigma ' )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . sampler  =  Sampler . get_sampler (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . conf . get_string ( ' network.sampler.sampler_type ' ) ) (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                    self . global_sigma ,  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                    self . local_sigma  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                )  
				
			 
			
		
	
		
		
			
				
					 
					 
					             
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					
					 
					 
					        # if continue load checkpoints  
					 
					 
					            self . grad_lambda  =  self . conf . get_float ( ' network.loss.lambda ' )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					        if  is_continue :  
					 
					 
					            self . normals_lambda  =  self . conf . get_float ( ' network.loss.normals_lambda ' )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            old_checkpnts_dir  =  os . path . join ( self . expdir ,  self . foldername ,  ' checkpoints ' )  
					 
					 
					            self . with_normals  =  self . normals_lambda  >  0  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            print ( ' loading checkpoint from:  ' ,  old_checkpnts_dir )  
					 
					 
					            self . d_in  =  self . conf . get_int ( ' train.d_in ' )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            saved_model_state  =  torch . load (  
					 
					 
					        except  Exception  as  e :  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					                os . path . join ( old_checkpnts_dir ,  ' ModelParameters ' ,  str ( kwargs [ ' checkpoint ' ] )  +  " .pth " ) )  
					 
					 
					            logger . error ( f " Error setting up network parameters:  { str ( e ) } " )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            self . network . load_state_dict ( saved_model_state [ " model_state_dict " ] )  
					 
					 
					            raise  
				
			 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _create_network ( self ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 创建网络 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . network  =  utils . get_class ( self . conf . get_string ( ' train.network_class ' ) ) (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                d_in = self . d_in ,  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                n_branch = int ( torch . max ( self . feature_mask ) . item ( ) ) ,  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                csg_tree = self . csg_tree ,  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                flag_convex = self . csg_flag_convex ,  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                * * self . conf . get_config ( ' network.inputs ' )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            )  
				
			 
			
		
	
		
		
			
				
					 
					 
					             
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					
					 
					 
					            data  =  torch . load (  
					 
					 
					            logger . info ( f " Network created:  { self . network } " )  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					                os . path . join ( old_checkpnts_dir ,  ' OptimizerParameters ' ,  str ( kwargs [ ' checkpoint ' ] )  +  " .pth " ) )  
					 
					 
					             
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            self . optimizer . load_state_dict ( data [ " optimizer_state_dict " ] )  
					 
					 
					            if  torch . cuda . is_available ( ) :  
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					            self . startepoch  =  saved_model_state [ ' epoch ' ]  
					 
					 
					                self . network . cuda ( )  
				
			 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  Exception  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error creating network:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _setup_optimizer ( self ,  kwargs ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 设置优化器 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . lr_schedules  =  self . get_learning_rate_schedules (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . conf . get_list ( ' train.learning_rate_schedule ' ) )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . weight_decay  =  self . conf . get_float ( ' train.weight_decay ' )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . startepoch  =  0  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . optimizer  =  torch . optim . Adam ( [ {  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                " params " :  self . network . parameters ( ) ,  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                " lr " :  self . lr_schedules [ 0 ] . get_learning_rate ( 0 ) ,  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                " weight_decay " :  self . weight_decay  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            } ] )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            # 如果继续训练,加载检查点  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            self . _load_checkpoints_if_continue ( kwargs )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  Exception  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error setting up optimizer:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					    def  _load_checkpoints_if_continue ( self ,  kwargs ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        """ 如果继续训练,加载检查点 """  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        try :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            model_params_path  =  os . path . join (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . checkpoints_path ,  self . model_params_subdir )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            ckpts  =  os . listdir ( model_params_path )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					             
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            if  len ( ckpts )  !=  0 :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                old_checkpnts_dir  =  os . path . join (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                    self . expdir ,  self . foldername ,  ' checkpoints ' )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                logger . info ( f ' Loading checkpoint from:  { old_checkpnts_dir } ' )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                # 加载模型状态  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                saved_model_state  =  torch . load ( os . path . join (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                    old_checkpnts_dir ,   
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                    ' ModelParameters ' ,   
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                    f " { kwargs [ ' checkpoint ' ] } .pth "  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                ) )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . network . load_state_dict ( saved_model_state [ " model_state_dict " ] )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                # 加载优化器状态  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                data  =  torch . load ( os . path . join (  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                    old_checkpnts_dir ,   
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                    ' OptimizerParameters ' ,   
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                    f " { kwargs [ ' checkpoint ' ] } .pth "  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                ) )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . optimizer . load_state_dict ( data [ " optimizer_state_dict " ] )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					                self . startepoch  =  saved_model_state [ ' epoch ' ]  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					        except  Exception  as  e :  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            logger . error ( f " Error loading checkpoints:  { str ( e ) } " )  
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					            raise  
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					    def  get_learning_rate_schedules ( self ,  schedule_specs ) :  
					 
					 
					    def  get_learning_rate_schedules ( self ,  schedule_specs ) :  
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					
 
				
			 
			
		
	
	
		
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
					@ -509,6 +679,8 @@ class ReconstructionRunner: 
				
			 
			
		
	
		
		
			
				
					 
					 
					            { " epoch " :  epoch ,  " optimizer_state_dict " :  self . optimizer . state_dict ( ) } ,  
					 
					 
					            { " epoch " :  epoch ,  " optimizer_state_dict " :  self . optimizer . state_dict ( ) } ,  
				
			 
			
		
	
		
		
			
				
					 
					 
					            os . path . join ( self . checkpoints_path ,  self . optimizer_params_subdir ,  " latest.pth " ) )  
					 
					 
					            os . path . join ( self . checkpoints_path ,  self . optimizer_params_subdir ,  " latest.pth " ) )  
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					     
				
			 
			
		
	
		
		
			
				
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					if  __name__  ==  ' __main__ ' :  
					 
					 
					if  __name__  ==  ' __main__ ' :  
				
			 
			
		
	
		
		
			
				
					 
					 
					
 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					 
					 
					    if  args . gpu  ==  " auto " :  
					 
					 
					    if  args . gpu  ==  " auto " :