From a40fb611d5481093bb7ad20e45c548a6df2c3315 Mon Sep 17 00:00:00 2001
From: mckay
Date: Tue, 18 Feb 2025 21:45:47 +0800
Subject: [PATCH] feat: Implement NHREP Dataset and Training Pipeline

- Added `data_loader.py` with an `NHREP_Dataset` class for loading point cloud, feature mask, and CSG tree data
- Implemented `CustomDataLoader` for flexible data loading with configurable parameters
- Refactored `train.py` into a structured training pipeline for NHRepNet
- Added support for feature sampling, device selection, and TensorBoard logging
- Introduced modular training methods with error handling and logging

---
 code/conversion/data_loader.py | 192 +++++++++++++++++++++++++++++++++
 code/conversion/train.py       | 100 ++++++++++++-----
 2 files changed, 264 insertions(+), 28 deletions(-)
 create mode 100644 code/conversion/data_loader.py

diff --git a/code/conversion/data_loader.py b/code/conversion/data_loader.py
new file mode 100644
index 0000000..2c7ca12
--- /dev/null
+++ b/code/conversion/data_loader.py
@@ -0,0 +1,192 @@
+import os
+import sys
+
+project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
+sys.path.append(project_dir)
+os.chdir(project_dir)
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+from pyhocon import ConfigFactory
+from typing import Tuple
+from utils.logger import logger
+from utils.general import load_point_cloud_by_file_extension, load_feature_mask
+
+'''
+Each model corresponds to three files:
+*_50k.xyz: 50,000 sampled points of the input B-Rep; can be visualized with MeshLab.
+    Each line: x, y, z, nx, ny, nz
+*_50k_mask.txt: (patch_id + 1) of each sampled point.
+    Each line: a single integer, e.g. 1 or 0
+*_50k_csg.conf: Boolean tree built on the patches, stored in nested lists.
+    'flag_convex' indicates the convexity of the root node, e.g.
+    csg{
+        list = [0,1,[2,3,4,],],
+        flag_convex = 1,
+    }
+'''
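+
+# Minimal parsing sketch for the conf format above (mirrors _load_single_data
+# below; the inline string is a hypothetical example, not a file from the repo):
+#
+#   conf = ConfigFactory.parse_string('csg { list = [0, 1, [2, 3, 4]], flag_convex = 1 }')
+#   conf.get_list('csg.list')        # -> [0, 1, [2, 3, 4]]
+#   conf.get_int('csg.flag_convex')  # -> 1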
+
+
+class NHREP_Dataset(Dataset):
+    def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False, if_feature_sample: bool = False):
+        """
+        Initialize the dataset.
+        :param data_dir: data directory
+        :param name_prefix: model name
+        :param if_baseline: whether to treat the model as a baseline (trivial CSG tree)
+        :param if_feature_sample: whether to load feature samples
+        """
+        self.data_dir = os.path.abspath(data_dir)  # resolve the data directory to an absolute path
+        self.if_baseline = if_baseline
+        self.if_feature_sample = if_feature_sample
+        self._load_single_data(self.data_dir, name_prefix, if_baseline, if_feature_sample)
+
+    def _check_data_file_exists(self, file_name: str):
+        if not os.path.exists(file_name):
+            logger.error(f"Data file not found: {os.path.abspath(file_name)}")
+            raise FileNotFoundError(f"Data file not found: {os.path.abspath(file_name)}")
+
+    def _load_feature_samples(self, data_dir: str, file_prefix: str) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Load feature samples."""
+        try:
+            logger.info(f"Loading feature samples for {file_prefix}")
+            # load feature points
+            input_fs_file = os.path.join(
+                data_dir,
+                file_prefix + '_feature.xyz'
+            )
+            self._check_data_file_exists(input_fs_file)
+            feature_data = torch.tensor(
+                np.loadtxt(input_fs_file),
+                dtype=torch.float32,
+                device='cuda'
+            )
+
+            # load feature mask
+            fs_mask_file = os.path.join(
+                data_dir,
+                file_prefix + '_feature_mask.txt'
+            )
+            self._check_data_file_exists(fs_mask_file)
+            feature_data_mask_pair = torch.tensor(
+                np.loadtxt(fs_mask_file),
+                dtype=torch.int64,
+                device='cuda'
+            )
+
+            return feature_data, feature_data_mask_pair
+        except Exception as e:
+            logger.error(f"Error loading feature samples: {str(e)}")
+            raise
+
+    def _load_single_data(self, data_dir: str, name_prefix: str, if_baseline: bool, if_feature_sample: bool):
+        """Load the data for a single model.
+        :param data_dir: data directory
+        :param name_prefix: model name
+        :param if_baseline: whether to treat the model as a baseline
+        :param if_feature_sample: whether to load feature samples
+        """
+        try:
+            logger.info(f"Loading data for {name_prefix}")
+            # load xyz file
+            # self.data: 2D float tensor; each row is a point in 3D space (xyz + normal)
+            xyz_file = os.path.join(
+                data_dir,
+                name_prefix + '.xyz'
+            )
+            self._check_data_file_exists(xyz_file)
+            self.data = load_point_cloud_by_file_extension(xyz_file)
+
+            # load mask file
+            # self.feature_mask: 1D integer array; each entry is a per-point patch id
+            mask_file = os.path.join(
+                data_dir,
+                name_prefix + '_mask.txt'
+            )
+            self._check_data_file_exists(mask_file)
+            self.feature_mask = load_feature_mask(mask_file)
+
+            # load csg file
+            # self.csg_tree: nested lists; each inner list is a node of the CSG tree
+            # self.csg_flag_convex: convexity flag of the root node
+            try:
+                if if_baseline:
+                    self.csg_tree = [0]
+                    self.csg_flag_convex = True
+                else:
+                    csg_conf_file = os.path.join(
+                        data_dir,
+                        name_prefix + '_csg.conf'
+                    )
+                    self._check_data_file_exists(csg_conf_file)
+                    csg_config = ConfigFactory.parse_file(csg_conf_file)
+                    self.csg_tree = csg_config.get_list('csg.list')
+                    self.csg_flag_convex = csg_config.get_int('csg.flag_convex')
+            except Exception as e:
+                logger.error(f"Error in CSG tree setup: {str(e)}")
+                raise
+
+            # load feature samples
+            # self.feature_data: 2D float tensor; each row is a point in 3D space
+            # self.feature_data_mask_pair: 1D integer tensor of feature masks
+            if if_feature_sample:
+                self.feature_data, self.feature_data_mask_pair = self._load_feature_samples(data_dir, name_prefix)
+        except Exception as e:
+            logger.error(f"Error loading data: {str(e)}")
+            raise
+
+    def __len__(self):
+        # required by torch's DataLoader: one sample per point
+        return self.data.shape[0]
+
+    def __getitem__(self, idx):
+        # required by torch's DataLoader: a single point and its patch id
+        return self.data[idx], self.feature_mask[idx]
+
+    def get_data(self):
+        return self.data
+
+    def get_feature_mask(self):
+        return self.feature_mask
+
+    def get_csg_tree(self):
+        return self.csg_tree, self.csg_flag_convex
+
+    def get_feature_data(self):
+        if self.if_feature_sample:
+            return self.feature_data, self.feature_data_mask_pair
+        else:
+            return None, None
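+
+# Minimal usage sketch for the getters above (paths follow the example in
+# __main__ below):
+#
+#   ds = NHREP_Dataset('../data/input_data', 'broken_bullet_50k', if_baseline=True)
+#   points = ds.get_data()            # (N, 6) tensor: xyz + normals
+#   tree, convex = ds.get_csg_tree()  # nested list and root convexity flag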
+
+
+class CustomDataLoader:
+    def __init__(self, data_dir, name_prefix: str, batch_size=32, shuffle=True, num_workers=4):
+        """
+        Initialize the data loader.
+        :param data_dir: data directory
+        :param name_prefix: model name
+        :param batch_size: batch size
+        :param shuffle: whether to shuffle the data
+        :param num_workers: number of worker subprocesses
+        """
+        self.dataset = NHREP_Dataset(data_dir, name_prefix)
+        self.dataloader = DataLoader(
+            self.dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=num_workers
+        )
+
+    def get_loader(self):
+        """Return the underlying DataLoader."""
+        return self.dataloader
+
+
+# Example usage
+if __name__ == "__main__":
+    # data directory and model name prefix
+    data_dir = '../data/input_data'
+    name_prefix = 'broken_bullet_50k'
+
+    # create a dataset instance and inspect what was loaded
+    dataset = NHREP_Dataset(data_dir, name_prefix)
+    csg_tree, flag_convex = dataset.get_csg_tree()
+    logger.info(f"Loaded {len(dataset)} points; CSG tree: {csg_tree}, convex root: {flag_convex}")
diff --git a/code/conversion/train.py b/code/conversion/train.py
index fddd09e..b6cac0e 100644
--- a/code/conversion/train.py
+++ b/code/conversion/train.py
@@ -1,41 +1,85 @@
+import os
+import sys
+
+project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
+sys.path.append(project_dir)
+os.chdir(project_dir)
+
 import torch
 import numpy as np
-import os
 from torch.utils.tensorboard import SummaryWriter
+from torch.optim.lr_scheduler import StepLR
+from tqdm import tqdm
 from utils.logger import logger
-from utils.general import gradient
-
+from data_loader import NHREP_Dataset
+from model.network import NHRepNet
+
 class NHREPNet_Training:
+    def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False, if_feature_sample: bool = False):
+        self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # initialize the model
+        d_in = 6  # input dimension: xyz coordinates plus normals
+        dims_sdf = [64, 64, 64]  # hidden layer dimensions
+        csg_tree, _ = self.dataset.get_csg_tree()
+        self.model = NHRepNet(d_in, dims_sdf, csg_tree).to(self.device)  # instantiate the model and move it to the device
+
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)  # Adam optimizer
+        self.scheduler = StepLR(self.optimizer, step_size=1000, gamma=0.1)  # learning-rate scheduler
+        self.nepochs = 15000  # number of training epochs
+        self.writer = SummaryWriter()  # TensorBoard writer
+
     def run_nhrepnet_training(self):
-        print("running")  # announce that training has started
-        self.data = self.data.cuda()  # move the data to the GPU
-        self.data.requires_grad_()  # enable gradient computation on the data
-        self.feature_mask = self.feature_mask.cuda()  # move the feature mask to the GPU
-        n_branch = int(torch.max(self.feature_mask).item())  # number of branches
-        n_batchsize = self.points_batch  # batch size
-        n_patch_batch = n_batchsize // n_branch  # per-branch patch batch size
-
-        # initialize the patch id lists
-        patch_id = [np.where(self.feature_mask.cpu().numpy() == i + 1)[0] for i in range(n_branch)]
+        logger.info("Starting training")
+        self.model.train()  # put the model in training mode
+
+        for epoch in tqdm(range(self.nepochs)):  # training loop
+            try:
+                self.train_one_epoch(epoch)
+                self.scheduler.step()  # update the learning rate
+            except Exception as e:
+                logger.error(f"Error during training: {str(e)}")
+                break
+        self.writer.close()  # flush TensorBoard events
+
+    def train_one_epoch(self, epoch):
+        total_loss = 0.0
+
+        # fetch the input data and move it to the device
+        input_data = self.dataset.get_data().to(self.device)
+
+        # forward pass
+        outputs = self.model(input_data)
 
-        for epoch in range(15000):  # training loop
-            indices = torch.cat([torch.tensor(patch_id[i][np.random.choice(len(patch_id[i]), n_patch_batch, replace=True)]).cuda() for i in range(n_branch)]).cuda()  # randomly sample patch indices
-            cur_data = self.data[indices]  # gather the current batch
-            mnfld_pnts = cur_data[:, :self.d_in]  # extract the manifold points
+        # compute the loss
+        loss = self.compute_loss(outputs)
+        total_loss += loss.item()
+
+        # backward pass
+        self.optimizer.zero_grad()  # clear the gradients
+        loss.backward()  # backpropagate
+        self.optimizer.step()  # update the parameters
 
-            # forward pass
-            mnfld_pred_all = self.network(mnfld_pnts)  # forward pass on the manifold points
-            mnfld_pred = mnfld_pred_all[:, 0]  # extract the manifold prediction
+        avg_loss = total_loss  # one full-batch step per epoch, so total equals average
+        logger.info(f'Epoch [{epoch}/{self.nepochs}], Loss: {avg_loss:.4f}')
+        self.writer.add_scalar('Loss/train', avg_loss, epoch)  # log the loss to TensorBoard
 
-            # compute the loss
-            loss = (mnfld_pred.abs()).mean()  # manifold loss
+    def compute_loss(self, outputs):
+        """
+        Compute the manifold loss.
 
-            self.optimizer.zero_grad()  # clear the optimizer's gradients
-            loss.backward()  # backpropagate
-            self.optimizer.step()  # update the model parameters
+        :param outputs: model outputs
+        :return: the manifold loss value
+        """
-            if epoch % 100 == 0:  # log the loss every 100 epochs
-                print(f'Epoch [{epoch}/{self.nepochs}], Loss: {loss.item():.4f}')  # print the current loss
+        # manifold loss: on-surface points should evaluate to zero (mean absolute value, used here as a placeholder)
+        manifold_loss = (outputs.abs()).mean()
+        return manifold_loss
 
-            self.tracing() #
\ No newline at end of file
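+# A possible extension (hedged sketch; the eikonal pairing and its 0.1 weight are
+# assumptions, not part of this pipeline): the placeholder loss above only drives
+# |f| to zero on surface points. SDF training usually adds an eikonal term, which
+# would require input_data.requires_grad_() before the forward pass:
+#
+#   grads = torch.autograd.grad(outputs.sum(), input_data, create_graph=True)[0]
+#   eikonal_loss = ((grads[:, :3].norm(2, dim=-1) - 1.0) ** 2).mean()
+#   loss = manifold_loss + 0.1 * eikonal_loss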
+
+if __name__ == "__main__":
+    data_dir = '../data/input_data'  # data directory
+    name_prefix = 'broken_bullet_50k'
+    train = NHREPNet_Training(data_dir, name_prefix, if_baseline=True, if_feature_sample=False)
+    train.run_nhrepnet_training()
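+
+# Training curves can then be inspected with TensorBoard; SummaryWriter() writes
+# event files under ./runs/ by default:
+#
+#   tensorboard --logdir runs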