"""Dataset pairing B-rep pickle files with SDF sample files for training."""
import os
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset

from brep2sdf.data.utils import process_brep_data
from brep2sdf.utils.logger import logger
class BRepSDFDataset(Dataset):
    """Dataset of paired B-rep (.pkl) and SDF (.npz) samples for one split."""

    def __init__(self, brep_dir: str, sdf_dir: str, split: str = 'train'):
        """
        Initialize the dataset.

        Args:
            brep_dir: root directory containing the .pkl B-rep files
            sdf_dir: root directory containing the .npz SDF files
            split: dataset split ('train', 'val', 'test')

        Raises:
            ValueError: if a split directory is missing, a file list is
                empty, or the brep/sdf file counts do not match.
        """
        super().__init__()
        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split

        # Fixed padding parameters consumed by process_brep_data.
        self.max_face = 70
        self.max_edge = 70
        self.bbox_scaled = 1.0

        # Fail fast if either split directory is absent.
        if not os.path.exists(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.exists(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")

        # Build the file lists (index-aligned; see __getitem__).
        self.brep_data_list = self._load_data_list(self.brep_dir)
        self.sdf_data_list = self._load_data_list(self.sdf_dir)

        # Reject empty splits.
        if len(self.brep_data_list) == 0:
            raise ValueError(f"No valid brep data found in {split} set")
        if len(self.sdf_data_list) == 0:
            raise ValueError(f"No valid sdf data found in {split} set")
        # __getitem__ pairs the two lists by index, so a length mismatch
        # would silently pair wrong files — reject it up front.
        if len(self.brep_data_list) != len(self.sdf_data_list):
            raise ValueError(
                f"B-rep/SDF count mismatch in {split} set: "
                f"{len(self.brep_data_list)} vs {len(self.sdf_data_list)}"
            )

        logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples")
def _load_data_list(self, data_dir):
|
|
data_list = []
|
|
for sample_file in os.listdir(data_dir):
|
|
path = os.path.join(data_dir, sample_file)
|
|
data_list.append(path)
|
|
|
|
#logger.info(data_list)
|
|
return data_list
|
|
|
|
def __len__(self):
|
|
return len(self.brep_data_list)
|
|
|
|
def __getitem__(self, idx):
|
|
"""获取单个数据样本"""
|
|
try:
|
|
brep_path = self.brep_data_list[idx]
|
|
sdf_path = self.sdf_data_list[idx]
|
|
name = os.path.splitext(os.path.basename(brep_path))[0]
|
|
|
|
# 加载B-rep和SDF数据
|
|
with open(brep_path, 'rb') as f:
|
|
brep_raw = pickle.load(f)
|
|
sdf_data = self._load_sdf_file(sdf_path)
|
|
|
|
try:
|
|
# 处理B-rep数据
|
|
brep_features = process_brep_data(
|
|
data=brep_raw,
|
|
max_face=self.max_face,
|
|
max_edge=self.max_edge,
|
|
bbox_scaled=self.bbox_scaled
|
|
)
|
|
|
|
# 检查返回值的类型和数量
|
|
if not isinstance(brep_features, tuple):
|
|
logger.error(f"process_brep_data returned {type(brep_features)}, expected tuple")
|
|
raise ValueError("Invalid return type from process_brep_data")
|
|
|
|
if len(brep_features) != 6:
|
|
logger.error(f"Expected 6 features, got {len(brep_features)}")
|
|
logger.error("Features returned:")
|
|
for i, feat in enumerate(brep_features):
|
|
if isinstance(feat, torch.Tensor):
|
|
logger.error(f" {i}: Tensor of shape {feat.shape}")
|
|
else:
|
|
logger.error(f" {i}: {type(feat)}")
|
|
raise ValueError(f"Incorrect number of features: {len(brep_features)}")
|
|
|
|
# 解包处理后的特征
|
|
edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features
|
|
|
|
# 构建返回字典
|
|
return {
|
|
'name': name,
|
|
'edge_ncs': edge_ncs, # [max_face, max_edge, 10, 3]
|
|
'edge_pos': edge_pos, # [max_face, max_edge, 6]
|
|
'edge_mask': edge_mask, # [max_face, max_edge]
|
|
'surf_ncs': surf_ncs, # [max_face, 100, 3]
|
|
'surf_pos': surf_pos, # [max_face, 6]
|
|
'vertex_pos': vertex_pos, # [max_face, max_edge, 6]
|
|
'sdf': sdf_data # [N, 4]
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"\nError processing B-rep data for file: {brep_path}")
|
|
logger.error(f"Error type: {type(e).__name__}")
|
|
logger.error(f"Error message: {str(e)}")
|
|
|
|
# 打印原始数据的结构
|
|
logger.error("\nRaw data structure:")
|
|
for key, value in brep_raw.items():
|
|
if isinstance(value, list):
|
|
logger.error(f" {key}: list of length {len(value)}")
|
|
if value:
|
|
logger.error(f" First element type: {type(value[0])}")
|
|
if hasattr(value[0], 'shape'):
|
|
logger.error(f" First element shape: {value[0].shape}")
|
|
elif hasattr(value, 'shape'):
|
|
logger.error(f" {key}: shape {value.shape}")
|
|
else:
|
|
logger.error(f" {key}: {type(value)}")
|
|
raise
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error loading sample from {brep_path}: {str(e)}")
|
|
logger.error("Data structure:")
|
|
raise
|
|
|
|
|
|
def _load_brep_file(self, brep_path):
|
|
"""加载B-rep特征文件"""
|
|
try:
|
|
# 1. 加载原始数据
|
|
with open(brep_path, 'rb') as f:
|
|
raw_data = pickle.load(f)
|
|
|
|
brep_data = {}
|
|
|
|
# 2. 处理几何数据(不等长序列)
|
|
geom_keys = ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']
|
|
for key in geom_keys:
|
|
if key in raw_data:
|
|
try:
|
|
# 确保数据是列表
|
|
if not isinstance(raw_data[key], list):
|
|
raise ValueError(f"{key} is not a list")
|
|
|
|
# 转换每个元素为张量
|
|
tensors = []
|
|
for i, x in enumerate(raw_data[key]):
|
|
try:
|
|
# 先转换为numpy数组
|
|
arr = np.array(x, dtype=np.float32)
|
|
# 再转换为张量
|
|
tensor = torch.from_numpy(arr)
|
|
tensors.append(tensor)
|
|
except Exception as e:
|
|
logger.error(f"Error converting {key}[{i}]:")
|
|
logger.error(f" Data type: {type(x)}")
|
|
if isinstance(x, np.ndarray):
|
|
logger.error(f" Shape: {x.shape}")
|
|
logger.error(f" dtype: {x.dtype}")
|
|
raise ValueError(f"Failed to convert {key}[{i}]: {str(e)}")
|
|
|
|
brep_data[key] = tensors
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing {key}:")
|
|
logger.error(f" Raw data type: {type(raw_data[key])}")
|
|
raise ValueError(f"Failed to process {key}: {str(e)}")
|
|
|
|
# 3. 处理固定形状的数据
|
|
fixed_keys = ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs']
|
|
for key in fixed_keys:
|
|
if key in raw_data:
|
|
try:
|
|
# 直接从原始数据转换
|
|
arr = np.array(raw_data[key], dtype=np.float32)
|
|
brep_data[key] = torch.from_numpy(arr)
|
|
except Exception as e:
|
|
logger.error(f"Error converting fixed shape data {key}:")
|
|
logger.error(f" Raw data type: {type(raw_data[key])}")
|
|
if isinstance(raw_data[key], np.ndarray):
|
|
logger.error(f" Shape: {raw_data[key].shape}")
|
|
logger.error(f" dtype: {raw_data[key].dtype}")
|
|
raise ValueError(f"Failed to convert {key}: {str(e)}")
|
|
|
|
# 4. 处理邻接矩阵
|
|
adj_keys = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
|
|
for key in adj_keys:
|
|
if key in raw_data:
|
|
try:
|
|
# 转换为整型数组
|
|
arr = np.array(raw_data[key], dtype=np.int32)
|
|
brep_data[key] = torch.from_numpy(arr)
|
|
except Exception as e:
|
|
logger.error(f"Error converting adjacency matrix {key}:")
|
|
logger.error(f" Raw data type: {type(raw_data[key])}")
|
|
if isinstance(raw_data[key], np.ndarray):
|
|
logger.error(f" Shape: {raw_data[key].shape}")
|
|
logger.error(f" dtype: {raw_data[key].dtype}")
|
|
raise ValueError(f"Failed to convert {key}: {str(e)}")
|
|
|
|
# 5. 验证必要的键是否存在
|
|
required_keys = {'surf_wcs', 'edge_wcs', 'corner_wcs'}
|
|
missing_keys = required_keys - set(brep_data.keys())
|
|
if missing_keys:
|
|
raise ValueError(f"Missing required keys: {missing_keys}")
|
|
|
|
# 6. 使用process_brep_data处理数据
|
|
try:
|
|
features = process_brep_data(
|
|
data=brep_data,
|
|
max_face=self.max_face,
|
|
max_edge=self.max_edge,
|
|
bbox_scaled=self.bbox_scaled
|
|
)
|
|
return features
|
|
|
|
except Exception as e:
|
|
logger.error("Error in process_brep_data:")
|
|
logger.error(f" Error message: {str(e)}")
|
|
# 打印数据形状信息
|
|
logger.error("\nInput data shapes:")
|
|
for key, value in brep_data.items():
|
|
if isinstance(value, list):
|
|
shapes = [t.shape for t in value]
|
|
logger.error(f" {key}: list of tensors with shapes {shapes}")
|
|
elif isinstance(value, torch.Tensor):
|
|
logger.error(f" {key}: tensor of shape {value.shape}")
|
|
raise
|
|
|
|
except Exception as e:
|
|
logger.error(f"\nError loading B-rep file: {brep_path}")
|
|
logger.error(f"Error message: {str(e)}")
|
|
raise
|
|
|
|
def _load_sdf_file(self, sdf_path):
|
|
"""加载和处理SDF数据"""
|
|
try:
|
|
# 加载SDF值
|
|
sdf_data = np.load(sdf_path)
|
|
if 'pos' not in sdf_data or 'neg' not in sdf_data:
|
|
raise ValueError("Missing pos/neg data in SDF file")
|
|
|
|
sdf_pos = sdf_data['pos'] # (N1, 4)
|
|
sdf_neg = sdf_data['neg'] # (N2, 4)
|
|
|
|
# 添加数据验证
|
|
if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
|
|
raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")
|
|
|
|
sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)
|
|
return torch.from_numpy(sdf_np.astype(np.float32))
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error loading SDF from {sdf_path}")
|
|
logger.error(f"Error type: {type(e).__name__}")
|
|
logger.error(f"Error message: {str(e)}")
|
|
raise
|
|
@staticmethod
|
|
def collate_fn(batch):
|
|
"""自定义批处理函数"""
|
|
# 收集所有样本的名称
|
|
names = [item['name'] for item in batch]
|
|
|
|
# 处理固定大小的张量数据
|
|
tensor_keys = ['edge_ncs', 'edge_pos', 'edge_mask',
|
|
'surf_ncs', 'surf_pos', 'vertex_pos']
|
|
tensors = {
|
|
key: torch.stack([item[key] for item in batch])
|
|
for key in tensor_keys
|
|
}
|
|
|
|
# 处理变长的SDF数据
|
|
sdf_data = [item['sdf'] for item in batch]
|
|
max_sdf_len = max(sdf.size(0) for sdf in sdf_data)
|
|
|
|
# 填充SDF数据
|
|
padded_sdfs = []
|
|
sdf_masks = []
|
|
for sdf in sdf_data:
|
|
pad_len = max_sdf_len - sdf.size(0)
|
|
if pad_len > 0:
|
|
padding = torch.zeros(pad_len, sdf.size(1),
|
|
dtype=sdf.dtype, device=sdf.device)
|
|
padded_sdf = torch.cat([sdf, padding], dim=0)
|
|
mask = torch.cat([
|
|
torch.ones(sdf.size(0), dtype=torch.bool),
|
|
torch.zeros(pad_len, dtype=torch.bool)
|
|
])
|
|
else:
|
|
padded_sdf = sdf
|
|
mask = torch.ones(sdf.size(0), dtype=torch.bool)
|
|
padded_sdfs.append(padded_sdf)
|
|
sdf_masks.append(mask)
|
|
|
|
# 合并所有数据
|
|
batch_data = {
|
|
'name': names,
|
|
'sdf': torch.stack(padded_sdfs),
|
|
'sdf_mask': torch.stack(sdf_masks),
|
|
**tensors
|
|
}
|
|
|
|
return batch_data
|
|
|
|
def test_dataset():
    """Smoke-test the dataset: construct it and fetch one batch."""
    try:
        # Test data locations.
        brep_dir = '/home/wch/myDeepSDF/test_data/pkl'
        sdf_dir = '/home/wch/myDeepSDF/test_data/sdf'
        split = 'train'

        logger.info("="*50)
        logger.info("Testing dataset")
        logger.info(f"B-rep directory: {brep_dir}")
        logger.info(f"SDF directory: {sdf_dir}")
        logger.info(f"Split: {split}")

        # Use the directories logged above.  (Bug fix: the constructor was
        # previously given different hard-coded '/home/wch/brep2sdf/...'
        # paths than the ones reported, so the log was misleading.)
        dataset = BRepSDFDataset(brep_dir=brep_dir, sdf_dir=sdf_dir, split=split)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

        for batch in dataloader:
            print(batch['sdf'].shape)
            break
    except Exception as e:
        logger.error(f"Error in test_dataset: {str(e)}")


if __name__ == '__main__':
    test_dataset()