@ -22,13 +22,14 @@ from model.network import NHRepNet # 导入 NHRepNet
from model . sample import Sampler
from model . sample import Sampler
class NHREPNet_Training:
    """Training driver for NHRepNet.

    Builds the point sampler, dataset, TensorBoard writer and checkpoint
    directories from a pyhocon config, then trains the SDF network.
    """

    def __init__(self, name_prefix: str, conf, if_baseline: bool = False, if_feature_sample: bool = False):
        """
        Args:
            name_prefix: shape/run name; used for summary and checkpoint paths.
            conf: parsed pyhocon config (result of ConfigFactory.parse_file).
            if_baseline: forwarded to NHREP_Dataset.
            if_feature_sample: forwarded to NHREP_Dataset.
        """
        self.conf = conf
        # Kept on the instance so init_checkpoints() can build run-specific
        # paths (the recovered source referenced a bare `name_prefix` there).
        self.name_prefix = name_prefix
        # The sampler class is selected by name from the config, then
        # instantiated with its global/local sigma properties.
        self.sampler = Sampler.get_sampler(
            self.conf.get_string('network.sampler.sampler_type'))(
            global_sigma=self.conf.get_float('network.sampler.properties.global_sigma'),
            local_sigma=self.conf.get_float('network.sampler.properties.local_sigma'),
        )
        data_dir = self.conf.get_string('train.input_path')
        self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): the diff this file was recovered from omitted unchanged
        # context lines around here — restore any attributes the pre-diff file
        # defined between these statements.
        self.d_in = 3                    # input dimension: x, y, z
        self.dims_sdf = [256, 256, 256]  # hidden-layer widths of the SDF net
        self.nepochs = 15000             # number of training epochs
        folder = self.conf.get_string('train.folderprefix')
        self.writer = SummaryWriter(os.path.join("summary", folder, name_prefix))  # TensorBoard logger
        # Prepare checkpoint directory layout up front.
        self.init_checkpoints()
