You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

193 lines
6.9 KiB

import os
import sys
import time
# Resolve the project root (the parent of this file's directory), append it to
# sys.path (presumably so the `utils.*` imports below can be resolved), and make
# it the working directory so relative paths are interpreted from there.
# NOTE(review): importing this module changes the process-wide working directory.
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pyhocon import ConfigFactory
from typing import List, Tuple
from utils.logger import logger
from utils.general import load_point_cloud_by_file_extension, load_feature_mask
'''
Each model corresponds to three files:
*_50k.xyz: 50,000 sampled points of the input B-Rep, can be visualized with MeshLab.
e.g. x,y,z,nx,ny,nz
*_50k_mask.txt: (patch_id + 1) of sampled points.
e.g. 1 or 0 each line
*_50k_csg.conf: Boolean tree built on the patches, stored in nested lists. 'flag_convex' indicates the convexity of the root node.
e.g.
csg{
list = [0,1,[2,3,4,],],
flag_convex = 1,
}
'''
class NHREP_Dataset(Dataset):
    """Load the point cloud, patch mask, CSG tree and optional feature samples of one model.

    For a model named ``name_prefix`` inside ``data_dir`` the following files are read:

    * ``<name_prefix>.xyz`` -> ``self.data`` (via ``load_point_cloud_by_file_extension``)
    * ``<name_prefix>_mask.txt`` -> ``self.feature_mask`` (via ``load_feature_mask``)
    * ``<name_prefix>_csg.conf`` -> ``self.csg_tree`` / ``self.csg_flag_convex``
      (not read when ``if_baseline`` is true)
    * ``<name_prefix>_feature.xyz`` and ``<name_prefix>_feature_mask.txt`` ->
      ``self.feature_data`` / ``self.feature_data_mask_pair`` (only when
      ``if_feature_sample`` is true)

    NOTE(review): this subclasses ``torch.utils.data.Dataset`` but defines neither
    ``__len__`` nor ``__getitem__``; data is accessed through the ``get_*`` methods.
    """

    def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False,
                 if_feature_sample: bool = False, device: str = 'cuda'):
        """Initialise the dataset and eagerly load every file for one model.

        :param data_dir: data directory (converted to an absolute path)
        :param name_prefix: model name, used as the file-name prefix
        :param if_baseline: if true, skip the CSG file and use the single-node tree ``[0]``
        :param if_feature_sample: if true, also load the feature samples and their mask
        :param device: torch device for the feature-sample tensors; defaults to
            ``'cuda'``, which was previously hard-coded
        :raises FileNotFoundError: if a required data file is missing
        """
        self.data_dir = os.path.abspath(data_dir)
        self.if_baseline = if_baseline
        self.if_feature_sample = if_feature_sample
        self.device = device
        self._load_single_data(self.data_dir, name_prefix, if_baseline, if_feature_sample)

    def _check_data_file_exists(self, file_name: str):
        """Log and raise ``FileNotFoundError`` if ``file_name`` does not exist.

        ``FileNotFoundError`` subclasses ``Exception``, so callers that catch
        ``Exception`` keep working.
        """
        if not os.path.exists(file_name):
            message = f"Data file not found: {file_name}, absolute path: {os.path.abspath(file_name)}"
            logger.error(message)
            raise FileNotFoundError(message)

    def _load_feature_samples(self, data_dir: str, file_prefix: str,
                              device: str = 'cuda') -> Tuple[torch.Tensor, torch.Tensor]:
        """Load the feature sample points and their mask as tensors.

        :param data_dir: directory containing the files
        :param file_prefix: model name used as the file-name prefix
        :param device: torch device the tensors are created on
        :returns: ``(feature_data, feature_data_mask_pair)`` as float32 / int64 tensors
        :raises FileNotFoundError: if either file is missing
        """
        try:
            logger.info(f"Loading feature samples for {file_prefix}")
            # Feature sample points.
            input_fs_file = os.path.join(data_dir, file_prefix + '_feature.xyz')
            self._check_data_file_exists(input_fs_file)
            feature_data = torch.tensor(
                np.loadtxt(input_fs_file),
                dtype=torch.float32,
                device=device,
            )
            # Mask for the feature samples.
            fs_mask_file = os.path.join(data_dir, file_prefix + '_feature_mask.txt')
            self._check_data_file_exists(fs_mask_file)
            feature_data_mask_pair = torch.tensor(
                np.loadtxt(fs_mask_file),
                dtype=torch.int64,
                device=device,
            )
            return feature_data, feature_data_mask_pair
        except Exception as e:
            logger.error(f"Error loading feature samples: {str(e)}")
            raise

    def _load_single_data(self, data_dir: str, name_prefix: str, if_baseline: bool, if_feature_sample: bool):
        """Load all files for a single model and store them on ``self``.

        :param data_dir: data directory
        :param name_prefix: model name
        :param if_baseline: whether this is the baseline setup (no CSG file)
        :param if_feature_sample: whether to load the feature samples
        """
        try:
            logger.info(f"Loading data for {name_prefix}")
            # Point cloud -> self.data
            xyz_file = os.path.join(data_dir, name_prefix + '.xyz')
            self._check_data_file_exists(xyz_file)
            self.data = load_point_cloud_by_file_extension(xyz_file)

            # Patch mask -> self.feature_mask
            mask_file = os.path.join(data_dir, name_prefix + '_mask.txt')
            self._check_data_file_exists(mask_file)
            self.feature_mask = load_feature_mask(mask_file)

            # CSG tree (nested list) and root-convexity flag.
            try:
                if if_baseline:
                    self.csg_tree = [0]
                    self.csg_flag_convex = True
                else:
                    csg_conf_file = os.path.join(data_dir, name_prefix + '_csg.conf')
                    self._check_data_file_exists(csg_conf_file)
                    csg_config = ConfigFactory.parse_file(csg_conf_file)
                    self.csg_tree = csg_config.get_list('csg.list')
                    # Normalise to bool so both branches yield the same type.
                    self.csg_flag_convex = bool(csg_config.get_int('csg.flag_convex'))
            except Exception as e:
                logger.error(f"Error in CSG tree setup: {str(e)}")
                raise

            # Optional feature samples.
            if if_feature_sample:
                # getattr fallback keeps the old 'cuda' behavior if __init__ was bypassed.
                device = getattr(self, 'device', 'cuda')
                self.feature_data, self.feature_data_mask_pair = self._load_feature_samples(
                    data_dir, name_prefix, device=device)
        except Exception as e:
            logger.error(f"Error loading data for {name_prefix}: {str(e)}")
            raise

    def get_data(self):
        """Return the loaded point cloud."""
        return self.data

    def get_feature_mask(self):
        """Return the loaded patch mask."""
        return self.feature_mask

    def get_csg_tree(self):
        """Return ``(csg_tree, csg_flag_convex)``."""
        return self.csg_tree, self.csg_flag_convex

    def get_feature_data(self):
        """Return ``(feature_data, feature_data_mask_pair)``, or ``(None, None)`` if not loaded."""
        if self.if_feature_sample:
            return self.feature_data, self.feature_data_mask_pair
        else:
            return None, None
