import os
import torch
from torch.utils.data import Dataset
import numpy as np
import pickle
from brep2sdf.utils.logger import logger
from brep2sdf.data.utils import process_brep_data
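
# BRepSDFDataset pairs pickled B-rep feature files (*.pkl) with sampled SDF
# point files (*.npz, each holding 'pos'/'neg' arrays of shape [N, 4]) that
# live under matching train/val/test subdirectories.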
class BRepSDFDataset(Dataset):
    def __init__(self, brep_dir: str, sdf_dir: str, split: str = 'train'):
        """
        Initialize the dataset.

        Args:
            brep_dir: directory containing the pkl files
            sdf_dir: directory containing the npz files
            split: dataset split ('train', 'val', 'test')
        """
        super().__init__()
        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split

        # Fixed parameters
        self.max_face = 70
        self.max_edge = 70
        self.bbox_scaled = 1.0

        # Check that the directories exist
        if not os.path.exists(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.exists(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")

        # Load the file lists
        self.brep_data_list = self._load_data_list(self.brep_dir)
        self.sdf_data_list = self._load_data_list(self.sdf_dir)

        # Make sure the dataset is not empty and the two lists pair up,
        # since __getitem__ matches B-rep and SDF files by index
        if len(self.brep_data_list) == 0:
            raise ValueError(f"No valid B-rep data found in {split} set")
        if len(self.sdf_data_list) == 0:
            raise ValueError(f"No valid SDF data found in {split} set")
        if len(self.brep_data_list) != len(self.sdf_data_list):
            raise ValueError(
                f"B-rep/SDF sample count mismatch in {split} set: "
                f"{len(self.brep_data_list)} vs {len(self.sdf_data_list)}"
            )
        logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples")

    # data_dir is either self.brep_dir or self.sdf_dir
    def _load_data_list(self, data_dir):
        # Sort the directory listing so that the B-rep and SDF lists
        # line up positionally (os.listdir order is not guaranteed)
        data_list = []
        for sample_file in sorted(os.listdir(data_dir)):
            data_list.append(os.path.join(data_dir, sample_file))
        return data_list

    def __len__(self):
        return len(self.brep_data_list)

    def __getitem__(self, idx):
        """Fetch a single data sample."""
        try:
            brep_path = self.brep_data_list[idx]
            sdf_path = self.sdf_data_list[idx]
            name = os.path.splitext(os.path.basename(brep_path))[0]

            # Load the B-rep and SDF data
            with open(brep_path, 'rb') as f:
                brep_raw = pickle.load(f)
            sdf_data = self._load_sdf_file(sdf_path)

            try:
                # Process the B-rep data
                brep_features = process_brep_data(
                    data=brep_raw,
                    max_face=self.max_face,
                    max_edge=self.max_edge,
                    bbox_scaled=self.bbox_scaled
                )

                # Validate the type and arity of the return value
                if not isinstance(brep_features, tuple):
                    logger.error(f"process_brep_data returned {type(brep_features)}, expected tuple")
                    raise ValueError("Invalid return type from process_brep_data")
                if len(brep_features) != 6:
                    logger.error(f"Expected 6 features, got {len(brep_features)}")
                    logger.error("Features returned:")
                    for i, feat in enumerate(brep_features):
                        if isinstance(feat, torch.Tensor):
                            logger.error(f"  {i}: Tensor of shape {feat.shape}")
                        else:
                            logger.error(f"  {i}: {type(feat)}")
                    raise ValueError(f"Incorrect number of features: {len(brep_features)}")

                # Unpack the processed features
                edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features

                # Build the sample dict
                return {
                    'name': name,
                    'edge_ncs': edge_ncs,      # [max_face, max_edge, 10, 3]
                    'edge_pos': edge_pos,      # [max_face, max_edge, 6]
                    'edge_mask': edge_mask,    # [max_face, max_edge]
                    'surf_ncs': surf_ncs,      # [max_face, 100, 3]
                    'surf_pos': surf_pos,      # [max_face, 6]
                    'vertex_pos': vertex_pos,  # [max_face, max_edge, 6]
                    'sdf': sdf_data            # [N, 4]
                }

            except Exception as e:
                logger.error(f"\nError processing B-rep data for file: {brep_path}")
                logger.error(f"Error type: {type(e).__name__}")
                logger.error(f"Error message: {str(e)}")
                # Dump the structure of the raw data for debugging
                logger.error("\nRaw data structure:")
                for key, value in brep_raw.items():
                    if isinstance(value, list):
                        logger.error(f"  {key}: list of length {len(value)}")
                        if value:
                            logger.error(f"    First element type: {type(value[0])}")
                            if hasattr(value[0], 'shape'):
                                logger.error(f"    First element shape: {value[0].shape}")
                    elif hasattr(value, 'shape'):
                        logger.error(f"  {key}: shape {value.shape}")
                    else:
                        logger.error(f"  {key}: {type(value)}")
                raise

        except Exception as e:
            logger.error(f"Error loading sample {idx}: {str(e)}")
            raise

    def _load_brep_file(self, brep_path):
        """Load a B-rep feature file."""
        try:
            # 1. Load the raw data
            with open(brep_path, 'rb') as f:
                raw_data = pickle.load(f)
            brep_data = {}

            # 2. Geometry data (variable-length sequences)
            geom_keys = ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']
            for key in geom_keys:
                if key in raw_data:
                    try:
                        # The entry must be a list
                        if not isinstance(raw_data[key], list):
                            raise ValueError(f"{key} is not a list")
                        # Convert each element to a tensor
                        tensors = []
                        for i, x in enumerate(raw_data[key]):
                            try:
                                # Convert to a numpy array first,
                                arr = np.array(x, dtype=np.float32)
                                # then to a tensor
                                tensors.append(torch.from_numpy(arr))
                            except Exception as e:
                                logger.error(f"Error converting {key}[{i}]:")
                                logger.error(f"  Data type: {type(x)}")
                                if isinstance(x, np.ndarray):
                                    logger.error(f"  Shape: {x.shape}")
                                    logger.error(f"  dtype: {x.dtype}")
                                raise ValueError(f"Failed to convert {key}[{i}]: {str(e)}")
                        brep_data[key] = tensors
                    except Exception as e:
                        logger.error(f"Error processing {key}:")
                        logger.error(f"  Raw data type: {type(raw_data[key])}")
                        raise ValueError(f"Failed to process {key}: {str(e)}")

            # 3. Fixed-shape data
            fixed_keys = ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs']
            for key in fixed_keys:
                if key in raw_data:
                    try:
                        # Convert directly from the raw data
                        arr = np.array(raw_data[key], dtype=np.float32)
                        brep_data[key] = torch.from_numpy(arr)
                    except Exception as e:
                        logger.error(f"Error converting fixed-shape data {key}:")
                        logger.error(f"  Raw data type: {type(raw_data[key])}")
                        if isinstance(raw_data[key], np.ndarray):
                            logger.error(f"  Shape: {raw_data[key].shape}")
                            logger.error(f"  dtype: {raw_data[key].dtype}")
                        raise ValueError(f"Failed to convert {key}: {str(e)}")

            # 4. Adjacency matrices
            adj_keys = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
            for key in adj_keys:
                if key in raw_data:
                    try:
                        # Convert to an integer array
                        arr = np.array(raw_data[key], dtype=np.int32)
                        brep_data[key] = torch.from_numpy(arr)
                    except Exception as e:
                        logger.error(f"Error converting adjacency matrix {key}:")
                        logger.error(f"  Raw data type: {type(raw_data[key])}")
                        if isinstance(raw_data[key], np.ndarray):
                            logger.error(f"  Shape: {raw_data[key].shape}")
                            logger.error(f"  dtype: {raw_data[key].dtype}")
                        raise ValueError(f"Failed to convert {key}: {str(e)}")

            # 5. Verify that the required keys are present
            required_keys = {'surf_wcs', 'edge_wcs', 'corner_wcs'}
            missing_keys = required_keys - set(brep_data.keys())
            if missing_keys:
                raise ValueError(f"Missing required keys: {missing_keys}")

            # 6. Run the data through process_brep_data
            try:
                return process_brep_data(
                    data=brep_data,
                    max_face=self.max_face,
                    max_edge=self.max_edge,
                    bbox_scaled=self.bbox_scaled
                )
            except Exception as e:
                logger.error("Error in process_brep_data:")
                logger.error(f"  Error message: {str(e)}")
                # Dump the input shapes for debugging
                logger.error("\nInput data shapes:")
                for key, value in brep_data.items():
                    if isinstance(value, list):
                        shapes = [t.shape for t in value]
                        logger.error(f"  {key}: list of tensors with shapes {shapes}")
                    elif isinstance(value, torch.Tensor):
                        logger.error(f"  {key}: tensor of shape {value.shape}")
                raise

        except Exception as e:
            logger.error(f"\nError loading B-rep file: {brep_path}")
            logger.error(f"Error message: {str(e)}")
            raise

    def _load_sdf_file(self, sdf_path):
        """Load and process SDF data."""
        try:
            # Load the sampled SDF values
            sdf_data = np.load(sdf_path)
            if 'pos' not in sdf_data or 'neg' not in sdf_data:
                raise ValueError("Missing pos/neg data in SDF file")
            sdf_pos = sdf_data['pos']  # (N1, 4)
            sdf_neg = sdf_data['neg']  # (N2, 4)

            # Validate the shapes
            if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
                raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")

            sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)
            return torch.from_numpy(sdf_np.astype(np.float32))
        except Exception as e:
            logger.error(f"Error loading SDF from {sdf_path}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            raise
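
    # Illustrative sketch, not part of the original file: DeepSDF-style training
    # often subsamples a fixed number of SDF points per item instead of padding
    # in collate_fn. A hypothetical helper like this could be called at the end
    # of _load_sdf_file; the name and num_samples default are assumptions.
    @staticmethod
    def _subsample_sdf(sdf: torch.Tensor, num_samples: int = 4096) -> torch.Tensor:
        """Randomly reduce (or resample with replacement) SDF points to a fixed count."""
        n = sdf.size(0)
        if n >= num_samples:
            idx = torch.randperm(n)[:num_samples]      # sample without replacement
        else:
            idx = torch.randint(0, n, (num_samples,))  # too few points: resample
        return sdf[idx]
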
    @staticmethod
    def collate_fn(batch):
        """Custom batching function."""
        # Collect the sample names
        names = [item['name'] for item in batch]

        # Stack the fixed-size tensors
        tensor_keys = ['edge_ncs', 'edge_pos', 'edge_mask',
                       'surf_ncs', 'surf_pos', 'vertex_pos']
        tensors = {
            key: torch.stack([item[key] for item in batch])
            for key in tensor_keys
        }

        # Handle the variable-length SDF data
        sdf_data = [item['sdf'] for item in batch]
        max_sdf_len = max(sdf.size(0) for sdf in sdf_data)

        # Pad each SDF tensor to the batch maximum and record a validity mask
        padded_sdfs = []
        sdf_masks = []
        for sdf in sdf_data:
            pad_len = max_sdf_len - sdf.size(0)
            if pad_len > 0:
                padding = torch.zeros(pad_len, sdf.size(1),
                                      dtype=sdf.dtype, device=sdf.device)
                padded_sdf = torch.cat([sdf, padding], dim=0)
                mask = torch.cat([
                    torch.ones(sdf.size(0), dtype=torch.bool),
                    torch.zeros(pad_len, dtype=torch.bool)
                ])
            else:
                padded_sdf = sdf
                mask = torch.ones(sdf.size(0), dtype=torch.bool)
            padded_sdfs.append(padded_sdf)
            sdf_masks.append(mask)

        # Merge everything into one batch dict
        return {
            'name': names,
            'sdf': torch.stack(padded_sdfs),
            'sdf_mask': torch.stack(sdf_masks),
            **tensors
        }
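

# Illustrative sketch, not part of the original file: the sdf_mask returned by
# collate_fn is meant to drop padded rows from the objective. A masked L1 loss
# over clamped SDF values might look like this; `pred` with shape [B, N] and
# the clamp threshold are assumptions.
def masked_sdf_l1_loss(pred: torch.Tensor, batch: dict, clamp: float = 0.1) -> torch.Tensor:
    """L1 loss on clamped SDF values, ignoring rows padded by collate_fn."""
    target = batch['sdf'][..., 3]     # SDF value is the 4th column of [x, y, z, sdf]
    mask = batch['sdf_mask'].float()  # [B, N]; zero on padded rows
    diff = (pred.clamp(-clamp, clamp) - target.clamp(-clamp, clamp)).abs()
    return (diff * mask).sum() / mask.sum().clamp(min=1)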


def test_dataset():
    """Smoke-test the dataset."""
    try:
        # Test paths
        brep_dir = '/home/wch/brep2sdf/test_data/pkl'
        sdf_dir = '/home/wch/brep2sdf/test_data/sdf'
        split = 'train'

        logger.info("=" * 50)
        logger.info("Testing dataset")
        logger.info(f"B-rep directory: {brep_dir}")
        logger.info(f"SDF directory: {sdf_dir}")
        logger.info(f"Split: {split}")

        dataset = BRepSDFDataset(brep_dir=brep_dir, sdf_dir=sdf_dir, split=split)
        # Use the custom collate_fn so variable-length SDF tensors batch correctly
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True,
                                                 collate_fn=BRepSDFDataset.collate_fn)
        for batch in dataloader:
            print(batch['sdf'].shape)
            break
    except Exception as e:
        logger.error(f"Error in test_dataset: {str(e)}")

if __name__ == '__main__':
    test_dataset()