import os
import torch
from torch.utils.data import Dataset
import numpy as np
import pickle
from brep2sdf.utils.logger import logger
from brep2sdf.data.utils import process_brep_data
from brep2sdf.config.default_config import get_default_config
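
# BRepSDFDataset pairs each pickled B-rep sample with randomly subsampled SDF
# query points for the same shape, returning one dictionary of tensors per
# sample (tensor shapes are noted inline in __getitem__).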
class BRepSDFDataset(Dataset):
    def __init__(self, brep_dir: str, sdf_dir: str, valid_data_dir: str, split: str = 'train'):
        """
        Initialize the dataset.

        Args:
            brep_dir: directory containing the .pkl files
            sdf_dir: directory containing the .npz files
            valid_data_dir: directory containing the <split>_success.txt lists
            split: dataset split ('train', 'val', 'test')
        """
        super().__init__()
        # Read parameters from the config file
        self.config = get_default_config()
        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split

        # Use config values instead of hard-coded constants
        self.max_face = self.config.data.max_face
        self.max_edge = self.config.data.max_edge
        self.bbox_scaled = self.config.data.bbox_scaled

        # Check that the data directories exist
        if not os.path.exists(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.exists(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")

        # Load the list of valid sample names from the success file, if present
        valid_data_file = os.path.join(valid_data_dir, f'{split}_success.txt')
        if os.path.exists(valid_data_file):
            self.valid_data_list = self._load_valid_list(valid_data_file)
        else:
            raise ValueError(f"Valid data file not found: {valid_data_file}")

        self.brep_data_list = self._load_data_list(self.brep_dir)
        self.sdf_data_list = self._load_data_list(self.sdf_dir)

        # Check that the dataset is not empty
        if len(self.brep_data_list) == 0:
            raise ValueError(f"No valid brep data found in {split} set")
        if len(self.sdf_data_list) == 0:
            raise ValueError(f"No valid sdf data found in {split} set")

        logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples")

    def _load_valid_list(self, valid_data_file: str):
        with open(valid_data_file, 'r') as f:
            valid_list = [line.strip() for line in f.readlines()]
        return valid_list
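
    # A sketch of the success-list format assumed here (inferred from the
    # basename comparison in _load_data_list, not from project docs): one
    # sample name per line, with no file extension. The names below are
    # placeholders:
    #
    #   00000007
    #   00000042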

    # data_dir is either self.brep_dir or self.sdf_dir
    def _load_data_list(self, data_dir):
        data_list = []
        # Sort so that the brep and sdf lists line up index-by-index;
        # os.listdir returns entries in arbitrary order.
        for sample_file in sorted(os.listdir(data_dir)):
            if sample_file.split('.')[0] in self.valid_data_list:
                path = os.path.join(data_dir, sample_file)
                data_list.append(path)
        return data_list
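
    # Assumed on-disk layout, derived from the path joins in __init__ and the
    # directory scan above (sample names are placeholders):
    #
    #   <brep_dir>/train/00000007.pkl
    #   <sdf_dir>/train/00000007.npz
    #   <valid_data_dir>/train_success.txt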

    def __len__(self):
        return len(self.brep_data_list)

    def __getitem__(self, idx):
        """Fetch a single data sample."""
        try:
            brep_path = self.brep_data_list[idx]
            sdf_path = self.sdf_data_list[idx]
            name = os.path.splitext(os.path.basename(brep_path))[0]

            # Load the B-rep and SDF data
            with open(brep_path, 'rb') as f:
                brep_raw = pickle.load(f)
            sdf_data = self._load_sdf_file(sdf_path)

            try:
                # Process the B-rep data
                brep_features = process_brep_data(
                    data=brep_raw,
                    max_face=self.max_face,
                    max_edge=self.max_edge,
                    bbox_scaled=self.bbox_scaled
                )

                '''
                # Print the processed data shapes
                logger.debug(f"Processed data shapes for {os.path.basename(brep_path)}:")
                for value in brep_features:
                    if isinstance(value, torch.Tensor):
                        logger.debug(f"  {value.shape}")

                # Check the type and arity of the return value
                if not isinstance(brep_features, tuple):
                    logger.error(f"process_brep_data returned {type(brep_features)}, expected tuple")
                    raise ValueError("Invalid return type from process_brep_data")
                if len(brep_features) != 6:
                    logger.error(f"Expected 6 features, got {len(brep_features)}")
                    logger.error("Features returned:")
                    for i, feat in enumerate(brep_features):
                        if isinstance(feat, torch.Tensor):
                            logger.error(f"  {i}: Tensor of shape {feat.shape}")
                        else:
                            logger.error(f"  {i}: {type(feat)}")
                    raise ValueError(f"Incorrect number of features: {len(brep_features)}")
                '''

                # Unpack the processed features
                edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features
                sdf_points = sdf_data[:, :3]
                sdf_values = sdf_data[:, 3:]

                # Build the sample dictionary
                return {
                    'name': name,
                    'edge_ncs': edge_ncs,      # [max_face, max_edge, 10, 3]
                    'edge_pos': edge_pos,      # [max_face, max_edge, 6]
                    'edge_mask': edge_mask,    # [max_face, max_edge]
                    'surf_ncs': surf_ncs,      # [max_face, 100, 3]
                    'surf_pos': surf_pos,      # [max_face, 6]
                    'vertex_pos': vertex_pos,  # [max_face, max_edge, 6]
                    'points': sdf_points,      # [num_queries, 3] xyz coordinates of the query points
                    'sdf': sdf_values          # [num_queries, 1] sdf value at each query point
                }

            except Exception as e:
                logger.error(f"\nError processing B-rep data for file: {brep_path}")
                logger.error(f"Error type: {type(e).__name__}")
                logger.error(f"Error message: {str(e)}")
                # Print the structure of the raw data
                logger.error("\nRaw data structure:")
                for key, value in brep_raw.items():
                    if isinstance(value, list):
                        logger.error(f"  {key}: list of length {len(value)}")
                        if value:
                            logger.error(f"  First element type: {type(value[0])}")
                            if hasattr(value[0], 'shape'):
                                logger.error(f"  First element shape: {value[0].shape}")
                    elif hasattr(value, 'shape'):
                        logger.error(f"  {key}: shape {value.shape}")
                    else:
                        logger.error(f"  {key}: {type(value)}")
                raise

        except Exception as e:
            logger.error(f"Error loading sample from {brep_path}: {str(e)}")
            raise

    def _load_sdf_file(self, sdf_path):
        """Load and process SDF data, with random subsampling."""
        try:
            # Load the SDF values
            sdf_data = np.load(sdf_path)
            if 'pos' not in sdf_data or 'neg' not in sdf_data:
                raise ValueError("Missing pos/neg data in SDF file")
            sdf_pos = sdf_data['pos']  # (N1, 4)
            sdf_neg = sdf_data['neg']  # (N2, 4)

            # Validate the data shapes
            if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
                raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")

            # Random subsampling
            max_points = self.config.data.num_query_points  # e.g. 4096

            # Keep the positive and negative samples balanced
            num_pos = min(max_points // 2, sdf_pos.shape[0])
            num_neg = min(max_points // 2, sdf_neg.shape[0])

            # Randomly sample the positive points
            if sdf_pos.shape[0] > num_pos:
                pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False)
                sdf_pos = sdf_pos[pos_indices]

            # Randomly sample the negative points
            if sdf_neg.shape[0] > num_neg:
                neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False)
                sdf_neg = sdf_neg[neg_indices]

            # Merge the two sets
            sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)

            # Shuffle once more
            np.random.shuffle(sdf_np)

            # If the total still exceeds the limit, subsample again
            if sdf_np.shape[0] > max_points:
                indices = np.random.choice(sdf_np.shape[0], max_points, replace=False)
                sdf_np = sdf_np[indices]

            return torch.from_numpy(sdf_np.astype(np.float32))

        except Exception as e:
            logger.error(f"Error loading SDF from {sdf_path}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            raise
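
# A minimal sketch of the .npz layout _load_sdf_file expects, inferred from the
# code above rather than from project documentation: two arrays of shape (N, 4),
# where each row is assumed to be (x, y, z, sdf), with 'pos' holding the
# non-negative-SDF points and 'neg' the negative ones.
#
#   import numpy as np
#   pos = np.random.rand(128, 4).astype(np.float32)   # placeholder data
#   neg = np.random.rand(128, 4).astype(np.float32)   # placeholder data
#   np.savez('sample.npz', pos=pos, neg=neg)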

def test_dataset():
    """Test dataset functionality."""
    try:
        # 1. Get the configuration
        config = get_default_config()
        brep_dir = config.data.brep_dir
        sdf_dir = config.data.sdf_dir
        valid_data_dir = config.data.valid_data_dir
        split = 'train'

        max_face = config.data.max_face
        max_edge = config.data.max_edge
        num_edge_points = config.model.num_edge_points
        num_surf_points = config.model.num_surf_points
        num_query_points = config.data.num_query_points

        # Expected per-sample shapes, built from the config values
        expected_shapes = {
            'edge_ncs': (max_face, max_edge, num_edge_points, 3),  # [max_face, max_edge, sample_points, xyz]
            'edge_pos': (max_face, max_edge, 6),
            'edge_mask': (max_face, max_edge),
            'surf_ncs': (max_face, num_surf_points, 3),
            'surf_pos': (max_face, 6),
            'vertex_pos': (max_face, max_edge, 2, 3),
            'points': (num_query_points, 3),
            'sdf': (num_query_points, 1)
        }

        logger.info("=" * 50)
        logger.info("Testing dataset")
        logger.info(f"B-rep directory: {brep_dir}")
        logger.info(f"SDF directory: {sdf_dir}")
        logger.info(f"Split: {split}")

        # 2. Initialize the dataset
        dataset = BRepSDFDataset(
            brep_dir=brep_dir,
            sdf_dir=sdf_dir,
            valid_data_dir=valid_data_dir,
            split=split
        )
        logger.info(f"\nDataset size: {len(dataset)}")

        # 3. Test loading a single sample and check its shapes
        logger.info("\nTesting single sample loading...")
        try:
            sample = dataset[0]
            logger.info("Sample keys and shapes:")
            for key, value in sample.items():
                if isinstance(value, torch.Tensor):
                    actual_shape = tuple(value.shape)
                    expected_shape = expected_shapes.get(key)
                    shape_match = actual_shape == expected_shape if expected_shape else None
                    logger.info(f"  {key}:")
                    logger.info(f"    Shape: {actual_shape}")
                    logger.info(f"    Expected: {expected_shape}")
                    logger.info(f"    Match: {shape_match}")
                    logger.info(f"    dtype: {value.dtype}")
                    logger.info(f"    grad: {value.requires_grad}")
                    # Only warn on a real mismatch, not on keys with no expectation
                    if shape_match is False:
                        logger.warning(f"    Shape mismatch for {key}!")
                else:
                    logger.info(f"  {key}: {type(value)}")
        except Exception as e:
            logger.error("Error loading single sample")
            logger.error(f"Error message: {str(e)}")
            raise

        # 4. Test the DataLoader and the batched shapes
        logger.info("\nTesting DataLoader...")
        try:
            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=2,
                shuffle=True,
                num_workers=0
            )
            batch = next(iter(dataloader))
            logger.info("Batch keys and shapes:")
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    actual_shape = tuple(value.shape)
                    expected_shape = expected_shapes.get(key)
                    if expected_shape:
                        expected_batch_shape = (2,) + expected_shape
                        shape_match = actual_shape == expected_batch_shape
                    else:
                        expected_batch_shape = None
                        shape_match = None
                    logger.info(f"  {key}:")
                    logger.info(f"    Shape: {actual_shape}")
                    logger.info(f"    Expected: {expected_batch_shape}")
                    logger.info(f"    Match: {shape_match}")
                    logger.info(f"    dtype: {value.dtype}")
                    # Only warn on a real mismatch, not on keys with no expectation
                    if shape_match is False:
                        logger.warning(f"    Shape mismatch for {key}!")
                elif isinstance(value, list):
                    logger.info(f"  {key}: list of length {len(value)}")
                else:
                    logger.info(f"  {key}: {type(value)}")
        except Exception as e:
            logger.error("Error in DataLoader")
            logger.error(f"Error message: {str(e)}")
            raise

        # 5. Validate the data ranges
        logger.info("\nValidating data ranges...")
        try:
            for key, value in batch.items():
                if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float64]:
                    logger.info(f"  {key}:")
                    logger.info(f"    min: {value.min().item():.4f}")
                    logger.info(f"    max: {value.max().item():.4f}")
                    logger.info(f"    mean: {value.mean().item():.4f}")
                    logger.info(f"    std: {value.std().item():.4f}")

                    # Check for NaN or Inf values
                    has_nan = torch.isnan(value).any()
                    has_inf = torch.isinf(value).any()
                    if has_nan or has_inf:
                        logger.warning(f"    Found NaN: {has_nan}, Inf: {has_inf}")
        except Exception as e:
            logger.error("Error validating data ranges")
            logger.error(f"Error message: {str(e)}")
            raise

        logger.info("\nAll tests completed successfully!")
        logger.info("=" * 50)

    except Exception as e:
        logger.error(f"Error in test_dataset: {str(e)}")
        raise

if __name__ == '__main__':
    test_dataset()