|
|
@ -183,9 +183,16 @@ class BRepSDFDataset(Dataset): |
|
|
|
max_points = self.config.data.num_query_points # 例如4096 |
|
|
|
|
|
|
|
# 确保正负样本均衡 |
|
|
|
num_pos = min(max_points // 2, sdf_pos.shape[0]) |
|
|
|
num_neg = min(max_points // 2, sdf_neg.shape[0]) |
|
|
|
if max_points // 2 > sdf_pos.shape[0]: |
|
|
|
logger.warning(f"正样本过少,期望>{max_points // 2},实际:{sdf_pos.shape[0]}") |
|
|
|
|
|
|
|
if max_points // 2 > sdf_neg.shape[0]: |
|
|
|
num_neg = sdf_neg.shape[0] |
|
|
|
else: |
|
|
|
num_neg = max_points // 2 |
|
|
|
|
|
|
|
num_pos = max_points - num_neg |
|
|
|
|
|
|
|
# 随机采样正样本 |
|
|
|
if sdf_pos.shape[0] > num_pos: |
|
|
|
pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False) |
|
|
@ -222,139 +229,261 @@ def test_dataset(): |
|
|
|
try: |
|
|
|
# 获取配置 |
|
|
|
config = get_default_config() |
|
|
|
brep_dir = config.data.brep_dir |
|
|
|
sdf_dir = config.data.sdf_dir |
|
|
|
valid_data_dir = config.data.valid_data_dir |
|
|
|
split = 'train' |
|
|
|
max_face = config.data.max_face |
|
|
|
max_edge = config.data.max_edge |
|
|
|
num_edge_points = config.model.num_edge_points |
|
|
|
num_surf_points = config.model.num_surf_points |
|
|
|
num_query_points = config.data.num_query_points |
|
|
|
|
|
|
|
# 定义预期的数据维度,使用配置中的参数 |
|
|
|
# 定义预期的数据维度 |
|
|
|
expected_shapes = { |
|
|
|
'edge_ncs': (max_face, max_edge, num_edge_points, 3), # [max_face, max_edge, sample_points, xyz] |
|
|
|
'edge_pos': (max_face, max_edge, 6), |
|
|
|
'edge_mask': (max_face, max_edge), |
|
|
|
'surf_ncs': (max_face, num_surf_points, 3), |
|
|
|
'surf_pos': (max_face, 6), |
|
|
|
'vertex_pos': (max_face, max_edge, 2, 3), |
|
|
|
'points': (num_query_points, 3), |
|
|
|
'sdf': (num_query_points, 1) |
|
|
|
'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3), |
|
|
|
'edge_pos': (config.data.max_face, config.data.max_edge, 6), |
|
|
|
'edge_mask': (config.data.max_face, config.data.max_edge), |
|
|
|
'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3), |
|
|
|
'surf_pos': (config.data.max_face, 6), |
|
|
|
'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3), |
|
|
|
'points': (config.data.num_query_points, 3), |
|
|
|
'sdf': (config.data.num_query_points, 1) |
|
|
|
} |
|
|
|
|
|
|
|
logger.info("="*50) |
|
|
|
logger.info("Testing dataset") |
|
|
|
logger.info(f"B-rep directory: {brep_dir}") |
|
|
|
logger.info(f"SDF directory: {sdf_dir}") |
|
|
|
logger.info(f"Split: {split}") |
|
|
|
logger.info("测试数据集") |
|
|
|
logger.info(f"预期形状:") |
|
|
|
for key, shape in expected_shapes.items(): |
|
|
|
logger.info(f" {key}: {shape}") |
|
|
|
|
|
|
|
# 初始化数据集 |
|
|
|
dataset = BRepSDFDataset( |
|
|
|
brep_dir=config.data.brep_dir, |
|
|
|
sdf_dir=config.data.sdf_dir, |
|
|
|
valid_data_dir=config.data.valid_data_dir, |
|
|
|
split='train' |
|
|
|
) |
|
|
|
|
|
|
|
# 测试数据加载 |
|
|
|
logger.info("\n测试数据加载...") |
|
|
|
sample = dataset[0] |
|
|
|
|
|
|
|
# 检查数据类型和形状 |
|
|
|
logger.info("\n数据类型和形状检查:") |
|
|
|
for key, value in sample.items(): |
|
|
|
if isinstance(value, torch.Tensor): |
|
|
|
actual_shape = tuple(value.shape) |
|
|
|
expected_shape = expected_shapes.get(key) |
|
|
|
shape_match = "✓" if actual_shape == expected_shape else "✗" |
|
|
|
|
|
|
|
logger.info(f"\n{key}:") |
|
|
|
logger.info(f" 实际形状: {actual_shape}") |
|
|
|
logger.info(f" 预期形状: {expected_shape}") |
|
|
|
logger.info(f" 匹配状态: {shape_match}") |
|
|
|
logger.info(f" 数据类型: {value.dtype}") |
|
|
|
|
|
|
|
# 仅对浮点类型计算数值范围、均值和标准差 |
|
|
|
if value.dtype.is_floating_point: |
|
|
|
logger.info(f" 数值范围: [{value.min():.3f}, {value.max():.3f}]") |
|
|
|
logger.info(f" 均值: {value.mean():.3f}") |
|
|
|
logger.info(f" 标准差: {value.std():.3f}") |
|
|
|
|
|
|
|
if shape_match == "✗": |
|
|
|
logger.warning(f" 形状不匹配: {key}") |
|
|
|
if key in ['points', 'sdf']: |
|
|
|
logger.warning(f" 查询点数量不一致,预期 {expected_shape[0]},实际 {actual_shape[0]}") |
|
|
|
elif key in ['edge_ncs', 'edge_pos', 'edge_mask']: |
|
|
|
logger.warning(f" 边数量不一致,预期 {expected_shape[1]},实际 {actual_shape[1]}") |
|
|
|
elif key in ['surf_ncs', 'surf_pos']: |
|
|
|
logger.warning(f" 面数量不一致,预期 {expected_shape[0]},实际 {actual_shape[0]}") |
|
|
|
|
|
|
|
# 测试批处理 |
|
|
|
logger.info("\n测试批处理...") |
|
|
|
batch_size = 4 |
|
|
|
dataloader = torch.utils.data.DataLoader( |
|
|
|
dataset, |
|
|
|
batch_size=batch_size, |
|
|
|
shuffle=True, |
|
|
|
num_workers=0 |
|
|
|
) |
|
|
|
|
|
|
|
batch = next(iter(dataloader)) |
|
|
|
logger.info("\n批处理形状检查:") |
|
|
|
for key, value in batch.items(): |
|
|
|
if isinstance(value, torch.Tensor): |
|
|
|
batch_shape = tuple(value.shape) |
|
|
|
expected_batch_shape = (batch_size,) + expected_shapes[key] |
|
|
|
shape_match = "✓" if batch_shape == expected_batch_shape else "✗" |
|
|
|
|
|
|
|
logger.info(f"\n{key}:") |
|
|
|
logger.info(f" 实际形状: {batch_shape}") |
|
|
|
logger.info(f" 预期形状: {expected_batch_shape}") |
|
|
|
logger.info(f" 匹配状态: {shape_match}") |
|
|
|
logger.info(f" 数据类型: {value.dtype}") |
|
|
|
|
|
|
|
# 仅对浮点类型计算数值范围、均值和标准差 |
|
|
|
if value.dtype.is_floating_point: |
|
|
|
logger.info(f" 数值范围: [{value.min():.3f}, {value.max():.3f}]") |
|
|
|
logger.info(f" 均值: {value.mean():.3f}") |
|
|
|
logger.info(f" 标准差: {value.std():.3f}") |
|
|
|
|
|
|
|
if shape_match == "✗": |
|
|
|
logger.warning(f" 批处理形状不匹配: {key}") |
|
|
|
|
|
|
|
logger.info("\n测试完成!") |
|
|
|
logger.info("="*50) |
|
|
|
|
|
|
|
# 2. 初始化数据集 |
|
|
|
except Exception as e: |
|
|
|
logger.error(f"测试过程中出错: {str(e)}") |
|
|
|
raise |
|
|
|
from collections import defaultdict |
|
|
|
from tqdm import tqdm |
|
|
|
def validate_dataset(split: str = 'train', num_samples: int = None): |
|
|
|
"""全面验证数据集 |
|
|
|
|
|
|
|
Args: |
|
|
|
split: 数据集分割 ('train', 'val', 'test') |
|
|
|
num_samples: 要检查的样本数量,None表示检查所有样本 |
|
|
|
""" |
|
|
|
try: |
|
|
|
config = get_default_config() |
|
|
|
logger.info(f"开始验证{split}数据集...") |
|
|
|
|
|
|
|
# 初始化数据集 |
|
|
|
dataset = BRepSDFDataset( |
|
|
|
brep_dir=brep_dir, |
|
|
|
sdf_dir=sdf_dir, |
|
|
|
valid_data_dir=valid_data_dir, |
|
|
|
split=split |
|
|
|
brep_dir=config.data.brep_dir, |
|
|
|
sdf_dir=config.data.sdf_dir, |
|
|
|
valid_data_dir=config.data.valid_data_dir, |
|
|
|
split='train' |
|
|
|
) |
|
|
|
logger.info(f"\nDataset size: {len(dataset)}") |
|
|
|
|
|
|
|
# 3. 测试单个样本加载和形状检查 |
|
|
|
logger.info("\nTesting single sample loading...") |
|
|
|
try: |
|
|
|
sample = dataset[0] |
|
|
|
logger.info("Sample keys and shapes:") |
|
|
|
for key, value in sample.items(): |
|
|
|
if isinstance(value, torch.Tensor): |
|
|
|
actual_shape = tuple(value.shape) |
|
|
|
expected_shape = expected_shapes.get(key) |
|
|
|
shape_match = actual_shape == expected_shape if expected_shape else None |
|
|
|
|
|
|
|
logger.info(f" {key}:") |
|
|
|
logger.info(f" Shape: {actual_shape}") |
|
|
|
logger.info(f" Expected: {expected_shape}") |
|
|
|
logger.info(f" Match: {shape_match}") |
|
|
|
logger.info(f" dtype: {value.dtype}") |
|
|
|
logger.info(f" grad: {value.requires_grad}") |
|
|
|
|
|
|
|
if not shape_match: |
|
|
|
logger.warning(f" Shape mismatch for {key}!") |
|
|
|
else: |
|
|
|
logger.info(f" {key}: {type(value)}") |
|
|
|
except Exception as e: |
|
|
|
logger.error("Error loading single sample") |
|
|
|
logger.error(f"Error message: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
# 4. 测试数据加载器和批处理形状 |
|
|
|
logger.info("\nTesting DataLoader...") |
|
|
|
try: |
|
|
|
dataloader = torch.utils.data.DataLoader( |
|
|
|
dataset, |
|
|
|
batch_size=2, |
|
|
|
shuffle=True, |
|
|
|
num_workers=0 |
|
|
|
) |
|
|
|
|
|
|
|
batch = next(iter(dataloader)) |
|
|
|
logger.info("Batch keys and shapes:") |
|
|
|
for key, value in batch.items(): |
|
|
|
if isinstance(value, torch.Tensor): |
|
|
|
actual_shape = tuple(value.shape) |
|
|
|
expected_shape = expected_shapes.get(key) |
|
|
|
if expected_shape: |
|
|
|
expected_batch_shape = (2,) + expected_shape |
|
|
|
shape_match = actual_shape == expected_batch_shape |
|
|
|
else: |
|
|
|
expected_batch_shape = None |
|
|
|
shape_match = None |
|
|
|
|
|
|
|
logger.info(f" {key}:") |
|
|
|
logger.info(f" Shape: {actual_shape}") |
|
|
|
logger.info(f" Expected: {expected_batch_shape}") |
|
|
|
logger.info(f" Match: {shape_match}") |
|
|
|
logger.info(f" dtype: {value.dtype}") |
|
|
|
|
|
|
|
if not shape_match: |
|
|
|
logger.warning(f" Shape mismatch for {key}!") |
|
|
|
elif isinstance(value, list): |
|
|
|
logger.info(f" {key}: list of length {len(value)}") |
|
|
|
else: |
|
|
|
logger.info(f" {key}: {type(value)}") |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error("Error in DataLoader") |
|
|
|
logger.error(f"Error message: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
# 5. 验证数据范围 |
|
|
|
logger.info("\nValidating data ranges...") |
|
|
|
try: |
|
|
|
for key, value in batch.items(): |
|
|
|
if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float64]: |
|
|
|
logger.info(f" {key}:") |
|
|
|
logger.info(f" min: {value.min().item():.4f}") |
|
|
|
logger.info(f" max: {value.max().item():.4f}") |
|
|
|
logger.info(f" mean: {value.mean().item():.4f}") |
|
|
|
logger.info(f" std: {value.std().item():.4f}") |
|
|
|
|
|
|
|
# 检查是否有NaN或Inf |
|
|
|
has_nan = torch.isnan(value).any() |
|
|
|
has_inf = torch.isinf(value).any() |
|
|
|
if has_nan or has_inf: |
|
|
|
logger.warning(f" Found NaN: {has_nan}, Inf: {has_inf}") |
|
|
|
total_samples = len(dataset) if num_samples is None else min(num_samples, len(dataset)) |
|
|
|
logger.info(f"总样本数: {total_samples}") |
|
|
|
|
|
|
|
# 初始化统计信息 |
|
|
|
stats = { |
|
|
|
'face_counts': [], |
|
|
|
'edge_counts': [], |
|
|
|
'vertex_counts': [], |
|
|
|
'sdf_point_counts': [], |
|
|
|
'invalid_samples': [], |
|
|
|
'shape_mismatches': defaultdict(int), |
|
|
|
'value_ranges': defaultdict(lambda: {'min': float('inf'), 'max': float('-inf')}), |
|
|
|
'nan_counts': defaultdict(int), |
|
|
|
'inf_counts': defaultdict(int) |
|
|
|
} |
|
|
|
|
|
|
|
# 遍历数据集 |
|
|
|
for idx in tqdm(range(total_samples), desc="验证数据"): |
|
|
|
try: |
|
|
|
sample = dataset[idx] |
|
|
|
|
|
|
|
# 1. 检查数据完整性 |
|
|
|
required_keys = ['surf_ncs', 'surf_pos', 'edge_ncs', 'edge_pos', |
|
|
|
'vertex_pos', 'points', 'sdf', 'edge_mask'] |
|
|
|
missing_keys = [key for key in required_keys if key not in sample] |
|
|
|
if missing_keys: |
|
|
|
stats['invalid_samples'].append((idx, f"缺少键: {missing_keys}")) |
|
|
|
continue |
|
|
|
|
|
|
|
# 2. 检查形状 |
|
|
|
expected_shapes = { |
|
|
|
'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3), |
|
|
|
'surf_pos': (config.data.max_face, 6), |
|
|
|
'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3), |
|
|
|
'edge_pos': (config.data.max_face, config.data.max_edge, 6), |
|
|
|
'edge_mask': (config.data.max_face, config.data.max_edge), |
|
|
|
'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3), |
|
|
|
'points': (config.data.num_query_points, 3), |
|
|
|
'sdf': (config.data.num_query_points, 1) |
|
|
|
} |
|
|
|
|
|
|
|
for key, expected_shape in expected_shapes.items(): |
|
|
|
if key in sample: |
|
|
|
actual_shape = tuple(sample[key].shape) |
|
|
|
if actual_shape != expected_shape: |
|
|
|
stats['shape_mismatches'][key] += 1 |
|
|
|
stats['invalid_samples'].append( |
|
|
|
(idx, f"{key} 形状不匹配: 预期 {expected_shape}, 实际 {actual_shape}") |
|
|
|
) |
|
|
|
|
|
|
|
# 3. 检查数值范围和无效值 |
|
|
|
for key, tensor in sample.items(): |
|
|
|
if isinstance(tensor, torch.Tensor) and tensor.dtype.is_floating_point: |
|
|
|
# 更新值范围 |
|
|
|
stats['value_ranges'][key]['min'] = min(stats['value_ranges'][key]['min'], |
|
|
|
tensor.min().item()) |
|
|
|
stats['value_ranges'][key]['max'] = max(stats['value_ranges'][key]['max'], |
|
|
|
tensor.max().item()) |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error("Error validating data ranges") |
|
|
|
logger.error(f"Error message: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
logger.info("\nAll tests completed successfully!") |
|
|
|
logger.info("="*50) |
|
|
|
# 检查NaN和Inf |
|
|
|
nan_count = torch.isnan(tensor).sum().item() |
|
|
|
inf_count = torch.isinf(tensor).sum().item() |
|
|
|
if nan_count > 0: |
|
|
|
stats['nan_counts'][key] += nan_count |
|
|
|
if inf_count > 0: |
|
|
|
stats['inf_counts'][key] += inf_count |
|
|
|
|
|
|
|
# 4. 收集统计信息 |
|
|
|
stats['face_counts'].append(sample['surf_ncs'].shape[0]) |
|
|
|
stats['edge_counts'].append(sample['edge_ncs'].shape[1]) |
|
|
|
stats['vertex_counts'].append(len(torch.unique(sample['vertex_pos'].reshape(-1, 3), dim=0))) |
|
|
|
stats['sdf_point_counts'].append(sample['points'].shape[0]) |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
stats['invalid_samples'].append((idx, str(e))) |
|
|
|
|
|
|
|
# 输出统计结果 |
|
|
|
logger.info("\n=== 数据集验证结果 ===") |
|
|
|
|
|
|
|
# 1. 基本统计信息 |
|
|
|
logger.info("\n基本统计信息:") |
|
|
|
logger.info(f"总样本数: {total_samples}") |
|
|
|
logger.info(f"有效样本数: {total_samples - len(stats['invalid_samples'])}") |
|
|
|
logger.info(f"无效样本数: {len(stats['invalid_samples'])}") |
|
|
|
|
|
|
|
# 2. 形状不匹配统计 |
|
|
|
if stats['shape_mismatches']: |
|
|
|
logger.info("\n形状不匹配统计:") |
|
|
|
for key, count in stats['shape_mismatches'].items(): |
|
|
|
logger.info(f" {key}: {count}个样本不匹配") |
|
|
|
|
|
|
|
# 3. 数值范围统计 |
|
|
|
logger.info("\n数值范围统计:") |
|
|
|
for key, ranges in stats['value_ranges'].items(): |
|
|
|
logger.info(f" {key}:") |
|
|
|
logger.info(f" 最小值: {ranges['min']:.3f}") |
|
|
|
logger.info(f" 最大值: {ranges['max']:.3f}") |
|
|
|
|
|
|
|
# 4. 无效值统计 |
|
|
|
if sum(stats['nan_counts'].values()) > 0 or sum(stats['inf_counts'].values()) > 0: |
|
|
|
logger.info("\n无效值统计:") |
|
|
|
for key in stats['nan_counts'].keys(): |
|
|
|
if stats['nan_counts'][key] > 0: |
|
|
|
logger.info(f" {key} 包含 {stats['nan_counts'][key]} 个 NaN 值") |
|
|
|
for key in stats['inf_counts'].keys(): |
|
|
|
if stats['inf_counts'][key] > 0: |
|
|
|
logger.info(f" {key} 包含 {stats['inf_counts'][key]} 个 Inf 值") |
|
|
|
|
|
|
|
# 5. 几何特征统计 |
|
|
|
logger.info("\n几何特征统计:") |
|
|
|
for name, values in [ |
|
|
|
('面数', stats['face_counts']), |
|
|
|
('边数', stats['edge_counts']), |
|
|
|
('顶点数', stats['vertex_counts']), |
|
|
|
('SDF采样点数', stats['sdf_point_counts']) |
|
|
|
]: |
|
|
|
values = np.array(values) |
|
|
|
logger.info(f" {name}:") |
|
|
|
logger.info(f" 最小值: {np.min(values)}") |
|
|
|
logger.info(f" 最大值: {np.max(values)}") |
|
|
|
logger.info(f" 平均值: {np.mean(values):.2f}") |
|
|
|
logger.info(f" 中位数: {np.median(values):.2f}") |
|
|
|
logger.info(f" 标准差: {np.std(values):.2f}") |
|
|
|
|
|
|
|
# 6. 输出无效样本详情 |
|
|
|
if stats['invalid_samples']: |
|
|
|
logger.info("\n无效样本详情:") |
|
|
|
for idx, error in stats['invalid_samples']: |
|
|
|
logger.info(f" 样本 {idx}: {error}") |
|
|
|
|
|
|
|
return stats |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error in test_dataset: {str(e)}") |
|
|
|
logger.error(f"验证过程出错: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_dataset() |
|
|
|
validate_dataset(split='train', num_samples=None) # 先测试100个样本 |