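"""Dataset that pairs pickled B-rep features (.pkl) with sampled SDF values (.npz) for brep2sdf."""
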
import os
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset

from brep2sdf.utils.logger import setup_logger


# Set up the module-level logger
logger = setup_logger('dataset')


class BRepSDFDataset(Dataset):
    def __init__(self, brep_dir: str, sdf_dir: str, split: str = 'train'):
        """
        Initialize the dataset.

        Args:
            brep_dir: directory containing the B-rep .pkl files
            sdf_dir: directory containing the SDF .npz files
            split: dataset split ('train', 'val', 'test')
        """
        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split

        # Make sure both split directories exist
        if not os.path.exists(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.exists(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")

        # Collect the sample file paths
        self.brep_data_list = self._load_data_list(self.brep_dir)
        self.sdf_data_list = self._load_data_list(self.sdf_dir)

        # Make sure the split is not empty
        if len(self.brep_data_list) == 0:
            raise ValueError(f"No valid brep data found in {split} set")
        if len(self.sdf_data_list) == 0:
            raise ValueError(f"No valid sdf data found in {split} set")

        logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples")

    # data_dir is either self.brep_dir or self.sdf_dir
    def _load_data_list(self, data_dir):
        data_list = []
        # Sort the file names so the B-rep and SDF lists stay aligned
        # sample-by-sample (os.listdir order is arbitrary).
        for sample_file in sorted(os.listdir(data_dir)):
            path = os.path.join(data_dir, sample_file)
            data_list.append(path)

        # logger.info(data_list)
        return data_list

    def __len__(self):
        return len(self.brep_data_list)

    def __getitem__(self, idx):
        """Fetch a single data sample."""
        brep_path = self.brep_data_list[idx]
        sdf_path = self.sdf_data_list[idx]

        try:
            # Use the file name (without extension) as the sample name
            name = os.path.splitext(os.path.basename(brep_path))[0]

            # Load the B-rep and SDF data
            brep_data = self._load_brep_file(brep_path)
            sdf_data = self._load_sdf_file(sdf_path)

            # Return the unpacked B-rep features plus the SDF data under its own key
            return {
                'name': name,
                **brep_data,      # unpack the B-rep features
                'sdf': sdf_data   # SDF sample points and values
            }

        except Exception as e:
            logger.error(f"Error loading sample from {brep_path}: {str(e)}")
            logger.error("Data structure:")
            if 'brep_data' in locals():
                for key, value in brep_data.items():
                    if isinstance(value, np.ndarray):
                        logger.error(f"  {key}: type={type(value)}, dtype={value.dtype}, shape={value.shape}")
            raise

    def _load_brep_file(self, brep_path):
        """Load a B-rep feature file."""
        try:
            with open(brep_path, 'rb') as f:
                brep_data = pickle.load(f)

            features = {}

            # 1. Geometry data (variable-length sequences): one tensor per face/edge
            for key in ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']:
                if key in brep_data:
                    try:
                        features[key] = [
                            torch.from_numpy(np.array(x, dtype=np.float32))
                            for x in brep_data[key]
                        ]
                    except Exception as e:
                        logger.error(f"Error converting {key}:")
                        logger.error(f"  Type: {type(brep_data[key])}")
                        if isinstance(brep_data[key], list):
                            logger.error(f"  List length: {len(brep_data[key])}")
                            if len(brep_data[key]) > 0:
                                logger.error(f"  First element type: {type(brep_data[key][0])}")
                                logger.error(f"  First element shape: {brep_data[key][0].shape}")
                                logger.error(f"  First element dtype: {brep_data[key][0].dtype}")
                        raise ValueError(f"Failed to convert {key}: {str(e)}")

            # 2. Fixed-shape data
            for key in ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs']:
                if key in brep_data:
                    try:
                        data = np.array(brep_data[key], dtype=np.float32)
                        features[key] = torch.from_numpy(data)
                    except Exception as e:
                        logger.error(f"Error converting {key}:")
                        logger.error(f"  Type: {type(brep_data[key])}")
                        if isinstance(brep_data[key], np.ndarray):
                            logger.error(f"  Shape: {brep_data[key].shape}")
                            logger.error(f"  dtype: {brep_data[key].dtype}")
                        raise ValueError(f"Failed to convert {key}: {str(e)}")

            # 3. Adjacency matrices
            for key in ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']:
                if key in brep_data:
                    try:
                        data = np.array(brep_data[key], dtype=np.int32)
                        features[key] = torch.from_numpy(data)
                    except Exception as e:
                        logger.error(f"Error converting {key}:")
                        logger.error(f"  Type: {type(brep_data[key])}")
                        if isinstance(brep_data[key], np.ndarray):
                            logger.error(f"  Shape: {brep_data[key].shape}")
                            logger.error(f"  dtype: {brep_data[key].dtype}")
                        raise ValueError(f"Failed to convert {key}: {str(e)}")

            return features

        except Exception as e:
            logger.error(f"\nError loading B-rep file: {brep_path}")
            logger.error(f"Error message: {str(e)}")

            # Dump the complete data structure for debugging
            if 'brep_data' in locals():
                logger.error("\nComplete data structure:")
                for key, value in brep_data.items():
                    logger.error(f"\n{key}:")
                    logger.error(f"  Type: {type(value)}")
                    if isinstance(value, np.ndarray):
                        logger.error(f"  Shape: {value.shape}")
                        logger.error(f"  dtype: {value.dtype}")
                    elif isinstance(value, list):
                        logger.error(f"  List length: {len(value)}")
                        if len(value) > 0:
                            logger.error(f"  First element type: {type(value[0])}")
                            if isinstance(value[0], np.ndarray):
                                logger.error(f"  First element shape: {value[0].shape}")
                                logger.error(f"  First element dtype: {value[0].dtype}")
            raise

    def _load_sdf_file(self, sdf_path):
        """Load and process the SDF data."""
        try:
            # Load the SDF values
            sdf_data = np.load(sdf_path)
            if 'pos' not in sdf_data or 'neg' not in sdf_data:
                raise ValueError("Missing pos/neg data in SDF file")

            sdf_pos = sdf_data['pos']  # (N1, 4)
            sdf_neg = sdf_data['neg']  # (N2, 4)

            # Validate the data shapes
            if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
                raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")

            sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)
            return torch.from_numpy(sdf_np.astype(np.float32))

        except Exception as e:
            logger.error(f"Error loading SDF from {sdf_path}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            raise

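# Usage note (a minimal sketch, not part of the original module): the geometry
# features ('surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs') are Python lists of
# variable-length tensors, and the number of SDF points may differ per sample,
# so torch's default collation only works with batch_size=1 (as in test_dataset
# below). The hypothetical brep_collate_fn keeps every field as a per-sample
# list; how a model consumes such a batch is an assumption, not defined here.
def brep_collate_fn(samples):
    """Collate BRepSDFDataset samples into a dict of per-field lists."""
    batch = {}
    for key in samples[0]:
        batch[key] = [sample[key] for sample in samples]
    return batch
# Example (hypothetical):
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=brep_collate_fn)
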
def test_dataset():
    """Test dataset functionality."""
    try:
        # Set up test paths
        brep_dir = '/home/wch/myDeepSDF/test_data/pkl'
        sdf_dir = '/home/wch/myDeepSDF/test_data/sdf'
        split = 'train'

        logger.info("=" * 50)
        logger.info("Testing dataset")
        logger.info(f"B-rep directory: {brep_dir}")
        logger.info(f"SDF directory: {sdf_dir}")
        logger.info(f"Split: {split}")

        # Note: the dataset below is built from hard-coded brep2sdf paths,
        # not from the variables logged above.
        dataset = BRepSDFDataset(
            brep_dir='/home/wch/brep2sdf/test_data/pkl',
            sdf_dir='/home/wch/brep2sdf/test_data/sdf',
            split='train'
        )
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

        for batch in dataloader:
            print(batch['sdf'].shape)
            break
    except Exception as e:
        logger.error(f"Error in test_dataset: {str(e)}")


if __name__ == '__main__':
    test_dataset()