import os
import sys

project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pyhocon import ConfigFactory
from typing import Tuple

from utils.logger import logger
from utils.general import load_point_cloud_by_file_extension, load_feature_mask

'''
Each model corresponds to three files:
*_50k.xyz: 50,000 sampled points of the input B-Rep, can be visualized with MeshLab.
    e.g. x,y,z,nx,ny,nz per line
*_50k_mask.txt: (patch_id + 1) of the sampled points.
    e.g. 1 or 0 on each line
*_50k_csg.conf: Boolean tree built on the patches, stored in nested lists.
    'flag_convex' indicates the convexity of the root node.
    e.g.
    csg{
        list = [0,1,[2,3,4,],],
        flag_convex = 1,
    }
'''
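
# A minimal sketch of how the *_csg.conf format above is consumed (using the same
# pyhocon calls as NHREP_Dataset below): a file containing
#
#     csg{
#         list = [0,1,[2,3,4,],],
#         flag_convex = 1,
#     }
#
# is parsed with ConfigFactory.parse_file(path); config.get_list('csg.list') then
# returns the nested patch-index list (here [0, 1, [2, 3, 4]]) and
# config.get_int('csg.flag_convex') returns 1.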


class NHREP_Dataset(Dataset):
    def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False, if_feature_sample: bool = False):
        """
        Initialize the dataset.
        :param data_dir: data directory
        :param name_prefix: model name
        :param if_baseline: whether this is a baseline model (skips the CSG config and uses a trivial tree)
        :param if_feature_sample: whether to load feature samples
        """
        self.data_dir = os.path.abspath(data_dir)  # convert the data directory to an absolute path
        self.if_baseline = if_baseline
        self.if_feature_sample = if_feature_sample
        self._load_single_data(self.data_dir, name_prefix, if_baseline, if_feature_sample)

    def _check_data_file_exists(self, file_name: str):
        if not os.path.exists(file_name):
            msg = f"Data file not found: {file_name}, absolute path: {os.path.abspath(file_name)}"
            logger.error(msg)
            raise Exception(msg)

    def _load_feature_samples(self, data_dir: str, file_prefix: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load feature samples (feature points and their mask pairs)."""
        try:
            logger.info(f"Loading feature samples for {file_prefix}")

            # load feature data
            input_fs_file = os.path.join(data_dir, file_prefix + '_feature.xyz')
            self._check_data_file_exists(input_fs_file)
            feature_data = torch.tensor(
                np.loadtxt(input_fs_file), dtype=torch.float32, device='cuda'
            )

            # load feature mask
            fs_mask_file = os.path.join(data_dir, file_prefix + '_feature_mask.txt')
            self._check_data_file_exists(fs_mask_file)
            feature_data_mask_pair = torch.tensor(
                np.loadtxt(fs_mask_file), dtype=torch.int64, device='cuda'
            )

            return feature_data, feature_data_mask_pair
        except Exception as e:
            logger.error(f"Error loading feature samples: {str(e)}")
            raise

    def _load_single_data(self, data_dir: str, name_prefix: str, if_baseline: bool, if_feature_sample: bool):
        """
        Load the data files of a single model.
        :param data_dir: data directory
        :param name_prefix: model name
        :param if_baseline: whether this is a baseline model
        :param if_feature_sample: whether to load feature samples
        """
        try:
            logger.info(f"Loading data for {name_prefix}")

            # load xyz file
            # self.data: 2D array of floats, each row represents a point in 3D space
            xyz_file = os.path.join(data_dir, name_prefix + '.xyz')
            self._check_data_file_exists(xyz_file)
            self.data = load_point_cloud_by_file_extension(xyz_file)

            # load mask file
            # self.feature_mask: 1D array of integers, each integer represents a feature mask
            mask_file = os.path.join(data_dir, name_prefix + '_mask.txt')
            self._check_data_file_exists(mask_file)
            self.feature_mask = load_feature_mask(mask_file)

            # load csg file
            # self.csg_tree: list of (nested) lists, each inner list represents a node in the CSG tree
            # self.csg_flag_convex: indicates whether the root node is convex
            try:
                if if_baseline:
                    self.csg_tree = [0]
                    self.csg_flag_convex = True
                else:
                    csg_conf_file = os.path.join(data_dir, name_prefix + '_csg.conf')
                    self._check_data_file_exists(csg_conf_file)
                    csg_config = ConfigFactory.parse_file(csg_conf_file)
                    self.csg_tree = csg_config.get_list('csg.list')
                    self.csg_flag_convex = csg_config.get_int('csg.flag_convex')
            except Exception as e:
                logger.error(f"Error in CSG tree setup: {str(e)}")
                raise

            # load feature samples
            # self.feature_data: 2D array of floats, each row represents a point in 3D space
            # self.feature_data_mask_pair: 1D array of integers, each integer represents a feature mask
            if if_feature_sample:
                self.feature_data, self.feature_data_mask_pair = self._load_feature_samples(data_dir, name_prefix)
        except Exception as e:
            logger.error(f"Error loading data for {name_prefix}: {str(e)}")
            raise

    def get_data(self):
        return self.data

    def get_feature_mask(self):
        return self.feature_mask

    def get_csg_tree(self):
        return self.csg_tree, self.csg_flag_convex

    def get_feature_data(self):
        if self.if_feature_sample:
            return self.feature_data, self.feature_data_mask_pair
        return None, None


class CustomDataLoader:
    def __init__(self, dataset: Dataset, batch_size=32, shuffle=True, num_workers=4):
        """
        Initialize the data loader.
        :param dataset: the dataset to wrap (e.g. an NHREP_Dataset instance)
        :param batch_size: batch size
        :param shuffle: whether to shuffle the data
        :param num_workers: number of worker subprocesses
        """
        self.dataset = dataset
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers
        )

    def get_loader(self):
        """Return the wrapped DataLoader."""
        return self.dataloader


# Example usage
if __name__ == "__main__":
    # data directory and model name prefix
    data_dir = '../data/input_data'
    name_prefix = 'broken_bullet_50k'

    # data augmentation example (not applied by NHREP_Dataset)
    transform = transforms.Compose([
        transforms.Normalize(mean=[0.5], std=[0.5]),  # normalization
    ])

    # create a dataset instance
    dataset = NHREP_Dataset(data_dir, name_prefix)
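
    # Hedged usage sketch: inspect what was loaded. The concrete return types of
    # get_data()/get_feature_mask() depend on the utils.general loaders, so they
    # are only assumed to be array-like (supporting len()) here.
    csg_tree, csg_flag_convex = dataset.get_csg_tree()
    print(f"points: {len(dataset.get_data())}, mask entries: {len(dataset.get_feature_mask())}")
    print(f"csg_tree: {csg_tree}, flag_convex: {csg_flag_convex}")

    # CustomDataLoader can wrap any map-style Dataset; note that NHREP_Dataset
    # would still need __len__/__getitem__ before the returned DataLoader is iterable.
    # loader = CustomDataLoader(dataset, batch_size=32).get_loader()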