import os
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset

from brep2sdf.utils.logger import setup_logger

# Set up the module logger
logger = setup_logger('dataset')
class BRepSDFDataset(Dataset):
    def __init__(self, brep_dir: str, sdf_dir: str, split: str = 'train'):
        """Initialize the dataset.

        Args:
            brep_dir: directory containing the B-rep .pkl files
            sdf_dir: directory containing the SDF .npz files
            split: dataset split ('train', 'val', 'test')
        """
        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split

        # Make sure both split directories exist
        if not os.path.exists(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.exists(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")

        # Collect the file lists for both modalities
        self.brep_data_list = self._load_data_list(self.brep_dir)
        self.sdf_data_list = self._load_data_list(self.sdf_dir)

        # Fail early on empty splits
        if len(self.brep_data_list) == 0:
            raise ValueError(f"No valid B-rep data found in {split} set")
        if len(self.sdf_data_list) == 0:
            raise ValueError(f"No valid SDF data found in {split} set")

        logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples")
    def _load_data_list(self, data_dir):
        """List the sample files in data_dir (either self.brep_dir or self.sdf_dir).

        The listing is sorted so that the B-rep and SDF lists stay aligned
        by index; os.listdir alone returns entries in arbitrary order.
        """
        data_list = []
        for sample_file in sorted(os.listdir(data_dir)):
            data_list.append(os.path.join(data_dir, sample_file))
        return data_list
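    # __getitem__ pairs B-rep and SDF files purely by sorted index. A minimal
    # sanity-check sketch, assuming the two directories use matching file
    # stems (e.g. abc.pkl <-> abc.npz); `_check_pairing` is an illustrative
    # helper name, not part of the original code:
    def _check_pairing(self):
        for brep_path, sdf_path in zip(self.brep_data_list, self.sdf_data_list):
            brep_stem = os.path.splitext(os.path.basename(brep_path))[0]
            sdf_stem = os.path.splitext(os.path.basename(sdf_path))[0]
            if brep_stem != sdf_stem:
                raise ValueError(f"Unpaired sample: {brep_path} vs {sdf_path}")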
    def __len__(self):
        return len(self.brep_data_list)
    def __getitem__(self, idx):
        """Fetch a single sample."""
        brep_path = self.brep_data_list[idx]
        sdf_path = self.sdf_data_list[idx]
        try:
            # Use the file name (without extension) as the sample name
            name = os.path.splitext(os.path.basename(brep_path))[0]

            # Load the B-rep features and the SDF samples
            brep_data = self._load_brep_file(brep_path)
            sdf_data = self._load_sdf_file(sdf_path)

            # Return a flat dict: the B-rep features are unpacked at the top
            # level, and the SDF samples are added under the 'sdf' key
            return {
                'name': name,
                **brep_data,
                'sdf': sdf_data,
            }
        except Exception as e:
            logger.error(f"Error loading sample from {brep_path}: {str(e)}")
            logger.error("Data structure:")
            if 'brep_data' in locals():
                for key, value in brep_data.items():
                    if isinstance(value, np.ndarray):
                        logger.error(f"  {key}: type={type(value)}, dtype={value.dtype}, shape={value.shape}")
            raise
    def _load_brep_file(self, brep_path):
        """Load a pickled B-rep feature file."""
        try:
            with open(brep_path, 'rb') as f:
                brep_data = pickle.load(f)

            features = {}

            # 1. Geometry data (variable-length sequences): keep each entry
            #    as its own tensor instead of stacking
            for key in ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']:
                if key in brep_data:
                    try:
                        features[key] = [
                            torch.from_numpy(np.array(x, dtype=np.float32))
                            for x in brep_data[key]
                        ]
                    except Exception as e:
                        logger.error(f"Error converting {key}:")
                        logger.error(f"  Type: {type(brep_data[key])}")
                        if isinstance(brep_data[key], list):
                            logger.error(f"  List length: {len(brep_data[key])}")
                            if len(brep_data[key]) > 0:
                                logger.error(f"  First element type: {type(brep_data[key][0])}")
                                logger.error(f"  First element shape: {brep_data[key][0].shape}")
                                logger.error(f"  First element dtype: {brep_data[key][0].dtype}")
                        raise ValueError(f"Failed to convert {key}: {str(e)}")

            # 2. Fixed-shape data: convert directly to float tensors
            for key in ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs']:
                if key in brep_data:
                    try:
                        data = np.array(brep_data[key], dtype=np.float32)
                        features[key] = torch.from_numpy(data)
                    except Exception as e:
                        logger.error(f"Error converting {key}:")
                        logger.error(f"  Type: {type(brep_data[key])}")
                        if isinstance(brep_data[key], np.ndarray):
                            logger.error(f"  Shape: {brep_data[key].shape}")
                            logger.error(f"  dtype: {brep_data[key].dtype}")
                        raise ValueError(f"Failed to convert {key}: {str(e)}")

            # 3. Adjacency matrices: convert to integer tensors
            for key in ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']:
                if key in brep_data:
                    try:
                        data = np.array(brep_data[key], dtype=np.int32)
                        features[key] = torch.from_numpy(data)
                    except Exception as e:
                        logger.error(f"Error converting {key}:")
                        logger.error(f"  Type: {type(brep_data[key])}")
                        if isinstance(brep_data[key], np.ndarray):
                            logger.error(f"  Shape: {brep_data[key].shape}")
                            logger.error(f"  dtype: {brep_data[key].dtype}")
                        raise ValueError(f"Failed to convert {key}: {str(e)}")

            return features

        except Exception as e:
            logger.error(f"\nError loading B-rep file: {brep_path}")
            logger.error(f"Error message: {str(e)}")
            # Dump the full data structure to help with debugging
            if 'brep_data' in locals():
                logger.error("\nComplete data structure:")
                for key, value in brep_data.items():
                    logger.error(f"\n{key}:")
                    logger.error(f"  Type: {type(value)}")
                    if isinstance(value, np.ndarray):
                        logger.error(f"  Shape: {value.shape}")
                        logger.error(f"  dtype: {value.dtype}")
                    elif isinstance(value, list):
                        logger.error(f"  List length: {len(value)}")
                        if len(value) > 0:
                            logger.error(f"  First element type: {type(value[0])}")
                            if isinstance(value[0], np.ndarray):
                                logger.error(f"  First element shape: {value[0].shape}")
                                logger.error(f"  First element dtype: {value[0].dtype}")
            raise
    def _load_sdf_file(self, sdf_path):
        """Load and preprocess the SDF samples."""
        try:
            # Load the sampled SDF values
            sdf_data = np.load(sdf_path)
            if 'pos' not in sdf_data or 'neg' not in sdf_data:
                raise ValueError("Missing pos/neg data in SDF file")

            sdf_pos = sdf_data['pos']  # (N1, 4): positive-SDF samples
            sdf_neg = sdf_data['neg']  # (N2, 4): negative-SDF samples

            # Validate the expected (N, 4) layout
            if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
                raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")

            sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)
            return torch.from_numpy(sdf_np.astype(np.float32))

        except Exception as e:
            logger.error(f"Error loading SDF from {sdf_path}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            raise
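
# Because the geometry keys ('surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs')
# hold lists of variable-length tensors, the default DataLoader collate only
# works with batch_size=1, as in test_dataset below. A minimal sketch of a
# custom collate_fn for larger batches, assuming every sample dict has the
# same keys; `brep_collate_fn` is an illustrative name, not part of the
# original code:
def brep_collate_fn(samples):
    batch = {}
    for key in samples[0]:
        values = [s[key] for s in samples]
        if key == 'name':
            batch[key] = values                      # keep names as a list of str
        elif isinstance(values[0], list):
            batch[key] = values                      # nested variable-length tensors
        elif isinstance(values[0], torch.Tensor) and len({v.shape for v in values}) == 1:
            batch[key] = torch.stack(values, dim=0)  # uniform shapes -> stack into a batch
        else:
            batch[key] = values                      # ragged tensors stay as a list
    return batch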
def test_dataset():
    """Smoke-test the dataset."""
    try:
        # Test paths
        brep_dir = '/home/wch/brep2sdf/test_data/pkl'
        sdf_dir = '/home/wch/brep2sdf/test_data/sdf'
        split = 'train'

        logger.info("=" * 50)
        logger.info("Testing dataset")
        logger.info(f"B-rep directory: {brep_dir}")
        logger.info(f"SDF directory: {sdf_dir}")
        logger.info(f"Split: {split}")

        dataset = BRepSDFDataset(brep_dir=brep_dir, sdf_dir=sdf_dir, split=split)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
        for batch in dataloader:
            print(batch['sdf'].shape)
            break
    except Exception as e:
        logger.error(f"Error in test_dataset: {str(e)}")

if __name__ == '__main__':
    test_dataset()