1 changed files with 68 additions and 0 deletions
@ -0,0 +1,68 @@ |
|||
import os |
|||
import torch |
|||
from torch.utils.data import Dataset |
|||
import numpy as np |
|||
|
|||
class BRepSDFDataset(Dataset): |
|||
def __init__(self, brep_data_dir:str, sdf_data_dir:str, split:str='train'): |
|||
self.brep_data_dir = brep_data_dir |
|||
self.sdf_data_dir = sdf_data_dir |
|||
self.split = split |
|||
self.brep_data_list = self._load_brep_data_list() |
|||
self.sdf_data_list = self._load_sdf_data_list() |
|||
|
|||
def _load_brep_data_list(self): |
|||
data_list = [] |
|||
split_dir = os.path.join(self.brep_data_dir, self.split) |
|||
for sample_dir in os.listdir(split_dir): |
|||
sample_path = os.path.join(split_dir, sample_dir) |
|||
if os.path.isdir(sample_path): |
|||
data_list.append(sample_path) |
|||
return data_list |
|||
|
|||
def _load_sdf_data_list(self): |
|||
data_list = [] |
|||
split_dir = os.path.join(self.sdf_data_dir, self.split) |
|||
for sample_dir in os.listdir(split_dir): |
|||
sample_path = os.path.join(split_dir, sample_dir) |
|||
if os.path.isdir(sample_path): |
|||
data_list.append(sample_path) |
|||
return data_list |
|||
|
|||
def __len__(self): |
|||
return len(self.brep_data_list) |
|||
|
|||
def __getitem__(self, idx): |
|||
sample_path = self.brep_data_list[idx] |
|||
|
|||
# 解析 .step 文件 |
|||
step_file = os.path.join(sample_path, 'model.step') |
|||
brep_features = self._parse_step_file(step_file) |
|||
|
|||
# 加载 .sdf 文件 |
|||
sdf_file = os.path.join(sample_path, 'sdf.npy') |
|||
sdf = np.load(sdf_file) |
|||
|
|||
# 转换为 torch 张量 |
|||
brep_features = torch.tensor(brep_features, dtype=torch.float32) |
|||
sdf = torch.tensor(sdf, dtype=torch.float32) |
|||
|
|||
return { |
|||
'brep_features': brep_features, |
|||
'sdf': sdf |
|||
} |
|||
|
|||
def _parse_step_file(self, step_file): |
|||
# 解析 .step 文件的逻辑 |
|||
# 返回 B-rep 特征 |
|||
pass |
|||
|
|||
# 使用示例 |
|||
if __name__ == '__main__': |
|||
dataset = BRepSDFDataset(brep_data_dir='path/to/brep_data', sdf_data_dir='path/to/sdf_data', split='train') |
|||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True) |
|||
|
|||
for batch in dataloader: |
|||
print(batch['brep_features'].shape) |
|||
print(batch['sdf'].shape) |
|||
break |
Loading…
Reference in new issue