You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
2.2 KiB
68 lines
2.2 KiB
7 months ago
|
import os
|
||
|
import torch
|
||
|
from torch.utils.data import Dataset
|
||
|
import numpy as np
|
||
|
|
||
|
class BRepSDFDataset(Dataset):
    """Dataset pairing B-rep features parsed from STEP files with SDF arrays.

    Expected on-disk layout (one sub-directory per sample)::

        <brep_data_dir>/<split>/<sample>/model.step
        <sdf_data_dir>/<split>/<sample>/sdf.npy

    A B-rep sample is paired with the SDF sample directory of the same name.
    When ``brep_data_dir`` and ``sdf_data_dir`` are the same directory this
    resolves to ``sdf.npy`` sitting next to ``model.step``.

    STEP parsing is not implemented here: subclasses must override
    ``_parse_step_file``.
    """

    def __init__(self, brep_data_dir: str, sdf_data_dir: str, split: str = 'train'):
        """Index the sample directories of ``split`` under both data roots.

        Args:
            brep_data_dir: Root directory holding the B-rep samples.
            sdf_data_dir: Root directory holding the SDF samples.
            split: Name of the sub-directory to read (default ``'train'``).

        Raises:
            OSError: If a ``<root>/<split>`` directory cannot be listed.
        """
        self.brep_data_dir = brep_data_dir
        self.sdf_data_dir = sdf_data_dir
        self.split = split
        self.brep_data_list = self._load_brep_data_list()
        self.sdf_data_list = self._load_sdf_data_list()

    def _list_sample_dirs(self, root_dir: str) -> list:
        """Return the sorted sample directory paths under ``root_dir/<split>``."""
        split_dir = os.path.join(root_dir, self.split)
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        candidates = (
            os.path.join(split_dir, entry) for entry in sorted(os.listdir(split_dir))
        )
        # Plain files lying in the split directory are not samples.
        return [path for path in candidates if os.path.isdir(path)]

    def _load_brep_data_list(self) -> list:
        """Return the sorted B-rep sample directories of the current split."""
        return self._list_sample_dirs(self.brep_data_dir)

    def _load_sdf_data_list(self) -> list:
        """Return the sorted SDF sample directories of the current split."""
        return self._list_sample_dirs(self.sdf_data_dir)

    def __len__(self) -> int:
        """Return the number of B-rep samples in the split."""
        return len(self.brep_data_list)

    def __getitem__(self, idx) -> dict:
        """Load one sample.

        Args:
            idx: Position in ``brep_data_list``.

        Returns:
            A dict with ``'brep_features'`` and ``'sdf'``, both float32
            ``torch.Tensor`` objects.

        Raises:
            IndexError: If ``idx`` is out of range.
            FileNotFoundError: If the matching ``sdf.npy`` does not exist.
            NotImplementedError: If ``_parse_step_file`` was not overridden.
        """
        brep_sample_path = self.brep_data_list[idx]

        # Parse the .step file
        step_file = os.path.join(brep_sample_path, 'model.step')
        brep_features = self._parse_step_file(step_file)

        # Load the SDF array from the SDF root, matched by sample name so a
        # missing or extra directory on either side cannot silently pair a
        # shape with the wrong SDF.
        sample_name = os.path.basename(os.path.normpath(brep_sample_path))
        sdf_file = os.path.join(self.sdf_data_dir, self.split, sample_name, 'sdf.npy')
        sdf = np.load(sdf_file)

        # Convert to torch tensors
        brep_features = torch.tensor(brep_features, dtype=torch.float32)
        sdf = torch.tensor(sdf, dtype=torch.float32)

        return {
            'brep_features': brep_features,
            'sdf': sdf
        }

    def _parse_step_file(self, step_file):
        """Parse ``step_file`` and return its B-rep features.

        This base implementation is a placeholder. The return value is passed
        to ``torch.tensor(..., dtype=torch.float32)``, so an override must
        return numeric array-like data.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        # Fail loudly here instead of returning None and letting
        # torch.tensor(None) raise a confusing error later.
        raise NotImplementedError(
            '_parse_step_file must be overridden to extract B-rep features '
            'from a STEP file'
        )
|
||
|
|
||
|
# Usage example
if __name__ == '__main__':
    # Build the dataset for the training split and wrap it in a shuffling
    # loader that yields batches of four samples.
    loader = torch.utils.data.DataLoader(
        BRepSDFDataset(
            brep_data_dir='path/to/brep_data',
            sdf_data_dir='path/to/sdf_data',
            split='train',
        ),
        batch_size=4,
        shuffle=True,
    )

    # Only the first batch is inspected; an empty loader prints nothing.
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        for key in ('brep_features', 'sdf'):
            print(first_batch[key].shape)
|