You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

102 lines
3.3 KiB

import os
import torch
from torch.utils.data import Dataset
import numpy as np
import pickle
from utils.logger import setup_logger
# Set up the module-level logger.
logger = setup_logger('dataset')
class BRepSDFDataset(Dataset):
    """Dataset that pairs B-rep features with SDF values for one data split.

    Every sub-directory of ``<brep_dir>/<split>`` is treated as one sample.
    For each sample, ``__getitem__`` reads ``model.step`` (through
    ``_parse_step_file``) and ``sdf.npy`` (through ``numpy.load``) from that
    sample directory.

    NOTE(review): ``<sdf_dir>/<split>`` is validated in ``__init__`` but never
    used for loading -- ``sdf.npy`` is read from the B-rep sample directory.
    The original parameter notes also describe pkl/npz inputs, while the code
    reads ``.step``/``.npy`` files. Confirm the intended on-disk layout.
    """

    def __init__(self, brep_dir: str, sdf_dir: str, split: str = 'train'):
        """Resolve the split directories and index the available samples.

        Args:
            brep_dir: Root directory of the B-rep data (originally documented
                as the pkl file directory); ``split`` is appended to it.
            sdf_dir: Root directory of the SDF data (originally documented as
                the npz file directory); ``split`` is appended to it.
            split: Dataset split to load ('train', 'val', 'test').

        Raises:
            ValueError: If either split directory does not exist (or is not a
                directory), or if no sample sub-directory is found.
        """
        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split

        # isdir (rather than exists) also rejects a regular file at that path,
        # which would otherwise crash later inside os.listdir.
        if not os.path.isdir(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.isdir(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")

        self.data_list = self._load_data_list()
        if not self.data_list:
            raise ValueError(f"No valid data found in {split} set")

        logger.info(f"Loaded {split} dataset with {len(self.data_list)} samples")

    def _load_data_list(self) -> list:
        """Return the paths of all sample sub-directories, sorted by name.

        ``os.listdir`` returns entries in arbitrary order, so the names are
        sorted to keep the index -> sample mapping stable between runs.
        Entries that are not directories are ignored.
        """
        candidates = (
            os.path.join(self.brep_dir, entry)
            for entry in sorted(os.listdir(self.brep_dir))
        )
        return [path for path in candidates if os.path.isdir(path)]

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.data_list)

    def __getitem__(self, idx) -> dict:
        """Load one sample and return it as float32 tensors.

        Returns:
            A dict with keys ``'brep_features'`` (result of
            ``_parse_step_file`` on ``model.step``) and ``'sdf'`` (contents
            of ``sdf.npy``), both converted to ``torch.float32`` tensors.
        """
        sample_path = self.data_list[idx]

        # Parse the STEP file into B-rep features.
        brep_features = self._parse_step_file(
            os.path.join(sample_path, 'model.step')
        )
        # Load the SDF array stored next to the STEP file.
        sdf = np.load(os.path.join(sample_path, 'sdf.npy'))

        return {
            'brep_features': torch.tensor(brep_features, dtype=torch.float32),
            'sdf': torch.tensor(sdf, dtype=torch.float32),
        }

    def _parse_step_file(self, step_file):
        """Parse a ``.step`` file and return its B-rep features.

        This is a placeholder: the parsing logic has not been written yet.
        It raises instead of silently returning ``None`` so that
        ``__getitem__`` fails with a clear message rather than inside
        ``torch.tensor``. Subclasses may override it.

        Raises:
            NotImplementedError: Always, until parsing is implemented.
        """
        raise NotImplementedError(
            f"STEP parsing is not implemented; cannot parse {step_file}"
        )
def test_dataset():
    """Smoke-test ``BRepSDFDataset`` by loading a single batch.

    Builds the dataset from the hard-coded test-data paths, wraps it in a
    ``DataLoader`` and prints the tensor shapes of the first batch. Any
    exception is logged instead of propagated, so this function never raises
    for dataset errors.
    """
    try:
        # Test data locations (developer-machine paths).
        brep_dir = '/home/wch/myDeepSDF/test_data/pkl'
        sdf_dir = '/home/wch/myDeepSDF/test_data/sdf'
        split = 'train'

        logger.info("="*50)
        logger.info("Testing dataset")
        logger.info(f"B-rep directory: {brep_dir}")
        logger.info(f"SDF directory: {sdf_dir}")
        logger.info(f"Split: {split}")

        # Use the constructor's real keyword names (brep_dir / sdf_dir) and
        # the paths logged above; the previous brep_data_dir / sdf_data_dir
        # keywords do not exist and always raised TypeError.
        dataset = BRepSDFDataset(brep_dir=brep_dir, sdf_dir=sdf_dir, split=split)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

        # Only the first batch is inspected.
        for batch in dataloader:
            print(batch['brep_features'].shape)
            print(batch['sdf'].shape)
            break
    except Exception as e:
        logger.error(f"Error in test_dataset: {str(e)}")
# Run the dataset smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_dataset()