|
|
@ -334,26 +334,133 @@ class BRepSDFDataset(Dataset): |
|
|
|
def test_dataset(): |
|
|
|
"""测试数据集功能""" |
|
|
|
try: |
|
|
|
# 设置测试路径 |
|
|
|
brep_dir = '/home/wch/myDeepSDF/test_data/pkl' |
|
|
|
sdf_dir = '/home/wch/myDeepSDF/test_data/sdf' |
|
|
|
# 1. 设置测试路径和预期的数据维度 |
|
|
|
brep_dir = '/home/wch/brep2sdf/test_data/pkl' |
|
|
|
sdf_dir = '/home/wch/brep2sdf/test_data/sdf' |
|
|
|
valid_data_dir = "/home/wch/brep2sdf/test_data/result/pkl" |
|
|
|
split = 'train' |
|
|
|
|
|
|
|
# 定义预期的数据维度 |
|
|
|
expected_shapes = { |
|
|
|
'edge_ncs': (70, 70, 10, 3), # [max_face, max_edge, sample_points, xyz] |
|
|
|
'edge_pos': (70, 70, 6), # [max_face, max_edge, bbox] |
|
|
|
'edge_mask': (70, 70), # [max_face, max_edge] |
|
|
|
'surf_ncs': (70, 100, 3), # [max_face, sample_points, xyz] |
|
|
|
'surf_pos': (70, 6), # [max_face, bbox] |
|
|
|
'vertex_pos': (70, 70, 2, 3), # [max_face, max_edge, 2_points, xyz] |
|
|
|
'sdf': (2097152, 4) # [num_points, xyz+sdf] |
|
|
|
} |
|
|
|
|
|
|
|
logger.info("="*50) |
|
|
|
logger.info("Testing dataset") |
|
|
|
logger.info(f"B-rep directory: {brep_dir}") |
|
|
|
logger.info(f"SDF directory: {sdf_dir}") |
|
|
|
logger.info(f"Split: {split}") |
|
|
|
|
|
|
|
# ... (其余测试代码保持不变) ... |
|
|
|
dataset = BRepSDFDataset(brep_dir='/home/wch/brep2sdf/test_data/pkl', sdf_dir='/home/wch/brep2sdf/test_data/sdf',valid_data_dir="/home/wch/brep2sdf/test_data/result/pkl", split='train') |
|
|
|
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True) |
|
|
|
|
|
|
|
for batch in dataloader: |
|
|
|
print(batch['sdf'].shape) |
|
|
|
break |
|
|
|
# 2. 初始化数据集 |
|
|
|
dataset = BRepSDFDataset( |
|
|
|
brep_dir=brep_dir, |
|
|
|
sdf_dir=sdf_dir, |
|
|
|
valid_data_dir=valid_data_dir, |
|
|
|
split=split |
|
|
|
) |
|
|
|
logger.info(f"\nDataset size: {len(dataset)}") |
|
|
|
|
|
|
|
# 3. 测试单个样本加载和形状检查 |
|
|
|
logger.info("\nTesting single sample loading...") |
|
|
|
try: |
|
|
|
sample = dataset[0] |
|
|
|
logger.info("Sample keys and shapes:") |
|
|
|
for key, value in sample.items(): |
|
|
|
if isinstance(value, torch.Tensor): |
|
|
|
actual_shape = tuple(value.shape) |
|
|
|
expected_shape = expected_shapes.get(key) |
|
|
|
shape_match = actual_shape == expected_shape if expected_shape else None |
|
|
|
|
|
|
|
logger.info(f" {key}:") |
|
|
|
logger.info(f" Shape: {actual_shape}") |
|
|
|
logger.info(f" Expected: {expected_shape}") |
|
|
|
logger.info(f" Match: {shape_match}") |
|
|
|
logger.info(f" dtype: {value.dtype}") |
|
|
|
|
|
|
|
if not shape_match: |
|
|
|
logger.warning(f" Shape mismatch for {key}!") |
|
|
|
else: |
|
|
|
logger.info(f" {key}: {type(value)}") |
|
|
|
except Exception as e: |
|
|
|
logger.error("Error loading single sample") |
|
|
|
logger.error(f"Error message: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
# 4. 测试数据加载器和批处理形状 |
|
|
|
logger.info("\nTesting DataLoader...") |
|
|
|
try: |
|
|
|
dataloader = torch.utils.data.DataLoader( |
|
|
|
dataset, |
|
|
|
batch_size=2, |
|
|
|
shuffle=True, |
|
|
|
num_workers=0 |
|
|
|
) |
|
|
|
|
|
|
|
batch = next(iter(dataloader)) |
|
|
|
logger.info("Batch keys and shapes:") |
|
|
|
for key, value in batch.items(): |
|
|
|
if isinstance(value, torch.Tensor): |
|
|
|
actual_shape = tuple(value.shape) |
|
|
|
expected_shape = expected_shapes.get(key) |
|
|
|
if expected_shape: |
|
|
|
expected_batch_shape = (2,) + expected_shape |
|
|
|
shape_match = actual_shape == expected_batch_shape |
|
|
|
else: |
|
|
|
expected_batch_shape = None |
|
|
|
shape_match = None |
|
|
|
|
|
|
|
logger.info(f" {key}:") |
|
|
|
logger.info(f" Shape: {actual_shape}") |
|
|
|
logger.info(f" Expected: {expected_batch_shape}") |
|
|
|
logger.info(f" Match: {shape_match}") |
|
|
|
logger.info(f" dtype: {value.dtype}") |
|
|
|
|
|
|
|
if not shape_match: |
|
|
|
logger.warning(f" Shape mismatch for {key}!") |
|
|
|
elif isinstance(value, list): |
|
|
|
logger.info(f" {key}: list of length {len(value)}") |
|
|
|
else: |
|
|
|
logger.info(f" {key}: {type(value)}") |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error("Error in DataLoader") |
|
|
|
logger.error(f"Error message: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
# 5. 验证数据范围 |
|
|
|
logger.info("\nValidating data ranges...") |
|
|
|
try: |
|
|
|
for key, value in batch.items(): |
|
|
|
if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float64]: |
|
|
|
logger.info(f" {key}:") |
|
|
|
logger.info(f" min: {value.min().item():.4f}") |
|
|
|
logger.info(f" max: {value.max().item():.4f}") |
|
|
|
logger.info(f" mean: {value.mean().item():.4f}") |
|
|
|
logger.info(f" std: {value.std().item():.4f}") |
|
|
|
|
|
|
|
# 检查是否有NaN或Inf |
|
|
|
has_nan = torch.isnan(value).any() |
|
|
|
has_inf = torch.isinf(value).any() |
|
|
|
if has_nan or has_inf: |
|
|
|
logger.warning(f" Found NaN: {has_nan}, Inf: {has_inf}") |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error("Error validating data ranges") |
|
|
|
logger.error(f"Error message: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
logger.info("\nAll tests completed successfully!") |
|
|
|
logger.info("="*50) |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error in test_dataset: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_dataset() |