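"""Dataset utilities for brep2sdf.

Pairs pickled B-rep feature files (.pkl) with sampled SDF point files (.npz),
optionally filters samples that exceed the configured face/edge limits, and
serves balanced positive/negative SDF query points for training.
"""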

import os
import pickle
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from brep2sdf.utils.logger import logger
from brep2sdf.data.utils import process_brep_data
from brep2sdf.config.default_config import get_default_config


class BRepSDFDataset(Dataset):
    def __init__(self, brep_dir: str, sdf_dir: str, valid_data_dir: str, use_filter: bool = True, split: str = 'train'):
        """
        Initialize the dataset.

        Args:
            brep_dir: directory containing the B-rep .pkl files
            sdf_dir: directory containing the SDF .npz files
            valid_data_dir: directory containing the per-split success lists
            use_filter: drop samples whose face/edge counts exceed the limits
            split: dataset split ('train', 'val', 'test')
        """
        super().__init__()
        # Read size limits from the config file instead of hard-coding them.
        self.config = get_default_config()
        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split
        self.max_face = self.config.data.max_face
        self.max_edge = self.config.data.max_edge
        self.bbox_scaled = self.config.data.bbox_scaled
        # Check that the data directories exist.
        if not os.path.exists(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.exists(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")
        # Load the list of valid sample names; fail fast if the split's
        # success list is missing.
        valid_data_file = os.path.join(valid_data_dir, f'{split}_success.txt')
        if os.path.exists(valid_data_file):
            self.valid_data_list = self._load_valid_list(valid_data_file)
        else:
            raise ValueError(f"Valid data file not found: {valid_data_file}")
        self.brep_data_list = self._load_data_list(self.brep_dir)
        self.sdf_data_list = self._load_data_list(self.sdf_dir)
        if use_filter:
            self._filter_num_faces_and_num_edges()
        # Make sure the dataset is not empty.
        if len(self.brep_data_list) == 0:
            raise ValueError(f"No valid brep data found in {split} set")
        if len(self.sdf_data_list) == 0:
            raise ValueError(f"No valid sdf data found in {split} set")
        logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples")

    def _load_valid_list(self, valid_data_file: str):
        with open(valid_data_file, 'r') as f:
            valid_list = [line.strip() for line in f.readlines()]
        return valid_list

    # data_dir is either self.brep_dir or self.sdf_dir.
    def _load_data_list(self, data_dir):
        data_list = []
        # Sort for a deterministic order, so the brep and sdf lists stay
        # aligned index-by-index.
        for sample_file in sorted(os.listdir(data_dir)):
            if sample_file.split('.')[0] in self.valid_data_list:
                path = os.path.join(data_dir, sample_file)
                data_list.append(path)
        return data_list

    def _filter_num_faces_and_num_edges(self):
        '''
        Drop samples whose face count or edge count exceeds max_face or max_edge.
        '''
        # Collect indices of samples within both limits. This must be a
        # per-dimension check; a tuple comparison such as
        # (num_faces, num_edges) <= (max_face, max_edge) is lexicographic and
        # would admit e.g. (5, 999) against limits (10, 50).
        filtered_indices = []
        for idx in range(len(self.brep_data_list)):
            num_faces, num_edges = self._get_brep_face_and_edge(self.brep_data_list[idx])
            if num_faces <= self.max_face and num_edges <= self.max_edge:
                filtered_indices.append(idx)
        # Use filtered_indices to update brep_data_list and sdf_data_list.
        self.brep_data_list = [self.brep_data_list[idx] for idx in filtered_indices]
        self.sdf_data_list = [self.sdf_data_list[idx] for idx in filtered_indices]

    def __len__(self):
        return len(self.brep_data_list)

    def __getitem__(self, idx):
        """Fetch a single sample."""
        try:
            brep_path = self.brep_data_list[idx]
            sdf_path = self.sdf_data_list[idx]
            name = os.path.splitext(os.path.basename(brep_path))[0]
            # Load the raw B-rep and SDF data.
            brep_raw = self._load_brep_file(brep_path)
            sdf_data = self._load_sdf_file(sdf_path)
            try:
                # Pad and normalize the B-rep data into fixed-size tensors.
                brep_features = process_brep_data(
                    data=brep_raw,
                    max_face=self.max_face,
                    max_edge=self.max_edge,
                    bbox_scaled=self.bbox_scaled
                )
                # Unpack the processed features.
                edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features
                sdf_points = sdf_data[:, :3]
                sdf_values = sdf_data[:, 3:]
                # Assemble the sample dictionary.
                return {
                    'name': name,
                    'edge_ncs': edge_ncs,      # [max_face, max_edge, 10, 3]
                    'edge_pos': edge_pos,      # [max_face, max_edge, 6]
                    'edge_mask': edge_mask,    # [max_face, max_edge]
                    'surf_ncs': surf_ncs,      # [max_face, 100, 3]
                    'surf_pos': surf_pos,      # [max_face, 6]
                    'vertex_pos': vertex_pos,  # [max_face, max_edge, 2, 3]
                    'points': sdf_points,      # [num_queries, 3] xyz coordinates of the query points
                    'sdf': sdf_values          # [num_queries, 1] sdf value of each query point
                }
            except Exception as e:
                logger.error(f"\nError processing B-rep data for file: {brep_path}")
                logger.error(f"Error type: {type(e).__name__}")
                logger.error(f"Error message: {str(e)}")
                # Dump the structure of the raw data to help debugging.
                logger.error("\nRaw data structure:")
                for key, value in brep_raw.items():
                    if isinstance(value, list):
                        logger.error(f"  {key}: list of length {len(value)}")
                        if value:
                            logger.error(f"    First element type: {type(value[0])}")
                            if hasattr(value[0], 'shape'):
                                logger.error(f"    First element shape: {value[0].shape}")
                    elif hasattr(value, 'shape'):
                        logger.error(f"  {key}: shape {value.shape}")
                    else:
                        logger.error(f"  {key}: {type(value)}")
                raise
        except Exception as e:
            logger.error(f"Error loading sample from {brep_path}: {str(e)}")
            raise

    def _load_brep_file(self, brep_path):
        with open(brep_path, 'rb') as f:
            brep_raw = pickle.load(f)
        return brep_raw

    def _load_sdf_file(self, sdf_path):
        """Load the SDF data and randomly subsample a balanced point set."""
        try:
            # Load the SDF values.
            sdf_data = np.load(sdf_path)
            if 'pos' not in sdf_data or 'neg' not in sdf_data:
                raise ValueError("Missing pos/neg data in SDF file")
            sdf_pos = sdf_data['pos']  # (N1, 4)
            sdf_neg = sdf_data['neg']  # (N2, 4)
            # Validate the array shapes.
            if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
                raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")
            # Random subsampling budget.
            max_points = self.config.data.num_query_points  # e.g. 4096
            # Keep positive and negative samples balanced.
            if max_points // 2 > sdf_pos.shape[0]:
                logger.warning(f"Too few positive samples: expected >{max_points // 2}, got {sdf_pos.shape[0]}")
            num_neg = min(max_points // 2, sdf_neg.shape[0])
            num_pos = max_points - num_neg
            # Randomly subsample the positive points.
            if sdf_pos.shape[0] > num_pos:
                pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False)
                sdf_pos = sdf_pos[pos_indices]
            # Randomly subsample the negative points.
            if sdf_neg.shape[0] > num_neg:
                neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False)
                sdf_neg = sdf_neg[neg_indices]
            # Merge and shuffle.
            sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)
            np.random.shuffle(sdf_np)
            # If the total still exceeds the limit, subsample once more.
            if sdf_np.shape[0] > max_points:
                indices = np.random.choice(sdf_np.shape[0], max_points, replace=False)
                sdf_np = sdf_np[indices]
            return torch.from_numpy(sdf_np.astype(np.float32))
        except Exception as e:
            logger.error(f"Error loading SDF from {sdf_path}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            raise

    def _get_brep_face_and_edge(self, brep_path: str) -> tuple[int, int]:
        brep: dict = self._load_brep_file(brep_path)
        # faceEdge_adj is a (num_faces, num_edges) adjacency matrix.
        face_edge_adj = brep["faceEdge_adj"]
        num_faces, num_edges = face_edge_adj.shape
        return num_faces, num_edges
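
# Usage sketch (the paths below are placeholders; real values normally come
# from get_default_config(), as in test_dataset() further down):
#
#   dataset = BRepSDFDataset(brep_dir='data/brep', sdf_dir='data/sdf',
#                            valid_data_dir='data/valid', split='train')
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
#   batch = next(iter(loader))  # dict of batched tensors; see __getitem__
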
def load_brep_file(brep_path):
    with open(brep_path, 'rb') as f:
        brep_raw = pickle.load(f)
    return brep_raw


def load_sdf_file(sdf_path: str, num_query_points: int = 4096) -> torch.Tensor:
    """
    Load SDF data and randomly subsample a balanced point set.

    Args:
        sdf_path: path to the SDF .npz file
        num_query_points: maximum number of sampled points (default 4096)

    Returns:
        sdf_tensor: the processed SDF data as an (N, 4) float tensor
    """
    try:
        # Load the SDF values.
        sdf_data = np.load(sdf_path)
        if 'pos' not in sdf_data or 'neg' not in sdf_data:
            raise ValueError("Missing pos/neg data in SDF file")
        sdf_pos = sdf_data['pos']  # (N1, 4)
        sdf_neg = sdf_data['neg']  # (N2, 4)
        # Validate the array shapes.
        if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
            raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")
        # Keep positive and negative samples balanced.
        if num_query_points // 2 > sdf_pos.shape[0]:
            logger.warning(f"Too few positive samples: expected >{num_query_points // 2}, got {sdf_pos.shape[0]}")
        num_neg = min(num_query_points // 2, sdf_neg.shape[0])
        num_pos = num_query_points - num_neg
        # Randomly subsample the positive points.
        if sdf_pos.shape[0] > num_pos:
            pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False)
            sdf_pos = sdf_pos[pos_indices]
        # Randomly subsample the negative points.
        if sdf_neg.shape[0] > num_neg:
            neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False)
            sdf_neg = sdf_neg[neg_indices]
        # Merge and shuffle.
        sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)
        np.random.shuffle(sdf_np)
        # If the total still exceeds the limit, subsample once more.
        if sdf_np.shape[0] > num_query_points:
            indices = np.random.choice(sdf_np.shape[0], num_query_points, replace=False)
            sdf_np = sdf_np[indices]
        return torch.from_numpy(sdf_np.astype(np.float32))
    except Exception as e:
        logger.error(f"Error loading SDF from {sdf_path}")
        logger.error(f"Error type: {type(e).__name__}")
        logger.error(f"Error message: {str(e)}")
        raise
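
# A minimal self-check sketch for load_sdf_file, assuming only the pos/neg
# .npz layout documented above; 'demo_sdf.npz' and _demo_load_sdf_file are
# hypothetical names introduced here for illustration.
def _demo_load_sdf_file(path: str = 'demo_sdf.npz') -> torch.Tensor:
    rng = np.random.default_rng(0)
    # Each row is (x, y, z, sdf): positive distances outside the surface,
    # negative distances inside.
    pos = np.concatenate([rng.uniform(-1, 1, (5000, 3)),
                          rng.uniform(0.0, 0.1, (5000, 1))], axis=1)
    neg = np.concatenate([rng.uniform(-1, 1, (5000, 3)),
                          rng.uniform(-0.1, 0.0, (5000, 1))], axis=1)
    np.savez(path, pos=pos, neg=neg)
    sdf = load_sdf_file(path, num_query_points=4096)
    assert sdf.shape == (4096, 4)  # balanced: 2048 pos + 2048 neg rows
    return sdf
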

def test_dataset():
    """Smoke-test the dataset."""
    try:
        # Fetch the config.
        config = get_default_config()
        # Expected tensor shapes per sample.
        expected_shapes = {
            'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3),
            'edge_pos': (config.data.max_face, config.data.max_edge, 6),
            'edge_mask': (config.data.max_face, config.data.max_edge),
            'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3),
            'surf_pos': (config.data.max_face, 6),
            'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3),
            'points': (config.data.num_query_points, 3),
            'sdf': (config.data.num_query_points, 1)
        }
        logger.info("=" * 50)
        logger.info("Testing dataset")
        logger.info("Expected shapes:")
        for key, shape in expected_shapes.items():
            logger.info(f"  {key}: {shape}")
        # Initialize the dataset.
        dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='train'
        )
        # Test loading a single sample.
        logger.info("\nTesting sample loading...")
        sample = dataset[0]
        # Check dtypes and shapes.
        logger.info("\nChecking dtypes and shapes:")
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                actual_shape = tuple(value.shape)
                expected_shape = expected_shapes.get(key)
                shape_match = "✓" if actual_shape == expected_shape else "✗"
                logger.info(f"\n{key}:")
                logger.info(f"  actual shape:   {actual_shape}")
                logger.info(f"  expected shape: {expected_shape}")
                logger.info(f"  match: {shape_match}")
                logger.info(f"  dtype: {value.dtype}")
                # Value range, mean and std only make sense for float tensors.
                if value.dtype.is_floating_point:
                    logger.info(f"  value range: [{value.min():.3f}, {value.max():.3f}]")
                    logger.info(f"  mean: {value.mean():.3f}")
                    logger.info(f"  std:  {value.std():.3f}")
                if shape_match == "✗":
                    logger.warning(f"  shape mismatch: {key}")
                    if key in ['points', 'sdf']:
                        logger.warning(f"  query point count differs: expected {expected_shape[0]}, got {actual_shape[0]}")
                    elif key in ['edge_ncs', 'edge_pos', 'edge_mask']:
                        logger.warning(f"  edge count differs: expected {expected_shape[1]}, got {actual_shape[1]}")
                    elif key in ['surf_ncs', 'surf_pos']:
                        logger.warning(f"  face count differs: expected {expected_shape[0]}, got {actual_shape[0]}")
        # Test batching.
        logger.info("\nTesting batching...")
        batch_size = 4
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0
        )
        batch = next(iter(dataloader))
        logger.info("\nChecking batch shapes:")
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch_shape = tuple(value.shape)
                expected_batch_shape = (batch_size,) + expected_shapes[key]
                shape_match = "✓" if batch_shape == expected_batch_shape else "✗"
                logger.info(f"\n{key}:")
                logger.info(f"  actual shape:   {batch_shape}")
                logger.info(f"  expected shape: {expected_batch_shape}")
                logger.info(f"  match: {shape_match}")
                logger.info(f"  dtype: {value.dtype}")
                # Value range, mean and std only make sense for float tensors.
                if value.dtype.is_floating_point:
                    logger.info(f"  value range: [{value.min():.3f}, {value.max():.3f}]")
                    logger.info(f"  mean: {value.mean():.3f}")
                    logger.info(f"  std:  {value.std():.3f}")
                if shape_match == "✗":
                    logger.warning(f"  batch shape mismatch: {key}")
        logger.info("\nTest finished!")
        logger.info("=" * 50)
    except Exception as e:
        logger.error(f"Error during testing: {str(e)}")
        raise


def validate_dataset(split: str = 'train', num_samples: int = None):
    """Run a full validation pass over the dataset.

    Args:
        split: dataset split ('train', 'val', 'test')
        num_samples: number of samples to check; None checks all samples
    """
    try:
        config = get_default_config()
        logger.info(f"Validating the {split} dataset...")
        # Initialize the dataset.
        dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split=split
        )
        total_samples = len(dataset) if num_samples is None else min(num_samples, len(dataset))
        logger.info(f"Total samples: {total_samples}")
        # Initialize the running statistics.
        stats = {
            'face_counts': [],
            'edge_counts': [],
            'vertex_counts': [],
            'sdf_point_counts': [],
            'invalid_samples': [],
            'shape_mismatches': defaultdict(int),
            'value_ranges': defaultdict(lambda: {'min': float('inf'), 'max': float('-inf')}),
            'nan_counts': defaultdict(int),
            'inf_counts': defaultdict(int)
        }
        # Iterate over the dataset.
        for idx in tqdm(range(total_samples), desc="Validating"):
            try:
                sample = dataset[idx]
                # 1. Check completeness.
                required_keys = ['surf_ncs', 'surf_pos', 'edge_ncs', 'edge_pos',
                                 'vertex_pos', 'points', 'sdf', 'edge_mask']
                missing_keys = [key for key in required_keys if key not in sample]
                if missing_keys:
                    stats['invalid_samples'].append((idx, f"missing keys: {missing_keys}"))
                    continue
                # 2. Check shapes.
                expected_shapes = {
                    'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3),
                    'surf_pos': (config.data.max_face, 6),
                    'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3),
                    'edge_pos': (config.data.max_face, config.data.max_edge, 6),
                    'edge_mask': (config.data.max_face, config.data.max_edge),
                    'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3),
                    'points': (config.data.num_query_points, 3),
                    'sdf': (config.data.num_query_points, 1)
                }
                for key, expected_shape in expected_shapes.items():
                    if key in sample:
                        actual_shape = tuple(sample[key].shape)
                        if actual_shape != expected_shape:
                            stats['shape_mismatches'][key] += 1
                            stats['invalid_samples'].append(
                                (idx, f"{key} shape mismatch: expected {expected_shape}, got {actual_shape}")
                            )
                # 3. Check value ranges and invalid values.
                for key, tensor in sample.items():
                    if isinstance(tensor, torch.Tensor) and tensor.dtype.is_floating_point:
                        # Update the running min/max.
                        stats['value_ranges'][key]['min'] = min(stats['value_ranges'][key]['min'],
                                                                tensor.min().item())
                        stats['value_ranges'][key]['max'] = max(stats['value_ranges'][key]['max'],
                                                                tensor.max().item())
                        # Count NaN and Inf entries.
                        nan_count = torch.isnan(tensor).sum().item()
                        inf_count = torch.isinf(tensor).sum().item()
                        if nan_count > 0:
                            stats['nan_counts'][key] += nan_count
                        if inf_count > 0:
                            stats['inf_counts'][key] += inf_count
                # 4. Collect geometry statistics.
                stats['face_counts'].append(sample['surf_ncs'].shape[0])
                stats['edge_counts'].append(sample['edge_ncs'].shape[1])
                stats['vertex_counts'].append(len(torch.unique(sample['vertex_pos'].reshape(-1, 3), dim=0)))
                stats['sdf_point_counts'].append(sample['points'].shape[0])
            except Exception as e:
                stats['invalid_samples'].append((idx, str(e)))
        # Report the results.
        logger.info("\n=== Dataset validation results ===")
        # 1. Basic statistics.
        logger.info("\nBasic statistics:")
        logger.info(f"Total samples:   {total_samples}")
        logger.info(f"Valid samples:   {total_samples - len(stats['invalid_samples'])}")
        logger.info(f"Invalid samples: {len(stats['invalid_samples'])}")
        # 2. Shape mismatches.
        if stats['shape_mismatches']:
            logger.info("\nShape mismatches:")
            for key, count in stats['shape_mismatches'].items():
                logger.info(f"  {key}: {count} mismatched samples")
        # 3. Value ranges.
        logger.info("\nValue ranges:")
        for key, ranges in stats['value_ranges'].items():
            logger.info(f"  {key}:")
            logger.info(f"    min: {ranges['min']:.3f}")
            logger.info(f"    max: {ranges['max']:.3f}")
        # 4. Invalid values.
        if sum(stats['nan_counts'].values()) > 0 or sum(stats['inf_counts'].values()) > 0:
            logger.info("\nInvalid values:")
            for key in stats['nan_counts'].keys():
                if stats['nan_counts'][key] > 0:
                    logger.info(f"  {key} contains {stats['nan_counts'][key]} NaN values")
            for key in stats['inf_counts'].keys():
                if stats['inf_counts'][key] > 0:
                    logger.info(f"  {key} contains {stats['inf_counts'][key]} Inf values")
        # 5. Geometry statistics.
        logger.info("\nGeometry statistics:")
        for name, values in [
            ('faces', stats['face_counts']),
            ('edges', stats['edge_counts']),
            ('vertices', stats['vertex_counts']),
            ('SDF sample points', stats['sdf_point_counts'])
        ]:
            values = np.array(values)
            logger.info(f"  {name}:")
            logger.info(f"    min:    {np.min(values)}")
            logger.info(f"    max:    {np.max(values)}")
            logger.info(f"    mean:   {np.mean(values):.2f}")
            logger.info(f"    median: {np.median(values):.2f}")
            logger.info(f"    std:    {np.std(values):.2f}")
        # 6. Details of invalid samples.
        if stats['invalid_samples']:
            logger.info("\nInvalid sample details:")
            for idx, error in stats['invalid_samples']:
                logger.info(f"  sample {idx}: {error}")
        return stats
    except Exception as e:
        logger.error(f"Error during validation: {str(e)}")
        raise


if __name__ == '__main__':
    validate_dataset(split='train', num_samples=None)  # None validates every sample