Browse Source

修复: 改进数据预处理中的数组填充逻辑

1. 修复了pad_zero函数中的负数维度问题
- 添加了对超长输入的截断处理
- 优化了掩码生成逻辑
- 保持多维数组的维度一致性

2. 主要改动:
- 处理数组长度 > max_len:进行截断
- 处理数组长度 < max_len:进行填充
- 处理数组长度 = max_len:直接返回
- 改进布尔掩码的生成方式
final
mckay 7 months ago
parent
commit
bee51c9dc0
  1. 46
      brep2sdf/data/utils.py

46
brep2sdf/data/utils.py

@ -107,7 +107,7 @@ def pad_repeat(x, max_len):
x_repeat = np.concatenate([sep1, sep2], 0)
return x_repeat
'''
def pad_zero(x, max_len, return_mask=False):
keys = np.ones(len(x))
padding = np.zeros((max_len-len(x))).astype(int)
@ -118,7 +118,49 @@ def pad_zero(x, max_len, return_mask=False):
return x_padded, mask
else:
return x_padded
'''
def pad_zero(x, max_len, return_mask=False):
"""填充或截断数组到指定长度
Args:
x: 输入数组
max_len: 目标长度
return_mask: 是否返回掩码
Returns:
x_padded: 处理后的数组
mask: (可选) 掩码标记实际数据(True)和填充(False)
"""
# 获取实际长度
actual_len = len(x)
# 如果实际长度超过最大长度,进行截断
if actual_len > max_len:
x = x[:max_len]
mask = np.ones(max_len, dtype=bool)
if return_mask:
return x, mask
return x
# 如果需要填充
if actual_len < max_len:
# 创建掩码
keys = np.ones(actual_len)
padding_mask = np.zeros(max_len - actual_len)
mask = np.concatenate([keys, padding_mask]) == 1
# 填充数据
padding = np.zeros((max_len - actual_len, *x.shape[1:]))
x_padded = np.concatenate([x, padding], axis=0)
else:
# 长度正好
mask = np.ones(max_len, dtype=bool)
x_padded = x
if return_mask:
return x_padded, mask
return x_padded
def plot_3d_bbox(ax, min_corner, max_corner, color='r'):
"""

Loading…
Cancel
Save