| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -334,26 +334,133 @@ class BRepSDFDataset(Dataset): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def test_dataset(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """测试数据集功能""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 设置测试路径 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        brep_dir = '/home/wch/myDeepSDF/test_data/pkl' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        sdf_dir = '/home/wch/myDeepSDF/test_data/sdf' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 1. 设置测试路径和预期的数据维度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        brep_dir = '/home/wch/brep2sdf/test_data/pkl' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        sdf_dir = '/home/wch/brep2sdf/test_data/sdf' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        valid_data_dir = "/home/wch/brep2sdf/test_data/result/pkl" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        split = 'train' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 定义预期的数据维度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        expected_shapes = { | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'edge_ncs': (70, 70, 10, 3),    # [max_face, max_edge, sample_points, xyz] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'edge_pos': (70, 70, 6),        # [max_face, max_edge, bbox] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'edge_mask': (70, 70),          # [max_face, max_edge] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'surf_ncs': (70, 100, 3),       # [max_face, sample_points, xyz] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'surf_pos': (70, 6),            # [max_face, bbox] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'vertex_pos': (70, 70, 2, 3),   # [max_face, max_edge, 2_points, xyz] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'sdf': (2097152, 4)             # [num_points, xyz+sdf] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("="*50) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("Testing dataset") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f"B-rep directory: {brep_dir}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f"SDF directory: {sdf_dir}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f"Split: {split}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # ... (其余测试代码保持不变) ... | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        dataset = BRepSDFDataset(brep_dir='/home/wch/brep2sdf/test_data/pkl', sdf_dir='/home/wch/brep2sdf/test_data/sdf',valid_data_dir="/home/wch/brep2sdf/test_data/result/pkl", split='train') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for batch in dataloader: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            print(batch['sdf'].shape) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            break | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 2. 初始化数据集 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        dataset = BRepSDFDataset( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            brep_dir=brep_dir, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf_dir=sdf_dir, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            valid_data_dir=valid_data_dir, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            split=split | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f"\nDataset size: {len(dataset)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 3. 测试单个样本加载和形状检查 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("\nTesting single sample loading...") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sample = dataset[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.info("Sample keys and shapes:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key, value in sample.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if isinstance(value, torch.Tensor): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    actual_shape = tuple(value.shape) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    expected_shape = expected_shapes.get(key) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    shape_match = actual_shape == expected_shape if expected_shape else None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"  {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    Shape: {actual_shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    Expected: {expected_shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    Match: {shape_match}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    dtype: {value.dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if not shape_match: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.warning(f"    Shape mismatch for {key}!") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"  {key}: {type(value)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error("Error loading single sample") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error message: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 4. 测试数据加载器和批处理形状 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("\nTesting DataLoader...") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            dataloader = torch.utils.data.DataLoader( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                dataset,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                batch_size=2, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                shuffle=True, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                num_workers=0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            batch = next(iter(dataloader)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.info("Batch keys and shapes:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key, value in batch.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if isinstance(value, torch.Tensor): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    actual_shape = tuple(value.shape) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    expected_shape = expected_shapes.get(key) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if expected_shape: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        expected_batch_shape = (2,) + expected_shape | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        shape_match = actual_shape == expected_batch_shape | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        expected_batch_shape = None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        shape_match = None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"  {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    Shape: {actual_shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    Expected: {expected_batch_shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    Match: {shape_match}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    dtype: {value.dtype}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if not shape_match: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.warning(f"    Shape mismatch for {key}!") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                elif isinstance(value, list): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"  {key}: list of length {len(value)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"  {key}: {type(value)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error("Error in DataLoader") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error message: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 5. 验证数据范围 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("\nValidating data ranges...") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for key, value in batch.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float64]: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"  {key}:") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    min: {value.min().item():.4f}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    max: {value.max().item():.4f}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    mean: {value.mean().item():.4f}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.info(f"    std: {value.std().item():.4f}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    # 检查是否有NaN或Inf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    has_nan = torch.isnan(value).any() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    has_inf = torch.isinf(value).any() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if has_nan or has_inf: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        logger.warning(f"    Found NaN: {has_nan}, Inf: {has_inf}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error("Error validating data ranges") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"Error message: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("\nAll tests completed successfully!") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("="*50) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.error(f"Error in test_dataset: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					if __name__ == '__main__': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    test_dataset() |