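"""Per-shape training script (summary inferred from the code below): fits an
implicit network whose output column 0 is a composed function h and whose
remaining columns are per-patch branch functions f_i, combined through a CSG
tree; after training, the model is traced to TorchScript."""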
import os
import sys
import time

project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)

import argparse
# parse args first and set gpu id
parser = argparse.ArgumentParser()
parser.add_argument('--points_batch', type=int, default=16384, help='point batch size')
parser.add_argument('--nepoch', type=int, default=15001, help='number of epochs to train for')
parser.add_argument('--conf', type=str, default='setup.conf')
parser.add_argument('--expname', type=str, default='single_shape')
parser.add_argument('--gpu', type=str, default='0', help='GPU id to use')
parser.add_argument('--is_continue', default=False, action="store_true", help='continue training from a saved checkpoint')
parser.add_argument('--checkpoint', default='latest', type=str)
parser.add_argument('--eval', default=False, action="store_true")
parser.add_argument('--summary', default=False, action="store_true", help='write tensorboard summary')
parser.add_argument('--baseline', default=False, action="store_true", help='run baseline')
parser.add_argument('--th_closeness', type=float, default=1e-5, help='threshold deciding whether two points are the same')
parser.add_argument('--cpu', default=False, action="store_true", help='save for cpu device')
parser.add_argument('--ab', default='none', type=str, help='ablation mode (overall/patch/cor/cons/cc/off/normal)')
parser.add_argument('--siren', default=False, action="store_true", help='siren normal loss')
parser.add_argument('--pt', default='ptfile path', type=str, help='output directory for traced .pt files')
parser.add_argument('--feature_sample', action="store_true", help='use feature curve samples')
parser.add_argument('--num_feature_sample', type=int, default=2048, help='number of feature samples used per batch')
parser.add_argument('--all_feature_sample', type=int, default=10000, help='number of all feature samples')

args = parser.parse_args()

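# Example invocation (a sketch; assumes this script lives under conversion/ and
# is named train.py -- the file name is an assumption):
#   python conversion/train.py --conf setup.conf --points_batch 16384 --nepoch 15001 --gpu 0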
os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(args.gpu)

from pyhocon import ConfigFactory
import numpy as np
import GPUtil
import torch
import utils.general as utils
from model.sample import Sampler
from model.network import gradient
from scipy.spatial import cKDTree
from utils.plots import plot_surface, plot_cuts_axis
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from utils.logger import logger

logger.info(f"project_dir: {project_dir}")

class ReconstructionRunner:

    def run_nhrepnet_training(self):
        print("running")
        self.data = self.data.cuda()
        self.data.requires_grad_()
        feature_mask_cpu = self.feature_mask.numpy()
        self.feature_mask = self.feature_mask.cuda()
        n_branch = int(torch.max(self.feature_mask).item())
        n_batchsize = self.points_batch
        n_patch_batch = n_batchsize // n_branch
        n_patch_last = n_batchsize - n_patch_batch * (n_branch - 1)

        patch_sup = True
        weight_mnfld_h = 1
        weight_mnfld_cs = 1
        weight_correction = 1
        a_correction = 100

        patch_id = []
        patch_id_n = []
        for i in range(n_branch):
            patch_id = patch_id + [np.where(feature_mask_cpu == i + 1)[0]]
            patch_id_n = patch_id_n + [patch_id[i].shape[0]]

        if self.eval:
            print("evaluating epoch: {0}".format(self.startepoch))
            my_path = os.path.join(self.cur_exp_dir, 'evaluation', str(self.startepoch))
            utils.mkdir_ifnotexists(os.path.join(self.cur_exp_dir, 'evaluation'))
            utils.mkdir_ifnotexists(my_path)
            for i in range(1):
                self.network.flag_output = i + 1
                self.plot_shapes(epoch=self.startepoch, path=my_path, file_suffix="_" + str(i), with_cuts=True)
            self.network.flag_output = 0
            return

        print("training begin")
        if args.summary:
            writer = SummaryWriter(os.path.join("summary", self.foldername))

        # branch mask is predefined
        branch_mask = torch.zeros(n_branch, n_batchsize).cuda()
        single_branch_mask_gt = torch.zeros(n_batchsize, n_branch).cuda()
        single_branch_mask_id = torch.zeros([n_batchsize], dtype=torch.long).cuda()
        for i in range(n_branch - 1):
            branch_mask[i, i * n_patch_batch: (i + 1) * n_patch_batch] = 1.0
            single_branch_mask_gt[i * n_patch_batch: (i + 1) * n_patch_batch, i] = 1.0
            single_branch_mask_id[i * n_patch_batch: (i + 1) * n_patch_batch] = i
        # last patch
        branch_mask[n_branch - 1, (n_branch - 1) * n_patch_batch:] = 1.0
        single_branch_mask_gt[(n_branch - 1) * n_patch_batch:, (n_branch - 1)] = 1.0
        single_branch_mask_id[(n_branch - 1) * n_patch_batch:] = (n_branch - 1)

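        # Batch layout sketch (the concrete numbers assume n_batchsize = 16384
        # and n_branch = 3): n_patch_batch = 5461, so rows [0, 5461) belong to
        # branch 0, rows [5461, 10922) to branch 1, and the last branch takes
        # the remaining n_patch_last = 5462 rows. branch_mask[i] is 1 exactly
        # on branch i's slice of the batch.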
        for epoch in range(self.startepoch, self.nepochs + 1):
            indices = torch.empty(0, dtype=torch.int64).cuda()
            for i in range(n_branch - 1):
                indices_nonfeature = torch.tensor(patch_id[i][np.random.choice(patch_id_n[i], n_patch_batch, True)]).cuda()
                indices = torch.cat((indices, indices_nonfeature), 0)
            # last patch
            indices_nonfeature = torch.tensor(patch_id[n_branch - 1][np.random.choice(patch_id_n[n_branch - 1], n_patch_last, True)]).cuda()
            indices = torch.cat((indices, indices_nonfeature), 0)

            cur_data = self.data[indices]
            mnfld_pnts = cur_data[:, :self.d_in]  # n_indices x 3
            mnfld_sigma = self.local_sigma[indices]  # noise points

            if epoch % self.conf.get_int('train.checkpoint_frequency') == 0:
                self.save_checkpoints(epoch)

            if epoch % self.conf.get_int('train.plot_frequency') == 0:
                print('plot validation epoch: ', epoch)
                for i in range(n_branch + 1):
                    self.network.flag_output = i + 1
                    self.plot_shapes(epoch, file_suffix="_" + str(i), with_cuts=False)
                self.network.flag_output = 0

            self.network.train()
            self.adjust_learning_rate(epoch)
            nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()

            # forward pass
            mnfld_pred_all = self.network(mnfld_pnts)
            nonmnfld_pred_all = self.network(nonmnfld_pnts)
            mnfld_pred = mnfld_pred_all[:, 0]
            nonmnfld_pred = nonmnfld_pred_all[:, 0]
            loss = 0.0
            mnfld_grad = gradient(mnfld_pnts, mnfld_pred)

            # manifold loss
            mnfld_loss = torch.zeros(1).cuda()
            if not args.ab == 'overall':
                mnfld_loss = (mnfld_pred.abs()).mean()
            loss = loss + weight_mnfld_h * mnfld_loss

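            # Output layout (inferred from the indexing in this loop body):
            # column 0 of the network output is the composed function h, and
            # columns 1..n_branch are the per-patch functions f_i.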
            # feature sample
            if args.feature_sample:
                feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda()
                feature_pnts = self.feature_data[feature_indices]
                feature_mask_pair = self.feature_data_mask_pair[feature_indices]
                feature_pred_all = self.network(feature_pnts)
                feature_pred = feature_pred_all[:, 0]
                feature_mnfld_loss = feature_pred.abs().mean()
                loss = loss + weight_mnfld_h * feature_mnfld_loss  # |h|

                # patch loss: a feature point lies on both patches it separates
                feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:, 0].tolist()]
                feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:, 1].tolist()]
                feature_fis_left = feature_pred_all[feature_id_left]
                feature_fis_right = feature_pred_all[feature_id_right]
                feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean()
                loss += feature_loss_patch

                # consistency loss: h should agree with both adjacent patch functions
                feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean()
                loss += weight_mnfld_cs * feature_loss_cons

            all_fi = torch.zeros([n_batchsize, 1], device='cuda')
            for i in range(n_branch - 1):
                all_fi[i * n_patch_batch: (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch: (i + 1) * n_patch_batch, i + 1]
            # last patch
            all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch]

            # manifold loss for patches
            mnfld_loss_patch = torch.zeros(1).cuda()
            if not args.ab == 'patch':
                if patch_sup:
                    mnfld_loss_patch = all_fi[:, 0].abs().mean()
            loss = loss + mnfld_loss_patch

            # correction loss: where h and the assigned patch function disagree
            # by more than th_closeness, pull them together (enabled after
            # epoch 10000, disabled for the baseline and 'cor'/'cc' ablations)
            correction_loss = torch.zeros(1).cuda()
            if not (args.ab == 'cor' or args.ab == 'cc') and epoch > 10000 and not args.baseline:
                mismatch_id = torch.abs(mnfld_pred - all_fi[:, 0]) > args.th_closeness
                if mismatch_id.sum() != 0:
                    correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:, 0])[mismatch_id]).mean()
            loss = loss + weight_correction * correction_loss

            # off-surface loss: penalizes near-zero values of h at the sampled
            # off-surface points, pushing the zero level set toward the data
            offsurface_loss = torch.zeros(1).cuda()
            if not args.ab == 'off':
                offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred[n_batchsize:])).mean()
            loss = loss + offsurface_loss

            # manifold consistency loss
            mnfld_consistency_loss = torch.zeros(1).cuda()
            if not (args.ab == 'cons' or args.ab == 'cc'):
                mnfld_consistency_loss = (mnfld_pred - all_fi[:, 0]).abs().mean()
            loss = loss + weight_mnfld_cs * mnfld_consistency_loss

            # eikonal loss for h: encourages ||grad h|| = 1 so that h behaves
            # like a signed distance field
            grad_loss_h = torch.zeros(1).cuda()
            single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred_all[:, 0])
            grad_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
            loss = loss + self.grad_lambda * grad_loss_h

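            # Note: `gradient` is imported from model.network; it is assumed to
            # be an autograd-based gradient along the lines of this sketch (not
            # the imported implementation itself):
            #   torch.autograd.grad(outputs, inputs,
            #                       grad_outputs=torch.ones_like(outputs),
            #                       create_graph=True)[0]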
            # normals loss
            normals_loss_h = torch.zeros(1).cuda()
            normals_loss = torch.zeros(1).cuda()
            normal_consistency_loss = torch.zeros(1).cuda()
            if not args.siren:
                if not args.ab == 'normal' and self.with_normals:
                    # all normals
                    normals = cur_data[:, -self.d_in:]
                    if patch_sup:
                        branch_grad = gradient(mnfld_pnts, all_fi[:, 0])
                        normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean()
                        loss = loss + self.normals_lambda * normals_loss

                        # only monitored, not used in the loss
                        mnfld_grad = gradient(mnfld_pnts, mnfld_pred_all[:, 0])
                        normal_consistency_loss = (mnfld_grad - branch_grad).abs().norm(2, dim=1).mean()
                else:
                    single_nonmnfld_grad = gradient(mnfld_pnts, all_fi[:, 0])
                    normals_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
                    loss = loss + self.normals_lambda * normals_loss_h
            else:
                # cosine-similarity normal loss (SIREN-style)
                normals = cur_data[:, -self.d_in:]
                normals_loss_h = (1 - F.cosine_similarity(mnfld_grad, normals, dim=-1)).mean()
                loss = loss + self.normals_lambda * normals_loss_h

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # tensorboard
            if args.summary and epoch % 100 == 0:
                writer.add_scalar('Loss/Total loss', loss.item(), epoch)
                writer.add_scalar('Loss/Manifold loss h', mnfld_loss.item(), epoch)
                writer.add_scalar('Loss/Manifold patch loss', mnfld_loss_patch.item(), epoch)
                writer.add_scalar('Loss/Manifold cons loss', mnfld_consistency_loss.item(), epoch)
                writer.add_scalar('Loss/Grad loss h', self.grad_lambda * grad_loss_h.item(), epoch)
                writer.add_scalar('Loss/Normal loss all', self.normals_lambda * normals_loss.item(), epoch)
                writer.add_scalar('Loss/Normal cs loss', self.normals_lambda * normal_consistency_loss.item(), epoch)
                writer.add_scalar('Loss/Assignment loss', correction_loss.item(), epoch)
                writer.add_scalar('Loss/Offsurface loss', offsurface_loss.item(), epoch)

            if epoch % self.conf.get_int('train.status_frequency') == 0:
                print('Train Epoch: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\t Manifold loss: {:.6f}'
                      '\tManifold patch loss: {:.6f}\t grad loss h: {:.6f}\t normals loss all: {:.6f}\t normals loss h: {:.6f}\t Manifold consistency loss: {:.6f}\tCorrection loss: {:.6f}\t Offsurface loss: {:.6f}'.format(
                          epoch, self.nepochs, 100. * epoch / self.nepochs,
                          loss.item(), mnfld_loss.item(), mnfld_loss_patch.item(), grad_loss_h.item(), normals_loss.item(), normals_loss_h.item(), mnfld_consistency_loss.item(), correction_loss.item(), offsurface_loss.item()))
                if args.feature_sample:
                    print('feature mnfld loss: {} patch loss: {} cons loss: {}'.format(feature_mnfld_loss.item(), feature_loss_patch.item(), feature_loss_cons.item()))

        self.tracing()

    def tracing(self):
        # network definition
        device = torch.device('cuda')
        if args.cpu:
            device = torch.device('cpu')
        network = utils.get_class(self.conf.get_string('train.network_class'))(
            d_in=3,
            flag_output=1,
            n_branch=int(torch.max(self.feature_mask).item()),
            csg_tree=self.csg_tree,
            flag_convex=self.csg_flag_convex,
            **self.conf.get_config('network.inputs'))
        network.to(device)
        ckpt_prefix = 'exps/single_shape/'
        save_prefix = '{}/'.format(args.pt)
        if not os.path.exists(save_prefix):
            os.mkdir(save_prefix)

        if args.cpu:
            saved_model_state = torch.load(ckpt_prefix + self.foldername + '/checkpoints/ModelParameters/latest.pth', map_location=device)
        else:
            saved_model_state = torch.load(ckpt_prefix + self.foldername + '/checkpoints/ModelParameters/latest.pth')
        network.load_state_dict(saved_model_state["model_state_dict"])

        # trace
        example = torch.rand(224, 3).to(device)
        traced_script_module = torch.jit.trace(network, example)
        traced_script_module.save(save_prefix + self.foldername + "_model_h.pt")
        print('converting to pt finished')

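    # Minimal consumer of the traced module (a sketch; the actual path depends
    # on --pt and the experiment folder name):
    #   model = torch.jit.load('<pt_dir>/<foldername>_model_h.pt')
    #   values = model(torch.rand(100, 3))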
    def plot_shapes(self, epoch, path=None, with_cuts=False, file_suffix="all"):
        # plot network validation shapes
        with torch.no_grad():
            self.network.eval()
            if not path:
                path = self.plots_dir

            indices = torch.tensor(np.random.choice(self.data.shape[0], self.points_batch, True))  # modified 0107, sample with replacement
            pnts = self.data[indices, :3]

            # draw nonmnfld pts
            mnfld_sigma = self.local_sigma[indices]
            nonmnfld_pnts = self.sampler.get_points(pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()
            pnts = nonmnfld_pnts

            plot_surface(with_points=True,
                         points=pnts,
                         decoder=self.network,
                         path=path,
                         epoch=epoch,
                         shapename=self.expname,
                         suffix=file_suffix,
                         **self.conf.get_config('plot'))

            if with_cuts:
                plot_cuts_axis(points=pnts,
                               decoder=self.network,
                               latent=None,
                               path=path,
                               epoch=epoch,
                               near_zero=False,
                               axis=2)

    def __init__(self, **kwargs):
        try:
            # 1. initialize basic settings
            self._initialize_basic_settings(kwargs)

            # 2. set up config file and experiment directories
            self._setup_config_and_directories(kwargs)

            # 3. load data
            self._load_data(kwargs)

            # 4. set up the CSG tree
            self._setup_csg_tree()

            # 5. compute local sigma
            self._compute_local_sigma()

            # 6. set up network and optimizer
            self._setup_network_and_optimizer(kwargs)

            print("Initialization completed successfully")

        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def _initialize_basic_settings(self, kwargs):
        """Initialize basic settings."""
        try:
            self.home_dir = os.path.abspath(os.getcwd())
            self.flag_list = kwargs.get('flag_list', False)
            self.expname = kwargs['expname']
            self.GPU_INDEX = kwargs['gpu_index']
            self.num_of_gpus = torch.cuda.device_count()
            self.eval = kwargs['eval']
            logger.debug("Basic settings initialized successfully")
        except KeyError as e:
            logger.error(f"Missing required parameter: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error in basic settings initialization: {str(e)}")
            raise

    def _setup_config_and_directories(self, kwargs):
        """Set up the config file and create the required directories."""
        try:
            # config setup
            if isinstance(kwargs['conf'], str):
                self.conf_filename = './conversion/' + kwargs['conf']
                self.conf = ConfigFactory.parse_file(self.conf_filename)
            else:
                self.conf = kwargs['conf']

            # create experiment directories
            self.exps_folder_name = 'exps'
            self.expdir = utils.concat_home_dir(os.path.join(
                self.home_dir, self.exps_folder_name, self.expname))
            utils.mkdir_ifnotexists(utils.concat_home_dir(
                os.path.join(self.home_dir, self.exps_folder_name)))
            utils.mkdir_ifnotexists(self.expdir)

            logger.debug("Config and directories setup completed")
        except Exception as e:
            logger.error(f"Error in config and directory setup: {str(e)}")
            raise

    def _load_data(self, kwargs):
        """Load the point cloud and feature mask."""
        try:
            if not self.flag_list:
                self._load_single_data()
            else:
                self._load_data_from_list(kwargs)

            if args.baseline:
                self.feature_mask = torch.ones(self.data.shape[0]).float()

            logger.info(f"Data loading finished. Data shape: {self.data.shape}")
        except FileNotFoundError as e:
            logger.error(f"Data file not found: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error in data loading: {str(e)}")
            raise

    def _load_single_data(self):
        """Load a single data file."""
        try:
            self.input_file = self.conf.get_string('train.input_path')
            self.data = utils.load_point_cloud_by_file_extension(self.input_file)
            self.feature_mask_file = self.conf.get_string('train.feature_mask_path')
            self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
            self.foldername = self.conf.get_string('train.foldername')
        except Exception as e:
            logger.error(f"Error loading single data file: {str(e)}")
            raise

    def _load_data_from_list(self, kwargs):
        """Load data from a file list."""
        try:
            self.input_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '.xyz'
            )
            if not os.path.exists(self.input_file):
                self.flag_data_load = False
                raise Exception(f"Data file not found: {self.input_file}, absolute path: {os.path.abspath(self.input_file)}")

            self.flag_data_load = True
            self.data = utils.load_point_cloud_by_file_extension(self.input_file)
            self.feature_mask_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_mask.txt'
            )

            if not args.baseline:
                self.feature_mask = utils.load_feature_mask(self.feature_mask_file)

            if args.feature_sample:
                self._load_feature_samples(kwargs)

            self.foldername = kwargs['folder_prefix'] + kwargs['file_prefix']
        except Exception as e:
            logger.error(f"Error loading data from list: {str(e)}")
            raise

    def _load_feature_samples(self, kwargs):
        """Load feature curve samples."""
        try:
            input_fs_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_feature.xyz'
            )
            self.feature_data = np.loadtxt(input_fs_file)
            self.feature_data = torch.tensor(
                self.feature_data,
                dtype=torch.float32,
                device='cuda'
            )

            fs_mask_file = os.path.join(
                self.conf.get_string('train.input_path'),
                kwargs['file_prefix'] + '_feature_mask.txt'
            )
            self.feature_data_mask_pair = torch.tensor(
                np.loadtxt(fs_mask_file),
                dtype=torch.int64,
                device='cuda'
            )
        except Exception as e:
            logger.error(f"Error loading feature samples: {str(e)}")
            raise

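    # Per-model input files expected under train.input_path (collected from the
    # loaders above and from _setup_csg_tree below):
    #   <prefix>.xyz                point samples (with normals when used)
    #   <prefix>_mask.txt           per-point patch labels in 1..n_branch
    #   <prefix>_feature.xyz        feature curve samples (--feature_sample)
    #   <prefix>_feature_mask.txt   left/right patch pair per feature sample
    #   <prefix>_csg.conf           CSG tree description (non-baseline runs)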
    def _setup_csg_tree(self):
        """Set up the CSG tree."""
        try:
            if args.baseline:
                self.csg_tree = [0]
                self.csg_flag_convex = True
            else:
                csg_conf_file = self.input_file[:-4] + '_csg.conf'
                csg_config = ConfigFactory.parse_file(csg_conf_file)
                self.csg_tree = csg_config.get_list('csg.list')
                self.csg_flag_convex = csg_config.get_int('csg.flag_convex')

            logger.info(f"CSG tree: {self.csg_tree}")
            logger.info(f"CSG convex flag: {self.csg_flag_convex}")
        except Exception as e:
            logger.error(f"Error in CSG tree setup: {str(e)}")
            raise

    def _compute_local_sigma(self):
        """Compute the per-point local sigma values."""
        try:
            sigma_set = []
            ptree = cKDTree(self.data)
            logger.debug("KD tree constructed")

            # sigma of a point = distance to its 50th nearest neighbor
            # (query 50 + 1 because the nearest hit is the point itself)
            for p in np.array_split(self.data, 100, axis=0):
                d = ptree.query(p, 50 + 1)
                sigma_set.append(d[0][:, -1])

            sigmas = np.concatenate(sigma_set)
            self.local_sigma = torch.from_numpy(sigmas).float().cuda()
        except Exception as e:
            logger.error(f"Error computing local sigma: {str(e)}")
            raise

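    # local_sigma[i] scales the perturbation the sampler applies around point i
    # when drawing off-surface samples (see Sampler.get_points), so the sample
    # spread adapts to the local point spacing; the exact noise model lives in
    # model.sample.Sampler.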
    def _setup_network_and_optimizer(self, kwargs):
        """Set up the network and the optimizer."""
        try:
            # set up directories
            self._setup_checkpoints_directories()

            # network parameters
            self._setup_network_parameters(kwargs)

            # create the network
            self._create_network()

            # set up the optimizer
            self._setup_optimizer(kwargs)

            logger.debug("Network and optimizer setup completed")
        except Exception as e:
            logger.error(f"Error in network and optimizer setup: {str(e)}")
            raise

    def _setup_checkpoints_directories(self):
        """Set up the checkpoint directories."""
        try:
            self.cur_exp_dir = os.path.join(self.expdir, self.foldername)
            utils.mkdir_ifnotexists(self.cur_exp_dir)

            self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
            utils.mkdir_ifnotexists(self.plots_dir)

            self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
            utils.mkdir_ifnotexists(self.checkpoints_path)

            self.model_params_subdir = "ModelParameters"
            self.optimizer_params_subdir = "OptimizerParameters"

            utils.mkdir_ifnotexists(os.path.join(
                self.checkpoints_path, self.model_params_subdir))
            utils.mkdir_ifnotexists(os.path.join(
                self.checkpoints_path, self.optimizer_params_subdir))
        except Exception as e:
            logger.error(f"Error setting up checkpoint directories: {str(e)}")
            raise

    def _setup_network_parameters(self, kwargs):
        """Set up the network parameters."""
        try:
            self.nepochs = kwargs['nepochs']
            self.points_batch = kwargs['points_batch']
            self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')

            self.sampler = Sampler.get_sampler(
                self.conf.get_string('network.sampler.sampler_type'))(
                self.global_sigma,
                self.local_sigma
            )

            self.grad_lambda = self.conf.get_float('network.loss.lambda')
            self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')
            self.with_normals = self.normals_lambda > 0
            self.d_in = self.conf.get_int('train.d_in')
        except Exception as e:
            logger.error(f"Error setting up network parameters: {str(e)}")
            raise

    def _create_network(self):
        """Create the network."""
        try:
            self.network = utils.get_class(self.conf.get_string('train.network_class'))(
                d_in=self.d_in,
                n_branch=int(torch.max(self.feature_mask).item()),
                csg_tree=self.csg_tree,
                flag_convex=self.csg_flag_convex,
                **self.conf.get_config('network.inputs')
            )

            logger.info(f"Network created: {self.network}")

            if torch.cuda.is_available():
                self.network.cuda()
        except Exception as e:
            logger.error(f"Error creating network: {str(e)}")
            raise

    def _setup_optimizer(self, kwargs):
        """Set up the optimizer."""
        try:
            self.lr_schedules = self.get_learning_rate_schedules(
                self.conf.get_list('train.learning_rate_schedule'))
            self.weight_decay = self.conf.get_float('train.weight_decay')

            self.startepoch = 0
            self.optimizer = torch.optim.Adam([{
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            }])

            # resume from a checkpoint when continuing training
            self._load_checkpoints_if_continue(kwargs)
        except Exception as e:
            logger.error(f"Error setting up optimizer: {str(e)}")
            raise

    def _load_checkpoints_if_continue(self, kwargs):
        """Load checkpoints when continuing training."""
        try:
            model_params_path = os.path.join(
                self.checkpoints_path, self.model_params_subdir)
            ckpts = os.listdir(model_params_path)

            # note: resumes whenever checkpoints exist; kwargs['is_continue']
            # is not consulted here
            if len(ckpts) != 0:
                old_checkpnts_dir = os.path.join(
                    self.expdir, self.foldername, 'checkpoints')
                logger.info(f'Loading checkpoint from: {old_checkpnts_dir}')

                # load model state
                saved_model_state = torch.load(os.path.join(
                    old_checkpnts_dir,
                    'ModelParameters',
                    f"{kwargs['checkpoint']}.pth"
                ))
                self.network.load_state_dict(saved_model_state["model_state_dict"])

                # load optimizer state
                data = torch.load(os.path.join(
                    old_checkpnts_dir,
                    'OptimizerParameters',
                    f"{kwargs['checkpoint']}.pth"
                ))
                self.optimizer.load_state_dict(data["optimizer_state_dict"])
                self.startepoch = saved_model_state['epoch']
        except Exception as e:
            logger.error(f"Error loading checkpoints: {str(e)}")
            raise

    def get_learning_rate_schedules(self, schedule_specs):

        schedules = []

        # use a distinct loop variable; the original shadowed the parameter
        for schedule_spec in schedule_specs:

            if schedule_spec["Type"] == "Step":
                schedules.append(
                    utils.StepLearningRateSchedule(
                        schedule_spec["Initial"],
                        schedule_spec["Interval"],
                        schedule_spec["Factor"],
                    )
                )

            else:
                raise Exception(
                    'no known learning rate schedule of type "{}"'.format(
                        schedule_spec["Type"]
                    )
                )

        return schedules

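    # Example 'train.learning_rate_schedule' entry (HOCON sketch; the field
    # names match the keys read above, the concrete values are assumptions):
    #   learning_rate_schedule = [{ Type = Step, Initial = 0.005, Interval = 2000, Factor = 0.5 }]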
    def adjust_learning_rate(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)

    def save_checkpoints(self, epoch):

        torch.save(
            {"epoch": epoch, "model_state_dict": self.network.state_dict()},
            os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
        torch.save(
            {"epoch": epoch, "model_state_dict": self.network.state_dict()},
            os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
        torch.save(
            {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
        torch.save(
            {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))


if __name__ == '__main__':

    if args.gpu == "auto":
        deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5,
                                        includeNan=False, excludeID=[], excludeUUID=[])
        gpu = deviceIDs[0]
    else:
        gpu = args.gpu

    nepoch = args.nepoch
    conf = ConfigFactory.parse_file('./conversion/' + args.conf)
    folderprefix = conf.get_string('train.folderprefix')
    fileprefix_list = conf.get_list('train.fileprefix_list')

    trainrunners = []
    for i in range(len(fileprefix_list)):
        fp = fileprefix_list[i]
        print('cur model: ', fp)

        begin = time.time()
        trainrunners.append(ReconstructionRunner(
            conf=args.conf,
            folder_prefix=folderprefix,
            file_prefix=fp,
            points_batch=args.points_batch,
            nepochs=nepoch,
            expname=args.expname,
            gpu_index=gpu,
            is_continue=args.is_continue,
            checkpoint=args.checkpoint,
            eval=args.eval,
            flag_list=True
        ))

        if trainrunners[i].flag_data_load:
            trainrunners[i].run_nhrepnet_training()
        end = time.time()
        dur = end - begin
        if args.baseline:
            fp = fp + "_bl"
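    # Per-model artifacts end up under exps/<expname>/<folder_prefix><file_prefix>/
    # (plots/ and checkpoints/); tracing() writes the TorchScript file to the
    # directory given by --pt.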