import pickle
import numpy as np
from scipy.stats import entropy
import os
import concurrent.futures
from tqdm import tqdm
import pandas as pd
# Inspect the saved data structure:
def inspect_data(pkl_file):
    """Inspect and print the data structure stored in a pickle file."""
    with open(pkl_file, 'rb') as f:
        data = pickle.load(f)

    print("Data structure overview:")
    print("=" * 50)

    # Iterate over all key/value pairs
    for key, value in data.items():
        print(f"\nKey: {key}")
        print("-" * 30)

        if isinstance(value, np.ndarray):
            print("Type: numpy.ndarray")
            print(f"Shape: {value.shape}")
            print(f"Dtype: {value.dtype}")
            if value.size > 0:
                if value.dtype == object:
                    print("Shape of first element:",
                          value[0].shape if hasattr(value[0], 'shape') else len(value[0]))
                    print("Sample:")
                    print(value[0][:3])
                else:
                    print(f"Value range: [{value.min()}, {value.max()}]")
                    print("Sample:")
                    print(value[:3])
        else:
            print(f"Type: {type(value)}")
            print(f"Value: {value}")
def statistic(pkl_file, n):
    '''
    Compute the spatial distribution statistics of 'surf_wcs'.

    The pickle file contains a dict with the following keys:
        # Geometry
        'surf_wcs': np.ndarray(dtype=object)       # shape (N,); each element is an (M, 3) float32 array of face point-cloud coordinates
        'edge_wcs': np.ndarray(dtype=object)       # shape (N,); each element is a (num_edge_sample_points, 3) float32 array of edge sample points
        'surf_ncs': np.ndarray(dtype=object)       # shape (N,); each element is an (M, 3) float32 array of normalized face point clouds
        'edge_ncs': np.ndarray(dtype=object)       # shape (N,); each element is a (num_edge_sample_points, 3) float32 array of normalized edge sample points
        'corner_wcs': np.ndarray(dtype=float32)    # shape (num_edges, 2, 3); the two endpoint coordinates of each edge
        'corner_unique': np.ndarray(dtype=float32) # shape (num_vertices, 3); unique vertex coordinates, num_vertices <= num_edges * 2
        # Topology
        'edgeFace_adj': np.ndarray(dtype=int32)    # shape (num_edges, num_faces); edge-face adjacency
        'edgeCorner_adj': np.ndarray(dtype=int32)  # shape (num_edges, 2); edge-vertex adjacency
        'faceEdge_adj': np.ndarray(dtype=int32)    # shape (num_faces, num_edges); face-edge adjacency
        # Bounding boxes
        'surf_bbox_wcs': np.ndarray(dtype=float32) # shape (num_faces, 6); per-face bounding box [xmin, ymin, zmin, xmax, ymax, zmax]
        'edge_bbox_wcs': np.ndarray(dtype=float32) # shape (num_edges, 6); per-edge bounding box [xmin, ymin, zmin, xmax, ymax, zmax]
    '''
    # Load the pkl file
    with open(pkl_file, 'rb') as f:
        data = pickle.load(f)

    # Merge the point clouds of all faces
    surf_wcs = data['surf_wcs']
    all_points = np.concatenate(surf_wcs)

    # Compute the global bounding box
    min_coords = np.min(all_points, axis=0)
    max_coords = np.max(all_points, axis=0)

    # Partition the space into (2^n)^3 = 8^n grid cells
    num_bins = 2 ** n
    bins = np.linspace(min_coords, max_coords, num_bins + 1, axis=0)

    # Initialize the distribution matrix
    distribution = np.zeros((num_bins, num_bins, num_bins), dtype=int)

    # Count the number of points falling into each grid cell
    for point in all_points:
        # Find the grid index of the point along each axis
        x_idx = np.searchsorted(bins[:, 0], point[0]) - 1
        y_idx = np.searchsorted(bins[:, 1], point[1]) - 1
        z_idx = np.searchsorted(bins[:, 2], point[2]) - 1
        # Clamp indices to the valid range (handles points on the boundary)
        x_idx = np.clip(x_idx, 0, num_bins - 1)
        y_idx = np.clip(y_idx, 0, num_bins - 1)
        z_idx = np.clip(z_idx, 0, num_bins - 1)
        # Update the distribution matrix
        distribution[x_idx, y_idx, z_idx] += 1

    # Compute the entropy of the occupancy distribution
    prob_distribution = distribution / np.sum(distribution)
    entropy_value = entropy(prob_distribution.flatten())

    # Compute the fraction of empty cells
    empty_cells = np.sum(distribution == 0)
    total_cells = num_bins ** 3
    empty_ratio = empty_cells / total_cells

    # Debug output, kept for reference:
    # print(f"Total points: {len(all_points)}")
    # print(f"Min coordinates: {min_coords}")
    # print(f"Max coordinates: {max_coords}")
    # print(f"Distribution in {num_bins}x{num_bins}x{num_bins} grid:")
    # print(distribution)
    # print(f"Entropy: {entropy_value}")
    # print(f"Empty cell ratio: {empty_ratio:.2%}")

    # Return the statistics
    return {
        'file': os.path.basename(pkl_file),  # identifies the row in the output CSV
        'total_points': len(all_points),
        'min_coordinates': min_coords,
        'max_coordinates': max_coords,
        'distribution': distribution,
        'entropy': entropy_value,
        'empty_ratio': empty_ratio
    }
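
# A vectorized sketch of the same binning step, shown for illustration only:
# np.histogramdd bins all points at once instead of looping with searchsorted.
# This helper is an assumed alternative, not called anywhere in this script;
# it reproduces statistic()'s uniform grid over the global bounding box, and
# histogramdd's convention of placing boundary points in the last bin closely
# matches the np.clip handling above (up to tie-breaking on interior edges).
def statistic_vectorized(all_points, n):
    num_bins = 2 ** n
    min_coords = all_points.min(axis=0)
    max_coords = all_points.max(axis=0)
    # One (num_bins, num_bins, num_bins) histogram over the bounding box
    distribution, _ = np.histogramdd(
        all_points,
        bins=num_bins,
        range=list(zip(min_coords, max_coords))
    )
    prob_distribution = distribution / distribution.sum()
    # Same two summary statistics as statistic()
    return entropy(prob_distribution.flatten()), float(np.mean(distribution == 0))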
def process_files(pkl_files, n, output_csv='results.csv'):
    """
    Process multiple files in parallel and save the results to a CSV file.

    Args:
        pkl_files: list of pkl file paths
        n: grid subdivision level (num_bins = 2^n per axis)
        output_csv: output CSV file path
    """
    results = []

    # Process the files in parallel with a thread pool
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Map each future back to its file so failures can be attributed
        futures = {executor.submit(statistic, pkl_file, n): pkl_file
                   for pkl_file in pkl_files}
        for future in concurrent.futures.as_completed(futures):
            try:
                results.append(future.result())
            except Exception as e:
                print(f"Error processing {futures[future]}: {e}")

    # Convert the results to a DataFrame and save as CSV
    df = pd.DataFrame(results)
    df.to_csv(output_csv, index=False)
    print(f"Results saved to {output_csv}")
def process_files_with_n_values(pkl_files, n_values, output_dir='results'):
    """
    Run the statistics for several values of n and save each run to its own CSV.

    Args:
        pkl_files: list of pkl file paths
        n_values: list of n values to test
        output_dir: output directory
    """
    # Create the output directory (os is already imported at the top)
    os.makedirs(output_dir, exist_ok=True)

    # Per-n summary statistics
    n_stats = []

    for n in tqdm(n_values, desc="Processing n values"):
        output_csv = os.path.join(output_dir, f'statistics_n{n}.csv')
        print(f"Processing with n={n}...")
        process_files(pkl_files, n, output_csv)

        # Read the CSV back and average over all files
        df = pd.read_csv(output_csv)
        avg_entropy = df['entropy'].mean()
        avg_empty_ratio = df['empty_ratio'].mean()

        # Record the summary for the current n
        n_stats.append({
            'n': n,
            'avg_entropy': avg_entropy,
            'avg_empty_ratio': avg_empty_ratio
        })
        print(f"Results for n={n} saved to {output_csv}")

    # Save the per-n averages as CSV
    n_stats_df = pd.DataFrame(n_stats)
    n_stats_csv = os.path.join(output_dir, 'n_averages.csv')
    n_stats_df.to_csv(n_stats_csv, index=False)
    print(f"Average statistics for n values saved to {n_stats_csv}")
if __name__ == "__main__":
    # pkl_file = "/home/wch/brep2sdf/test_data/pkl/train/bathtub_0004.pkl"  # replace with your own path
    directory = "/home/wch/brep2sdf/test_data/pkl/test"

    # Collect all .pkl files in the directory
    pkl_files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.pkl')]

    # Process the files in parallel for a range of n values
    n_values = [2, 3, 4, 5, 6, 7]  # adjust as needed
    process_files_with_n_values(pkl_files, n_values, output_dir='n_test_results')