diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py index 40948ae..159e289 100644 --- a/brep2sdf/data/utils.py +++ b/brep2sdf/data/utils.py @@ -107,7 +107,7 @@ def pad_repeat(x, max_len): x_repeat = np.concatenate([sep1, sep2], 0) return x_repeat - +''' def pad_zero(x, max_len, return_mask=False): keys = np.ones(len(x)) padding = np.zeros((max_len-len(x))).astype(int) @@ -118,7 +118,49 @@ def pad_zero(x, max_len, return_mask=False): return x_padded, mask else: return x_padded - +''' +def pad_zero(x, max_len, return_mask=False): + """填充或截断数组到指定长度 + + Args: + x: 输入数组 + max_len: 目标长度 + return_mask: 是否返回掩码 + + Returns: + x_padded: 处理后的数组 + mask: (可选) 掩码,标记实际数据(True)和填充(False) + """ + # 获取实际长度 + actual_len = len(x) + + # 如果实际长度超过最大长度,进行截断 + if actual_len > max_len: + x = x[:max_len] + mask = np.ones(max_len, dtype=bool) + if return_mask: + return x, mask + return x + + # 如果需要填充 + if actual_len < max_len: + # 创建掩码 + keys = np.ones(actual_len) + padding_mask = np.zeros(max_len - actual_len) + mask = np.concatenate([keys, padding_mask]) == 1 + + # 填充数据 + padding = np.zeros((max_len - actual_len, *x.shape[1:])) + x_padded = np.concatenate([x, padding], axis=0) + else: + # 长度正好 + mask = np.ones(max_len, dtype=bool) + x_padded = x + + if return_mask: + return x_padded, mask + return x_padded + def plot_3d_bbox(ax, min_corner, max_corner, color='r'): """