import os
import sys
import time
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)
import argparse
# parse args first and set gpu id
parser = argparse.ArgumentParser()
parser.add_argument('--points_batch', type=int, default=16384, help='point batch size')
parser.add_argument('--nepoch', type=int, default=15001, help='number of epochs to train for')
parser.add_argument('--conf', type=str, default='setup.conf')
parser.add_argument('--expname', type=str, default='single_shape')
parser.add_argument('--gpu', type=str, default='0', help='GPU id to use')
parser.add_argument('--is_continue', default=False, action="store_true", help='continue training from a saved checkpoint')
parser.add_argument('--checkpoint', default='latest', type=str)
parser.add_argument('--eval', default=False, action="store_true")
parser.add_argument('--summary', default=False, action="store_true", help='write tensorboard summary')
parser.add_argument('--baseline', default=False, action="store_true", help='run baseline')
parser.add_argument('--th_closeness', type=float, default=1e-5, help='threshold deciding whether two points are the same')
parser.add_argument('--cpu', default=False, action="store_true", help='save for cpu device')
parser.add_argument('--ab', default='none', type=str, help='ablation')
parser.add_argument('--siren', default=False, action="store_true", help='siren normal loss')
parser.add_argument('--pt', default='ptfile path', type=str, help='output folder for the traced TorchScript (.pt) model')
parser.add_argument('--feature_sample', action="store_true", help='use feature curve samples')
parser.add_argument('--num_feature_sample', type=int, default=2048, help='number of feature samples drawn per batch')
parser.add_argument('--all_feature_sample', type=int, default=10000, help='total number of feature samples')
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(args.gpu)
from pyhocon import ConfigFactory
import numpy as np
import GPUtil
import torch
import utils.general as utils
from model.sample import Sampler
from model.network import gradient
from scipy.spatial import cKDTree
from utils.plots import plot_surface, plot_cuts_axis
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from utils.logger import logger
logger.info(f"project_dir: {project_dir}")
class ReconstructionRunner:
def run_nhrepnet_training(self):
print("running") # 输出训练开始的提示信息
self.data = self.data.cuda() # 将数据移动到 GPU 上
self.data.requires_grad_() # 设置数据以便计算梯度
feature_mask_cpu = self.feature_mask.numpy() # 将特征掩码转换为 NumPy 数组
self.feature_mask = self.feature_mask.cuda() # 将特征掩码移动到 GPU 上
n_branch = int(torch.max(self.feature_mask).item()) # 计算分支数量
n_batchsize = self.points_batch # 设置批次大小
n_patch_batch = n_batchsize // n_branch # 计算每个分支的补丁批次大小
n_patch_last = n_batchsize - n_patch_batch * (n_branch - 1) # 计算最后一个分支的补丁大小
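        # Example: with the default points_batch=16384 and n_branch=3, this gives
        # n_patch_batch = 16384 // 3 = 5461 and n_patch_last = 16384 - 2 * 5461 = 5462.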
        patch_sup = True  # supervise per-branch predictions on their own patches
        weight_mnfld_h = 1  # weight of the manifold loss on h
        weight_mnfld_cs = 1  # weight of the manifold consistency loss
        weight_correction = 1  # weight of the correction loss
        a_correction = 100  # scale factor inside the correction loss
        patch_id = []  # per-branch point indices
        patch_id_n = []  # per-branch point counts
        for i in range(n_branch):
            patch_id = patch_id + [np.where(feature_mask_cpu == i + 1)[0]]  # indices of the points in branch i
            patch_id_n = patch_id_n + [patch_id[i].shape[0]]  # number of points in branch i

        if self.eval:  # evaluation-only mode: plot and return
            print("evaluating epoch: {0}".format(self.startepoch))
            my_path = os.path.join(self.cur_exp_dir, 'evaluation', str(self.startepoch))  # where evaluation plots go
            utils.mkdir_ifnotexists(os.path.join(self.cur_exp_dir, 'evaluation'))
            utils.mkdir_ifnotexists(my_path)
            for i in range(1):
                self.network.flag_output = i + 1  # select which output to plot
                self.plot_shapes(epoch=self.startepoch, path=my_path, file_suffix="_" + str(i), with_cuts=True)
            self.network.flag_output = 0  # reset to the combined output
            return

        print("training begin")
        if args.summary:
            writer = SummaryWriter(os.path.join("summary", self.foldername))  # TensorBoard writer
        # branch mask is predefined
        branch_mask = torch.zeros(n_branch, n_batchsize).cuda()  # row i marks the batch slice of branch i
        single_branch_mask_gt = torch.zeros(n_batchsize, n_branch).cuda()  # one-hot branch assignment per point
        single_branch_mask_id = torch.zeros([n_batchsize], dtype=torch.long).cuda()  # branch id per point
        for i in range(n_branch - 1):
            branch_mask[i, i * n_patch_batch : (i + 1) * n_patch_batch] = 1.0
            single_branch_mask_gt[i * n_patch_batch : (i + 1) * n_patch_batch, i] = 1.0
            single_branch_mask_id[i * n_patch_batch : (i + 1) * n_patch_batch] = i
        # last patch
        branch_mask[n_branch - 1, (n_branch - 1) * n_patch_batch:] = 1.0
        single_branch_mask_gt[(n_branch - 1) * n_patch_batch:, (n_branch - 1)] = 1.0
        single_branch_mask_id[(n_branch - 1) * n_patch_batch:] = (n_branch - 1)
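        # Illustration for n_branch=2 and n_batchsize=4 (so n_patch_batch=2):
        #   branch_mask           = [[1, 1, 0, 0],
        #                            [0, 0, 1, 1]]
        #   single_branch_mask_gt = [[1, 0], [1, 0], [0, 1], [0, 1]]
        #   single_branch_mask_id = [0, 0, 1, 1]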
        for epoch in range(self.startepoch, self.nepochs + 1):  # main training loop
            indices = torch.empty(0, dtype=torch.int64).cuda()  # point indices of this epoch's batch
            for i in range(n_branch - 1):
                indices_nonfeature = torch.tensor(patch_id[i][np.random.choice(patch_id_n[i], n_patch_batch, True)]).cuda()  # sample branch i with replacement
                indices = torch.cat((indices, indices_nonfeature), 0)
            # last patch
            indices_nonfeature = torch.tensor(patch_id[n_branch - 1][np.random.choice(patch_id_n[n_branch - 1], n_patch_last, True)]).cuda()
            indices = torch.cat((indices, indices_nonfeature), 0)
            cur_data = self.data[indices]  # current batch (points and, if present, normals)
            mnfld_pnts = cur_data[:, :self.d_in]  # on-surface (manifold) points
            mnfld_sigma = self.local_sigma[indices]  # per-point sampling scale for the off-surface samples
            if epoch % self.conf.get_int('train.checkpoint_frequency') == 0:  # periodic checkpointing
self.save_checkpoints(epoch)
            if epoch % self.conf.get_int('train.plot_frequency') == 0:  # periodic validation plots
                print('plot validation epoch: ', epoch)
                for i in range(n_branch + 1):
                    self.network.flag_output = i + 1  # plot h and each branch output in turn
                    self.plot_shapes(epoch, file_suffix="_" + str(i), with_cuts=False)
                self.network.flag_output = 0  # reset to the combined output
            self.network.train()
            self.adjust_learning_rate(epoch)
            nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()  # sample off-surface points around the batch
            # forward pass
            mnfld_pred_all = self.network(mnfld_pnts)  # predictions at on-surface points
            nonmnfld_pred_all = self.network(nonmnfld_pnts)  # predictions at off-surface points
            mnfld_pred = mnfld_pred_all[:, 0]  # column 0 is the overall function h
            nonmnfld_pred = nonmnfld_pred_all[:, 0]
            loss = 0.0
            mnfld_grad = gradient(mnfld_pnts, mnfld_pred)  # gradient of h at the on-surface points
            # manifold loss
            mnfld_loss = torch.zeros(1).cuda()
            if args.ab != 'overall':  # skipped in the 'overall' ablation
                mnfld_loss = mnfld_pred.abs().mean()  # |h| should vanish on the surface
            loss = loss + weight_mnfld_h * mnfld_loss
            # feature sample
            if args.feature_sample:
                feature_indices = torch.randperm(args.all_feature_sample)[:args.num_feature_sample].cuda()  # random subset of feature-curve samples
                feature_pnts = self.feature_data[feature_indices]
                feature_mask_pair = self.feature_data_mask_pair[feature_indices]  # ids of the two patches adjacent to each feature point
                feature_pred_all = self.network(feature_pnts)
                feature_pred = feature_pred_all[:, 0]
                feature_mnfld_loss = feature_pred.abs().mean()  # h should also vanish on the feature curves
                loss = loss + weight_mnfld_h * feature_mnfld_loss
                # patch loss:
                feature_id_left = [list(range(args.num_feature_sample)), feature_mask_pair[:, 0].tolist()]
                feature_id_right = [list(range(args.num_feature_sample)), feature_mask_pair[:, 1].tolist()]
                feature_fis_left = feature_pred_all[feature_id_left]  # left adjacent patch evaluated at the feature points
                feature_fis_right = feature_pred_all[feature_id_right]  # right adjacent patch evaluated at the feature points
                feature_loss_patch = feature_fis_left.abs().mean() + feature_fis_right.abs().mean()  # both adjacent patches should vanish there
                loss += feature_loss_patch
                # consistency loss:
                feature_loss_cons = (feature_fis_left - feature_pred).abs().mean() + (feature_fis_right - feature_pred).abs().mean()  # adjacent patches should agree with h
                loss += weight_mnfld_cs * feature_loss_cons
            all_fi = torch.zeros([n_batchsize, 1], device='cuda')  # each point's prediction from its own branch
            for i in range(n_branch - 1):
                all_fi[i * n_patch_batch : (i + 1) * n_patch_batch, 0] = mnfld_pred_all[i * n_patch_batch : (i + 1) * n_patch_batch, i + 1]  # branch i evaluated on its own points
            # last patch
            all_fi[(n_branch - 1) * n_patch_batch:, 0] = mnfld_pred_all[(n_branch - 1) * n_patch_batch:, n_branch]
            # manifold loss for patches
            mnfld_loss_patch = torch.zeros(1).cuda()
            if args.ab != 'patch':  # skipped in the 'patch' ablation
                if patch_sup:
                    mnfld_loss_patch = all_fi[:, 0].abs().mean()  # each branch should vanish on its own patch
            loss = loss + mnfld_loss_patch
            # correction loss
            correction_loss = torch.zeros(1).cuda()
            if not (args.ab == 'cor' or args.ab == 'cc') and epoch > 10000 and not args.baseline:  # enabled after warm-up, unless ablated or baseline
                mismatch_id = torch.abs(mnfld_pred - all_fi[:, 0]) > args.th_closeness  # points where h disagrees with the branch prediction
                if mismatch_id.sum() != 0:
                    correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:, 0])[mismatch_id]).mean()  # penalize the disagreement
                loss = loss + weight_correction * correction_loss
            # off surface loss
            offsurface_loss = torch.zeros(1).cuda()
            if args.ab != 'off':
                offsurface_loss = torch.exp(-100.0 * torch.abs(nonmnfld_pred[n_batchsize:])).mean()  # push |h| away from zero off the surface
            loss = loss + offsurface_loss
            # manifold consistency loss
            mnfld_consistency_loss = torch.zeros(1).cuda()
            if not (args.ab == 'cons' or args.ab == 'cc'):
                mnfld_consistency_loss = (mnfld_pred - all_fi[:, 0]).abs().mean()  # h should agree with the branch predictions on-surface
            loss = loss + weight_mnfld_cs * mnfld_consistency_loss
            # eikonal loss for h: (||grad h|| - 1)^2 encourages a signed-distance-like field
            grad_loss_h = torch.zeros(1).cuda()
            single_nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred_all[:, 0])
            grad_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
            loss = loss + self.grad_lambda * grad_loss_h
            # normals loss
            normals_loss_h = torch.zeros(1).cuda()
            normals_loss = torch.zeros(1).cuda()
            normal_consistency_loss = torch.zeros(1).cuda()
            if not args.siren:
                if args.ab != 'normal' and self.with_normals:
                    # all normals
                    normals = cur_data[:, -self.d_in:]  # ground-truth normals sit in the last d_in columns
                    if patch_sup:
                        branch_grad = gradient(mnfld_pnts, all_fi[:, 0])  # gradient of each point's own branch
                        normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean()  # branch gradients should match the given normals
                    loss = loss + self.normals_lambda * normals_loss
                    # monitored only, not used in the loss
                    mnfld_grad = gradient(mnfld_pnts, mnfld_pred_all[:, 0])
                    normal_consistency_loss = (mnfld_grad - branch_grad).abs().norm(2, dim=1).mean()  # agreement between h's gradient and the branch gradient
                else:
                    single_nonmnfld_grad = gradient(mnfld_pnts, all_fi[:, 0])
                    normals_loss_h = ((single_nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()  # fall back to an eikonal term on the branch predictions
                    loss = loss + self.normals_lambda * normals_loss_h
            else:
                # SIREN-style normal loss via cosine similarity
                normals = cur_data[:, -self.d_in:]
                normals_loss_h = (1 - F.cosine_similarity(mnfld_grad, normals, dim=-1)).mean()
                loss = loss + self.normals_lambda * normals_loss_h
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # tensorboard
            if args.summary and epoch % 100 == 0:  # log losses every 100 epochs
                writer.add_scalar('Loss/Total loss', loss.item(), epoch)
                writer.add_scalar('Loss/Manifold loss h', mnfld_loss.item(), epoch)
                writer.add_scalar('Loss/Manifold patch loss', mnfld_loss_patch.item(), epoch)
                writer.add_scalar('Loss/Manifold cons loss', mnfld_consistency_loss.item(), epoch)
                writer.add_scalar('Loss/Grad loss h', self.grad_lambda * grad_loss_h.item(), epoch)
                writer.add_scalar('Loss/Normal loss all', self.normals_lambda * normals_loss.item(), epoch)
                writer.add_scalar('Loss/Normal cs loss', self.normals_lambda * normal_consistency_loss.item(), epoch)
                writer.add_scalar('Loss/Assignment loss', correction_loss.item(), epoch)
                writer.add_scalar('Loss/Offsurface loss', offsurface_loss.item(), epoch)
            if epoch % self.conf.get_int('train.status_frequency') == 0:  # periodic status logging
                logger.info('Train Epoch: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\t Manifold loss: {:.6f}'
                            '\tManifold patch loss: {:.6f}\t grad loss h: {:.6f}\t normals loss all: {:.6f}\t normals loss h: {:.6f}\t Manifold consistency loss: {:.6f}\tCorrection loss: {:.6f}\t Offsurface loss: {:.6f}'.format(
                                epoch, self.nepochs, 100. * epoch / self.nepochs,
                                loss.item(), mnfld_loss.item(), mnfld_loss_patch.item(), grad_loss_h.item(), normals_loss.item(), normals_loss_h.item(), mnfld_consistency_loss.item(), correction_loss.item(), offsurface_loss.item()))
                if args.feature_sample:
                    logger.info('feature mnfld loss: {} patch loss: {} cons loss: {}'.format(feature_mnfld_loss.item(), feature_loss_patch.item(), feature_loss_cons.item()))
        self.tracing()  # export the trained network as a TorchScript module
def tracing(self):
        # rebuild the network for export
device = torch.device('cuda')
if args.cpu:
device = torch.device('cpu')
network = utils.get_class(self.conf.get_string('train.network_class'))(d_in=3, flag_output = 1,
n_branch = int(torch.max(self.feature_mask).item()),
csg_tree = self.csg_tree,
flag_convex = self.csg_flag_convex,
**self.conf.get_config(
'network.inputs'))
network.to(device)
ckpt_prefix = 'exps/single_shape/'
save_prefix = '{}/'.format(args.pt)
if not os.path.exists(save_prefix):
os.mkdir(save_prefix)
        saved_model_state = torch.load(
            ckpt_prefix + self.foldername + '/checkpoints/ModelParameters/latest.pth',
            map_location=device)  # map_location covers both the CPU and GPU cases
        network.load_state_dict(saved_model_state["model_state_dict"])
        # trace the network with a dummy input batch
        example = torch.rand(224, 3).to(device)
traced_script_module = torch.jit.trace(network, example)
traced_script_module.save(save_prefix + self.foldername + "_model_h.pt")
logger.info('converting to pt finished')
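
        # Minimal sketch of consuming the exported module (paths are hypothetical),
        # assuming an (N, 3) float tensor of query points:
        #   model = torch.jit.load('<pt>/<foldername>_model_h.pt')
        #   values = model(torch.rand(100, 3))  # implicit function values at the queries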
def plot_shapes(self, epoch, path=None, with_cuts=False, file_suffix="all"):
# plot network validation shapes
with torch.no_grad():
self.network.eval()
if not path:
path = self.plots_dir
            indices = torch.tensor(np.random.choice(self.data.shape[0], self.points_batch, True))  # sample with replacement
            pnts = self.data[indices, :3]
            # also include off-surface points in the plot
            mnfld_sigma = self.local_sigma[indices]
            nonmnfld_pnts = self.sampler.get_points(pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()
            pnts = nonmnfld_pnts
plot_surface(with_points=True,
points=pnts,
decoder=self.network,
path=path,
epoch=epoch,
shapename=self.expname,
suffix = file_suffix,
**self.conf.get_config('plot'))
if with_cuts:
plot_cuts_axis(points=pnts,
decoder=self.network,
latent = None,
path=path,
epoch=epoch,
near_zero=False,
axis = 2)
def __init__(self, **kwargs):
try:
            # 1. basic settings
            self._initialize_basic_settings(kwargs)
            # 2. config file and experiment directories
            self._setup_config_and_directories(kwargs)
            # 3. data loading
            self._load_data(kwargs)
            # 4. CSG tree setup
            self._setup_csg_tree()
            # 5. local sigma computation
            self._compute_local_sigma()
            # 6. network and optimizer setup
            self._setup_network_and_optimizer(kwargs)
print("Initialization completed successfully")
except Exception as e:
logger.error(f"Error during initialization: {str(e)}")
raise
def _initialize_basic_settings(self, kwargs):
"""初始化基础设置"""
try:
self.home_dir = os.path.abspath(os.getcwd())
self.flag_list = kwargs.get('flag_list', False)
self.expname = kwargs['expname']
self.GPU_INDEX = kwargs['gpu_index']
self.num_of_gpus = torch.cuda.device_count()
self.eval = kwargs['eval']
logger.debug("Basic settings initialized successfully")
except KeyError as e:
logger.error(f"Missing required parameter: {str(e)}")
raise
except Exception as e:
logger.error(f"Error in basic settings initialization: {str(e)}")
raise
    def _setup_config_and_directories(self, kwargs):
        """Parse the config file and create the experiment directories."""
        try:
            # config setup
            if isinstance(kwargs['conf'], str):
                self.conf_filename = './conversion/' + kwargs['conf']
                self.conf = ConfigFactory.parse_file(self.conf_filename)
            else:
                self.conf = kwargs['conf']
            # create the experiment directories
self.exps_folder_name = 'exps'
self.expdir = utils.concat_home_dir(os.path.join(
self.home_dir, self.exps_folder_name, self.expname))
utils.mkdir_ifnotexists(utils.concat_home_dir(
os.path.join(self.home_dir, self.exps_folder_name)))
utils.mkdir_ifnotexists(self.expdir)
logger.debug("Config and directories setup completed")
except Exception as e:
logger.error(f"Error in config and directory setup: {str(e)}")
raise
def _load_data(self, kwargs):
"""加载数据和特征掩码"""
try:
if not self.flag_list:
self._load_single_data()
else:
self._load_data_from_list(kwargs)
if args.baseline:
self.feature_mask = torch.ones(self.data.shape[0]).float()
logger.info(f"Data loading finished. Data shape: {self.data.shape}")
except FileNotFoundError as e:
logger.error(f"Data file not found: {str(e)}")
raise
except Exception as e:
logger.error(f"Error in data loading: {str(e)}")
raise
def _load_single_data(self):
"""加载单个数据文件"""
try:
self.input_file = self.conf.get_string('train.input_path')
self.data = utils.load_point_cloud_by_file_extension(self.input_file)
self.feature_mask_file = self.conf.get_string('train.feature_mask_path')
self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
self.foldername = self.conf.get_string('train.foldername')
except Exception as e:
logger.error(f"Error loading single data file: {str(e)}")
raise
def _load_data_from_list(self, kwargs):
"""从列表加载数据"""
try:
self.input_file = os.path.join(
self.conf.get_string('train.input_path'),
kwargs['file_prefix']+'.xyz'
)
if not os.path.exists(self.input_file):
self.flag_data_load = False
raise Exception(f"Data file not found: {self.input_file}, absolute path: {os.path.abspath(self.input_file)}")
self.flag_data_load = True
self.data = utils.load_point_cloud_by_file_extension(self.input_file)
self.feature_mask_file = os.path.join(
self.conf.get_string('train.input_path'),
kwargs['file_prefix']+'_mask.txt'
)
if not args.baseline:
self.feature_mask = utils.load_feature_mask(self.feature_mask_file)
if args.feature_sample:
self._load_feature_samples(kwargs)
self.foldername = kwargs['folder_prefix'] + kwargs['file_prefix']
except Exception as e:
logger.error(f"Error loading data from list: {str(e)}")
raise
def _load_feature_samples(self, kwargs):
"""加载特征样本"""
try:
input_fs_file = os.path.join(
self.conf.get_string('train.input_path'),
kwargs['file_prefix']+'_feature.xyz'
)
self.feature_data = np.loadtxt(input_fs_file)
self.feature_data = torch.tensor(
self.feature_data,
dtype=torch.float32,
device='cuda'
)
fs_mask_file = os.path.join(
self.conf.get_string('train.input_path'),
kwargs['file_prefix']+'_feature_mask.txt'
)
self.feature_data_mask_pair = torch.tensor(
np.loadtxt(fs_mask_file),
dtype=torch.int64,
device='cuda'
)
except Exception as e:
logger.error(f"Error loading feature samples: {str(e)}")
raise
def _setup_csg_tree(self):
"""设置CSG树"""
try:
if args.baseline:
self.csg_tree = [0]
self.csg_flag_convex = True
else:
csg_conf_file = self.input_file[:-4]+'_csg.conf'
csg_config = ConfigFactory.parse_file(csg_conf_file)
self.csg_tree = csg_config.get_list('csg.list')
self.csg_flag_convex = csg_config.get_int('csg.flag_convex')
logger.info(f"CSG tree: {self.csg_tree}")
logger.info(f"CSG convex flag: {self.csg_flag_convex}")
except Exception as e:
logger.error(f"Error in CSG tree setup: {str(e)}")
raise
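
    # The companion '<input_file>_csg.conf' read above is expected to provide the two
    # keys used there; an assumed example in HOCON syntax:
    #   csg { list = [1, [2, 3]] flag_convex = 1 }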
def _compute_local_sigma(self):
"""计算局部sigma值"""
try:
sigma_set = []
ptree = cKDTree(self.data)
logger.debug("KD tree constructed")
            for p in np.array_split(self.data, 100, axis=0):
                d = ptree.query(p, 50 + 1)  # 51 nearest neighbours, including the query point itself
                sigma_set.append(d[0][:, -1])  # sigma = distance to the 50th nearest neighbour
sigmas = np.concatenate(sigma_set)
self.local_sigma = torch.from_numpy(sigmas).float().cuda()
except Exception as e:
logger.error(f"Error computing local sigma: {str(e)}")
raise
def _setup_network_and_optimizer(self, kwargs):
"""设置网络和优化器"""
try:
# 设置目录
self._setup_checkpoints_directories()
# 网络参数设置
self._setup_network_parameters(kwargs)
# 创建网络
self._create_network()
# 设置优化器
self._setup_optimizer(kwargs)
logger.debug("Network and optimizer setup completed")
except Exception as e:
logger.error(f"Error in network and optimizer setup: {str(e)}")
raise
def _setup_checkpoints_directories(self):
"""设置检查点目录"""
try:
self.cur_exp_dir = os.path.join(self.expdir, self.foldername)
utils.mkdir_ifnotexists(self.cur_exp_dir)
self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
utils.mkdir_ifnotexists(self.plots_dir)
self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
utils.mkdir_ifnotexists(self.checkpoints_path)
self.model_params_subdir = "ModelParameters"
self.optimizer_params_subdir = "OptimizerParameters"
utils.mkdir_ifnotexists(os.path.join(
self.checkpoints_path, self.model_params_subdir))
utils.mkdir_ifnotexists(os.path.join(
self.checkpoints_path, self.optimizer_params_subdir))
except Exception as e:
logger.error(f"Error setting up checkpoint directories: {str(e)}")
raise
def _setup_network_parameters(self, kwargs):
"""设置网络参数"""
try:
self.nepochs = kwargs['nepochs']
self.points_batch = kwargs['points_batch']
self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
self.sampler = Sampler.get_sampler(
self.conf.get_string('network.sampler.sampler_type'))(
self.global_sigma,
self.local_sigma
)
self.grad_lambda = self.conf.get_float('network.loss.lambda')
self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')
self.with_normals = self.normals_lambda > 0
self.d_in = self.conf.get_int('train.d_in')
except Exception as e:
logger.error(f"Error setting up network parameters: {str(e)}")
raise
def _create_network(self):
"""创建网络"""
try:
self.network = utils.get_class(self.conf.get_string('train.network_class'))(
d_in=self.d_in,
n_branch=int(torch.max(self.feature_mask).item()),
csg_tree=self.csg_tree,
flag_convex=self.csg_flag_convex,
**self.conf.get_config('network.inputs')
)
logger.info(f"Network created: {self.network}")
if torch.cuda.is_available():
self.network.cuda()
except Exception as e:
logger.error(f"Error creating network: {str(e)}")
raise
def _setup_optimizer(self, kwargs):
"""设置优化器"""
try:
self.lr_schedules = self.get_learning_rate_schedules(
self.conf.get_list('train.learning_rate_schedule'))
self.weight_decay = self.conf.get_float('train.weight_decay')
self.startepoch = 0
self.optimizer = torch.optim.Adam([{
"params": self.network.parameters(),
"lr": self.lr_schedules[0].get_learning_rate(0),
"weight_decay": self.weight_decay
}])
            # when continuing training, load the checkpoints
self._load_checkpoints_if_continue(kwargs)
except Exception as e:
logger.error(f"Error setting up optimizer: {str(e)}")
raise
def _load_checkpoints_if_continue(self, kwargs):
"""如果继续训练,加载检查点"""
try:
model_params_path = os.path.join(
self.checkpoints_path, self.model_params_subdir)
ckpts = os.listdir(model_params_path)
            if len(ckpts) != 0:  # resume whenever a previous checkpoint exists
old_checkpnts_dir = os.path.join(
self.expdir, self.foldername, 'checkpoints')
logger.info(f'Loading checkpoint from: {old_checkpnts_dir}')
                # load the model state
saved_model_state = torch.load(os.path.join(
old_checkpnts_dir,
'ModelParameters',
f"{kwargs['checkpoint']}.pth"
))
self.network.load_state_dict(saved_model_state["model_state_dict"])
                # load the optimizer state
data = torch.load(os.path.join(
old_checkpnts_dir,
'OptimizerParameters',
f"{kwargs['checkpoint']}.pth"
))
self.optimizer.load_state_dict(data["optimizer_state_dict"])
self.startepoch = saved_model_state['epoch']
except Exception as e:
logger.error(f"Error loading checkpoints: {str(e)}")
raise
    def get_learning_rate_schedules(self, schedule_specs):
        schedules = []
        for schedule_spec in schedule_specs:  # renamed loop variable; it previously shadowed the argument
            if schedule_spec["Type"] == "Step":
                schedules.append(
                    utils.StepLearningRateSchedule(
                        schedule_spec["Initial"],
                        schedule_spec["Interval"],
                        schedule_spec["Factor"],
                    )
                )
            else:
                raise Exception(
                    'no known learning rate schedule of type "{}"'.format(
                        schedule_spec["Type"]
                    )
                )
        return schedules
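
    # Each spec in 'train.learning_rate_schedule' is a dict with the keys read above;
    # an assumed example entry: { Type = "Step", Initial = 0.005, Interval = 2000, Factor = 0.5 }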
def adjust_learning_rate(self, epoch):
for i, param_group in enumerate(self.optimizer.param_groups):
param_group["lr"] = self.lr_schedules[i].get_learning_rate(epoch)
def save_checkpoints(self, epoch):
torch.save(
{"epoch": epoch, "model_state_dict": self.network.state_dict()},
os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
torch.save(
{"epoch": epoch, "model_state_dict": self.network.state_dict()},
os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
torch.save(
{"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
torch.save(
{"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))
if __name__ == '__main__':
if args.gpu == "auto":
deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False, excludeID=[],
excludeUUID=[])
gpu = deviceIDs[0]
else:
gpu = args.gpu
nepoch = args.nepoch
conf = ConfigFactory.parse_file('./conversion/' + args.conf)
folderprefix = conf.get_string('train.folderprefix')
fileprefix_list = conf.get_list('train.fileprefix_list')
trainrunners = []
for i in range(len(fileprefix_list)):
fp = fileprefix_list[i]
        print('cur model:', fp)
begin = time.time()
trainrunners.append(ReconstructionRunner(
conf=args.conf,
folder_prefix = folderprefix,
file_prefix = fp,
points_batch=args.points_batch,
nepochs=nepoch,
expname=args.expname,
gpu_index=gpu,
is_continue=args.is_continue,
checkpoint=args.checkpoint,
eval=args.eval,
flag_list = True
)
)
if trainrunners[i].flag_data_load:
trainrunners[i].run_nhrepnet_training()
end = time.time()
dur = end - begin
if args.baseline:
fp = fp+"_bl"