From c03a08a13ea0950cfc62e81d4acfb2b5fe1c948d Mon Sep 17 00:00:00 2001 From: mckay Date: Mon, 18 Nov 2024 00:59:38 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=95=B0=E6=8D=AE=E9=9B=86=E5=8F=AF?= =?UTF-8?q?=E4=BB=A5=E5=8A=A0=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/data/utils.py | 52 ++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py index cfc31bd..40948ae 100644 --- a/brep2sdf/data/utils.py +++ b/brep2sdf/data/utils.py @@ -950,14 +950,8 @@ def construct_brep(surf_wcs, edge_wcs, FaceEdgeAdj, EdgeVertexAdj): return solid -def process_brep_data( - data: dict, - max_face: int, - max_edge: int, - bbox_scaled: float, - aug: bool = False, - data_class: Optional[int] = None -) -> Tuple[torch.Tensor, ...]: +def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: float, + aug: bool = False, data_class: Optional[int] = None) -> Tuple[torch.Tensor, ...]: """ 处理B-rep数据的函数 @@ -1087,29 +1081,47 @@ def process_brep_data( edge_mask = np.stack(edge_mask) # [num_faces, max_edge] vertex_pos = np.stack(vert_pos_new) # [num_faces, max_edge, 2, 3] - # 面特征打乱 - random_indices = np.random.permutation(surf_pos.shape[0]) - surf_pos = surf_pos[random_indices] # [num_faces, 6] - edge_pos = edge_pos[random_indices] # [num_faces, max_edge, 6] + # 处理edge_ncs (确保形状为 [num_faces, max_edge, 10, 3]) + edge_ncs_list = [] + for face_idx in range(len(edge_ncs)): + face_edges = [] + for edge_idx in range(len(edge_ncs[face_idx])): + edge_points = edge_ncs[face_idx][edge_idx] + # 确保每条边有10个点 + if len(edge_points) > 10: + indices = np.random.choice(len(edge_points), 10, replace=False) + edge_points = edge_points[indices] + elif len(edge_points) < 10: + indices = np.random.choice(len(edge_points), 10-len(edge_points)) + edge_points = np.concatenate([edge_points, edge_points[indices]], axis=0) + face_edges.append(edge_points) + + # 填充到最大边数 + while len(face_edges) < max_edge: + face_edges.append(np.zeros((10, 3), dtype=np.float32)) + + edge_ncs_list.append(np.stack(face_edges)) + + edge_ncs = np.stack(edge_ncs_list).astype(np.float32) # [num_faces, max_edge, 10, 3] - # 处理surf_ncs (对象数组) + # 处理surf_ncs (确保形状为 [num_faces, 100, 3]) surf_ncs_list = [] - for idx in random_indices: - points = surf_ncs[idx] # 获取当前面的点云 - # 确保点云数据形状正确 (N, 3) -> (100, 3) + for points in surf_ncs: if len(points) > 100: - # 如果点数超过100,随机采样 indices = np.random.choice(len(points), 100, replace=False) points = points[indices] elif len(points) < 100: - # 如果点数少于100,重复采样 indices = np.random.choice(len(points), 100-len(points)) points = np.concatenate([points, points[indices]], axis=0) surf_ncs_list.append(points) - # 将列表转换为numpy数组 - surf_ncs = np.stack(surf_ncs_list) # [num_faces, 100, 3] + surf_ncs = np.stack(surf_ncs_list).astype(np.float32) # [num_faces, 100, 3] + # 面特征打乱 + random_indices = np.random.permutation(surf_pos.shape[0]) + surf_pos = surf_pos[random_indices] # [num_faces, 6] + edge_pos = edge_pos[random_indices] # [num_faces, max_edge, 6] + surf_ncs = surf_ncs[random_indices] # [num_faces, 100, 3] edge_ncs = edge_ncs[random_indices] # [num_faces, max_edge, 10, 3] edge_mask = edge_mask[random_indices] # [num_faces, max_edge] vertex_pos = vertex_pos[random_indices] # [num_faces, max_edge, 2, 3]