From 1f92faad7289767d0cf2e6f4678eec243b12db73 Mon Sep 17 00:00:00 2001 From: mckay Date: Mon, 5 May 2025 16:05:08 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=8A=A8=E6=80=81=E5=A1=AB?= =?UTF-8?q?=E5=85=85=EF=BC=8C=E9=81=BF=E5=85=8Dwarning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/data/utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py index 9b82050..d609f98 100644 --- a/brep2sdf/data/utils.py +++ b/brep2sdf/data/utils.py @@ -43,13 +43,19 @@ def process_surf_ncs_with_dynamic_padding(surf_ncs: np.ndarray) -> torch.Tensor: 使用 pad_sequence 动态填充 surf_ncs。 参数: - surf_ncs: 形状为 (N,) 的 np.ndarray(dtype=object),每个元素是形状为 (M, 3) 的 float32 数组。 + surf_ncs: 形状为 (N,) 的 np.ndarray(dtype=object),每个元素是形状为 (M, 3) 的 float32 数组或张量。 返回: padded_tensor: 形状为 (N, M_max, 3) 的张量,其中 M_max 是最长子数组的长度。 """ - # 转换为张量列表 - tensor_list = [torch.tensor(arr, dtype=torch.float32) for arr in surf_ncs] + tensor_list = [] + for arr in surf_ncs: + if isinstance(arr, np.ndarray): + tensor_list.append(torch.from_numpy(arr).clone().detach()) + elif isinstance(arr, torch.Tensor): + tensor_list.append(arr.clone().detach()) + else: + raise ValueError(f"不支持的类型: {type(arr)},期望 np.ndarray 或 torch.Tensor。") # 动态填充 padded_tensor = pad_sequence(tensor_list, batch_first=True, padding_value=float('inf'))