You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
192 lines
6.9 KiB
192 lines
6.9 KiB
import os
|
|
import sys
|
|
import time
|
|
|
|
# Resolve the directory one level above this file's directory, make it
# importable (the `utils.*` imports below rely on this), and make it the
# working directory.
# NOTE(review): both the sys.path mutation and the chdir happen at import time.
project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_dir)
os.chdir(project_dir)
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torchvision import transforms
|
|
from pyhocon import ConfigFactory
|
|
from typing import List, Tuple
|
|
from utils.logger import logger
|
|
from utils.general import load_point_cloud_by_file_extension, load_feature_mask
|
|
'''
|
|
一个模型 对应 三个文件
|
|
*_50k.xyz: 50,000 sampled points of the input B-Rep, can be visualized with MeshLab.
|
|
e.g. x,y,z,nx,ny,nz
|
|
*_50k_mask.txt: (patch_id + 1) of sampled points.
|
|
e.g. 1 or 0 each line
|
|
*_50k_csg.conf: Boolean tree built on the patches, stored in nested lists. 'flag_convex' indicates the convexity of the root node.
|
|
e.g.
|
|
csg{
|
|
list = [0,1,[2,3,4,],],
|
|
flag_convex = 1,
|
|
}
|
|
'''
|
|
|
|
|
|
class NHREP_Dataset(Dataset):
    """Eagerly load all input files belonging to one model.

    For a model named ``name_prefix`` inside ``data_dir`` the constructor reads:

    * ``<name_prefix>.xyz`` via ``load_point_cloud_by_file_extension``
    * ``<name_prefix>_mask.txt`` via ``load_feature_mask``
    * ``<name_prefix>_csg.conf`` (skipped when ``if_baseline`` is true)
    * ``<name_prefix>_feature.xyz`` and ``<name_prefix>_feature_mask.txt``
      (only when ``if_feature_sample`` is true)

    NOTE(review): the class does not define ``__len__`` / ``__getitem__``;
    the loaded data is exposed through the ``get_*`` accessors instead.
    """

    def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False,
                 if_feature_sample: bool = False, device=None):
        """Initialise the dataset and load every file immediately.

        :param data_dir: directory containing the model files (made absolute)
        :param name_prefix: model name, used as the file-name prefix
        :param if_baseline: when true, use the trivial CSG tree ``[0]`` with
            a convex root instead of reading ``*_csg.conf``
        :param if_feature_sample: when true, also load the feature samples
        :param device: torch device on which the feature-sample tensors are
            created; ``None`` selects ``'cuda'`` when available, else ``'cpu'``
        :raises FileNotFoundError: if a required file is missing
        """
        super().__init__()
        self.data_dir = os.path.abspath(data_dir)
        self.if_baseline = if_baseline
        self.if_feature_sample = if_feature_sample
        if device is None:
            # Previously hard-coded to 'cuda', which crashed on CPU-only hosts.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self._load_single_data(self.data_dir, name_prefix, if_baseline, if_feature_sample)

    def _check_data_file_exists(self, file_name: str):
        """Log and raise ``FileNotFoundError`` when ``file_name`` does not exist."""
        if os.path.exists(file_name):
            return
        message = f"Data file not found: {file_name}, absolute path: {os.path.abspath(file_name)}"
        logger.error(message)
        # FileNotFoundError subclasses Exception, so existing handlers still match.
        raise FileNotFoundError(message)

    def _load_feature_samples(self, data_dir: str, file_prefix: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load the feature samples and their mask as tensors on ``self.device``.

        :param data_dir: directory containing the feature files
        :param file_prefix: model name prefix
        :return: ``(feature_data, feature_data_mask_pair)``; the first is
            float32, the second int64
        :raises FileNotFoundError: if either file is missing
        """
        try:
            logger.info(f"Loading feature samples for {file_prefix}")

            # Feature sample values.
            input_fs_file = os.path.join(data_dir, file_prefix + '_feature.xyz')
            self._check_data_file_exists(input_fs_file)
            feature_data = torch.tensor(
                np.loadtxt(input_fs_file),
                dtype=torch.float32,
                device=self.device,
            )

            # Feature mask.
            # NOTE(review): the name suggests each row holds a pair of ids -- confirm.
            fs_mask_file = os.path.join(data_dir, file_prefix + '_feature_mask.txt')
            self._check_data_file_exists(fs_mask_file)
            feature_data_mask_pair = torch.tensor(
                np.loadtxt(fs_mask_file),
                dtype=torch.int64,
                device=self.device,
            )

            return feature_data, feature_data_mask_pair
        except Exception as e:
            logger.error(f"Error loading feature samples: {str(e)}")
            raise

    def _load_csg_tree(self, data_dir: str, name_prefix: str, if_baseline: bool):
        """Return ``(csg_tree, csg_flag_convex)`` for the model.

        The baseline uses the fixed tree ``[0]`` with a convex root; otherwise
        ``csg.list`` and ``csg.flag_convex`` are read from ``*_csg.conf``.
        Note the flag is ``True`` for the baseline but an ``int`` when parsed.
        """
        if if_baseline:
            return [0], True

        csg_conf_file = os.path.join(data_dir, name_prefix + '_csg.conf')
        self._check_data_file_exists(csg_conf_file)
        csg_config = ConfigFactory.parse_file(csg_conf_file)
        return csg_config.get_list('csg.list'), csg_config.get_int('csg.flag_convex')

    def _load_single_data(self, data_dir: str, name_prefix: str, if_baseline: bool, if_feature_sample: bool):
        """Load every file of one model into instance attributes.

        :param data_dir: data directory
        :param name_prefix: model name
        :param if_baseline: whether this is the baseline model
        :param if_feature_sample: whether to load the feature samples
        :raises FileNotFoundError: if a required file is missing
        """
        try:
            logger.info(f"Loading data for {name_prefix}")

            # Point cloud -> self.data
            xyz_file = os.path.join(data_dir, name_prefix + '.xyz')
            self._check_data_file_exists(xyz_file)
            self.data = load_point_cloud_by_file_extension(xyz_file)

            # Per-point mask -> self.feature_mask
            mask_file = os.path.join(data_dir, name_prefix + '_mask.txt')
            self._check_data_file_exists(mask_file)
            self.feature_mask = load_feature_mask(mask_file)

            # CSG tree -> self.csg_tree (nested list), self.csg_flag_convex
            try:
                self.csg_tree, self.csg_flag_convex = self._load_csg_tree(
                    data_dir, name_prefix, if_baseline
                )
            except Exception as e:
                logger.error(f"Error in CSG tree setup: {str(e)}")
                raise

            # Optional feature samples -> self.feature_data, self.feature_data_mask_pair
            if if_feature_sample:
                self.feature_data, self.feature_data_mask_pair = self._load_feature_samples(data_dir, name_prefix)

        except Exception as e:
            logger.error(f"Error loading data from list: {str(e)}")
            raise

    def get_data(self):
        """Return the point cloud loaded from ``<name_prefix>.xyz``."""
        return self.data

    def get_feature_mask(self):
        """Return the mask loaded from ``<name_prefix>_mask.txt``."""
        return self.feature_mask

    def get_csg_tree(self):
        """Return ``(csg_tree, csg_flag_convex)``."""
        return self.csg_tree, self.csg_flag_convex

    def get_feature_data(self):
        """Return ``(feature_data, feature_data_mask_pair)``.

        Both are ``None`` when the dataset was built without feature samples.
        """
        if not self.if_feature_sample:
            return None, None
        return self.feature_data, self.feature_data_mask_pair
|
|
|
|
class CustomDataLoader:
    """Build a dataset and wrap it in a ``torch.utils.data.DataLoader``."""

    def __init__(self, data_dir, batch_size=32, shuffle=True, num_workers=4, transform=None, dataset=None):
        """Initialise the data loader.

        :param data_dir: data directory, forwarded to the default dataset
        :param batch_size: batch size
        :param shuffle: whether to shuffle the data
        :param num_workers: number of worker sub-processes
        :param transform: augmentation / transform, forwarded to the default dataset
        :param dataset: optional ready-made dataset to wrap; when given,
            ``data_dir`` and ``transform`` are not used
        """
        if dataset is None:
            # NOTE(review): CustomDataset is neither defined nor imported in
            # this file, so this default path raises NameError unless the name
            # is provided elsewhere. Pass `dataset=` explicitly until then.
            dataset = CustomDataset(data_dir, transform)
        self.dataset = dataset
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    def get_loader(self):
        """Return the wrapped ``DataLoader``."""
        return self.dataloader
|
|
|
|
# Example usage
if __name__ == "__main__":
    # Data directory and model-name prefix.
    # NOTE(review): the module calls os.chdir(project_dir) at import time, so
    # this relative path is resolved against project_dir -- confirm it is intended.
    data_dir = '../data/input_data'
    name_prefix = 'broken_bullet_50k'

    # Example augmentation pipeline (normalisation).
    # NOTE(review): `transform` is built but never passed to the dataset.
    transform = transforms.Compose(
        [transforms.Normalize(mean=[0.5], std=[0.5])]
    )

    # Create the dataset instance.
    dataset = NHREP_Dataset(data_dir=data_dir, name_prefix=name_prefix)
|
|
|
|
|
|
|