From 59f17a7199b00b4f16ed56bf46e3164029449106 Mon Sep 17 00:00:00 2001 From: mckay Date: Thu, 5 Dec 2024 01:15:27 +0800 Subject: [PATCH] feat: script read npz --- brep2sdf/scripts/read_npz.py | 118 +++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 brep2sdf/scripts/read_npz.py diff --git a/brep2sdf/scripts/read_npz.py b/brep2sdf/scripts/read_npz.py new file mode 100644 index 0000000..a6e9c28 --- /dev/null +++ b/brep2sdf/scripts/read_npz.py @@ -0,0 +1,118 @@ +import numpy as np +import argparse +from brep2sdf.utils.logger import logger +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D + +def view_npz_file(file_path: str, save_plot: bool = False): + """查看指定的npz文件内容 + + Args: + file_path: npz文件路径 + save_plot: 是否保存可视化图像 + """ + try: + # 加载npz文件 + data = np.load(file_path) + + # 打印基本信息 + logger.info(f"\n=== NPZ文件内容分析 ===") + logger.info(f"文件路径: {file_path}") + logger.info(f"\n包含的数组:") + + # 分析每个数组 + for key in data.files: + array = data[key] + logger.info(f"\n{key}:") + logger.info(f" 形状: {array.shape}") + logger.info(f" 类型: {array.dtype}") + + if np.issubdtype(array.dtype, np.number): + logger.info(f" 最小值: {array.min()}") + logger.info(f" 最大值: {array.max()}") + logger.info(f" 均值: {array.mean():.4f}") + logger.info(f" 标准差: {array.std():.4f}") + logger.info(f" 非零元素数量: {np.count_nonzero(array)}") + + # 如果是点云数据(形状为 Nx3 或 Nx4),绘制3D散点图 + if len(array.shape) == 2 and array.shape[1] in [3, 4]: + fig = plt.figure(figsize=(10, 10)) + ax = fig.add_subplot(111, projection='3d') + + # 绘制点云 + points = array[:, :3] # 取前三列作为坐标 + if array.shape[1] == 4: + # 如果有第四列,用它来设置颜色 + colors = array[:, 3] + scatter = ax.scatter(points[:, 0], points[:, 1], points[:, 2], + c=colors, cmap='viridis') + plt.colorbar(scatter) + else: + ax.scatter(points[:, 0], points[:, 1], points[:, 2]) + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.set_title(f'{key} 点云可视化') + + if save_plot: + plt.savefig(f'{key}_visualization.png') + else: + plt.show() + plt.close() + + # 如果文件包含pos和neg数组,打印它们的统计信息 + if 'pos' in data.files and 'neg' in data.files: + logger.info("\n正负样本统计:") + logger.info(f"正样本数量: {len(data['pos'])}") + logger.info(f"负样本数量: {len(data['neg'])}") + + # 计算正负样本的SDF值分布 + if data['pos'].shape[1] == 4: + pos_sdf = data['pos'][:, 3] + neg_sdf = data['neg'][:, 3] + + logger.info("\nSDF值统计:") + logger.info("正样本:") + logger.info(f" 最小值: {pos_sdf.min():.4f}") + logger.info(f" 最大值: {pos_sdf.max():.4f}") + logger.info(f" 均值: {pos_sdf.mean():.4f}") + logger.info(f" 标准差: {pos_sdf.std():.4f}") + + logger.info("负样本:") + logger.info(f" 最小值: {neg_sdf.min():.4f}") + logger.info(f" 最大值: {neg_sdf.max():.4f}") + logger.info(f" 均值: {neg_sdf.mean():.4f}") + logger.info(f" 标准差: {neg_sdf.std():.4f}") + + # 绘制SDF值分布直方图 + plt.figure(figsize=(12, 6)) + plt.hist(pos_sdf, bins=50, alpha=0.5, label='正样本') + plt.hist(neg_sdf, bins=50, alpha=0.5, label='负样本') + plt.xlabel('SDF值') + plt.ylabel('频率') + plt.title('SDF值分布') + plt.legend() + + if save_plot: + plt.savefig('sdf_distribution.png') + else: + plt.show() + plt.close() + + except Exception as e: + logger.error(f"查看文件时出错: {str(e)}") + raise + +def main(): + parser = argparse.ArgumentParser(description='查看npz文件内容') + parser.add_argument('file', type=str, help='npz文件路径') + parser.add_argument('--save-plot', action='store_true', help='是否保存可视化图像') + + args = parser.parse_args() + view_npz_file(args.file, args.save_plot) + +if __name__ == '__main__': + # 直接查看指定文件 + file_path = '/home/wch/brep2sdf/test_data/sdf/train/lamp_0582.npz' + view_npz_file(file_path) \ No newline at end of file