From 3053eda7fd3189190623419ab37f76a89340af6f Mon Sep 17 00:00:00 2001
From: mckay
Date: Mon, 18 Nov 2024 23:45:17 +0800
Subject: [PATCH] test: update dataset loading test
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/data/data.py | 127 ++++++++++++++++++++++++++++++++++++++----
 1 file changed, 117 insertions(+), 10 deletions(-)

diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py
index 40c7296..05dd07f 100644
--- a/brep2sdf/data/data.py
+++ b/brep2sdf/data/data.py
@@ -334,26 +334,133 @@ class BRepSDFDataset(Dataset):
 
 def test_dataset():
     """Test dataset functionality"""
     try:
-        # Set test paths
-        brep_dir = '/home/wch/myDeepSDF/test_data/pkl'
-        sdf_dir = '/home/wch/myDeepSDF/test_data/sdf'
+        # 1. Set test paths and the expected data dimensions
+        brep_dir = '/home/wch/brep2sdf/test_data/pkl'
+        sdf_dir = '/home/wch/brep2sdf/test_data/sdf'
+        valid_data_dir = "/home/wch/brep2sdf/test_data/result/pkl"
         split = 'train'
 
+        # Expected per-sample tensor shapes
+        expected_shapes = {
+            'edge_ncs': (70, 70, 10, 3),   # [max_face, max_edge, sample_points, xyz]
+            'edge_pos': (70, 70, 6),       # [max_face, max_edge, bbox]
+            'edge_mask': (70, 70),         # [max_face, max_edge]
+            'surf_ncs': (70, 100, 3),      # [max_face, sample_points, xyz]
+            'surf_pos': (70, 6),           # [max_face, bbox]
+            'vertex_pos': (70, 70, 2, 3),  # [max_face, max_edge, 2_points, xyz]
+            'sdf': (2097152, 4)            # [num_points, xyz+sdf]
+        }
+
         logger.info("="*50)
         logger.info("Testing dataset")
         logger.info(f"B-rep directory: {brep_dir}")
         logger.info(f"SDF directory: {sdf_dir}")
         logger.info(f"Split: {split}")
 
-        # ... (rest of the test code unchanged) ...
-        dataset = BRepSDFDataset(brep_dir='/home/wch/brep2sdf/test_data/pkl', sdf_dir='/home/wch/brep2sdf/test_data/sdf',valid_data_dir="/home/wch/brep2sdf/test_data/result/pkl", split='train')
-        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
-
-        for batch in dataloader:
-            print(batch['sdf'].shape)
-            break
+        # 2. Initialize the dataset
+        dataset = BRepSDFDataset(
+            brep_dir=brep_dir,
+            sdf_dir=sdf_dir,
+            valid_data_dir=valid_data_dir,
+            split=split
+        )
+        logger.info(f"\nDataset size: {len(dataset)}")
+
+        # 3. Test single-sample loading and check shapes
+        logger.info("\nTesting single sample loading...")
+        try:
+            sample = dataset[0]
+            logger.info("Sample keys and shapes:")
+            for key, value in sample.items():
+                if isinstance(value, torch.Tensor):
+                    actual_shape = tuple(value.shape)
+                    expected_shape = expected_shapes.get(key)
+                    shape_match = actual_shape == expected_shape if expected_shape else None
+
+                    logger.info(f"  {key}:")
+                    logger.info(f"    Shape: {actual_shape}")
+                    logger.info(f"    Expected: {expected_shape}")
+                    logger.info(f"    Match: {shape_match}")
+                    logger.info(f"    dtype: {value.dtype}")
+
+                    # Warn only on a real mismatch, not when no expectation exists
+                    if shape_match is False:
+                        logger.warning(f"    Shape mismatch for {key}!")
+                else:
+                    logger.info(f"  {key}: {type(value)}")
+        except Exception as e:
+            logger.error("Error loading single sample")
+            logger.error(f"Error message: {str(e)}")
+            raise
+
+        # 4. Test the DataLoader and batched shapes
+        logger.info("\nTesting DataLoader...")
+        try:
+            dataloader = torch.utils.data.DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                num_workers=0
+            )
+
+            batch = next(iter(dataloader))
+            logger.info("Batch keys and shapes:")
+            for key, value in batch.items():
+                if isinstance(value, torch.Tensor):
+                    actual_shape = tuple(value.shape)
+                    expected_shape = expected_shapes.get(key)
+                    if expected_shape:
+                        expected_batch_shape = (2,) + expected_shape
+                        shape_match = actual_shape == expected_batch_shape
+                    else:
+                        expected_batch_shape = None
+                        shape_match = None
+
+                    logger.info(f"  {key}:")
+                    logger.info(f"    Shape: {actual_shape}")
+                    logger.info(f"    Expected: {expected_batch_shape}")
+                    logger.info(f"    Match: {shape_match}")
+                    logger.info(f"    dtype: {value.dtype}")
+
+                    if shape_match is False:
+                        logger.warning(f"    Shape mismatch for {key}!")
+                elif isinstance(value, list):
+                    logger.info(f"  {key}: list of length {len(value)}")
+                else:
+                    logger.info(f"  {key}: {type(value)}")
+
+        except Exception as e:
+            logger.error("Error in DataLoader")
+            logger.error(f"Error message: {str(e)}")
+            raise
+
+        # 5. Validate data ranges
+        logger.info("\nValidating data ranges...")
+        try:
+            for key, value in batch.items():
+                if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float64]:
+                    logger.info(f"  {key}:")
+                    logger.info(f"    min: {value.min().item():.4f}")
+                    logger.info(f"    max: {value.max().item():.4f}")
+                    logger.info(f"    mean: {value.mean().item():.4f}")
+                    logger.info(f"    std: {value.std().item():.4f}")
+
+                    # Check for NaN or Inf (.item() so plain bools are logged)
+                    has_nan = torch.isnan(value).any().item()
+                    has_inf = torch.isinf(value).any().item()
+                    if has_nan or has_inf:
+                        logger.warning(f"    Found NaN: {has_nan}, Inf: {has_inf}")
+
+        except Exception as e:
+            logger.error("Error validating data ranges")
+            logger.error(f"Error message: {str(e)}")
+            raise
+
+        logger.info("\nAll tests completed successfully!")
+        logger.info("="*50)
+
     except Exception as e:
         logger.error(f"Error in test_dataset: {str(e)}")
+        raise
 
 if __name__ == '__main__':
     test_dataset()
\ No newline at end of file
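
Note: the batched-shape check in step 4 relies on the default collate behavior of
torch.utils.data.DataLoader, which stacks same-shaped per-sample tensors so that a
sample tensor of shape S comes out of a batch_size=2 loader with shape (2,) + S;
that is why the test compares against (2,) + expected_shape. A minimal,
self-contained sketch of that invariant follows; ToyDataset is a hypothetical
stand-in for BRepSDFDataset that reuses one shape from expected_shapes.

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        """Illustrative stand-in returning one fixed-shape field per sample."""
        def __len__(self):
            return 4
        def __getitem__(self, idx):
            return {'surf_pos': torch.zeros(70, 6)}  # per-sample shape from expected_shapes

    loader = DataLoader(ToyDataset(), batch_size=2, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    assert tuple(batch['surf_pos'].shape) == (2,) + (70, 6)  # batch dim prepended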