
test: update the dataset loading test

main
mckay · 4 months ago
commit 3053eda7fd
1 changed file: brep2sdf/data/data.py (127 changed lines)

@@ -334,26 +334,133 @@ class BRepSDFDataset(Dataset):
 def test_dataset():
     """Test dataset functionality."""
     try:
-        # Set up the test paths
-        brep_dir = '/home/wch/myDeepSDF/test_data/pkl'
-        sdf_dir = '/home/wch/myDeepSDF/test_data/sdf'
+        # 1. Set up the test paths and expected data dimensions
+        brep_dir = '/home/wch/brep2sdf/test_data/pkl'
+        sdf_dir = '/home/wch/brep2sdf/test_data/sdf'
+        valid_data_dir = "/home/wch/brep2sdf/test_data/result/pkl"
+        split = 'train'
+
+        # Define the expected per-sample data dimensions
+        expected_shapes = {
+            'edge_ncs': (70, 70, 10, 3),   # [max_face, max_edge, sample_points, xyz]
+            'edge_pos': (70, 70, 6),       # [max_face, max_edge, bbox]
+            'edge_mask': (70, 70),         # [max_face, max_edge]
+            'surf_ncs': (70, 100, 3),      # [max_face, sample_points, xyz]
+            'surf_pos': (70, 6),           # [max_face, bbox]
+            'vertex_pos': (70, 70, 2, 3),  # [max_face, max_edge, 2_points, xyz]
+            'sdf': (2097152, 4)            # [num_points, xyz+sdf]
+        }
 
         logger.info("="*50)
         logger.info("Testing dataset")
         logger.info(f"B-rep directory: {brep_dir}")
         logger.info(f"SDF directory: {sdf_dir}")
         logger.info(f"Split: {split}")
 
-        # ... (rest of the test code unchanged) ...
-        dataset = BRepSDFDataset(brep_dir='/home/wch/brep2sdf/test_data/pkl', sdf_dir='/home/wch/brep2sdf/test_data/sdf',valid_data_dir="/home/wch/brep2sdf/test_data/result/pkl", split='train')
-        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
-        for batch in dataloader:
-            print(batch['sdf'].shape)
-            break
+        # 2. Initialize the dataset
+        dataset = BRepSDFDataset(
+            brep_dir=brep_dir,
+            sdf_dir=sdf_dir,
+            valid_data_dir=valid_data_dir,
+            split=split
+        )
+        logger.info(f"\nDataset size: {len(dataset)}")
+
+        # 3. Test single-sample loading and shape checks
+        logger.info("\nTesting single sample loading...")
+        try:
+            sample = dataset[0]
+            logger.info("Sample keys and shapes:")
+            for key, value in sample.items():
+                if isinstance(value, torch.Tensor):
+                    actual_shape = tuple(value.shape)
+                    expected_shape = expected_shapes.get(key)
+                    shape_match = actual_shape == expected_shape if expected_shape else None
+                    logger.info(f"  {key}:")
+                    logger.info(f"    Shape: {actual_shape}")
+                    logger.info(f"    Expected: {expected_shape}")
+                    logger.info(f"    Match: {shape_match}")
+                    logger.info(f"    dtype: {value.dtype}")
+                    # None (no expectation for this key) is not a mismatch
+                    if shape_match is False:
+                        logger.warning(f"    Shape mismatch for {key}!")
+                else:
+                    logger.info(f"  {key}: {type(value)}")
+        except Exception as e:
+            logger.error("Error loading single sample")
+            logger.error(f"Error message: {str(e)}")
+            raise
+
+        # 4. Test the DataLoader and batched shapes
+        logger.info("\nTesting DataLoader...")
+        try:
+            dataloader = torch.utils.data.DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                num_workers=0
+            )
+            batch = next(iter(dataloader))
+            logger.info("Batch keys and shapes:")
+            for key, value in batch.items():
+                if isinstance(value, torch.Tensor):
+                    actual_shape = tuple(value.shape)
+                    expected_shape = expected_shapes.get(key)
+                    if expected_shape:
+                        # Batching prepends the batch dimension to the per-sample shape
+                        expected_batch_shape = (2,) + expected_shape
+                        shape_match = actual_shape == expected_batch_shape
+                    else:
+                        expected_batch_shape = None
+                        shape_match = None
+                    logger.info(f"  {key}:")
+                    logger.info(f"    Shape: {actual_shape}")
+                    logger.info(f"    Expected: {expected_batch_shape}")
+                    logger.info(f"    Match: {shape_match}")
+                    logger.info(f"    dtype: {value.dtype}")
+                    if shape_match is False:
+                        logger.warning(f"    Shape mismatch for {key}!")
+                elif isinstance(value, list):
+                    logger.info(f"  {key}: list of length {len(value)}")
+                else:
+                    logger.info(f"  {key}: {type(value)}")
+        except Exception as e:
+            logger.error("Error in DataLoader")
+            logger.error(f"Error message: {str(e)}")
+            raise
+
+        # 5. Validate data ranges
+        logger.info("\nValidating data ranges...")
+        try:
+            for key, value in batch.items():
+                if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float64]:
+                    logger.info(f"  {key}:")
+                    logger.info(f"    min: {value.min().item():.4f}")
+                    logger.info(f"    max: {value.max().item():.4f}")
+                    logger.info(f"    mean: {value.mean().item():.4f}")
+                    logger.info(f"    std: {value.std().item():.4f}")
+
+                    # Check for NaN or Inf values
+                    has_nan = torch.isnan(value).any()
+                    has_inf = torch.isinf(value).any()
+                    if has_nan or has_inf:
+                        logger.warning(f"    Found NaN: {has_nan}, Inf: {has_inf}")
+        except Exception as e:
+            logger.error("Error validating data ranges")
+            logger.error(f"Error message: {str(e)}")
+            raise
+
+        logger.info("\nAll tests completed successfully!")
+        logger.info("="*50)
 
     except Exception as e:
         logger.error(f"Error in test_dataset: {str(e)}")
         raise
 
 if __name__ == '__main__':
     test_dataset()
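
For quick reuse outside the logger-driven walkthrough above, the same shape contract can be expressed as plain assertions. The following is a minimal standalone sketch, not part of this commit: it assumes BRepSDFDataset is importable from brep2sdf.data.data (the module this commit modifies) and that the test data paths from the diff exist on disk; the helper name assert_batch_shapes is hypothetical.

    # Standalone shape-contract check; a sketch, not committed code.
    # Assumes brep2sdf.data.data exposes BRepSDFDataset as in the diff
    # above and that the /home/wch/brep2sdf/test_data paths exist.
    from torch.utils.data import DataLoader

    from brep2sdf.data.data import BRepSDFDataset

    # Expected per-sample shapes, copied from the test above.
    EXPECTED_SHAPES = {
        'edge_ncs':   (70, 70, 10, 3),
        'edge_pos':   (70, 70, 6),
        'edge_mask':  (70, 70),
        'surf_ncs':   (70, 100, 3),
        'surf_pos':   (70, 6),
        'vertex_pos': (70, 70, 2, 3),
        'sdf':        (2097152, 4),
    }

    def assert_batch_shapes(batch: dict, batch_size: int) -> None:
        """Raise AssertionError if any tensor deviates from its padded shape."""
        for key, expected in EXPECTED_SHAPES.items():
            actual = tuple(batch[key].shape)
            assert actual == (batch_size,) + expected, (
                f"{key}: got {actual}, expected {(batch_size,) + expected}"
            )

    if __name__ == '__main__':
        dataset = BRepSDFDataset(
            brep_dir='/home/wch/brep2sdf/test_data/pkl',
            sdf_dir='/home/wch/brep2sdf/test_data/sdf',
            valid_data_dir='/home/wch/brep2sdf/test_data/result/pkl',
            split='train',
        )
        batch = next(iter(DataLoader(dataset, batch_size=2, shuffle=True)))
        assert_batch_shapes(batch, batch_size=2)
        print('all shape checks passed')

Note that the default DataLoader collate only succeeds here because every sample is padded to the same fixed per-sample shapes listed in expected_shapes; variable-size samples would require a custom collate_fn.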