import os
import torch
from torch.utils.data import Dataset
import numpy as np
import pickle
from brep2sdf.utils.logger import logger
from brep2sdf.data.utils import process_brep_data
from brep2sdf.config.default_config import get_default_config

class BRepSDFDataset(Dataset):
    def __init__(self, brep_dir: str, sdf_dir: str, valid_data_dir: str, split: str = 'train'):
"""
初始化数据集
参数:
brep_dir: pkl文件目录
sdf_dir: npz文件目录
split: 数据集分割('train', 'val', 'test')
"""
        super().__init__()
        # Use parameters from the config file
        self.config = get_default_config()

        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split

        # Read size limits from the config instead of hard-coding them
        self.max_face = self.config.data.max_face
        self.max_edge = self.config.data.max_edge
        self.bbox_scaled = self.config.data.bbox_scaled

        # Check that the data directories exist
        if not os.path.exists(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.exists(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")

        # Load the list of valid sample names for this split
        valid_data_file = os.path.join(valid_data_dir, f'{split}_success.txt')
        if os.path.exists(valid_data_file):
            self.valid_data_list = self._load_valid_list(valid_data_file)
        else:
            raise ValueError(f"Valid data file not found: {valid_data_file}")

        self.brep_data_list = self._load_data_list(self.brep_dir)
        self.sdf_data_list = self._load_data_list(self.sdf_dir)

        # Make sure the dataset is not empty
        if len(self.brep_data_list) == 0:
            raise ValueError(f"No valid brep data found in {split} set")
        if len(self.sdf_data_list) == 0:
            raise ValueError(f"No valid sdf data found in {split} set")

        logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples")

    def _load_valid_list(self, valid_data_file: str):
        """Read the list of valid sample names, one per line."""
        with open(valid_data_file, 'r') as f:
            valid_list = [line.strip() for line in f if line.strip()]
        return valid_list
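
    # Expected {split}_success.txt layout, inferred from how the list is used
    # below (one sample name per line, matching the data file stems), e.g.:
    #   00000001
    #   00000002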

    # data_dir is self.brep_dir or self.sdf_dir
    def _load_data_list(self, data_dir):
        data_list = []
        # Sort for a deterministic order so the brep and sdf lists stay
        # aligned by index (os.listdir order is arbitrary)
        for sample_file in sorted(os.listdir(data_dir)):
            if sample_file.split('.')[0] in self.valid_data_list:
                data_list.append(os.path.join(data_dir, sample_file))
        return data_list

    def __len__(self):
        return len(self.brep_data_list)

    def __getitem__(self, idx):
        """Return a single sample as a dict of tensors."""
        try:
            brep_path = self.brep_data_list[idx]
            sdf_path = self.sdf_data_list[idx]
            name = os.path.splitext(os.path.basename(brep_path))[0]

            # Load the raw B-rep and SDF data
            with open(brep_path, 'rb') as f:
                brep_raw = pickle.load(f)
            sdf_data = self._load_sdf_file(sdf_path)
            try:
                # Convert the raw B-rep dict into padded feature tensors
                brep_features = process_brep_data(
                    data=brep_raw,
                    max_face=self.max_face,
                    max_edge=self.max_edge,
                    bbox_scaled=self.bbox_scaled
                )
                # Unpack the processed features
                edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features
                sdf_points = sdf_data[:, :3]
                sdf_values = sdf_data[:, 3:]

                return {
                    'name': name,
                    'edge_ncs': edge_ncs,      # [max_face, max_edge, num_edge_points, 3]
                    'edge_pos': edge_pos,      # [max_face, max_edge, 6]
                    'edge_mask': edge_mask,    # [max_face, max_edge]
                    'surf_ncs': surf_ncs,      # [max_face, num_surf_points, 3]
                    'surf_pos': surf_pos,      # [max_face, 6]
                    'vertex_pos': vertex_pos,  # [max_face, max_edge, 6]
                    'points': sdf_points,      # [num_query_points, 3] xyz of the query points
                    'sdf': sdf_values          # [num_query_points, 1] SDF value at each query point
                }
            except Exception as e:
                logger.error(f"\nError processing B-rep data for file: {brep_path}")
                logger.error(f"Error type: {type(e).__name__}")
                logger.error(f"Error message: {str(e)}")
                # Dump the structure of the raw data to aid debugging
                logger.error("\nRaw data structure:")
                for key, value in brep_raw.items():
                    if isinstance(value, list):
                        logger.error(f"  {key}: list of length {len(value)}")
                        if value:
                            logger.error(f"    First element type: {type(value[0])}")
                            if hasattr(value[0], 'shape'):
                                logger.error(f"    First element shape: {value[0].shape}")
                    elif hasattr(value, 'shape'):
                        logger.error(f"  {key}: shape {value.shape}")
                    else:
                        logger.error(f"  {key}: {type(value)}")
                raise

        except Exception as e:
            logger.error(f"Error loading sample from {brep_path}: {str(e)}")
            raise

    def _load_sdf_file(self, sdf_path):
        """Load SDF samples from an .npz file and randomly subsample them."""
        try:
            sdf_data = np.load(sdf_path)
            if 'pos' not in sdf_data or 'neg' not in sdf_data:
                raise ValueError("Missing pos/neg data in SDF file")
            sdf_pos = sdf_data['pos']  # (N1, 4)
            sdf_neg = sdf_data['neg']  # (N2, 4)
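            # Assumed .npz layout (DeepSDF-style convention, inferred from the
            # shape checks below): each row is (x, y, z, sdf), with 'pos'
            # holding non-negative-SDF points and 'neg' negative-SDF points.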

            # Validate the shapes: each row should be (x, y, z, sdf)
            if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
                raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")

            # Randomly subsample to at most num_query_points points,
            # keeping positive and negative samples balanced
            max_points = self.config.data.num_query_points  # e.g. 4096
            num_pos = min(max_points // 2, sdf_pos.shape[0])
            num_neg = min(max_points // 2, sdf_neg.shape[0])
            if sdf_pos.shape[0] > num_pos:
                pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False)
                sdf_pos = sdf_pos[pos_indices]
            if sdf_neg.shape[0] > num_neg:
                neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False)
                sdf_neg = sdf_neg[neg_indices]

            # Merge and shuffle; note the result can have fewer than max_points
            # rows when either side has too few samples
            sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)
            np.random.shuffle(sdf_np)

            # If the total still exceeds the limit, subsample once more
            if sdf_np.shape[0] > max_points:
                indices = np.random.choice(sdf_np.shape[0], max_points, replace=False)
                sdf_np = sdf_np[indices]

            return torch.from_numpy(sdf_np.astype(np.float32))

        except Exception as e:
            logger.error(f"Error loading SDF from {sdf_path}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            raise


def test_dataset():
    """Smoke-test the dataset: sample loading, batching and value ranges."""
    try:
        # 1. Pull paths and size limits from the config
        config = get_default_config()
        brep_dir = config.data.brep_dir
        sdf_dir = config.data.sdf_dir
        valid_data_dir = config.data.valid_data_dir

        split = 'train'
        max_face = config.data.max_face
        max_edge = config.data.max_edge
        num_edge_points = config.model.num_edge_points
        num_surf_points = config.model.num_surf_points
        num_query_points = config.data.num_query_points

        # Expected tensor shapes, derived from the config
        expected_shapes = {
            'edge_ncs': (max_face, max_edge, num_edge_points, 3),  # [max_face, max_edge, sample_points, xyz]
            'edge_pos': (max_face, max_edge, 6),
            'edge_mask': (max_face, max_edge),
            'surf_ncs': (max_face, num_surf_points, 3),
            'surf_pos': (max_face, 6),
            'vertex_pos': (max_face, max_edge, 2, 3),
            'points': (num_query_points, 3),
            'sdf': (num_query_points, 1)
        }
logger.info("="*50)
logger.info("Testing dataset")
logger.info(f"B-rep directory: {brep_dir}")
logger.info(f"SDF directory: {sdf_dir}")
logger.info(f"Split: {split}")
# 2. 初始化数据集
dataset = BRepSDFDataset(
brep_dir=brep_dir,
sdf_dir=sdf_dir,
valid_data_dir=valid_data_dir,
split=split
)
logger.info(f"\nDataset size: {len(dataset)}")

        # 3. Test single-sample loading and check shapes
        logger.info("\nTesting single sample loading...")
        try:
            sample = dataset[0]
            logger.info("Sample keys and shapes:")
            for key, value in sample.items():
                if isinstance(value, torch.Tensor):
                    actual_shape = tuple(value.shape)
                    expected_shape = expected_shapes.get(key)
                    shape_match = actual_shape == expected_shape if expected_shape else None
                    logger.info(f"  {key}:")
                    logger.info(f"    Shape: {actual_shape}")
                    logger.info(f"    Expected: {expected_shape}")
                    logger.info(f"    Match: {shape_match}")
                    logger.info(f"    dtype: {value.dtype}")
                    logger.info(f"    grad: {value.requires_grad}")
                    if not shape_match:
                        logger.warning(f"    Shape mismatch for {key}!")
                else:
                    logger.info(f"  {key}: {type(value)}")
        except Exception as e:
            logger.error("Error loading single sample")
            logger.error(f"Error message: {str(e)}")
            raise

        # 4. Test the DataLoader and batched shapes
        logger.info("\nTesting DataLoader...")
        try:
            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=2,
                shuffle=True,
                num_workers=0
            )
            batch = next(iter(dataloader))
            logger.info("Batch keys and shapes:")
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    actual_shape = tuple(value.shape)
                    expected_shape = expected_shapes.get(key)
                    if expected_shape:
                        expected_batch_shape = (2,) + expected_shape
                        shape_match = actual_shape == expected_batch_shape
                    else:
                        expected_batch_shape = None
                        shape_match = None
                    logger.info(f"  {key}:")
                    logger.info(f"    Shape: {actual_shape}")
                    logger.info(f"    Expected: {expected_batch_shape}")
                    logger.info(f"    Match: {shape_match}")
                    logger.info(f"    dtype: {value.dtype}")
                    if not shape_match:
                        logger.warning(f"    Shape mismatch for {key}!")
                elif isinstance(value, list):
                    logger.info(f"  {key}: list of length {len(value)}")
                else:
                    logger.info(f"  {key}: {type(value)}")
        except Exception as e:
            logger.error("Error in DataLoader")
            logger.error(f"Error message: {str(e)}")
            raise

        # 5. Validate value ranges
        logger.info("\nValidating data ranges...")
        try:
            for key, value in batch.items():
                if isinstance(value, torch.Tensor) and value.dtype in [torch.float32, torch.float64]:
                    logger.info(f"  {key}:")
                    logger.info(f"    min: {value.min().item():.4f}")
                    logger.info(f"    max: {value.max().item():.4f}")
                    logger.info(f"    mean: {value.mean().item():.4f}")
                    logger.info(f"    std: {value.std().item():.4f}")
                    # Flag NaN or Inf values
                    has_nan = torch.isnan(value).any()
                    has_inf = torch.isinf(value).any()
                    if has_nan or has_inf:
                        logger.warning(f"    Found NaN: {has_nan}, Inf: {has_inf}")
        except Exception as e:
            logger.error("Error validating data ranges")
            logger.error(f"Error message: {str(e)}")
            raise

        logger.info("\nAll tests completed successfully!")
        logger.info("=" * 50)

    except Exception as e:
        logger.error(f"Error in test_dataset: {str(e)}")
        raise


if __name__ == '__main__':
    test_dataset()
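
# Minimal consumption sketch (hypothetical `model`; the batch keys match the
# dict returned by BRepSDFDataset.__getitem__ above):
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
#   for batch in loader:
#       pred = model(batch)  # predict SDF at batch['points']
#       loss = torch.nn.functional.l1_loss(pred, batch['sdf'])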