diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py index 159e289..354aa83 100644 --- a/brep2sdf/data/utils.py +++ b/brep2sdf/data/utils.py @@ -1169,15 +1169,19 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo vertex_pos = vertex_pos[random_indices] # [num_faces, max_edge, 2, 3] # 填充到最大面数 - surf_pos = pad_zero(surf_pos, max_face) # [max_face, 6] + surf_pos, face_mask = pad_zero(surf_pos, max_face, return_mask=True) # [max_face, 6] surf_ncs = pad_zero(surf_ncs, max_face) # [max_face, 100, 3] edge_pos = pad_zero(edge_pos, max_face) # [max_face, max_edge, 6] edge_ncs = pad_zero(edge_ncs, max_face) # [max_face, max_edge, 10, 3] vertex_pos = pad_zero(vertex_pos, max_face) # [max_face, max_edge, 2, 3] - # 扩展边掩码 - padding = np.zeros((max_face-len(edge_mask), *edge_mask.shape[1:])) == 0 # [max_face-num_faces, max_edge] - edge_mask = np.concatenate([edge_mask, padding], 0) # [max_face, max_edge] + # 扩展边掩码 - 使用face_mask来创建新的edge_mask + if len(edge_mask) > max_face: + edge_mask = edge_mask[:max_face] + else: + # 创建填充掩码 + padding = np.zeros((max_face-len(edge_mask), max_edge), dtype=bool) + edge_mask = np.concatenate([edge_mask, padding], axis=0) # [max_face, max_edge] # 转换为张量并返回 if data_class is not None: