|
|
@ -43,13 +43,19 @@ def process_surf_ncs_with_dynamic_padding(surf_ncs: np.ndarray) -> torch.Tensor: |
|
|
|
使用 pad_sequence 动态填充 surf_ncs。 |
|
|
|
|
|
|
|
参数: |
|
|
|
surf_ncs: 形状为 (N,) 的 np.ndarray(dtype=object),每个元素是形状为 (M, 3) 的 float32 数组。 |
|
|
|
surf_ncs: 形状为 (N,) 的 np.ndarray(dtype=object),每个元素是形状为 (M, 3) 的 float32 数组或张量。 |
|
|
|
|
|
|
|
返回: |
|
|
|
padded_tensor: 形状为 (N, M_max, 3) 的张量,其中 M_max 是最长子数组的长度。 |
|
|
|
""" |
|
|
|
# 转换为张量列表 |
|
|
|
tensor_list = [torch.tensor(arr, dtype=torch.float32) for arr in surf_ncs] |
|
|
|
tensor_list = [] |
|
|
|
for arr in surf_ncs: |
|
|
|
if isinstance(arr, np.ndarray): |
|
|
|
tensor_list.append(torch.from_numpy(arr).clone().detach()) |
|
|
|
elif isinstance(arr, torch.Tensor): |
|
|
|
tensor_list.append(arr.clone().detach()) |
|
|
|
else: |
|
|
|
raise ValueError(f"不支持的类型: {type(arr)},期望 np.ndarray 或 torch.Tensor。") |
|
|
|
|
|
|
|
# 动态填充 |
|
|
|
padded_tensor = pad_sequence(tensor_list, batch_first=True, padding_value=float('inf')) |
|
|
|