From b5928becf747ba4a217af64c349d25c2ac78a894 Mon Sep 17 00:00:00 2001 From: mckay Date: Thu, 5 Dec 2024 01:15:59 +0800 Subject: [PATCH] fix: neg sdf too small, -> shape error --- brep2sdf/data/data.py | 375 ++++++++++++++++++++++++++++-------------- 1 file changed, 252 insertions(+), 123 deletions(-) diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index 5e6fe84..ac80975 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -183,9 +183,16 @@ class BRepSDFDataset(Dataset): max_points = self.config.data.num_query_points # 例如4096 # 确保正负样本均衡 - num_pos = min(max_points // 2, sdf_pos.shape[0]) - num_neg = min(max_points // 2, sdf_neg.shape[0]) + if max_points // 2 > sdf_pos.shape[0]: + logger.warning(f"正样本过少,期望>{max_points // 2},实际:{sdf_pos.shape[0]}") + + if max_points // 2 > sdf_neg.shape[0]: + num_neg = sdf_neg.shape[0] + else: + num_neg = max_points // 2 + num_pos = max_points - num_neg + # 随机采样正样本 if sdf_pos.shape[0] > num_pos: pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False) @@ -222,139 +229,261 @@ def test_dataset(): try: # 获取配置 config = get_default_config() - brep_dir = config.data.brep_dir - sdf_dir = config.data.sdf_dir - valid_data_dir = config.data.valid_data_dir - split = 'train' - max_face = config.data.max_face - max_edge = config.data.max_edge - num_edge_points = config.model.num_edge_points - num_surf_points = config.model.num_surf_points - num_query_points = config.data.num_query_points - # 定义预期的数据维度,使用配置中的参数 + # 定义预期的数据维度 expected_shapes = { - 'edge_ncs': (max_face, max_edge, num_edge_points, 3), # [max_face, max_edge, sample_points, xyz] - 'edge_pos': (max_face, max_edge, 6), - 'edge_mask': (max_face, max_edge), - 'surf_ncs': (max_face, num_surf_points, 3), - 'surf_pos': (max_face, 6), - 'vertex_pos': (max_face, max_edge, 2, 3), - 'points': (num_query_points, 3), - 'sdf': (num_query_points, 1) + 'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3), + 'edge_pos': (config.data.max_face, config.data.max_edge, 6), + 'edge_mask': (config.data.max_face, config.data.max_edge), + 'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3), + 'surf_pos': (config.data.max_face, 6), + 'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3), + 'points': (config.data.num_query_points, 3), + 'sdf': (config.data.num_query_points, 1) } logger.info("="*50) - logger.info("Testing dataset") - logger.info(f"B-rep directory: {brep_dir}") - logger.info(f"SDF directory: {sdf_dir}") - logger.info(f"Split: {split}") + logger.info("测试数据集") + logger.info(f"预期形状:") + for key, shape in expected_shapes.items(): + logger.info(f" {key}: {shape}") + + # 初始化数据集 + dataset = BRepSDFDataset( + brep_dir=config.data.brep_dir, + sdf_dir=config.data.sdf_dir, + valid_data_dir=config.data.valid_data_dir, + split='train' + ) + + # 测试数据加载 + logger.info("\n测试数据加载...") + sample = dataset[0] + + # 检查数据类型和形状 + logger.info("\n数据类型和形状检查:") + for key, value in sample.items(): + if isinstance(value, torch.Tensor): + actual_shape = tuple(value.shape) + expected_shape = expected_shapes.get(key) + shape_match = "✓" if actual_shape == expected_shape else "✗" + + logger.info(f"\n{key}:") + logger.info(f" 实际形状: {actual_shape}") + logger.info(f" 预期形状: {expected_shape}") + logger.info(f" 匹配状态: {shape_match}") + logger.info(f" 数据类型: {value.dtype}") + + # 仅对浮点类型计算数值范围、均值和标准差 + if value.dtype.is_floating_point: + logger.info(f" 数值范围: [{value.min():.3f}, {value.max():.3f}]") + logger.info(f" 均值: {value.mean():.3f}") + logger.info(f" 标准差: {value.std():.3f}") + + if shape_match == "✗": + logger.warning(f" 形状不匹配: {key}") + if key in ['points', 'sdf']: + logger.warning(f" 查询点数量不一致,预期 {expected_shape[0]},实际 {actual_shape[0]}") + elif key in ['edge_ncs', 'edge_pos', 'edge_mask']: + logger.warning(f" 边数量不一致,预期 {expected_shape[1]},实际 {actual_shape[1]}") + elif key in ['surf_ncs', 'surf_pos']: + logger.warning(f" 面数量不一致,预期 {expected_shape[0]},实际 {actual_shape[0]}") + + # 测试批处理 + logger.info("\n测试批处理...") + batch_size = 4 + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=0 + ) + + batch = next(iter(dataloader)) + logger.info("\n批处理形状检查:") + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch_shape = tuple(value.shape) + expected_batch_shape = (batch_size,) + expected_shapes[key] + shape_match = "✓" if batch_shape == expected_batch_shape else "✗" + + logger.info(f"\n{key}:") + logger.info(f" 实际形状: {batch_shape}") + logger.info(f" 预期形状: {expected_batch_shape}") + logger.info(f" 匹配状态: {shape_match}") + logger.info(f" 数据类型: {value.dtype}") + + # 仅对浮点类型计算数值范围、均值和标准差 + if value.dtype.is_floating_point: + logger.info(f" 数值范围: [{value.min():.3f}, {value.max():.3f}]") + logger.info(f" 均值: {value.mean():.3f}") + logger.info(f" 标准差: {value.std():.3f}") + + if shape_match == "✗": + logger.warning(f" 批处理形状不匹配: {key}") + + logger.info("\n测试完成!") + logger.info("="*50) - # 2. 初始化数据集 + except Exception as e: + logger.error(f"测试过程中出错: {str(e)}") + raise +from collections import defaultdict +from tqdm import tqdm +def validate_dataset(split: str = 'train', num_samples: int = None): + """全面验证数据集 + + Args: + split: 数据集分割 ('train', 'val', 'test') + num_samples: 要检查的样本数量,None表示检查所有样本 + """ + try: + config = get_default_config() + logger.info(f"开始验证{split}数据集...") + + # 初始化数据集 dataset = BRepSDFDataset( - brep_dir=brep_dir, - sdf_dir=sdf_dir, - valid_data_dir=valid_data_dir, - split=split + brep_dir=config.data.brep_dir, + sdf_dir=config.data.sdf_dir, + valid_data_dir=config.data.valid_data_dir, + split='train' ) - logger.info(f"\nDataset size: {len(dataset)}") - # 3. 测试单个样本加载和形状检查 - logger.info("\nTesting single sample loading...") - try: - sample = dataset[0] - logger.info("Sample keys and shapes:") - for key, value in sample.items(): - if isinstance(value, torch.Tensor): - actual_shape = tuple(value.shape) - expected_shape = expected_shapes.get(key) - shape_match = actual_shape == expected_shape if expected_shape else None - - logger.info(f" {key}:") - logger.info(f" Shape: {actual_shape}") - logger.info(f" Expected: {expected_shape}") - logger.info(f" Match: {shape_match}") - logger.info(f" dtype: {value.dtype}") - logger.info(f" grad: {value.requires_grad}") - - if not shape_match: - logger.warning(f" Shape mismatch for {key}!") - else: - logger.info(f" {key}: {type(value)}") - except Exception as e: - logger.error("Error loading single sample") - logger.error(f"Error message: {str(e)}") - raise - - # 4. 测试数据加载器和批处理形状 - logger.info("\nTesting DataLoader...") - try: - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=2, - shuffle=True, - num_workers=0 - ) - - batch = next(iter(dataloader)) - logger.info("Batch keys and shapes:") - for key, value in batch.items(): - if isinstance(value, torch.Tensor): - actual_shape = tuple(value.shape) - expected_shape = expected_shapes.get(key) - if expected_shape: - expected_batch_shape = (2,) + expected_shape - shape_match = actual_shape == expected_batch_shape - else: - expected_batch_shape = None - shape_match = None - - logger.info(f" {key}:") - logger.info(f" Shape: {actual_shape}") - logger.info(f" Expected: {expected_batch_shape}") - logger.info(f" Match: {shape_match}") - logger.info(f" dtype: {value.dtype}") - - if not shape_match: - logger.warning(f" Shape mismatch for {key}!") - elif isinstance(value, list): - logger.info(f" {key}: list of length {len(value)}") - else: - logger.info(f" {key}: {type(value)}") - - except Exception as e: - logger.error("Error in DataLoader") - logger.error(f"Error message: {str(e)}") - raise - - # 5. 验证数据范围 - logger.info("\nValidating data ranges...") - try: - for key, value in batch.items(): - if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float64]: - logger.info(f" {key}:") - logger.info(f" min: {value.min().item():.4f}") - logger.info(f" max: {value.max().item():.4f}") - logger.info(f" mean: {value.mean().item():.4f}") - logger.info(f" std: {value.std().item():.4f}") - - # 检查是否有NaN或Inf - has_nan = torch.isnan(value).any() - has_inf = torch.isinf(value).any() - if has_nan or has_inf: - logger.warning(f" Found NaN: {has_nan}, Inf: {has_inf}") + total_samples = len(dataset) if num_samples is None else min(num_samples, len(dataset)) + logger.info(f"总样本数: {total_samples}") + + # 初始化统计信息 + stats = { + 'face_counts': [], + 'edge_counts': [], + 'vertex_counts': [], + 'sdf_point_counts': [], + 'invalid_samples': [], + 'shape_mismatches': defaultdict(int), + 'value_ranges': defaultdict(lambda: {'min': float('inf'), 'max': float('-inf')}), + 'nan_counts': defaultdict(int), + 'inf_counts': defaultdict(int) + } + + # 遍历数据集 + for idx in tqdm(range(total_samples), desc="验证数据"): + try: + sample = dataset[idx] + + # 1. 检查数据完整性 + required_keys = ['surf_ncs', 'surf_pos', 'edge_ncs', 'edge_pos', + 'vertex_pos', 'points', 'sdf', 'edge_mask'] + missing_keys = [key for key in required_keys if key not in sample] + if missing_keys: + stats['invalid_samples'].append((idx, f"缺少键: {missing_keys}")) + continue + + # 2. 检查形状 + expected_shapes = { + 'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3), + 'surf_pos': (config.data.max_face, 6), + 'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3), + 'edge_pos': (config.data.max_face, config.data.max_edge, 6), + 'edge_mask': (config.data.max_face, config.data.max_edge), + 'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3), + 'points': (config.data.num_query_points, 3), + 'sdf': (config.data.num_query_points, 1) + } + + for key, expected_shape in expected_shapes.items(): + if key in sample: + actual_shape = tuple(sample[key].shape) + if actual_shape != expected_shape: + stats['shape_mismatches'][key] += 1 + stats['invalid_samples'].append( + (idx, f"{key} 形状不匹配: 预期 {expected_shape}, 实际 {actual_shape}") + ) + + # 3. 检查数值范围和无效值 + for key, tensor in sample.items(): + if isinstance(tensor, torch.Tensor) and tensor.dtype.is_floating_point: + # 更新值范围 + stats['value_ranges'][key]['min'] = min(stats['value_ranges'][key]['min'], + tensor.min().item()) + stats['value_ranges'][key]['max'] = max(stats['value_ranges'][key]['max'], + tensor.max().item()) - except Exception as e: - logger.error("Error validating data ranges") - logger.error(f"Error message: {str(e)}") - raise - - logger.info("\nAll tests completed successfully!") - logger.info("="*50) + # 检查NaN和Inf + nan_count = torch.isnan(tensor).sum().item() + inf_count = torch.isinf(tensor).sum().item() + if nan_count > 0: + stats['nan_counts'][key] += nan_count + if inf_count > 0: + stats['inf_counts'][key] += inf_count + + # 4. 收集统计信息 + stats['face_counts'].append(sample['surf_ncs'].shape[0]) + stats['edge_counts'].append(sample['edge_ncs'].shape[1]) + stats['vertex_counts'].append(len(torch.unique(sample['vertex_pos'].reshape(-1, 3), dim=0))) + stats['sdf_point_counts'].append(sample['points'].shape[0]) + + except Exception as e: + stats['invalid_samples'].append((idx, str(e))) + + # 输出统计结果 + logger.info("\n=== 数据集验证结果 ===") + + # 1. 基本统计信息 + logger.info("\n基本统计信息:") + logger.info(f"总样本数: {total_samples}") + logger.info(f"有效样本数: {total_samples - len(stats['invalid_samples'])}") + logger.info(f"无效样本数: {len(stats['invalid_samples'])}") + + # 2. 形状不匹配统计 + if stats['shape_mismatches']: + logger.info("\n形状不匹配统计:") + for key, count in stats['shape_mismatches'].items(): + logger.info(f" {key}: {count}个样本不匹配") + + # 3. 数值范围统计 + logger.info("\n数值范围统计:") + for key, ranges in stats['value_ranges'].items(): + logger.info(f" {key}:") + logger.info(f" 最小值: {ranges['min']:.3f}") + logger.info(f" 最大值: {ranges['max']:.3f}") + + # 4. 无效值统计 + if sum(stats['nan_counts'].values()) > 0 or sum(stats['inf_counts'].values()) > 0: + logger.info("\n无效值统计:") + for key in stats['nan_counts'].keys(): + if stats['nan_counts'][key] > 0: + logger.info(f" {key} 包含 {stats['nan_counts'][key]} 个 NaN 值") + for key in stats['inf_counts'].keys(): + if stats['inf_counts'][key] > 0: + logger.info(f" {key} 包含 {stats['inf_counts'][key]} 个 Inf 值") + + # 5. 几何特征统计 + logger.info("\n几何特征统计:") + for name, values in [ + ('面数', stats['face_counts']), + ('边数', stats['edge_counts']), + ('顶点数', stats['vertex_counts']), + ('SDF采样点数', stats['sdf_point_counts']) + ]: + values = np.array(values) + logger.info(f" {name}:") + logger.info(f" 最小值: {np.min(values)}") + logger.info(f" 最大值: {np.max(values)}") + logger.info(f" 平均值: {np.mean(values):.2f}") + logger.info(f" 中位数: {np.median(values):.2f}") + logger.info(f" 标准差: {np.std(values):.2f}") + + # 6. 输出无效样本详情 + if stats['invalid_samples']: + logger.info("\n无效样本详情:") + for idx, error in stats['invalid_samples']: + logger.info(f" 样本 {idx}: {error}") + + return stats except Exception as e: - logger.error(f"Error in test_dataset: {str(e)}") + logger.error(f"验证过程出错: {str(e)}") raise if __name__ == '__main__': - test_dataset() \ No newline at end of file + validate_dataset(split='train', num_samples=None) # 先测试100个样本 \ No newline at end of file