import os
import sys
import time
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)
import argparse
# parse args first and set gpu id
parser = argparse.ArgumentParser()
parser.add_argument('--points_batch', type=int, default=16384, help='point batch size')
parser.add_argument('--nepoch', type=int, default=15001, help='number of epochs to train for')
parser.add_argument('--conf', type=str, default='setup.conf')
parser.add_argument('--expname', type=str, default='single_shape')
parser.add_argument('--gpu', type=str, default='0', help='GPU id to use')
parser.add_argument('--is_continue', default=False, action="store_true", help='continue training from a checkpoint')
parser.add_argument('--checkpoint', default='latest', type=str)
parser.add_argument('--eval', default=False, action="store_true")
parser.add_argument('--summary', default=False, action="store_true", help='write tensorboard summary')
parser.add_argument('--baseline', default=False, action="store_true", help='run baseline')
parser.add_argument('--th_closeness', type=float, default=1e-5, help='threshold deciding whether two points are the same')
parser.add_argument('--cpu', default=False, action="store_true", help='save for cpu device')
parser.add_argument('--ab', default='none', type=str, help='ablation')
parser.add_argument('--siren', default=False, action="store_true", help='siren normal loss')
parser.add_argument('--pt', default='ptfile path', type=str, help='output folder for the traced .pt models')
parser.add_argument('--feature_sample', action="store_true", help='use feature curve samples')
parser.add_argument('--num_feature_sample', type=int, default=2048, help='number of feature samples drawn per iteration')
parser.add_argument('--all_feature_sample', type=int, default=10000, help='total number of feature samples')
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(args.gpu)
from pyhocon import ConfigFactory
import numpy as np
import GPUtil
import torch
import utils.general as utils
from model.sample import Sampler
from model.network import gradient
from scipy.spatial import cKDTree
from utils.plots import plot_surface, plot_cuts_axis
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from utils.logger import logger
logger.info(f"project_dir: {project_dir}")


class ReconstructionRunner:

    def run_nhrepnet_training(self):
        print("running")
        self.data = self.data.cuda()
        self.data.requires_grad_()
        feature_mask_cpu = self.feature_mask.numpy()
        self.feature_mask = self.feature_mask.cuda()
        n_branch = int(torch.max(self.feature_mask).item())
        n_batchsize = self.points_batch
        n_patch_batch = n_batchsize // n_branch
        n_patch_last = n_batchsize - n_patch_batch * (n_branch - 1)
        patch_sup = True
        weight_mnfld_h = 1
        weight_mnfld_cs = 1
        weight_correction = 1
        a_correction = 100
        patch_id = []
        patch_id_n = []
        for i in range(n_branch):
            patch_id = patch_id + [np.where(feature_mask_cpu == i + 1)[0]]
            patch_id_n = patch_id_n + [patch_id[i].shape[0]]
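        # patch_id[i]: indices of the points belonging to patch i (the feature
        # mask uses 1-based patch labels); patch_id_n[i]: point count of patch i.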
        if self.eval:
            print("evaluating epoch: {0}".format(self.startepoch))
            my_path = os.path.join(self.cur_exp_dir, 'evaluation', str(self.startepoch))
            utils.mkdir_ifnotexists(os.path.join(self.cur_exp_dir, 'evaluation'))
            utils.mkdir_ifnotexists(my_path)
            for i in range(1):
                self.network.flag_output = i + 1
                self.plot_shapes(epoch=self.startepoch, path=my_path, file_suffix="_" + str(i), with_cuts=True)
                self.network.flag_output = 0
            return
print("training begin")
if args.summary == True:
writer = SummaryWriter(os.path.join("summary", self.foldername))
# branch mask is predefined
branch_mask = torch.zeros(n_branch, n_batchsize).cuda()
single_branch_mask_gt = torch.zeros(n_batchsize, n_branch).cuda()
single_branch_mask_id = torch.zeros([n_batchsize], dtype = torch.long).cuda()
for i in range(n_branch - 1):
branch_mask[i, i * n_patch_batch : (i + 1) * n_patch_batch] = 1.0
single_branch_mask_gt[i * n_patch_batch : (i + 1) * n_patch_batch, i] = 1.0
single_branch_mask_id[i * n_patch_batch : (i + 1) * n_patch_batch] = i
#last patch
branch_mask[n_branch - 1, (n_branch - 1) * n_patch_batch:] = 1.0
single_branch_mask_gt[(n_branch - 1) * n_patch_batch:, (n_branch - 1)] = 1.0
single_branch_mask_id[(n_branch - 1) * n_patch_batch:] = (n_branch - 1)
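        # Batch layout used throughout training: patches 0..n_branch-2 occupy
        # n_patch_batch rows each; the last patch takes the remaining
        # n_patch_last rows. The per-epoch index sampling below follows it.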
        for epoch in range(self.startepoch, self.nepochs + 1):
            indices = torch.empty(0, dtype=torch.int64).cuda()
            for i in range(n_branch - 1):
                indices_nonfeature = torch.tensor(patch_id[i][np.random.choice(patch_id_n[i], n_patch_batch, True)]).cuda()
                indices = torch.cat((indices, indices_nonfeature), 0)
            # last patch
            indices_nonfeature = torch.tensor(patch_id[n_branch - 1][np.random.choice(patch_id_n[n_branch - 1], n_patch_last, True)]).cuda()
            indices = torch.cat((indices, indices_nonfeature), 0)
            cur_data = self.data[indices]
            mnfld_pnts = cur_data[:, :self.d_in]  # n_indices x 3
            mnfld_sigma = self.local_sigma[indices]  # noise points
            if epoch % self.conf.get_int('train.checkpoint_frequency') == 0:
                self.save_checkpoints(epoch)
            if epoch % self.conf.get_int('train.plot_frequency') == 0:
                print('plot validation epoch: ', epoch)
                for i in range(n_branch + 1):
                    self.network.flag_output = i + 1
                    self.plot_shapes(epoch, file_suffix="_" + str(i), with_cuts=False)
                self.network.flag_output = 0
            self.network.train()
            self.adjust_learning_rate(epoch)
            nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()
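            # The sampler draws off-surface points: local perturbations of the
            # batch points (scaled by mnfld_sigma) plus globally drawn samples.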
            # forward pass
            mnfld_pred_all = self.network(mnfld_pnts)
            nonmnfld_pred_all = self.network(nonmnfld_pnts)
            mnfld_pred = mnfld_pred_all[:, 0]
            nonmnfld_pred = nonmnfld_pred_all[:, 0]
            loss = 0.0
            mnfld_grad = gradient(mnfld_pnts, mnfld_pred)
            # manifold loss
            mnfld_loss = torch.zeros(1).cuda()
            if not args.ab == 'overall':
                mnfld_loss = (mnfld_pred.abs()).mean()
            loss = loss + weight_mnfld_h * mnfld_loss
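            # Column 0 of the network output is the composite function h;
            # columns 1..n_branch are the per-patch branch functions f_i.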
            # feature sample
            if args.feature_sample:
                feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda()
                feature_pnts = self.feature_data[feature_indices]
                feature_mask_pair = self.feature_data_mask_pair[feature_indices]
                feature_pred_all = self.network(feature_pnts)
                feature_pred = feature_pred_all[:, 0]
                feature_mnfld_loss = feature_pred.abs().mean()
                loss = loss + weight_mnfld_h * feature_mnfld_loss  # |h|
                # patch loss:
                feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:, 0].tolist()]
                feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:, 1].tolist()]
                feature_fis_left = feature_pred_all[feature_id_left]
                feature_fis_right = feature_pred_all[feature_id_right]
                feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean()
                loss += feature_loss_patch
                # consistency loss:
                feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean()
                loss += weight_mnfld_cs * feature_loss_cons
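            # all_fi collects, for each batch point, the value of the branch
            # function that owns that point's patch (column i+1 for patch i).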
            all_fi = torch.zeros([n_batchsize, 1], device='cuda')
            for i in range(n_branch - 1):
                all_fi[i * n_patch_batch : (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch : (i + 1) * n_patch_batch, i + 1]
            # last patch
            all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch]
            # manifold loss for patches
            mnfld_loss_patch = torch.zeros(1).cuda()
            if not args.ab == 'patch':
                if patch_sup:
                    mnfld_loss_patch = all_fi[:, 0].abs().mean()
            loss = loss + mnfld_loss_patch
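            # The remaining terms regularize the composite h: a correction term
            # (active after epoch 10000) on points where h and the owning branch
            # f_i disagree by more than th_closeness, an off-surface repulsion
            # term, and an h-vs-branch consistency term.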
            # correction loss
            correction_loss = torch.zeros(1).cuda()
            if not (args.ab == 'cor' or args.ab == 'cc') and epoch > 10000 and not args.baseline:
                mismatch_id = torch.abs(mnfld_pred - all_fi[:, 0]) > args.th_closeness
                if mismatch_id.sum() != 0:
                    correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:, 0])[mismatch_id]).mean()
            loss = loss + weight_correction * correction_loss
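            # nonmnfld_pred[n_batchsize:] selects the globally drawn free-space
            # samples appended by the sampler after the locally perturbed ones;
            # exp(-100|h|) penalizes spurious zero level sets away from the surface.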
            # off-surface loss
            offsurface_loss = torch.zeros(1).cuda()
            if not args.ab == 'off':
                offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred[n_batchsize:])).mean()
            loss = loss + offsurface_loss
            # manifold consistency loss
            mnfld_consistency_loss = torch.zeros(1).cuda()
            if not (args.ab == 'cons' or args.ab == 'cc'):
                mnfld_consistency_loss = (mnfld_pred - all_fi[:, 0]).abs().mean()
            loss = loss + weight_mnfld_cs * mnfld_consistency_loss
            # eikonal loss for h
            grad_loss_h = torch.zeros(1).cuda()
            single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred_all[:, 0])
            grad_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
            loss = loss + self.grad_lambda * grad_loss_h
            # normals loss
            normals_loss_h = torch.zeros(1).cuda()
            normals_loss = torch.zeros(1).cuda()
            normal_consistency_loss = torch.zeros(1).cuda()
            if not args.siren:
                if not args.ab == 'normal' and self.with_normals:
                    # all normals
                    normals = cur_data[:, -self.d_in:]
                    if patch_sup:
                        branch_grad = gradient(mnfld_pnts, all_fi[:, 0])
                        normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean()
                        loss = loss + self.normals_lambda * normals_loss
                    # only monitored for logging, not used in the loss
                    mnfld_grad = gradient(mnfld_pnts, mnfld_pred_all[:, 0])
                    normal_consistency_loss = (mnfld_grad - branch_grad).abs().norm(2, dim=1).mean()
                else:
                    single_nonmnfld_grad = gradient(mnfld_pnts, all_fi[:, 0])
                    normals_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
                    loss = loss + self.normals_lambda * normals_loss_h
            else:
                # cosine similarity between the gradient of h and the GT normals
                normals = cur_data[:, -self.d_in:]
                normals_loss_h = (1 - F.cosine_similarity(mnfld_grad, normals, dim=-1)).mean()
                loss = loss + self.normals_lambda * normals_loss_h
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # tensorboard
            if args.summary and epoch % 100 == 0:
                writer.add_scalar('Loss/Total loss', loss.item(), epoch)
                writer.add_scalar('Loss/Manifold loss h', mnfld_loss.item(), epoch)
                writer.add_scalar('Loss/Manifold patch loss', mnfld_loss_patch.item(), epoch)
                writer.add_scalar('Loss/Manifold cons loss', mnfld_consistency_loss.item(), epoch)
                writer.add_scalar('Loss/Grad loss h', self.grad_lambda * grad_loss_h.item(), epoch)
                writer.add_scalar('Loss/Normal loss all', self.normals_lambda * normals_loss.item(), epoch)
                writer.add_scalar('Loss/Normal cs loss', self.normals_lambda * normal_consistency_loss.item(), epoch)
                writer.add_scalar('Loss/Assignment loss', correction_loss.item(), epoch)
                writer.add_scalar('Loss/Offsurface loss', offsurface_loss.item(), epoch)
            if epoch % self.conf.get_int('train.status_frequency') == 0:
                print('Train Epoch: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\t Manifold loss: {:.6f}'
                      '\tManifold patch loss: {:.6f}\t grad loss h: {:.6f}\t normals loss all: {:.6f}\t normals loss h: {:.6f}\t Manifold consistency loss: {:.6f}\tCorrection loss: {:.6f}\t Offsurface loss: {:.6f}'.format(
                          epoch, self.nepochs, 100. * epoch / self.nepochs,
                          loss.item(), mnfld_loss.item(), mnfld_loss_patch.item(), grad_loss_h.item(), normals_loss.item(), normals_loss_h.item(), mnfld_consistency_loss.item(), correction_loss.item(), offsurface_loss.item()))
                if args.feature_sample:
                    print('feature mnfld loss: {} patch loss: {} cons loss: {}'.format(feature_mnfld_loss.item(), feature_loss_patch.item(), feature_loss_cons.item()))
        self.tracing()
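
    # Export the trained implicit network as a TorchScript module so it can be
    # loaded without the Python class definition (e.g. from C++ or a separate
    # evaluation pipeline).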
    def tracing(self):
        # network definition
        device = torch.device('cuda')
        if args.cpu:
            device = torch.device('cpu')
        network = utils.get_class(self.conf.get_string('train.network_class'))(
            d_in=3,
            flag_output=1,
            n_branch=int(torch.max(self.feature_mask).item()),
            csg_tree=self.csg_tree,
            flag_convex=self.csg_flag_convex,
            **self.conf.get_config('network.inputs'))
        network.to(device)
        ckpt_prefix = 'exps/single_shape/'
        save_prefix = '{}/'.format(args.pt)
        if not os.path.exists(save_prefix):
            os.mkdir(save_prefix)
        saved_model_state = torch.load(ckpt_prefix + self.foldername + '/checkpoints/ModelParameters/latest.pth',
                                       map_location=device)
        network.load_state_dict(saved_model_state["model_state_dict"])
        # trace
        example = torch.rand(224, 3).to(device)
        traced_script_module = torch.jit.trace(network, example)
        traced_script_module.save(save_prefix + self.foldername + "_model_h.pt")
        print('converting to pt finished')
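
    # A minimal reload sketch for the traced module (path assumed to follow the
    # save call above):
    #   model = torch.jit.load(os.path.join(args.pt, foldername + "_model_h.pt"))
    #   values = model(torch.rand(1000, 3))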

    def plot_shapes(self, epoch, path=None, with_cuts=False, file_suffix="all"):
        # plot network validation shapes
        with torch.no_grad():
            self.network.eval()
            if not path:
                path = self.plots_dir
            indices = torch.tensor(np.random.choice(self.data.shape[0], self.points_batch, True))  # modified 0107, with replacement
            pnts = self.data[indices, :3]
            # draw nonmnfld pts
            mnfld_sigma = self.local_sigma[indices]
            nonmnfld_pnts = self.sampler.get_points(pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()
            pnts = nonmnfld_pnts
            plot_surface(with_points=True,
                         points=pnts,
                         decoder=self.network,
                         path=path,
                         epoch=epoch,
                         shapename=self.expname,
                         suffix=file_suffix,
                         **self.conf.get_config('plot'))
            if with_cuts:
                plot_cuts_axis(points=pnts,
                               decoder=self.network,
                               latent=None,
                               path=path,
                               epoch=epoch,
                               near_zero=False,
                               axis=2)

    def __init__(self, **kwargs):
        try:
            # 1. initialize basic settings
            self._initialize_basic_settings(kwargs)
            # 2. set up the config file and experiment directories
            self._setup_config_and_directories(kwargs)
            # 3. load the data
            self._load_data(kwargs)
            # 4. set up the CSG tree
            self._setup_csg_tree()
            # 5. compute the local sigma values
            self._compute_local_sigma()
            # 6. set up the network and optimizer
            self._setup_network_and_optimizer(kwargs)
            print("Initialization completed successfully")
        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def _initialize_basic_settings(self, kwargs):
        """Initialize basic settings."""
        try:
            self.home_dir = os.path.abspath(os.getcwd())
            self.flag_list = kwargs.get('flag_list', False)
            self.expname = kwargs['expname']
            self.GPU_INDEX = kwargs['gpu_index']
            self.num_of_gpus = torch.cuda.device_count()
            self.eval = kwargs['eval']
            logger.debug("Basic settings initialized successfully")
        except KeyError as e:
            logger.error(f"Missing required parameter: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error in basic settings initialization: {str(e)}")
            raise

    def _setup_config_and_directories(self, kwargs):
        """Set up the config file and create the required directories."""
        try:
            # config setup
            if isinstance(kwargs['conf'], str):
                self.conf_filename = './conversion/' + kwargs['conf']
                self.conf = ConfigFactory.parse_file(self.conf_filename)
            else:
                self.conf = kwargs['conf']
            # create the experiment directories
            self.exps_folder_name = 'exps'
            self.expdir = utils.concat_home_dir(os.path.join(
                self.home_dir, self.exps_folder_name, self.expname))
            utils.mkdir_ifnotexists(utils.concat_home_dir(
                os.path.join(self.home_dir, self.exps_folder_name)))
            utils.mkdir_ifnotexists(self.expdir)
            logger.debug("Config and directories setup completed")
        except Exception as e:
            logger.error(f"Error in config and directory setup: {str(e)}")
            raise

    def _load_data(self, kwargs):
        """Load the point cloud data and the feature mask."""
        try:
            if not self.flag_list:
                self._load_single_data()
            else:
                self._load_data_from_list(kwargs)
            if args.baseline:
                self.feature_mask = torch.ones(self.data.shape[0]).float()
            logger.info(f"Data loading finished. Data shape: {self.data.shape}")
        except FileNotFoundError as e:
            logger.error(f"Data file not found: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error in data loading: {str(e)}")
            raise

    def _load_single_data(self):
        """Load a single data file."""
        try:
            self.input_file = self.conf.get_string('train.input_path')
            self.data = utils.load_point_cloud_by_file_extension(self.input_file)
            self.feature_mask_file = self.conf.get_string('train.feature_mask_path')
            self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
            self.foldername = self.conf.get_string('train.foldername')
        except Exception as e:
            logger.error(f"Error loading single data file: {str(e)}")
            raise

    def _load_data_from_list(self, kwargs):
        """Load the data from a file list."""
        try:
            self.input_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '.xyz'
            )
            if not os.path.exists(self.input_file):
                self.flag_data_load = False
                raise Exception(f"Data file not found: {self.input_file}, absolute path: {os.path.abspath(self.input_file)}")
            self.flag_data_load = True
            self.data = utils.load_point_cloud_by_file_extension(self.input_file)
            self.feature_mask_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_mask.txt'
            )
            if not args.baseline:
                self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
            if args.feature_sample:
                self._load_feature_samples(kwargs)
            self.foldername = kwargs['folder_prefix'] + kwargs['file_prefix']
        except Exception as e:
            logger.error(f"Error loading data from list: {str(e)}")
            raise

    def _load_feature_samples(self, kwargs):
        """Load the feature curve samples."""
        try:
            input_fs_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_feature.xyz'
            )
            self.feature_data = np.loadtxt(input_fs_file)
            self.feature_data = torch.tensor(
                self.feature_data,
                dtype=torch.float32,
                device='cuda'
            )
            fs_mask_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_feature_mask.txt'
            )
            self.feature_data_mask_pair = torch.tensor(
                np.loadtxt(fs_mask_file),
                dtype=torch.int64,
                device='cuda'
            )
        except Exception as e:
            logger.error(f"Error loading feature samples: {str(e)}")
            raise

    def _setup_csg_tree(self):
        """Set up the CSG tree."""
        try:
            if args.baseline:
                self.csg_tree = [0]
                self.csg_flag_convex = True
            else:
                csg_conf_file = self.input_file[:-4] + '_csg.conf'
                csg_config = ConfigFactory.parse_file(csg_conf_file)
                self.csg_tree = csg_config.get_list('csg.list')
                self.csg_flag_convex = csg_config.get_int('csg.flag_convex')
            logger.info(f"CSG tree: {self.csg_tree}")
            logger.info(f"CSG convex flag: {self.csg_flag_convex}")
        except Exception as e:
            logger.error(f"Error in CSG tree setup: {str(e)}")
            raise
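
    # Per-point noise scale for off-surface sampling: the distance from each
    # point to its 51st nearest neighbor, with the k-NN query done in 100
    # chunks to bound memory.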
    def _compute_local_sigma(self):
        """Compute the local sigma values."""
        try:
            sigma_set = []
            ptree = cKDTree(self.data)
            logger.debug("KD tree constructed")
            for p in np.array_split(self.data, 100, axis=0):
                d = ptree.query(p, 50 + 1)
                sigma_set.append(d[0][:, -1])
            sigmas = np.concatenate(sigma_set)
            self.local_sigma = torch.from_numpy(sigmas).float().cuda()
        except Exception as e:
            logger.error(f"Error computing local sigma: {str(e)}")
            raise

    def _setup_network_and_optimizer(self, kwargs):
        """Set up the network and the optimizer."""
        try:
            # set up the directories
            self._setup_checkpoints_directories()
            # network parameters
            self._setup_network_parameters(kwargs)
            # create the network
            self._create_network()
            # set up the optimizer
            self._setup_optimizer(kwargs)
            logger.debug("Network and optimizer setup completed")
        except Exception as e:
            logger.error(f"Error in network and optimizer setup: {str(e)}")
            raise

    def _setup_checkpoints_directories(self):
        """Set up the checkpoint directories."""
        try:
            self.cur_exp_dir = os.path.join(self.expdir, self.foldername)
            utils.mkdir_ifnotexists(self.cur_exp_dir)
            self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
            utils.mkdir_ifnotexists(self.plots_dir)
            self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
            utils.mkdir_ifnotexists(self.checkpoints_path)
            self.model_params_subdir = "ModelParameters"
            self.optimizer_params_subdir = "OptimizerParameters"
            utils.mkdir_ifnotexists(os.path.join(
                self.checkpoints_path, self.model_params_subdir))
            utils.mkdir_ifnotexists(os.path.join(
                self.checkpoints_path, self.optimizer_params_subdir))
        except Exception as e:
            logger.error(f"Error setting up checkpoint directories: {str(e)}")
            raise

    def _setup_network_parameters(self, kwargs):
        """Set up the network parameters."""
        try:
            self.nepochs = kwargs['nepochs']
            self.points_batch = kwargs['points_batch']
            self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
            self.sampler = Sampler.get_sampler(
                self.conf.get_string('network.sampler.sampler_type'))(
                    self.global_sigma,
                    self.local_sigma
            )
            self.grad_lambda = self.conf.get_float('network.loss.lambda')
            self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')
            self.with_normals = self.normals_lambda > 0
            self.d_in = self.conf.get_int('train.d_in')
        except Exception as e:
            logger.error(f"Error setting up network parameters: {str(e)}")
            raise

    def _create_network(self):
        """Create the network."""
        try:
            self.network = utils.get_class(self.conf.get_string('train.network_class'))(
                d_in=self.d_in,
                n_branch=int(torch.max(self.feature_mask).item()),
                csg_tree=self.csg_tree,
                flag_convex=self.csg_flag_convex,
                **self.conf.get_config('network.inputs')
            )
            logger.info(f"Network created: {self.network}")
            if torch.cuda.is_available():
                self.network.cuda()
        except Exception as e:
            logger.error(f"Error creating network: {str(e)}")
            raise

    def _setup_optimizer(self, kwargs):
        """Set up the optimizer."""
        try:
            self.lr_schedules = self.get_learning_rate_schedules(
                self.conf.get_list('train.learning_rate_schedule'))
            self.weight_decay = self.conf.get_float('train.weight_decay')
            self.startepoch = 0
            self.optimizer = torch.optim.Adam([{
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            }])
            # if resuming training, load the checkpoints
            self._load_checkpoints_if_continue(kwargs)
        except Exception as e:
            logger.error(f"Error setting up optimizer: {str(e)}")
            raise

    def _load_checkpoints_if_continue(self, kwargs):
        """Resume from existing checkpoints if any are found for this experiment."""
        try:
            model_params_path = os.path.join(
                self.checkpoints_path, self.model_params_subdir)
            ckpts = os.listdir(model_params_path)
            if len(ckpts) != 0:
                old_checkpnts_dir = os.path.join(
                    self.expdir, self.foldername, 'checkpoints')
                logger.info(f'Loading checkpoint from: {old_checkpnts_dir}')
                # load the model state
                saved_model_state = torch.load(os.path.join(
                    old_checkpnts_dir,
                    'ModelParameters',
                    f"{kwargs['checkpoint']}.pth"
                ))
                self.network.load_state_dict(saved_model_state["model_state_dict"])
                # load the optimizer state
                data = torch.load(os.path.join(
                    old_checkpnts_dir,
                    'OptimizerParameters',
                    f"{kwargs['checkpoint']}.pth"
                ))
                self.optimizer.load_state_dict(data["optimizer_state_dict"])
                self.startepoch = saved_model_state['epoch']
        except Exception as e:
            logger.error(f"Error loading checkpoints: {str(e)}")
            raise

    def get_learning_rate_schedules(self, schedule_specs):
        schedules = []
        for schedule_spec in schedule_specs:
            if schedule_spec["Type"] == "Step":
                schedules.append(
                    utils.StepLearningRateSchedule(
                        schedule_spec["Initial"],
                        schedule_spec["Interval"],
                        schedule_spec["Factor"],
                    )
                )
            else:
                raise Exception(
                    'no known learning rate schedule of type "{}"'.format(
                        schedule_spec["Type"]
                    )
                )
        return schedules
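
    # A sketch of the expected 'train.learning_rate_schedule' entry in the
    # HOCON conf (field names taken from the parsing above; values are
    # illustrative, not the project defaults):
    #   learning_rate_schedule = [{
    #     Type = Step
    #     Initial = 0.005
    #     Interval = 2000
    #     Factor = 0.5
    #   }]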

    def adjust_learning_rate(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)

    def save_checkpoints(self, epoch):
        torch.save(
            {"epoch": epoch, "model_state_dict": self.network.state_dict()},
            os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
        torch.save(
            {"epoch": epoch, "model_state_dict": self.network.state_dict()},
            os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
        torch.save(
            {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
        torch.save(
            {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))


if __name__ == '__main__':
    if args.gpu == "auto":
        deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5,
                                        includeNan=False, excludeID=[], excludeUUID=[])
        gpu = deviceIDs[0]
    else:
        gpu = args.gpu
    nepoch = args.nepoch
    conf = ConfigFactory.parse_file('./conversion/' + args.conf)
    folderprefix = conf.get_string('train.folderprefix')
    fileprefix_list = conf.get_list('train.fileprefix_list')
    trainrunners = []
    for i in range(len(fileprefix_list)):
        fp = fileprefix_list[i]
        print('cur model: ', fp)
        begin = time.time()
        trainrunners.append(ReconstructionRunner(
            conf=args.conf,
            folder_prefix=folderprefix,
            file_prefix=fp,
            points_batch=args.points_batch,
            nepochs=nepoch,
            expname=args.expname,
            gpu_index=gpu,
            is_continue=args.is_continue,
            checkpoint=args.checkpoint,
            eval=args.eval,
            flag_list=True
        ))
        if trainrunners[i].flag_data_load:
            trainrunners[i].run_nhrepnet_training()
        end = time.time()
        dur = end - begin
        if args.baseline:
            fp = fp + "_bl"