Browse Source

优化动态填充,避免warning

final
mckay 1 month ago
parent
commit
1f92faad72
  1. 12
      brep2sdf/data/utils.py

12
brep2sdf/data/utils.py

@ -43,13 +43,19 @@ def process_surf_ncs_with_dynamic_padding(surf_ncs: np.ndarray) -> torch.Tensor:
使用 pad_sequence 动态填充 surf_ncs 使用 pad_sequence 动态填充 surf_ncs
参数: 参数:
surf_ncs: 形状为 (N,) np.ndarray(dtype=object)每个元素是形状为 (M, 3) float32 数组 surf_ncs: 形状为 (N,) np.ndarray(dtype=object)每个元素是形状为 (M, 3) float32 数组或张量
返回: 返回:
padded_tensor: 形状为 (N, M_max, 3) 的张量其中 M_max 是最长子数组的长度 padded_tensor: 形状为 (N, M_max, 3) 的张量其中 M_max 是最长子数组的长度
""" """
# 转换为张量列表 tensor_list = []
tensor_list = [torch.tensor(arr, dtype=torch.float32) for arr in surf_ncs] for arr in surf_ncs:
if isinstance(arr, np.ndarray):
tensor_list.append(torch.from_numpy(arr).clone().detach())
elif isinstance(arr, torch.Tensor):
tensor_list.append(arr.clone().detach())
else:
raise ValueError(f"不支持的类型: {type(arr)},期望 np.ndarray 或 torch.Tensor。")
# 动态填充 # 动态填充
padded_tensor = pad_sequence(tensor_list, batch_first=True, padding_value=float('inf')) padded_tensor = pad_sequence(tensor_list, batch_first=True, padding_value=float('inf'))

Loading…
Cancel
Save