From 1ef73946da9028f12a528fdb61f8825fa9aa9c55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=90=9B=E6=B6=B5?= Date: Wed, 13 Nov 2024 17:11:01 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E6=95=B0=E6=8D=AE=E9=9B=86da?= =?UTF-8?q?taset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/data.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 data/data.py diff --git a/data/data.py b/data/data.py new file mode 100644 index 0000000..78153ad --- /dev/null +++ b/data/data.py @@ -0,0 +1,68 @@ +import os +import torch +from torch.utils.data import Dataset +import numpy as np + +class BRepSDFDataset(Dataset): + def __init__(self, brep_data_dir:str, sdf_data_dir:str, split:str='train'): + self.brep_data_dir = brep_data_dir + self.sdf_data_dir = sdf_data_dir + self.split = split + self.brep_data_list = self._load_brep_data_list() + self.sdf_data_list = self._load_sdf_data_list() + + def _load_brep_data_list(self): + data_list = [] + split_dir = os.path.join(self.brep_data_dir, self.split) + for sample_dir in os.listdir(split_dir): + sample_path = os.path.join(split_dir, sample_dir) + if os.path.isdir(sample_path): + data_list.append(sample_path) + return data_list + + def _load_sdf_data_list(self): + data_list = [] + split_dir = os.path.join(self.sdf_data_dir, self.split) + for sample_dir in os.listdir(split_dir): + sample_path = os.path.join(split_dir, sample_dir) + if os.path.isdir(sample_path): + data_list.append(sample_path) + return data_list + + def __len__(self): + return len(self.brep_data_list) + + def __getitem__(self, idx): + sample_path = self.brep_data_list[idx] + + # 解析 .step 文件 + step_file = os.path.join(sample_path, 'model.step') + brep_features = self._parse_step_file(step_file) + + # 加载 .sdf 文件 + sdf_file = os.path.join(sample_path, 'sdf.npy') + sdf = np.load(sdf_file) + + # 转换为 torch 张量 + brep_features = torch.tensor(brep_features, dtype=torch.float32) + sdf = torch.tensor(sdf, dtype=torch.float32) + + return { + 'brep_features': brep_features, + 'sdf': sdf + } + + def _parse_step_file(self, step_file): + # 解析 .step 文件的逻辑 + # 返回 B-rep 特征 + pass + +# 使用示例 +if __name__ == '__main__': + dataset = BRepSDFDataset(brep_data_dir='path/to/brep_data', sdf_data_dir='path/to/sdf_data', split='train') + dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True) + + for batch in dataloader: + print(batch['brep_features'].shape) + print(batch['sdf'].shape) + break \ No newline at end of file