import pickle
import numpy as np
from scipy.stats import entropy
import os
import concurrent.futures
from tqdm import tqdm
import pandas as pd


# Inspect the saved data structure:

def inspect_data(pkl_file):
    """Inspect and print the structure of the data stored in a pickle file."""
    with open(pkl_file, 'rb') as f:
        data = pickle.load(f)

    print("Data structure overview:")
    print("=" * 50)

    # Iterate over all key-value pairs
    for key, value in data.items():
        print(f"\nKey: {key}")
        print("-" * 30)

        if isinstance(value, np.ndarray):
            print("Type: numpy.ndarray")
            print(f"Shape: {value.shape}")
            print(f"Dtype: {value.dtype}")
            if value.size > 0:
                if value.dtype == object:
                    print("Shape of first element:", value[0].shape if hasattr(value[0], 'shape') else len(value[0]))
                else:
                    print("Value range:", f"[{value.min()}, {value.max()}]")
                print("Sample:")
                if value.dtype == object:
                    print(value[0][:3] if value.size > 0 else "empty")
                else:
                    print(value[:3] if value.size > 0 else "empty")
        else:
            print(f"Type: {type(value)}")
            print(f"Value: {value}")


def statistic(pkl_file, n):
    '''
    Compute the spatial distribution statistics of surf_wcs.

    The pickle file holds a dict with the following keys:

    # Geometry
    'surf_wcs': np.ndarray(dtype=object)    # shape (N,); each element is an (M, 3) float32 array of per-face point-cloud coordinates
    'edge_wcs': np.ndarray(dtype=object)    # shape (N,); each element is a (num_edge_sample_points, 3) float32 array of per-edge sample points
    'surf_ncs': np.ndarray(dtype=object)    # shape (N,); each element is an (M, 3) float32 array of normalized face point clouds
    'edge_ncs': np.ndarray(dtype=object)    # shape (N,); each element is a (num_edge_sample_points, 3) float32 array of normalized edge sample points
    'corner_wcs': np.ndarray(dtype=float32)    # shape (num_edges, 2, 3); the two endpoint coordinates of each edge
    'corner_unique': np.ndarray(dtype=float32) # shape (num_vertices, 3); unique vertex coordinates, num_vertices <= num_edges * 2

    # Topology
    'edgeFace_adj': np.ndarray(dtype=int32)    # shape (num_edges, num_faces); edge-face adjacency
    'edgeCorner_adj': np.ndarray(dtype=int32)  # shape (num_edges, 2); edge-vertex adjacency
    'faceEdge_adj': np.ndarray(dtype=int32)    # shape (num_faces, num_edges); face-edge adjacency

    # Bounding boxes
    'surf_bbox_wcs': np.ndarray(dtype=float32) # shape (num_faces, 6); per-face bounding box [xmin, ymin, zmin, xmax, ymax, zmax]
    'edge_bbox_wcs': np.ndarray(dtype=float32) # shape (num_edges, 6); per-edge bounding box [xmin, ymin, zmin, xmax, ymax, zmax]
    '''
    # Load the pkl file
    with open(pkl_file, 'rb') as f:
        data = pickle.load(f)

    # Get the surf_wcs data and merge the point clouds of all faces
    surf_wcs = data['surf_wcs']
    all_points = np.concatenate(surf_wcs)

    # Compute the global bounding box
    min_coords = np.min(all_points, axis=0)
    max_coords = np.max(all_points, axis=0)

    # Divide the space into 2^(3n) grid cells (2^n bins per axis)
    num_bins = 2 ** n
    bins = np.linspace(min_coords, max_coords, num_bins + 1, axis=0)

    # Initialize the distribution matrix
    distribution = np.zeros((num_bins, num_bins, num_bins), dtype=int)

    # Count the number of points in each grid cell
    for point in all_points:
        # Find the grid index of the point along each axis
        x_idx = np.searchsorted(bins[:, 0], point[0]) - 1
        y_idx = np.searchsorted(bins[:, 1], point[1]) - 1
        z_idx = np.searchsorted(bins[:, 2], point[2]) - 1
        # Clamp the indices to the valid range
        x_idx = np.clip(x_idx, 0, num_bins - 1)
        y_idx = np.clip(y_idx, 0, num_bins - 1)
        z_idx = np.clip(z_idx, 0, num_bins - 1)
        # Update the distribution matrix
        distribution[x_idx, y_idx, z_idx] += 1

    # Compute the entropy of the occupancy distribution
    prob_distribution = distribution / np.sum(distribution)
    entropy_value = entropy(prob_distribution.flatten())

    # Compute the fraction of empty cells
    empty_cells = np.sum(distribution == 0)
    total_cells = num_bins ** 3
    empty_ratio = empty_cells / total_cells

    # Print the statistics
    '''
    print(f"Total points: {len(all_points)}")
    print(f"Min coordinates: {min_coords}")
    print(f"Max coordinates: {max_coords}")
    print(f"Distribution in {num_bins}x{num_bins}x{num_bins} grid:")
    print(distribution)
    print(f"Entropy: {entropy_value}")
    print(f"Empty cell ratio: {empty_ratio:.2%}")
    '''
    # Return the statistics
    return {
        'total_points': len(all_points),
        'min_coordinates': min_coords,
        'max_coordinates': max_coords,
        'distribution': distribution,
        'entropy': entropy_value,
        'empty_ratio': empty_ratio
    }
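

# A minimal vectorized sketch of the same binning step, added as a suggestion
# rather than part of the original pipeline: np.histogramdd replaces the
# per-point searchsorted loop and should yield the same counts (points on the
# rightmost edge fall into the last bin, matching the clip above).
def grid_distribution(all_points, min_coords, max_coords, n):
    num_bins = 2 ** n
    # One array of bin edges per axis, spanning the global bounding box
    edges = [np.linspace(min_coords[d], max_coords[d], num_bins + 1) for d in range(3)]
    hist, _ = np.histogramdd(all_points, bins=edges)
    return hist.astype(int)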


def process_files(pkl_files, n, output_csv='results.csv'):
    """
    Process multiple files in parallel and save the results as CSV.

    Args:
        pkl_files: list of pkl file paths
        n: subdivision level of the space
        output_csv: output CSV file path
    """
    results = []

    # Process in parallel with a thread pool
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(statistic, pkl_file, n) for pkl_file in pkl_files]

        for future in concurrent.futures.as_completed(futures):
            try:
                result = future.result()
                results.append(result)
            except Exception as e:
                print(f"Error processing file: {e}")

    # Convert the results to a DataFrame and save as CSV
    df = pd.DataFrame(results)
    df.to_csv(output_csv, index=False)
    print(f"Results saved to {output_csv}")


def process_files_with_n_values(pkl_files, n_values, output_dir='results'):
    """
    Evaluate different values of n and save the results to separate CSV files.

    Args:
        pkl_files: list of pkl file paths
        n_values: list of n values to test
        output_dir: output directory for the results
    """
    # Create the output directory
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Collect per-n summary statistics
    n_stats = []

    # Iterate over the n values
    for n in tqdm(n_values, desc="Processing n values"):
        output_csv = os.path.join(output_dir, f'statistics_n{n}.csv')
        print(f"Processing with n={n}...")
        process_files(pkl_files, n, output_csv)

        # Read the CSV back and compute the averages
        df = pd.read_csv(output_csv)
        avg_entropy = df['entropy'].mean()
        avg_empty_ratio = df['empty_ratio'].mean()

        # Record the summary for the current n
        n_stats.append({
            'n': n,
            'avg_entropy': avg_entropy,
            'avg_empty_ratio': avg_empty_ratio
        })
        print(f"Results for n={n} saved to {output_csv}")

    # Save the per-n summary as CSV
    n_stats_df = pd.DataFrame(n_stats)
    n_stats_csv = os.path.join(output_dir, 'n_averages.csv')
    n_stats_df.to_csv(n_stats_csv, index=False)
    print(f"Average statistics for n values saved to {n_stats_csv}")


if __name__ == "__main__":
    # pkl_file = "/home/wch/brep2sdf/test_data/pkl/train/bathtub_0004.pkl"  # Replace with your file path
    directory = "/home/wch/brep2sdf/test_data/pkl/test"

    # Collect all .pkl files in the directory
    pkl_files = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith('.pkl')]

    # Process the files in parallel and save the results
    n_values = [2, 3, 4, 5, 6, 7]  # Adjust as needed
    process_files_with_n_values(pkl_files, n_values, output_dir='n_test_results')