From 9e6020adafeb98b504ed1ca75c62daedf7efd3c7 Mon Sep 17 00:00:00 2001 From: mckay Date: Mon, 18 Nov 2024 22:58:12 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D:=20=E6=94=B9=E8=BF=9B?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=A2=84=E5=A4=84=E7=90=86=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E6=95=B0=E7=BB=84=E5=A1=AB=E5=85=85=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 修复了pad_zero函数中的负数维度问题 - 添加了对超长输入的截断处理 - 优化了掩码生成逻辑 - 保持多维数组的维度一致性 2. 主要改动: - 处理数组长度 > max_len:进行截断 - 处理数组长度 < max_len:进行填充 - 处理数组长度 = max_len:直接返回 - 改进布尔掩码的生成方式 --- brep2sdf/data/utils.py | 46 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py index 40948ae..159e289 100644 --- a/brep2sdf/data/utils.py +++ b/brep2sdf/data/utils.py @@ -107,7 +107,7 @@ def pad_repeat(x, max_len): x_repeat = np.concatenate([sep1, sep2], 0) return x_repeat - +''' def pad_zero(x, max_len, return_mask=False): keys = np.ones(len(x)) padding = np.zeros((max_len-len(x))).astype(int) @@ -118,7 +118,49 @@ def pad_zero(x, max_len, return_mask=False): return x_padded, mask else: return x_padded - +''' +def pad_zero(x, max_len, return_mask=False): + """填充或截断数组到指定长度 + + Args: + x: 输入数组 + max_len: 目标长度 + return_mask: 是否返回掩码 + + Returns: + x_padded: 处理后的数组 + mask: (可选) 掩码,标记实际数据(True)和填充(False) + """ + # 获取实际长度 + actual_len = len(x) + + # 如果实际长度超过最大长度,进行截断 + if actual_len > max_len: + x = x[:max_len] + mask = np.ones(max_len, dtype=bool) + if return_mask: + return x, mask + return x + + # 如果需要填充 + if actual_len < max_len: + # 创建掩码 + keys = np.ones(actual_len) + padding_mask = np.zeros(max_len - actual_len) + mask = np.concatenate([keys, padding_mask]) == 1 + + # 填充数据 + padding = np.zeros((max_len - actual_len, *x.shape[1:])) + x_padded = np.concatenate([x, padding], axis=0) + else: + # 长度正好 + mask = np.ones(max_len, dtype=bool) + x_padded = x + + if return_mask: + return x_padded, mask + return x_padded + def plot_3d_bbox(ax, min_corner, max_corner, color='r'): """