@ -88,6 +88,7 @@ class NHREPNet_Training:
except Exception as e :
logger . error ( f " 训练过程中发生错误: { str ( e ) } " )
break
self . tracing ( )
def train_one_epoch ( self , epoch , patch_id , patch_id_n , n_patch_batch , n_patch_last , n_branch , n_batchsize ) :
#logger.info(f"Epoch {epoch}/{self.nepochs} 开始")
@ -211,10 +212,12 @@ class NHREPNet_Training:
self . checkpoints_path = os . path . join ( " ../exps/single_shape " , name_prefix , " checkpoints " )
self . ModelParameters_path = os . path . join ( self . checkpoints_path , " ModelParameters " )
self . OptimizerParameters_path = os . path . join ( self . checkpoints_path , " OptimizerParameters " )
self . TorchScript_path = os . path . join ( self . checkpoints_path , " TorchScript " )
# 创建目录
os . makedirs ( self . ModelParameters_path , exist_ok = True )
os . makedirs ( self . OptimizerParameters_path , exist_ok = True )
os . makedirs ( self . TorchScript_path , exist_ok = True )
def save_checkpoints ( self , epoch ) :
torch . save (
@ -230,6 +233,21 @@ class NHREPNet_Training:
{ " epoch " : epoch , " optimizer_state_dict " : self . scheduler . optimizer . state_dict ( ) } ,
os . path . join ( self . OptimizerParameters_path , " latest.pth " ) )
def tracing ( self ) :
csg_tree , flag_convex = self . dataset . get_csg_tree ( )
network = get_class ( self . conf . get_string ( ' train.network_class ' ) ) (
d_in = self . d_in ,
n_branch = int ( torch . max ( self . feature_mask ) . item ( ) ) ,
csg_tree = csg_tree ,
flag_convex = flag_convex ,
* * self . conf . get_config ( ' network.inputs ' )
) . to ( self . device )
#trace
example = torch . rand ( 224 , 3 ) . to ( self . device )
traced_script_module = torch . jit . trace ( network , example )
traced_script_module . save ( os . path . join ( self . TorchScript_path , " model_h.pt " ) )
if __name__ == " __main__ " :
name_prefix = ' broken_bullet_50k '
conf = ConfigFactory . parse_file ( ' ./conversion/setup.conf ' )