import os
import pickle
from collections import defaultdict
from typing import Optional

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from brep2sdf.utils.logger import logger
from brep2sdf.data.utils import process_brep_data
from brep2sdf.config.default_config import get_default_config


class BRepSDFDataset(Dataset):
    def __init__(self, brep_dir: str, sdf_dir: str, valid_data_dir: str, use_filter: bool = True, split: str = 'train'):
        """
        Initialize the dataset.

        Args:
            brep_dir: directory of B-rep .pkl files
            sdf_dir: directory of SDF .npz files
            valid_data_dir: directory holding the <split>_success.txt lists of valid samples
            use_filter: drop samples whose face/edge counts exceed the configured maxima
            split: dataset split ('train', 'val', 'test')
        """
        super().__init__()
        # Use the config file.
        self.config = get_default_config()

        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split

        # Take these limits from the config instead of hard-coded constants.
        self.max_face = self.config.data.max_face
        self.max_edge = self.config.data.max_edge
        self.bbox_scaled = self.config.data.bbox_scaled

        # Check that the data directories exist.
        if not os.path.exists(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
        if not os.path.exists(self.sdf_dir):
            raise ValueError(f"SDF directory not found: {self.sdf_dir}")

        # Load the data lists.
        # If the split's success file exists, load the list of valid sample names.
        valid_data_file = os.path.join(valid_data_dir, f'{split}_success.txt')
        if os.path.exists(valid_data_file):
            self.valid_data_list = self._load_valid_list(valid_data_file)
        else:
            raise ValueError(f"Valid data file not found: {valid_data_file}")

        self.brep_data_list = self._load_data_list(self.brep_dir)
        self.sdf_data_list = self._load_data_list(self.sdf_dir)

        if use_filter:
            self._filter_num_faces_and_num_edges()

        # Make sure the dataset is not empty.
        if len(self.brep_data_list) == 0:
            raise ValueError(f"No valid brep data found in {split} set")
        if len(self.sdf_data_list) == 0:
            raise ValueError(f"No valid sdf data found in {split} set")

        logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples")

    def _load_valid_list(self, valid_data_file: str):
        with open(valid_data_file, 'r') as f:
            valid_list = [line.strip() for line in f]
        return valid_list

    # data_dir is self.brep_dir or self.sdf_dir. The listing is sorted so the
    # B-rep and SDF lists stay index-aligned, since __getitem__ pairs them by index.
    def _load_data_list(self, data_dir):
        data_list = []
        for sample_file in sorted(os.listdir(data_dir)):
            if os.path.splitext(sample_file)[0] in self.valid_data_list:
                path = os.path.join(data_dir, sample_file)
                data_list.append(path)

        #logger.info(data_list)
        return data_list

    def _filter_num_faces_and_num_edges(self):
        '''
        Drop samples whose face count exceeds max_face or whose edge count exceeds max_edge.
        '''
        # Collect indices of samples that satisfy both limits (compared
        # element-wise, not as tuples).
        filtered_indices = []
        for idx in range(len(self.brep_data_list)):
            num_faces, num_edges = self._get_brep_face_and_edge(self.brep_data_list[idx])
            if num_faces <= self.max_face and num_edges <= self.max_edge:
                filtered_indices.append(idx)

        #filtered_indices = filtered_indices[0:8] # TODO rm

        # Use filtered_indices to update brep_data_list and sdf_data_list.
        self.brep_data_list = [self.brep_data_list[idx] for idx in filtered_indices]
        self.sdf_data_list = [self.sdf_data_list[idx] for idx in filtered_indices]

    def __len__(self):
        return len(self.brep_data_list)

    def __getitem__(self, idx):
        """Fetch a single data sample."""
        try:
            brep_path = self.brep_data_list[idx]
            sdf_path = self.sdf_data_list[idx]
            name = os.path.splitext(os.path.basename(brep_path))[0]

            # Load the B-rep and SDF data.
            brep_raw = self._load_brep_file(brep_path)
            sdf_data = self._load_sdf_file(sdf_path)

            try:
                # Process the B-rep data.
                brep_features = process_brep_data(
                    data=brep_raw,
                    max_face=self.max_face,
                    max_edge=self.max_edge,
                    bbox_scaled=self.bbox_scaled
                )
                '''
                # Print the processed data shapes.
                logger.debug(f"Processed data shapes for {os.path.basename(brep_path)}:")
                for value in brep_features:
                    if isinstance(value, torch.Tensor):
                        logger.debug(f"  {value.shape}")
                # Check the type and number of the returned values.
                if not isinstance(brep_features, tuple):
                    logger.error(f"process_brep_data returned {type(brep_features)}, expected tuple")
                    raise ValueError("Invalid return type from process_brep_data")

                if len(brep_features) != 6:
                    logger.error(f"Expected 6 features, got {len(brep_features)}")
                    logger.error("Features returned:")
                    for i, feat in enumerate(brep_features):
                        if isinstance(feat, torch.Tensor):
                            logger.error(f"  {i}: Tensor of shape {feat.shape}")
                        else:
                            logger.error(f"  {i}: {type(feat)}")
                    raise ValueError(f"Incorrect number of features: {len(brep_features)}")
                '''
                # Unpack the processed features.
                edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features
                sdf_points = sdf_data[:, :3]
                sdf_values = sdf_data[:, 3:]

                # Build the sample dictionary.
                return {
                    'name': name,
                    'edge_ncs': edge_ncs,      # [max_face, max_edge, num_edge_points, 3]
                    'edge_pos': edge_pos,      # [max_face, max_edge, 6]
                    'edge_mask': edge_mask,    # [max_face, max_edge]
                    'surf_ncs': surf_ncs,      # [max_face, num_surf_points, 3]
                    'surf_pos': surf_pos,      # [max_face, 6]
                    'vertex_pos': vertex_pos,  # [max_face, max_edge, 2, 3]
                    'points': sdf_points,      # [num_queries, 3] xyz coordinates of all query points
                    'sdf': sdf_values          # [num_queries, 1] sdf values of all query points
                }

            except Exception as e:
                logger.error(f"\nError processing B-rep data for file: {brep_path}")
                logger.error(f"Error type: {type(e).__name__}")
                logger.error(f"Error message: {str(e)}")

                # Print the structure of the raw data.
                logger.error("\nRaw data structure:")
                for key, value in brep_raw.items():
                    if isinstance(value, list):
                        logger.error(f"  {key}: list of length {len(value)}")
                        if value:
                            logger.error(f"  First element type: {type(value[0])}")
                            if hasattr(value[0], 'shape'):
                                logger.error(f"  First element shape: {value[0].shape}")
                    elif hasattr(value, 'shape'):
                        logger.error(f"  {key}: shape {value.shape}")
                    else:
                        logger.error(f"  {key}: {type(value)}")
                raise

        except Exception as e:
            logger.error(f"Error loading sample from {brep_path}: {str(e)}")
            raise

    def _load_brep_file(self, brep_path):
        with open(brep_path, 'rb') as f:
            brep_raw = pickle.load(f)
        return brep_raw

    def _load_sdf_file(self, sdf_path):
        """Load and process SDF data, with random subsampling."""
        try:
            # Load the SDF values.
            sdf_data = np.load(sdf_path)
            if 'pos' not in sdf_data or 'neg' not in sdf_data:
                raise ValueError("Missing pos/neg data in SDF file")

            sdf_pos = sdf_data['pos']  # (N1, 4)
            sdf_neg = sdf_data['neg']  # (N2, 4)

            # Validate the data layout.
            if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
                raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")

            # Random subsampling.
            max_points = self.config.data.num_query_points  # e.g. 4096

            # Keep positive and negative samples balanced.
            if max_points // 2 > sdf_pos.shape[0]:
                logger.warning(f"Too few positive samples: expected at least {max_points // 2}, got {sdf_pos.shape[0]}")

            num_neg = min(max_points // 2, sdf_neg.shape[0])
            num_pos = max_points - num_neg

            # Randomly sample positive points.
            if sdf_pos.shape[0] > num_pos:
                pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False)
                sdf_pos = sdf_pos[pos_indices]

            # Randomly sample negative points.
            if sdf_neg.shape[0] > num_neg:
                neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False)
                sdf_neg = sdf_neg[neg_indices]

            # Merge positives and negatives.
            sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)

            # Shuffle the merged points.
            np.random.shuffle(sdf_np)

            # If the total still exceeds the limit, subsample once more.
            if sdf_np.shape[0] > max_points:
                indices = np.random.choice(sdf_np.shape[0], max_points, replace=False)
                sdf_np = sdf_np[indices]

            #logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})")
            return torch.from_numpy(sdf_np.astype(np.float32))

        except Exception as e:
            logger.error(f"Error loading SDF from {sdf_path}")
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            raise

    def _get_brep_face_and_edge(self, brep_path: str) -> tuple[int, int]:
        brep: dict = self._load_brep_file(brep_path)
        face_edge_adj = brep["faceEdge_adj"]
        num_faces, num_edges = face_edge_adj.shape
        return num_faces, num_edges
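

# Minimal usage sketch for BRepSDFDataset (the directory paths are hypothetical;
# the class expects <brep_dir>/<split>/*.pkl, <sdf_dir>/<split>/*.npz, and a
# <valid_data_dir>/<split>_success.txt listing the valid sample names):
#
#     dataset = BRepSDFDataset(
#         brep_dir='data/brep',
#         sdf_dir='data/sdf',
#         valid_data_dir='data/valid',
#         split='train',
#     )
#     sample = dataset[0]
#     # sample['points']: [num_query_points, 3], sample['sdf']: [num_query_points, 1]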


def load_brep_file(brep_path):
    with open(brep_path, 'rb') as f:
        brep_raw = pickle.load(f)
    return brep_raw
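

# Illustrative sketch (the path is hypothetical): the pickle holds a dict of
# B-rep arrays, including a 'faceEdge_adj' matrix of shape [num_faces, num_edges]
# as read by BRepSDFDataset._get_brep_face_and_edge above:
#
#     brep_raw = load_brep_file('data/brep/train/sample.pkl')
#     num_faces, num_edges = brep_raw['faceEdge_adj'].shape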


def load_sdf_file(sdf_path: str, num_query_points: int = 4096) -> torch.Tensor:
    """
    Load and process SDF data, with random subsampling.

    Args:
        sdf_path: path to the SDF file
        num_query_points: maximum number of sampled points (default 4096)

    Returns:
        sdf_tensor: the processed SDF data tensor
    """
    try:
        # Load the SDF values.
        sdf_data = np.load(sdf_path)
        if 'pos' not in sdf_data or 'neg' not in sdf_data:
            raise ValueError("Missing pos/neg data in SDF file")

        sdf_pos = sdf_data['pos']  # (N1, 4)
        sdf_neg = sdf_data['neg']  # (N2, 4)

        # Validate the data layout.
        if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
            raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")

        # Keep positive and negative samples balanced.
        if num_query_points // 2 > sdf_pos.shape[0]:
            logger.warning(f"Too few positive samples: expected at least {num_query_points // 2}, got {sdf_pos.shape[0]}")

        num_neg = min(num_query_points // 2, sdf_neg.shape[0])
        num_pos = num_query_points - num_neg

        # Randomly sample positive points.
        if sdf_pos.shape[0] > num_pos:
            pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False)
            sdf_pos = sdf_pos[pos_indices]

        # Randomly sample negative points.
        if sdf_neg.shape[0] > num_neg:
            neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False)
            sdf_neg = sdf_neg[neg_indices]

        # Merge positives and negatives.
        sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)

        # Shuffle the merged points.
        np.random.shuffle(sdf_np)

        # If the total still exceeds the limit, subsample once more.
        if sdf_np.shape[0] > num_query_points:
            indices = np.random.choice(sdf_np.shape[0], num_query_points, replace=False)
            sdf_np = sdf_np[indices]

        return torch.from_numpy(sdf_np.astype(np.float32))

    except Exception as e:
        logger.error(f"Error loading SDF from {sdf_path}")
        logger.error(f"Error type: {type(e).__name__}")
        logger.error(f"Error message: {str(e)}")
        raise
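

# Illustrative sketch (the path is hypothetical): load_sdf_file expects an .npz
# with 'pos' and 'neg' arrays of shape (N, 4), i.e. xyz plus the signed distance,
# and returns a balanced, shuffled (num_query_points, 4) float tensor:
#
#     sdf = load_sdf_file('data/sdf/train/sample.npz', num_query_points=4096)
#     points, values = sdf[:, :3], sdf[:, 3:]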


def test_dataset():
    """Smoke-test the dataset."""
    try:
        # Get the config.
        config = get_default_config()

        # Expected tensor shapes.
        expected_shapes = {
            'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3),
            'edge_pos': (config.data.max_face, config.data.max_edge, 6),
            'edge_mask': (config.data.max_face, config.data.max_edge),
            'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3),
            'surf_pos': (config.data.max_face, 6),
            'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3),
            'points': (config.data.num_query_points, 3),
            'sdf': (config.data.num_query_points, 1)
        }

        logger.info("=" * 50)
        logger.info("Testing dataset")
        logger.info("Expected shapes:")
        for key, shape in expected_shapes.items():
            logger.info(f"  {key}: {shape}")

        # Initialize the dataset.
        dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split='train'
        )

        # Test loading a single sample.
        logger.info("\nTesting data loading...")
        sample = dataset[0]

        # Check dtypes and shapes.
        logger.info("\nChecking dtypes and shapes:")
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                actual_shape = tuple(value.shape)
                expected_shape = expected_shapes.get(key)
                shape_match = "✓" if actual_shape == expected_shape else "✗"

                logger.info(f"\n{key}:")
                logger.info(f"  actual shape:   {actual_shape}")
                logger.info(f"  expected shape: {expected_shape}")
                logger.info(f"  match:          {shape_match}")
                logger.info(f"  dtype:          {value.dtype}")

                # Value range, mean and std only make sense for float tensors.
                if value.dtype.is_floating_point:
                    logger.info(f"  value range: [{value.min():.3f}, {value.max():.3f}]")
                    logger.info(f"  mean: {value.mean():.3f}")
                    logger.info(f"  std:  {value.std():.3f}")

                if shape_match == "✗":
                    logger.warning(f"  shape mismatch: {key}")
                    if key in ['points', 'sdf']:
                        logger.warning(f"  query point count differs: expected {expected_shape[0]}, got {actual_shape[0]}")
                    elif key in ['edge_ncs', 'edge_pos', 'edge_mask']:
                        logger.warning(f"  edge count differs: expected {expected_shape[1]}, got {actual_shape[1]}")
                    elif key in ['surf_ncs', 'surf_pos']:
                        logger.warning(f"  face count differs: expected {expected_shape[0]}, got {actual_shape[0]}")

        # Test batching.
        logger.info("\nTesting batching...")
        batch_size = 4
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0
        )

        batch = next(iter(dataloader))
        logger.info("\nChecking batch shapes:")
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch_shape = tuple(value.shape)
                expected_batch_shape = (batch_size,) + expected_shapes[key]
                shape_match = "✓" if batch_shape == expected_batch_shape else "✗"

                logger.info(f"\n{key}:")
                logger.info(f"  actual shape:   {batch_shape}")
                logger.info(f"  expected shape: {expected_batch_shape}")
                logger.info(f"  match:          {shape_match}")
                logger.info(f"  dtype:          {value.dtype}")

                # Value range, mean and std only make sense for float tensors.
                if value.dtype.is_floating_point:
                    logger.info(f"  value range: [{value.min():.3f}, {value.max():.3f}]")
                    logger.info(f"  mean: {value.mean():.3f}")
                    logger.info(f"  std:  {value.std():.3f}")

                if shape_match == "✗":
                    logger.warning(f"  batch shape mismatch: {key}")

        logger.info("\nAll tests finished!")
        logger.info("=" * 50)

    except Exception as e:
        logger.error(f"Error during testing: {str(e)}")
        raise


def validate_dataset(split: str = 'train', num_samples: Optional[int] = None):
    """Run a full validation pass over the dataset.

    Args:
        split: dataset split ('train', 'val', 'test')
        num_samples: number of samples to check; None means check all samples
    """
    try:
        config = get_default_config()
        logger.info(f"Validating the {split} dataset...")

        # Initialize the dataset.
        dataset = BRepSDFDataset(
            brep_dir=config.data.brep_dir,
            sdf_dir=config.data.sdf_dir,
            valid_data_dir=config.data.valid_data_dir,
            split=split
        )

        total_samples = len(dataset) if num_samples is None else min(num_samples, len(dataset))
        logger.info(f"Total samples: {total_samples}")

        # Initialize the statistics.
        stats = {
            'face_counts': [],
            'edge_counts': [],
            'vertex_counts': [],
            'sdf_point_counts': [],
            'invalid_samples': [],
            'shape_mismatches': defaultdict(int),
            'value_ranges': defaultdict(lambda: {'min': float('inf'), 'max': float('-inf')}),
            'nan_counts': defaultdict(int),
            'inf_counts': defaultdict(int)
        }

        # Iterate over the dataset.
        for idx in tqdm(range(total_samples), desc="Validating data"):
            try:
                sample = dataset[idx]

                # 1. Check completeness.
                required_keys = ['surf_ncs', 'surf_pos', 'edge_ncs', 'edge_pos',
                                 'vertex_pos', 'points', 'sdf', 'edge_mask']
                missing_keys = [key for key in required_keys if key not in sample]
                if missing_keys:
                    stats['invalid_samples'].append((idx, f"missing keys: {missing_keys}"))
                    continue

                # 2. Check shapes.
                expected_shapes = {
                    'surf_ncs': (config.data.max_face, config.model.num_surf_points, 3),
                    'surf_pos': (config.data.max_face, 6),
                    'edge_ncs': (config.data.max_face, config.data.max_edge, config.model.num_edge_points, 3),
                    'edge_pos': (config.data.max_face, config.data.max_edge, 6),
                    'edge_mask': (config.data.max_face, config.data.max_edge),
                    'vertex_pos': (config.data.max_face, config.data.max_edge, 2, 3),
                    'points': (config.data.num_query_points, 3),
                    'sdf': (config.data.num_query_points, 1)
                }

                for key, expected_shape in expected_shapes.items():
                    if key in sample:
                        actual_shape = tuple(sample[key].shape)
                        if actual_shape != expected_shape:
                            stats['shape_mismatches'][key] += 1
                            stats['invalid_samples'].append(
                                (idx, f"{key} shape mismatch: expected {expected_shape}, got {actual_shape}")
                            )

                # 3. Check value ranges and invalid values.
                for key, tensor in sample.items():
                    if isinstance(tensor, torch.Tensor) and tensor.dtype.is_floating_point:
                        # Update value ranges.
                        stats['value_ranges'][key]['min'] = min(stats['value_ranges'][key]['min'],
                                                                tensor.min().item())
                        stats['value_ranges'][key]['max'] = max(stats['value_ranges'][key]['max'],
                                                                tensor.max().item())

                        # Check for NaN and Inf.
                        nan_count = torch.isnan(tensor).sum().item()
                        inf_count = torch.isinf(tensor).sum().item()
                        if nan_count > 0:
                            stats['nan_counts'][key] += nan_count
                        if inf_count > 0:
                            stats['inf_counts'][key] += inf_count

                # 4. Collect geometry statistics.
                stats['face_counts'].append(sample['surf_ncs'].shape[0])
                stats['edge_counts'].append(sample['edge_ncs'].shape[1])
                stats['vertex_counts'].append(len(torch.unique(sample['vertex_pos'].reshape(-1, 3), dim=0)))
                stats['sdf_point_counts'].append(sample['points'].shape[0])

            except Exception as e:
                stats['invalid_samples'].append((idx, str(e)))

        # Report the results.
        logger.info("\n=== Dataset validation results ===")

        # 1. Basic statistics.
        logger.info("\nBasic statistics:")
        logger.info(f"Total samples: {total_samples}")
        logger.info(f"Valid samples: {total_samples - len(stats['invalid_samples'])}")
        logger.info(f"Invalid samples: {len(stats['invalid_samples'])}")

        # 2. Shape mismatches.
        if stats['shape_mismatches']:
            logger.info("\nShape mismatches:")
            for key, count in stats['shape_mismatches'].items():
                logger.info(f"  {key}: {count} mismatched samples")

        # 3. Value ranges.
        logger.info("\nValue ranges:")
        for key, ranges in stats['value_ranges'].items():
            logger.info(f"  {key}:")
            logger.info(f"    min: {ranges['min']:.3f}")
            logger.info(f"    max: {ranges['max']:.3f}")

        # 4. Invalid values.
        if sum(stats['nan_counts'].values()) > 0 or sum(stats['inf_counts'].values()) > 0:
            logger.info("\nInvalid values:")
            for key in stats['nan_counts'].keys():
                if stats['nan_counts'][key] > 0:
                    logger.info(f"  {key} contains {stats['nan_counts'][key]} NaN values")
            for key in stats['inf_counts'].keys():
                if stats['inf_counts'][key] > 0:
                    logger.info(f"  {key} contains {stats['inf_counts'][key]} Inf values")

        # 5. Geometry statistics.
        logger.info("\nGeometry statistics:")
        for name, values in [
            ('faces', stats['face_counts']),
            ('edges', stats['edge_counts']),
            ('vertices', stats['vertex_counts']),
            ('SDF sample points', stats['sdf_point_counts'])
        ]:
            values = np.array(values)
            logger.info(f"  {name}:")
            logger.info(f"    min: {np.min(values)}")
            logger.info(f"    max: {np.max(values)}")
            logger.info(f"    mean: {np.mean(values):.2f}")
            logger.info(f"    median: {np.median(values):.2f}")
            logger.info(f"    std: {np.std(values):.2f}")

        # 6. Details of invalid samples.
        if stats['invalid_samples']:
            logger.info("\nInvalid sample details:")
            for idx, error in stats['invalid_samples']:
                logger.info(f"  sample {idx}: {error}")

        return stats

    except Exception as e:
        logger.error(f"Error during validation: {str(e)}")
        raise


if __name__ == '__main__':
    # Validate the full split; pass num_samples (e.g. 100) to spot-check a subset.
    validate_dataset(split='train', num_samples=None)