| 
						
						
							
								
							
						
						
					 | 
					@ -183,9 +183,16 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            max_points = self.config.data.num_query_points  # 例如4096 | 
					 | 
					 | 
					            max_points = self.config.data.num_query_points  # 例如4096 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 确保正负样本均衡 | 
					 | 
					 | 
					            # 确保正负样本均衡 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            num_pos = min(max_points // 2, sdf_pos.shape[0]) | 
					 | 
					 | 
					            if max_points // 2 > sdf_pos.shape[0]: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            num_neg = min(max_points // 2, sdf_neg.shape[0]) | 
					 | 
					 | 
					                logger.warning(f"正样本过少,期望>{max_points // 2},实际:{sdf_pos.shape[0]}") | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            if max_points // 2 > sdf_neg.shape[0]: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                num_neg = sdf_neg.shape[0] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                num_neg = max_points // 2 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            num_pos = max_points - num_neg | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 随机采样正样本 | 
					 | 
					 | 
					            # 随机采样正样本 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if sdf_pos.shape[0] > num_pos: | 
					 | 
					 | 
					            if sdf_pos.shape[0] > num_pos: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False) | 
					 | 
					 | 
					                pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -222,139 +229,261 @@ def test_dataset(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    try: | 
					 | 
					 | 
					    try: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 获取配置 | 
					 | 
					 | 
					        # 获取配置 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        config = get_default_config() | 
					 | 
					 | 
					        config = get_default_config() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        brep_dir = config.data.brep_dir | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        sdf_dir = config.data.sdf_dir | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        valid_data_dir = config.data.valid_data_dir | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        split = 'train' | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        max_face = config.data.max_face | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        max_edge = config.data.max_edge | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        num_edge_points = config.model.num_edge_points | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        num_surf_points = config.model.num_surf_points | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        num_query_points = config.data.num_query_points | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 定义预期的数据维度,使用配置中的参数 | 
					 | 
					 | 
					        # 定义预期的数据维度 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        expected_shapes = { | 
					 | 
					 | 
					        expected_shapes = { | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            'edge_ncs': (max_face, max_edge, num_edge_points, 3),    # [max_face, max_edge, sample_points, xyz] | 
					 | 
					 | 
					            'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            'edge_pos': (max_face, max_edge, 6),         | 
					 | 
					 | 
					            'edge_pos': (config.data.max_face, config.data.max_edge, 6), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            'edge_mask': (max_face, max_edge),           | 
					 | 
					 | 
					            'edge_mask': (config.data.max_face, config.data.max_edge), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            'surf_ncs': (max_face, num_surf_points, 3),   | 
					 | 
					 | 
					            'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            'surf_pos': (max_face, 6),             | 
					 | 
					 | 
					            'surf_pos': (config.data.max_face, 6), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            'vertex_pos': (max_face, max_edge, 2, 3),    | 
					 | 
					 | 
					            'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            'points': (num_query_points, 3), | 
					 | 
					 | 
					            'points': (config.data.num_query_points, 3), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            'sdf': (num_query_points, 1)              | 
					 | 
					 | 
					            'sdf': (config.data.num_query_points, 1) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        } | 
					 | 
					 | 
					        } | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info("="*50) | 
					 | 
					 | 
					        logger.info("="*50) | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        logger.info("Testing dataset") | 
					 | 
					 | 
					        logger.info("测试数据集") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        logger.info(f"B-rep directory: {brep_dir}") | 
					 | 
					 | 
					        logger.info(f"预期形状:") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        logger.info(f"SDF directory: {sdf_dir}") | 
					 | 
					 | 
					        for key, shape in expected_shapes.items(): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        logger.info(f"Split: {split}") | 
					 | 
					 | 
					            logger.info(f"  {key}: {shape}") | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 初始化数据集 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        dataset = BRepSDFDataset( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            brep_dir=config.data.brep_dir, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            sdf_dir=config.data.sdf_dir, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            valid_data_dir=config.data.valid_data_dir, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            split='train' | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 测试数据加载 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info("\n测试数据加载...") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        sample = dataset[0] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 检查数据类型和形状 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info("\n数据类型和形状检查:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        for key, value in sample.items(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            if isinstance(value, torch.Tensor): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                actual_shape = tuple(value.shape) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                expected_shape = expected_shapes.get(key) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                shape_match = "✓" if actual_shape == expected_shape else "✗" | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"\n{key}:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"  实际形状: {actual_shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"  预期形状: {expected_shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"  匹配状态: {shape_match}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"  数据类型: {value.dtype}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                # 仅对浮点类型计算数值范围、均值和标准差 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                if value.dtype.is_floating_point: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.info(f"  数值范围: [{value.min():.3f}, {value.max():.3f}]") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.info(f"  均值: {value.mean():.3f}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.info(f"  标准差: {value.std():.3f}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                if shape_match == "✗": | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.warning(f"  形状不匹配: {key}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    if key in ['points', 'sdf']: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.warning(f"  查询点数量不一致,预期 {expected_shape[0]},实际 {actual_shape[0]}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    elif key in ['edge_ncs', 'edge_pos', 'edge_mask']: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.warning(f"  边数量不一致,预期 {expected_shape[1]},实际 {actual_shape[1]}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    elif key in ['surf_ncs', 'surf_pos']: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        logger.warning(f"  面数量不一致,预期 {expected_shape[0]},实际 {actual_shape[0]}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 测试批处理 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info("\n测试批处理...") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        batch_size = 4 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        dataloader = torch.utils.data.DataLoader( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            dataset, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            batch_size=batch_size, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            shuffle=True, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            num_workers=0 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        batch = next(iter(dataloader)) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info("\n批处理形状检查:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        for key, value in batch.items(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            if isinstance(value, torch.Tensor): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                batch_shape = tuple(value.shape) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                expected_batch_shape = (batch_size,) + expected_shapes[key] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                shape_match = "✓" if batch_shape == expected_batch_shape else "✗" | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"\n{key}:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"  实际形状: {batch_shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"  预期形状: {expected_batch_shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"  匹配状态: {shape_match}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"  数据类型: {value.dtype}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                # 仅对浮点类型计算数值范围、均值和标准差 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                if value.dtype.is_floating_point: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.info(f"  数值范围: [{value.min():.3f}, {value.max():.3f}]") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.info(f"  均值: {value.mean():.3f}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.info(f"  标准差: {value.std():.3f}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                if shape_match == "✗": | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.warning(f"  批处理形状不匹配: {key}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info("\n测试完成!") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info("="*50) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 2. 初始化数据集 | 
					 | 
					 | 
					    except Exception as e: | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.error(f"测试过程中出错: {str(e)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        raise | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					from collections import defaultdict | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					from tqdm import tqdm | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					def validate_dataset(split: str = 'train', num_samples: int = None): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    """全面验证数据集 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    Args: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        split: 数据集分割 ('train', 'val', 'test') | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        num_samples: 要检查的样本数量,None表示检查所有样本 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    """ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    try: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        config = get_default_config() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info(f"开始验证{split}数据集...") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 初始化数据集 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        dataset = BRepSDFDataset( | 
					 | 
					 | 
					        dataset = BRepSDFDataset( | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            brep_dir=brep_dir, | 
					 | 
					 | 
					            brep_dir=config.data.brep_dir, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            sdf_dir=sdf_dir, | 
					 | 
					 | 
					            sdf_dir=config.data.sdf_dir, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            valid_data_dir=valid_data_dir, | 
					 | 
					 | 
					            valid_data_dir=config.data.valid_data_dir, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            split=split | 
					 | 
					 | 
					            split='train' | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info(f"\nDataset size: {len(dataset)}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 3. 测试单个样本加载和形状检查 | 
					 | 
					 | 
					        total_samples = len(dataset) if num_samples is None else min(num_samples, len(dataset)) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        logger.info("\nTesting single sample loading...") | 
					 | 
					 | 
					        logger.info(f"总样本数: {total_samples}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        try: | 
					 | 
					 | 
					         | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            sample = dataset[0] | 
					 | 
					 | 
					        # 初始化统计信息 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            logger.info("Sample keys and shapes:") | 
					 | 
					 | 
					        stats = { | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            for key, value in sample.items(): | 
					 | 
					 | 
					            'face_counts': [], | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                if isinstance(value, torch.Tensor): | 
					 | 
					 | 
					            'edge_counts': [], | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    actual_shape = tuple(value.shape) | 
					 | 
					 | 
					            'vertex_counts': [], | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    expected_shape = expected_shapes.get(key) | 
					 | 
					 | 
					            'sdf_point_counts': [], | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    shape_match = actual_shape == expected_shape if expected_shape else None | 
					 | 
					 | 
					            'invalid_samples': [], | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                     | 
					 | 
					 | 
					            'shape_mismatches': defaultdict(int), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"  {key}:") | 
					 | 
					 | 
					            'value_ranges': defaultdict(lambda: {'min': float('inf'), 'max': float('-inf')}), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"    Shape: {actual_shape}") | 
					 | 
					 | 
					            'nan_counts': defaultdict(int), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"    Expected: {expected_shape}") | 
					 | 
					 | 
					            'inf_counts': defaultdict(int) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"    Match: {shape_match}") | 
					 | 
					 | 
					        } | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"    dtype: {value.dtype}") | 
					 | 
					 | 
					         | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"    grad: {value.requires_grad}") | 
					 | 
					 | 
					        # 遍历数据集 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                     | 
					 | 
					 | 
					        for idx in tqdm(range(total_samples), desc="验证数据"): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    if not shape_match: | 
					 | 
					 | 
					            try: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                        logger.warning(f"    Shape mismatch for {key}!") | 
					 | 
					 | 
					                sample = dataset[idx] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                else: | 
					 | 
					 | 
					                 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"  {key}: {type(value)}") | 
					 | 
					 | 
					                # 1. 检查数据完整性 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        except Exception as e: | 
					 | 
					 | 
					                required_keys = ['surf_ncs', 'surf_pos', 'edge_ncs', 'edge_pos',  | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            logger.error("Error loading single sample") | 
					 | 
					 | 
					                               'vertex_pos', 'points', 'sdf', 'edge_mask'] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            logger.error(f"Error message: {str(e)}") | 
					 | 
					 | 
					                missing_keys = [key for key in required_keys if key not in sample] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            raise | 
					 | 
					 | 
					                if missing_keys: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					             | 
					 | 
					 | 
					                    stats['invalid_samples'].append((idx, f"缺少键: {missing_keys}")) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 4. 测试数据加载器和批处理形状 | 
					 | 
					 | 
					                    continue | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        logger.info("\nTesting DataLoader...") | 
					 | 
					 | 
					                 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        try: | 
					 | 
					 | 
					                # 2. 检查形状 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            dataloader = torch.utils.data.DataLoader( | 
					 | 
					 | 
					                expected_shapes = { | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                dataset,  | 
					 | 
					 | 
					                    'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                batch_size=2, | 
					 | 
					 | 
					                    'surf_pos': (config.data.max_face, 6), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                shuffle=True, | 
					 | 
					 | 
					                    'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                num_workers=0 | 
					 | 
					 | 
					                    'edge_pos': (config.data.max_face, config.data.max_edge, 6), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            ) | 
					 | 
					 | 
					                    'edge_mask': (config.data.max_face, config.data.max_edge), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					             | 
					 | 
					 | 
					                    'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            batch = next(iter(dataloader)) | 
					 | 
					 | 
					                    'points': (config.data.num_query_points, 3), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            logger.info("Batch keys and shapes:") | 
					 | 
					 | 
					                    'sdf': (config.data.num_query_points, 1) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            for key, value in batch.items(): | 
					 | 
					 | 
					                } | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                if isinstance(value, torch.Tensor): | 
					 | 
					 | 
					                 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    actual_shape = tuple(value.shape) | 
					 | 
					 | 
					                for key, expected_shape in expected_shapes.items(): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    expected_shape = expected_shapes.get(key) | 
					 | 
					 | 
					                    if key in sample: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    if expected_shape: | 
					 | 
					 | 
					                        actual_shape = tuple(sample[key].shape) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                        expected_batch_shape = (2,) + expected_shape | 
					 | 
					 | 
					                        if actual_shape != expected_shape: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                        shape_match = actual_shape == expected_batch_shape | 
					 | 
					 | 
					                            stats['shape_mismatches'][key] += 1 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    else: | 
					 | 
					 | 
					                            stats['invalid_samples'].append( | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                        expected_batch_shape = None | 
					 | 
					 | 
					                                (idx, f"{key} 形状不匹配: 预期 {expected_shape}, 实际 {actual_shape}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                        shape_match = None | 
					 | 
					 | 
					                            ) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                     | 
					 | 
					 | 
					                 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"  {key}:") | 
					 | 
					 | 
					                # 3. 检查数值范围和无效值 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"    Shape: {actual_shape}") | 
					 | 
					 | 
					                for key, tensor in sample.items(): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"    Expected: {expected_batch_shape}") | 
					 | 
					 | 
					                    if isinstance(tensor, torch.Tensor) and tensor.dtype.is_floating_point: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"    Match: {shape_match}") | 
					 | 
					 | 
					                        # 更新值范围 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    logger.info(f"    dtype: {value.dtype}") | 
					 | 
					 | 
					                        stats['value_ranges'][key]['min'] = min(stats['value_ranges'][key]['min'],  | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                     | 
					 | 
					 | 
					                                                              tensor.min().item()) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    if not shape_match: | 
					 | 
					 | 
					                        stats['value_ranges'][key]['max'] = max(stats['value_ranges'][key]['max'],  | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                        logger.warning(f"    Shape mismatch for {key}!") | 
					 | 
					 | 
					                                                              tensor.max().item()) | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					                elif isinstance(value, list): | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    logger.info(f"  {key}: list of length {len(value)}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                else: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    logger.info(f"  {key}: {type(value)}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                     | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        except Exception as e: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            logger.error("Error in DataLoader") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            logger.error(f"Error message: {str(e)}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            raise | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 5. 验证数据范围 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info("\nValidating data ranges...") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        try: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            for key, value in batch.items(): | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float64]: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    logger.info(f"  {key}:") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    logger.info(f"    min: {value.min().item():.4f}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    logger.info(f"    max: {value.max().item():.4f}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    logger.info(f"    mean: {value.mean().item():.4f}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    logger.info(f"    std: {value.std().item():.4f}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                     | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    # 检查是否有NaN或Inf | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    has_nan = torch.isnan(value).any() | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    has_inf = torch.isinf(value).any() | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    if has_nan or has_inf: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                        logger.warning(f"    Found NaN: {has_nan}, Inf: {has_inf}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					                         | 
					 | 
					 | 
					                         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        except Exception as e: | 
					 | 
					 | 
					                        # 检查NaN和Inf | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            logger.error("Error validating data ranges") | 
					 | 
					 | 
					                        nan_count = torch.isnan(tensor).sum().item() | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            logger.error(f"Error message: {str(e)}") | 
					 | 
					 | 
					                        inf_count = torch.isinf(tensor).sum().item() | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            raise | 
					 | 
					 | 
					                        if nan_count > 0: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					             | 
					 | 
					 | 
					                            stats['nan_counts'][key] += nan_count | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        logger.info("\nAll tests completed successfully!") | 
					 | 
					 | 
					                        if inf_count > 0: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        logger.info("="*50) | 
					 | 
					 | 
					                            stats['inf_counts'][key] += inf_count | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                # 4. 收集统计信息 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                stats['face_counts'].append(sample['surf_ncs'].shape[0]) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                stats['edge_counts'].append(sample['edge_ncs'].shape[1]) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                stats['vertex_counts'].append(len(torch.unique(sample['vertex_pos'].reshape(-1, 3), dim=0))) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                stats['sdf_point_counts'].append(sample['points'].shape[0]) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            except Exception as e: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                stats['invalid_samples'].append((idx, str(e))) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 输出统计结果 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info("\n=== 数据集验证结果 ===") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 1. 基本统计信息 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info("\n基本统计信息:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info(f"总样本数: {total_samples}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info(f"有效样本数: {total_samples - len(stats['invalid_samples'])}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info(f"无效样本数: {len(stats['invalid_samples'])}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 2. 形状不匹配统计 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        if stats['shape_mismatches']: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info("\n形状不匹配统计:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            for key, count in stats['shape_mismatches'].items(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"  {key}: {count}个样本不匹配") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 3. 数值范围统计 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info("\n数值范围统计:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        for key, ranges in stats['value_ranges'].items(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info(f"  {key}:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info(f"    最小值: {ranges['min']:.3f}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info(f"    最大值: {ranges['max']:.3f}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 4. 无效值统计 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        if sum(stats['nan_counts'].values()) > 0 or sum(stats['inf_counts'].values()) > 0: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info("\n无效值统计:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            for key in stats['nan_counts'].keys(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                if stats['nan_counts'][key] > 0: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.info(f"  {key} 包含 {stats['nan_counts'][key]} 个 NaN 值") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            for key in stats['inf_counts'].keys(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                if stats['inf_counts'][key] > 0: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logger.info(f"  {key} 包含 {stats['inf_counts'][key]} 个 Inf 值") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 5. 几何特征统计 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info("\n几何特征统计:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        for name, values in [ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            ('面数', stats['face_counts']), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            ('边数', stats['edge_counts']), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            ('顶点数', stats['vertex_counts']), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            ('SDF采样点数', stats['sdf_point_counts']) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        ]: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            values = np.array(values) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info(f"  {name}:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info(f"    最小值: {np.min(values)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info(f"    最大值: {np.max(values)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info(f"    平均值: {np.mean(values):.2f}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info(f"    中位数: {np.median(values):.2f}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info(f"    标准差: {np.std(values):.2f}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 6. 输出无效样本详情 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        if stats['invalid_samples']: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.info("\n无效样本详情:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            for idx, error in stats['invalid_samples']: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.info(f"  样本 {idx}: {error}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        return stats | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    except Exception as e: | 
					 | 
					 | 
					    except Exception as e: | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        logger.error(f"Error in test_dataset: {str(e)}") | 
					 | 
					 | 
					        logger.error(f"验证过程出错: {str(e)}") | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        raise | 
					 | 
					 | 
					        raise | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					if __name__ == '__main__': | 
					 | 
					 | 
					if __name__ == '__main__': | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    test_dataset() | 
					 | 
					 | 
					    validate_dataset(split='train', num_samples=None)  # 先测试100个样本 |