def run_nhrepnet_training ( self ) :
def run_nhrepnet_training ( self ) :
# 数据准备
# 数据准备
@ -48,7 +52,7 @@ class NHREPNet_Training:
feature_mask_cpu = self . dataset . get_feature_mask ( ) . numpy ( ) # 特征掩码
feature_mask_cpu = self . dataset . get_feature_mask ( ) . numpy ( ) # 特征掩码
self . feature_mask = torch . from_numpy ( feature_mask_cpu ) . to ( self . device ) # 特征掩码 # 特征掩码
self . feature_mask = torch . from_numpy ( feature_mask_cpu ) . to ( self . device ) # 特征掩码 # 特征掩码
self . points_batch = 16384 # 批次大小
self . points_batch = 16384 # 批次大小
self . compute_local_sigma ( )
n_branch = int ( torch . max ( self . feature_mask ) . item ( ) ) # 计算分支数量
n_branch = int ( torch . max ( self . feature_mask ) . item ( ) ) # 计算分支数量
n_batchsize = self . points_batch # 设置批次大小
n_batchsize = self . points_batch # 设置批次大小
@ -78,7 +82,7 @@ class NHREPNet_Training:
logger . info ( " 开始训练 " )
logger . info ( " 开始训练 " )
self . model . train ( ) # 设置模型为训练模式
self . model . train ( ) # 设置模型为训练模式
for epoch in range ( self . nepochs ) : # 开始训练循环
for epoch in tqdm ( range ( self . nepochs ) , desc = " 训练进度 " , unit = " epoch " ) : # 开始训练循环
try :
try :
self . train_one_epoch ( epoch , patch_id , patch_id_n , n_patch_batch , n_patch_last , n_branch , n_batchsize )
self . train_one_epoch ( epoch , patch_id , patch_id_n , n_patch_batch , n_patch_last , n_branch , n_batchsize )
except Exception as e :
except Exception as e :
@ -86,14 +90,13 @@ class NHREPNet_Training:
break
break
def train_one_epoch ( self , epoch , patch_id , patch_id_n , n_patch_batch , n_patch_last , n_branch , n_batchsize ) :
def train_one_epoch ( self , epoch , patch_id , patch_id_n , n_patch_batch , n_patch_last , n_branch , n_batchsize ) :
logger . info ( f " Epoch { epoch } / { self . nepochs } 开始 " )
#logger.info(f"Epoch {epoch}/{self.nepochs} 开始" )
# 1.3,获取索引
# 1.3,获取索引
indices = self . get_indices ( patch_id , patch_id_n , n_patch_batch , n_patch_last , n_branch )
indices = self . get_indices ( patch_id , patch_id_n , n_patch_batch , n_patch_last , n_branch )
# 1.4,获取数据
# 1.4,获取数据
cur_data = self . data [ indices ] # x, y, z, nx, ny, nz
cur_data = self . data [ indices ] # x, y, z, nx, ny, nz
mnfld_pnts = cur_data [ : , : self . d_in ] # 提取流形点
mnfld_pnts = cur_data [ : , : self . d_in ] # 提取流形点
self . compute_local_sigma ( )
mnfld_sigma = self . local_sigma [ indices ] # 提取噪声点
mnfld_sigma = self . local_sigma [ indices ] # 提取噪声点
nonmnfld_pnts = self . sampler . get_points ( mnfld_pnts . unsqueeze ( 0 ) , mnfld_sigma . unsqueeze ( 0 ) ) . squeeze ( ) # 生成非流形点
nonmnfld_pnts = self . sampler . get_points ( mnfld_pnts . unsqueeze ( 0 ) , mnfld_sigma . unsqueeze ( 0 ) ) . squeeze ( ) # 生成非流形点
@ -101,9 +104,6 @@ class NHREPNet_Training:
#TODO 记录了log
#TODO 记录了log
# 2,前向传播
# 2,前向传播
self . scheduler . adjust_learning_rate ( epoch )
#logger.info(f"mnfld_pnts: {mnfld_pnts.shape}")
#logger.info(f"mnfld_pnts: {mnfld_pnts.shape}")
#logger.info(f"nonmnfld_pnts: {nonmnfld_pnts.shape}")
#logger.info(f"nonmnfld_pnts: {nonmnfld_pnts.shape}")
@ -116,7 +116,7 @@ class NHREPNet_Training:
normals = cur_data [ : , - self . d_in : ]
normals = cur_data [ : , - self . d_in : ]
# 计算损失
# 计算损失
loss = self . loss_manager . compute_loss (
loss , loss_details = self . loss_manager . compute_loss (
mnfld_pnts = mnfld_pnts ,
mnfld_pnts = mnfld_pnts ,
normals = normals ,
normals = normals ,
mnfld_pred_all = mnfld_pred_all ,
mnfld_pred_all = mnfld_pred_all ,
@ -128,7 +128,7 @@ class NHREPNet_Training:
n_patch_last = n_patch_last ,
n_patch_last = n_patch_last ,
) # 计算损失
) # 计算损失
self . scheduler . step ( loss )
self . scheduler . step ( loss , epoch )
# 反向传播
# 反向传播
self . scheduler . optimizer . zero_grad ( ) # 清空梯度
self . scheduler . optimizer . zero_grad ( ) # 清空梯度
@ -136,8 +136,13 @@ class NHREPNet_Training:
self . scheduler . optimizer . step ( ) # 更新参数
self . scheduler . optimizer . step ( ) # 更新参数
avg_loss = loss . item ( )
avg_loss = loss . item ( )
logger . info ( f ' Epoch [ { epoch } / { self . nepochs } ], Average Loss: { avg_loss : .4f } ' )
if epoch % 100 == 0 :
self . writer . add_scalar ( ' Loss/train ' , avg_loss , epoch ) # 记录损失到 TensorBoard
#logger.info(f'Epoch [{epoch}/{self.nepochs}]')
self . writer . add_scalar ( ' Loss/train ' , avg_loss , epoch ) # 记录损失到 TensorBoard
for k , v in loss_details . items ( ) :
self . writer . add_scalar ( ' Loss/ ' + k , v . item ( ) , epoch )
if epoch % self . conf . get_int ( ' train.checkpoint_frequency ' ) == 0 : # 每隔一定轮次保存检查点
self . save_checkpoints ( epoch )
#============================ 前向传播 数据准备 ============================
#============================ 前向传播 数据准备 ============================
def compute_patch ( self , n_branch , n_patch_batch , n_patch_last , feature_mask_cpu ) :
def compute_patch ( self , n_branch , n_patch_batch , n_patch_last , feature_mask_cpu ) :
@ -202,23 +207,34 @@ class NHREPNet_Training:
#============================ 保存模型 ============================
#============================ 保存模型 ============================
def save_checkpoints ( self , epoch ) :
def init_checkpoints ( self ) :
self . checkpoints_path = os . path . join ( " ../exps/single_shape " , name_prefix , " checkpoints " )
self . ModelParameters_path = os . path . join ( self . checkpoints_path , " ModelParameters " )
self . OptimizerParameters_path = os . path . join ( self . checkpoints_path , " OptimizerParameters " )
# 创建目录
os . makedirs ( self . ModelParameters_path , exist_ok = True )
os . makedirs ( self . OptimizerParameters_path , exist_ok = True )
def save_checkpoints ( self , epoch ) :
torch . save (
torch . save (
{ " epoch " : epoch , " model_state_dict " : self . network . state_dict ( ) } ,
{ " epoch " : epoch , " model_state_dict " : self . model . state_dict ( ) } ,
os . path . join ( self . checkpoints_path , self . model_params_subdir , str ( epoch ) + " .pth " ) )
os . path . join ( self . ModelParameters_path , str ( epoch ) + " .pth " ) )
torch . save (
torch . save (
{ " epoch " : epoch , " model_state_dict " : self . network . state_dict ( ) } ,
{ " epoch " : epoch , " model_state_dict " : self . model . state_dict ( ) } ,
os . path . join ( self . checkpoints_path , self . model_params_subdir , " latest.pth " ) )
os . path . join ( self . ModelParameters_path , " latest.pth " ) )
torch . save (
torch . save (
{ " epoch " : epoch , " optimizer_state_dict " : self . optimizer . state_dict ( ) } ,
{ " epoch " : epoch , " optimizer_state_dict " : self . scheduler . optimizer . state_dict ( ) } ,
os . path . join ( self . checkpoints_path , self . optimizer_params_subdir , str ( epoch ) + " .pth " ) )
os . path . join ( self . OptimizerParameters_path , str ( epoch ) + " .pth " ) )
torch . save (
torch . save (
{ " epoch " : epoch , " optimizer_state_dict " : self . optimizer . state_dict ( ) } ,
{ " epoch " : epoch , " optimizer_state_dict " : self . scheduler . optimizer . state_dict ( ) } ,
os . path . join ( self . checkpoints_path , self . optimizer_params_subdir , " latest.pth " ) )
os . path . join ( self . OptimizerParameters_path , " latest.pth " ) )
if __name__ == "__main__":
    # Shape to train on; the input path itself now comes from the config
    # (train.input_path), so no data_dir is passed here.
    name_prefix = 'broken_bullet_50k'
    conf = ConfigFactory.parse_file('./conversion/setup.conf')
    try:
        train = NHREPNet_Training(name_prefix, conf, if_baseline=True, if_feature_sample=False)
        train.run_nhrepnet_training()
    except Exception as e:
        # top-level boundary: log the failure instead of a raw traceback
        logger.error(str(e))