Browse Source

feat: script read npz

main
mckay 3 months ago
parent
commit
59f17a7199
  1. 118
      brep2sdf/scripts/read_npz.py

118
brep2sdf/scripts/read_npz.py

@ -0,0 +1,118 @@
import numpy as np
import argparse
from brep2sdf.utils.logger import logger
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def view_npz_file(file_path: str, save_plot: bool = False):
"""查看指定的npz文件内容
Args:
file_path: npz文件路径
save_plot: 是否保存可视化图像
"""
try:
# 加载npz文件
data = np.load(file_path)
# 打印基本信息
logger.info(f"\n=== NPZ文件内容分析 ===")
logger.info(f"文件路径: {file_path}")
logger.info(f"\n包含的数组:")
# 分析每个数组
for key in data.files:
array = data[key]
logger.info(f"\n{key}:")
logger.info(f" 形状: {array.shape}")
logger.info(f" 类型: {array.dtype}")
if np.issubdtype(array.dtype, np.number):
logger.info(f" 最小值: {array.min()}")
logger.info(f" 最大值: {array.max()}")
logger.info(f" 均值: {array.mean():.4f}")
logger.info(f" 标准差: {array.std():.4f}")
logger.info(f" 非零元素数量: {np.count_nonzero(array)}")
# 如果是点云数据(形状为 Nx3 或 Nx4),绘制3D散点图
if len(array.shape) == 2 and array.shape[1] in [3, 4]:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
# 绘制点云
points = array[:, :3] # 取前三列作为坐标
if array.shape[1] == 4:
# 如果有第四列,用它来设置颜色
colors = array[:, 3]
scatter = ax.scatter(points[:, 0], points[:, 1], points[:, 2],
c=colors, cmap='viridis')
plt.colorbar(scatter)
else:
ax.scatter(points[:, 0], points[:, 1], points[:, 2])
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title(f'{key} 点云可视化')
if save_plot:
plt.savefig(f'{key}_visualization.png')
else:
plt.show()
plt.close()
# 如果文件包含pos和neg数组,打印它们的统计信息
if 'pos' in data.files and 'neg' in data.files:
logger.info("\n正负样本统计:")
logger.info(f"正样本数量: {len(data['pos'])}")
logger.info(f"负样本数量: {len(data['neg'])}")
# 计算正负样本的SDF值分布
if data['pos'].shape[1] == 4:
pos_sdf = data['pos'][:, 3]
neg_sdf = data['neg'][:, 3]
logger.info("\nSDF值统计:")
logger.info("正样本:")
logger.info(f" 最小值: {pos_sdf.min():.4f}")
logger.info(f" 最大值: {pos_sdf.max():.4f}")
logger.info(f" 均值: {pos_sdf.mean():.4f}")
logger.info(f" 标准差: {pos_sdf.std():.4f}")
logger.info("负样本:")
logger.info(f" 最小值: {neg_sdf.min():.4f}")
logger.info(f" 最大值: {neg_sdf.max():.4f}")
logger.info(f" 均值: {neg_sdf.mean():.4f}")
logger.info(f" 标准差: {neg_sdf.std():.4f}")
# 绘制SDF值分布直方图
plt.figure(figsize=(12, 6))
plt.hist(pos_sdf, bins=50, alpha=0.5, label='正样本')
plt.hist(neg_sdf, bins=50, alpha=0.5, label='负样本')
plt.xlabel('SDF值')
plt.ylabel('频率')
plt.title('SDF值分布')
plt.legend()
if save_plot:
plt.savefig('sdf_distribution.png')
else:
plt.show()
plt.close()
except Exception as e:
logger.error(f"查看文件时出错: {str(e)}")
raise
def main():
parser = argparse.ArgumentParser(description='查看npz文件内容')
parser.add_argument('file', type=str, help='npz文件路径')
parser.add_argument('--save-plot', action='store_true', help='是否保存可视化图像')
args = parser.parse_args()
view_npz_file(args.file, args.save_plot)
if __name__ == '__main__':
# 直接查看指定文件
file_path = '/home/wch/brep2sdf/test_data/sdf/train/lamp_0582.npz'
view_npz_file(file_path)
Loading…
Cancel
Save