class CustomDataLoader:
    """Wrapper that builds a ``torch.utils.data.DataLoader`` over a dataset.

    NOTE(review): ``CustomDataset`` is neither defined nor imported in the visible
    part of this file, so constructing this class raises ``NameError`` unless it
    is provided elsewhere. Confirm this, or switch to a dataset class that
    implements ``__len__``/``__getitem__``.
    """

    def __init__(self, data_dir, batch_size=32, shuffle=True, num_workers=4, transform=None):
        """Create the dataset and wrap it in a DataLoader.

        :param data_dir: data directory passed to the dataset
        :param batch_size: number of samples per batch
        :param shuffle: whether to shuffle the data
        :param num_workers: number of worker subprocesses used for loading
        :param transform: optional transform / augmentation passed to the dataset
        """
        source = CustomDataset(data_dir, transform)
        self.dataset = source
        self.dataloader = DataLoader(
            source,
            num_workers=num_workers,
            shuffle=shuffle,
            batch_size=batch_size,
        )

    def get_loader(self):
        """Return the wrapped DataLoader."""
        return self.dataloader
# Example usage
if __name__ == "__main__":
    # Data directory and model-name prefix.
    # NOTE(review): os.chdir(project_dir) runs at import time, so this relative
    # path is resolved against the parent of the project directory. Confirm this is intended.
    data_dir = '../data/input_data'  # data directory
    name_prefix = 'broken_bullet_50k'
    # Example data-augmentation pipeline.
    # NOTE(review): `transform` is not used in the visible code, and
    # NHREP_Dataset's constructor takes no transform argument.
    transform = transforms.Compose([
        transforms.Normalize(mean=[0.5], std=[0.5]),  # normalization
    ])
    # Create the dataset instance.
    dataset = NHREP_Dataset(data_dir, name_prefix)