import os
import sys
import time

project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from pyhocon import ConfigFactory
from scipy.spatial import cKDTree

from utils.logger import logger
from utils.general import get_class
from data_loader import NHREP_Dataset
from loss import LossManager
from learning_rate import LearningRateScheduler
from model.network import NHRepNet  # import NHRepNet
from model.sample import Sampler


class NHREPNet_Training:
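    """Training driver for NHRepNet.

    Wraps dataset loading, point sampling, loss computation, learning-rate
    scheduling, checkpointing and TensorBoard logging for single-shape training.
    """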

    def __init__(self, name_prefix: str, conf, if_baseline: bool = False, if_feature_sample: bool = False):
        self.conf = conf
        self.name_prefix = name_prefix  # kept for checkpoint paths (see init_checkpoints)
        self.sampler = Sampler.get_sampler(
            self.conf.get_string('network.sampler.sampler_type'))(
            global_sigma=self.conf.get_float('network.sampler.properties.global_sigma'),
            local_sigma=self.conf.get_float('network.sampler.properties.local_sigma')
        )
        data_dir = self.conf.get_string('train.input_path')
        self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # model hyperparameters
        self.d_in = 3  # input dimension: x, y, z
        self.dims_sdf = [256, 256, 256]  # hidden-layer dimensions

        self.nepochs = 15000  # number of training epochs
        folder = self.conf.get_string('train.folderprefix')
        self.writer = SummaryWriter(os.path.join("summary", folder, name_prefix))  # TensorBoard writer

        # checkpoint directories
        self.init_checkpoints()

    def run_nhrepnet_training(self):
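        """Prepare the training data, build the model, and run the training loop."""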
        # data preparation
        logger.info("Preparing data")
        self.data = self.dataset.get_data().to(self.device).requires_grad_()  # x, y, z, nx, ny, nz
        feature_mask_cpu = self.dataset.get_feature_mask().numpy()  # feature mask
        self.feature_mask = torch.from_numpy(feature_mask_cpu).to(self.device)
        self.points_batch = 16384  # batch size
        self.compute_local_sigma()

        n_branch = int(torch.max(self.feature_mask).item())  # number of branches
        n_batchsize = self.points_batch  # batch size
        n_patch_batch = n_batchsize // n_branch  # patch batch size per branch
        n_patch_last = n_batchsize - n_patch_batch * (n_branch - 1)  # patch size of the last branch

        # 1. prepare the training data
        # 1.1 collect the point indices and counts of each branch patch
        patch_id, patch_id_n = self.compute_patch(n_branch, n_patch_batch, n_patch_last, feature_mask_cpu)

        # 1.2 build the branch masks
        branch_mask, single_branch_mask_gt, single_branch_mask_id = self.get_branch_mask(n_branch, n_patch_batch, n_patch_last)

        # 1.3 initialize the model
        csg_tree, flag_convex = self.dataset.get_csg_tree()
        self.model = get_class(self.conf.get_string('train.network_class'))(
            d_in=self.d_in,
            n_branch=n_branch,
            csg_tree=csg_tree,
            flag_convex=flag_convex,
            **self.conf.get_config('network.inputs')
        ).to(self.device)
        self.scheduler = LearningRateScheduler(self.conf.get_list('train.learning_rate_schedule'), self.conf.get_float('train.weight_decay'), self.model.parameters())
        self.loss_manager = LossManager(ablation="none")

        logger.info("Starting training")
        self.model.train()  # set the model to training mode

        for epoch in tqdm(range(self.nepochs), desc="Training", unit="epoch"):
            try:
                self.train_one_epoch(epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize)
            except Exception as e:
                logger.error(f"Error during training: {str(e)}")
                break
        self.tracing()

    def train_one_epoch(self, epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize):
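        """Sample one batch, run the forward pass, compute the loss, and update the model."""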
        # 1.4 sample the point indices for this batch
        indices = self.get_indices(patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch)

        # 1.5 gather the batch data
        cur_data = self.data[indices]  # x, y, z, nx, ny, nz
        mnfld_pnts = cur_data[:, :self.d_in]  # manifold points
        mnfld_sigma = self.local_sigma[indices]  # per-point noise scale

        nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()  # sample non-manifold points

        # 2. forward pass
        mnfld_pred_all = self.model(mnfld_pnts)
        nonmnfld_pred_all = self.model(nonmnfld_pnts)

        normals = cur_data[:, -self.d_in:]

        # 3. compute the loss
        loss, loss_details = self.loss_manager.compute_loss(
            mnfld_pnts=mnfld_pnts,
            normals=normals,
            mnfld_pred_all=mnfld_pred_all,
            nonmnfld_pnts=nonmnfld_pnts,
            nonmnfld_pred_all=nonmnfld_pred_all,
            n_batchsize=n_batchsize,
            n_branch=n_branch,
            n_patch_batch=n_patch_batch,
            n_patch_last=n_patch_last,
        )

        self.scheduler.step(loss, epoch)

        # 4. backward pass
        self.scheduler.optimizer.zero_grad()  # clear gradients
        loss.backward()  # backpropagate
        self.scheduler.optimizer.step()  # update parameters

        avg_loss = loss.item()
        if epoch % 100 == 0:
            self.writer.add_scalar('Loss/train', avg_loss, epoch)  # log the loss to TensorBoard
            for k, v in loss_details.items():
                self.writer.add_scalar('Loss/' + k, v.item(), epoch)
        if epoch % self.conf.get_int('train.checkpoint_frequency') == 0:  # save a checkpoint periodically
            self.save_checkpoints(epoch)

    #============================ Forward-pass data preparation ============================

    def compute_patch(self, n_branch, n_patch_batch, n_patch_last, feature_mask_cpu):
        '''
        Collect, for each branch, the indices of the input points whose feature
        mask equals that branch id (1-based), plus the per-branch point counts.
        n_patch_batch and n_patch_last are currently unused.
        '''
        patch_id = []
        patch_id_n = []
        for i in range(n_branch):
            patch_id.append(np.where(feature_mask_cpu == i + 1)[0])
            patch_id_n.append(patch_id[i].shape[0])
        return patch_id, patch_id_n

    def get_branch_mask(self, n_branch, n_patch_batch, n_patch_last):
        '''
        branch_mask: one row per branch, one column per sample in the batch;
            marks which batch samples are drawn from each branch.
        single_branch_mask_gt: one row per sample, one column per branch;
            one-hot encoding of the branch each sample belongs to.
        single_branch_mask_id: branch id of each sample.
        '''
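        # Illustrative layout (hypothetical sizes: n_branch=2, n_patch_batch=3, n_patch_last=4):
        #   branch_mask           = [[1, 1, 1, 0, 0, 0, 0],
        #                            [0, 0, 0, 1, 1, 1, 1]]
        #   single_branch_mask_id = [0, 0, 0, 1, 1, 1, 1]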
        # the masks cover the full batch: (n_branch - 1) patches of size
        # n_patch_batch plus a last patch of size n_patch_last
        n_batchsize = n_patch_batch * (n_branch - 1) + n_patch_last
        branch_mask = torch.zeros(n_branch, n_batchsize).to(self.device)
        single_branch_mask_gt = torch.zeros(n_batchsize, n_branch).to(self.device)
        single_branch_mask_id = torch.zeros([n_batchsize], dtype=torch.long).to(self.device)
        for i in range(n_branch - 1):
            branch_mask[i, i * n_patch_batch : (i + 1) * n_patch_batch] = 1.0
            single_branch_mask_gt[i * n_patch_batch : (i + 1) * n_patch_batch, i] = 1.0
            single_branch_mask_id[i * n_patch_batch : (i + 1) * n_patch_batch] = i
        branch_mask[n_branch - 1, (n_branch - 1) * n_patch_batch:] = 1.0
        single_branch_mask_gt[(n_branch - 1) * n_patch_batch:, (n_branch - 1)] = 1.0
        single_branch_mask_id[(n_branch - 1) * n_patch_batch:] = (n_branch - 1)
        return branch_mask, single_branch_mask_gt, single_branch_mask_id

    def get_indices(self, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch):
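        """Sample, with replacement, n_patch_batch point indices from each branch
        (n_patch_last from the final branch), concatenated into one training batch."""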
        indices = torch.empty(0, dtype=torch.int64).to(self.device)
        for i in range(n_branch - 1):
            indices_nonfeature = torch.tensor(patch_id[i][np.random.choice(patch_id_n[i], n_patch_batch, True)]).to(self.device)
            indices = torch.cat((indices, indices_nonfeature), 0)
        # last patch
        indices_nonfeature = torch.tensor(patch_id[n_branch - 1][np.random.choice(patch_id_n[n_branch - 1], n_patch_last, True)]).to(self.device)
        indices = torch.cat((indices, indices_nonfeature), 0)
        return indices

    def compute_local_sigma(self):
        """Compute a per-point sigma, the distance to each point's 50th nearest
        neighbor, used to scale the local sampling noise around that point."""
        try:
            sigma_set = []
            data_cpu = self.data.detach().cpu().numpy()
            ptree = cKDTree(data_cpu)
            logger.debug("KD tree constructed")

            # query in chunks to bound memory; k = 50 + 1 because the nearest
            # neighbor of each point is the point itself
            for p in np.array_split(data_cpu, 100, axis=0):
                d = ptree.query(p, 50 + 1)
                sigma_set.append(d[0][:, -1])

            sigmas = np.concatenate(sigma_set)
            self.local_sigma = torch.from_numpy(sigmas).float().to(self.device)
        except Exception as e:
            logger.error(f"Error computing local sigma: {str(e)}")
            raise

    #============================ Model saving ============================

    def init_checkpoints(self):
        self.checkpoints_path = os.path.join("../exps/single_shape", self.name_prefix, "checkpoints")
        self.ModelParameters_path = os.path.join(self.checkpoints_path, "ModelParameters")
        self.OptimizerParameters_path = os.path.join(self.checkpoints_path, "OptimizerParameters")
        self.TorchScript_path = os.path.join(self.checkpoints_path, "TorchScript")

        # create the checkpoint directories
        os.makedirs(self.ModelParameters_path, exist_ok=True)
        os.makedirs(self.OptimizerParameters_path, exist_ok=True)
        os.makedirs(self.TorchScript_path, exist_ok=True)

    def save_checkpoints(self, epoch):
        torch.save(
            {"epoch": epoch, "model_state_dict": self.model.state_dict()},
            os.path.join(self.ModelParameters_path, str(epoch) + ".pth"))
        torch.save(
            {"epoch": epoch, "model_state_dict": self.model.state_dict()},
            os.path.join(self.ModelParameters_path, "latest.pth"))
        torch.save(
            {"epoch": epoch, "optimizer_state_dict": self.scheduler.optimizer.state_dict()},
            os.path.join(self.OptimizerParameters_path, str(epoch) + ".pth"))
        torch.save(
            {"epoch": epoch, "optimizer_state_dict": self.scheduler.optimizer.state_dict()},
            os.path.join(self.OptimizerParameters_path, "latest.pth"))

    def tracing(self):
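        """Export the trained network as a TorchScript module via torch.jit.trace."""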
        csg_tree, flag_convex = self.dataset.get_csg_tree()
        network = get_class(self.conf.get_string('train.network_class'))(
            d_in=self.d_in,
            n_branch=int(torch.max(self.feature_mask).item()),
            csg_tree=csg_tree,
            flag_convex=flag_convex,
            **self.conf.get_config('network.inputs')
        ).to(self.device)
        # load the trained weights saved by save_checkpoints; a freshly
        # constructed network would otherwise be traced with random weights
        checkpoint = torch.load(os.path.join(self.ModelParameters_path, "latest.pth"))
        network.load_state_dict(checkpoint["model_state_dict"])
        network.eval()

        # trace with a dummy batch of 3D points
        example = torch.rand(224, 3).to(self.device)
        traced_script_module = torch.jit.trace(network, example)
        traced_script_module.save(os.path.join(self.TorchScript_path, "model_h.pt"))


if __name__ == "__main__":
    name_prefix = 'broken_bullet_50k'
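    # A minimal sketch of the HOCON keys this script reads (values here are
    # illustrative assumptions, not the shipped setup.conf):
    #
    #   train {
    #       input_path = "../data"
    #       folderprefix = "exp"
    #       network_class = "model.network.NHRepNet"
    #       learning_rate_schedule = [ ... ]
    #       weight_decay = 0
    #       checkpoint_frequency = 1000
    #   }
    #   network {
    #       sampler { sampler_type = "...", properties { global_sigma = 1.8, local_sigma = 0.01 } }
    #       inputs { ... }
    #   }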
    conf = ConfigFactory.parse_file('./conversion/setup.conf')
    try:
        train = NHREPNet_Training(name_prefix, conf, if_baseline=True, if_feature_sample=False)
        train.run_nhrepnet_training()
    except Exception as e:
        logger.error(str(e))