1 changed files with 118 additions and 0 deletions
@ -0,0 +1,118 @@ |
|||
import numpy as np |
|||
import argparse |
|||
from brep2sdf.utils.logger import logger |
|||
import matplotlib.pyplot as plt |
|||
from mpl_toolkits.mplot3d import Axes3D |
|||
|
|||
def view_npz_file(file_path: str, save_plot: bool = False): |
|||
"""查看指定的npz文件内容 |
|||
|
|||
Args: |
|||
file_path: npz文件路径 |
|||
save_plot: 是否保存可视化图像 |
|||
""" |
|||
try: |
|||
# 加载npz文件 |
|||
data = np.load(file_path) |
|||
|
|||
# 打印基本信息 |
|||
logger.info(f"\n=== NPZ文件内容分析 ===") |
|||
logger.info(f"文件路径: {file_path}") |
|||
logger.info(f"\n包含的数组:") |
|||
|
|||
# 分析每个数组 |
|||
for key in data.files: |
|||
array = data[key] |
|||
logger.info(f"\n{key}:") |
|||
logger.info(f" 形状: {array.shape}") |
|||
logger.info(f" 类型: {array.dtype}") |
|||
|
|||
if np.issubdtype(array.dtype, np.number): |
|||
logger.info(f" 最小值: {array.min()}") |
|||
logger.info(f" 最大值: {array.max()}") |
|||
logger.info(f" 均值: {array.mean():.4f}") |
|||
logger.info(f" 标准差: {array.std():.4f}") |
|||
logger.info(f" 非零元素数量: {np.count_nonzero(array)}") |
|||
|
|||
# 如果是点云数据(形状为 Nx3 或 Nx4),绘制3D散点图 |
|||
if len(array.shape) == 2 and array.shape[1] in [3, 4]: |
|||
fig = plt.figure(figsize=(10, 10)) |
|||
ax = fig.add_subplot(111, projection='3d') |
|||
|
|||
# 绘制点云 |
|||
points = array[:, :3] # 取前三列作为坐标 |
|||
if array.shape[1] == 4: |
|||
# 如果有第四列,用它来设置颜色 |
|||
colors = array[:, 3] |
|||
scatter = ax.scatter(points[:, 0], points[:, 1], points[:, 2], |
|||
c=colors, cmap='viridis') |
|||
plt.colorbar(scatter) |
|||
else: |
|||
ax.scatter(points[:, 0], points[:, 1], points[:, 2]) |
|||
|
|||
ax.set_xlabel('X') |
|||
ax.set_ylabel('Y') |
|||
ax.set_zlabel('Z') |
|||
ax.set_title(f'{key} 点云可视化') |
|||
|
|||
if save_plot: |
|||
plt.savefig(f'{key}_visualization.png') |
|||
else: |
|||
plt.show() |
|||
plt.close() |
|||
|
|||
# 如果文件包含pos和neg数组,打印它们的统计信息 |
|||
if 'pos' in data.files and 'neg' in data.files: |
|||
logger.info("\n正负样本统计:") |
|||
logger.info(f"正样本数量: {len(data['pos'])}") |
|||
logger.info(f"负样本数量: {len(data['neg'])}") |
|||
|
|||
# 计算正负样本的SDF值分布 |
|||
if data['pos'].shape[1] == 4: |
|||
pos_sdf = data['pos'][:, 3] |
|||
neg_sdf = data['neg'][:, 3] |
|||
|
|||
logger.info("\nSDF值统计:") |
|||
logger.info("正样本:") |
|||
logger.info(f" 最小值: {pos_sdf.min():.4f}") |
|||
logger.info(f" 最大值: {pos_sdf.max():.4f}") |
|||
logger.info(f" 均值: {pos_sdf.mean():.4f}") |
|||
logger.info(f" 标准差: {pos_sdf.std():.4f}") |
|||
|
|||
logger.info("负样本:") |
|||
logger.info(f" 最小值: {neg_sdf.min():.4f}") |
|||
logger.info(f" 最大值: {neg_sdf.max():.4f}") |
|||
logger.info(f" 均值: {neg_sdf.mean():.4f}") |
|||
logger.info(f" 标准差: {neg_sdf.std():.4f}") |
|||
|
|||
# 绘制SDF值分布直方图 |
|||
plt.figure(figsize=(12, 6)) |
|||
plt.hist(pos_sdf, bins=50, alpha=0.5, label='正样本') |
|||
plt.hist(neg_sdf, bins=50, alpha=0.5, label='负样本') |
|||
plt.xlabel('SDF值') |
|||
plt.ylabel('频率') |
|||
plt.title('SDF值分布') |
|||
plt.legend() |
|||
|
|||
if save_plot: |
|||
plt.savefig('sdf_distribution.png') |
|||
else: |
|||
plt.show() |
|||
plt.close() |
|||
|
|||
except Exception as e: |
|||
logger.error(f"查看文件时出错: {str(e)}") |
|||
raise |
|||
|
|||
def main(): |
|||
parser = argparse.ArgumentParser(description='查看npz文件内容') |
|||
parser.add_argument('file', type=str, help='npz文件路径') |
|||
parser.add_argument('--save-plot', action='store_true', help='是否保存可视化图像') |
|||
|
|||
args = parser.parse_args() |
|||
view_npz_file(args.file, args.save_plot) |
|||
|
|||
if __name__ == '__main__': |
|||
# 直接查看指定文件 |
|||
file_path = '/home/wch/brep2sdf/test_data/sdf/train/lamp_0582.npz' |
|||
view_npz_file(file_path) |
Loading…
Reference in new issue