|
@ -107,7 +107,7 @@ def pad_repeat(x, max_len): |
|
|
x_repeat = np.concatenate([sep1, sep2], 0) |
|
|
x_repeat = np.concatenate([sep1, sep2], 0) |
|
|
return x_repeat |
|
|
return x_repeat |
|
|
|
|
|
|
|
|
|
|
|
''' |
|
|
def pad_zero(x, max_len, return_mask=False): |
|
|
def pad_zero(x, max_len, return_mask=False): |
|
|
keys = np.ones(len(x)) |
|
|
keys = np.ones(len(x)) |
|
|
padding = np.zeros((max_len-len(x))).astype(int) |
|
|
padding = np.zeros((max_len-len(x))).astype(int) |
|
@ -118,6 +118,48 @@ def pad_zero(x, max_len, return_mask=False): |
|
|
return x_padded, mask |
|
|
return x_padded, mask |
|
|
else: |
|
|
else: |
|
|
return x_padded |
|
|
return x_padded |
|
|
|
|
|
''' |
|
|
|
|
|
def pad_zero(x, max_len, return_mask=False): |
|
|
|
|
|
"""填充或截断数组到指定长度 |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
x: 输入数组 |
|
|
|
|
|
max_len: 目标长度 |
|
|
|
|
|
return_mask: 是否返回掩码 |
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
x_padded: 处理后的数组 |
|
|
|
|
|
mask: (可选) 掩码,标记实际数据(True)和填充(False) |
|
|
|
|
|
""" |
|
|
|
|
|
# 获取实际长度 |
|
|
|
|
|
actual_len = len(x) |
|
|
|
|
|
|
|
|
|
|
|
# 如果实际长度超过最大长度,进行截断 |
|
|
|
|
|
if actual_len > max_len: |
|
|
|
|
|
x = x[:max_len] |
|
|
|
|
|
mask = np.ones(max_len, dtype=bool) |
|
|
|
|
|
if return_mask: |
|
|
|
|
|
return x, mask |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
# 如果需要填充 |
|
|
|
|
|
if actual_len < max_len: |
|
|
|
|
|
# 创建掩码 |
|
|
|
|
|
keys = np.ones(actual_len) |
|
|
|
|
|
padding_mask = np.zeros(max_len - actual_len) |
|
|
|
|
|
mask = np.concatenate([keys, padding_mask]) == 1 |
|
|
|
|
|
|
|
|
|
|
|
# 填充数据 |
|
|
|
|
|
padding = np.zeros((max_len - actual_len, *x.shape[1:])) |
|
|
|
|
|
x_padded = np.concatenate([x, padding], axis=0) |
|
|
|
|
|
else: |
|
|
|
|
|
# 长度正好 |
|
|
|
|
|
mask = np.ones(max_len, dtype=bool) |
|
|
|
|
|
x_padded = x |
|
|
|
|
|
|
|
|
|
|
|
if return_mask: |
|
|
|
|
|
return x_padded, mask |
|
|
|
|
|
return x_padded |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_3d_bbox(ax, min_corner, max_corner, color='r'): |
|
|
def plot_3d_bbox(ax, min_corner, max_corner, color='r'): |
|
|