Browse Source

fix: 数据集可以加载

main
mckay 7 months ago
parent
commit
c03a08a13e
  1. 52
      brep2sdf/data/utils.py

52
brep2sdf/data/utils.py

@ -950,14 +950,8 @@ def construct_brep(surf_wcs, edge_wcs, FaceEdgeAdj, EdgeVertexAdj):
return solid
def process_brep_data(
data: dict,
max_face: int,
max_edge: int,
bbox_scaled: float,
aug: bool = False,
data_class: Optional[int] = None
) -> Tuple[torch.Tensor, ...]:
def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: float,
aug: bool = False, data_class: Optional[int] = None) -> Tuple[torch.Tensor, ...]:
"""
处理B-rep数据的函数
@ -1087,29 +1081,47 @@ def process_brep_data(
edge_mask = np.stack(edge_mask) # [num_faces, max_edge]
vertex_pos = np.stack(vert_pos_new) # [num_faces, max_edge, 2, 3]
# 面特征打乱
random_indices = np.random.permutation(surf_pos.shape[0])
surf_pos = surf_pos[random_indices] # [num_faces, 6]
edge_pos = edge_pos[random_indices] # [num_faces, max_edge, 6]
# 处理edge_ncs (确保形状为 [num_faces, max_edge, 10, 3])
edge_ncs_list = []
for face_idx in range(len(edge_ncs)):
face_edges = []
for edge_idx in range(len(edge_ncs[face_idx])):
edge_points = edge_ncs[face_idx][edge_idx]
# 确保每条边有10个点
if len(edge_points) > 10:
indices = np.random.choice(len(edge_points), 10, replace=False)
edge_points = edge_points[indices]
elif len(edge_points) < 10:
indices = np.random.choice(len(edge_points), 10-len(edge_points))
edge_points = np.concatenate([edge_points, edge_points[indices]], axis=0)
face_edges.append(edge_points)
# 填充到最大边数
while len(face_edges) < max_edge:
face_edges.append(np.zeros((10, 3), dtype=np.float32))
# 处理surf_ncs (对象数组)
edge_ncs_list.append(np.stack(face_edges))
edge_ncs = np.stack(edge_ncs_list).astype(np.float32) # [num_faces, max_edge, 10, 3]
# 处理surf_ncs (确保形状为 [num_faces, 100, 3])
surf_ncs_list = []
for idx in random_indices:
points = surf_ncs[idx] # 获取当前面的点云
# 确保点云数据形状正确 (N, 3) -> (100, 3)
for points in surf_ncs:
if len(points) > 100:
# 如果点数超过100,随机采样
indices = np.random.choice(len(points), 100, replace=False)
points = points[indices]
elif len(points) < 100:
# 如果点数少于100,重复采样
indices = np.random.choice(len(points), 100-len(points))
points = np.concatenate([points, points[indices]], axis=0)
surf_ncs_list.append(points)
# 将列表转换为numpy数组
surf_ncs = np.stack(surf_ncs_list) # [num_faces, 100, 3]
surf_ncs = np.stack(surf_ncs_list).astype(np.float32) # [num_faces, 100, 3]
# 面特征打乱
random_indices = np.random.permutation(surf_pos.shape[0])
surf_pos = surf_pos[random_indices] # [num_faces, 6]
edge_pos = edge_pos[random_indices] # [num_faces, max_edge, 6]
surf_ncs = surf_ncs[random_indices] # [num_faces, 100, 3]
edge_ncs = edge_ncs[random_indices] # [num_faces, max_edge, 10, 3]
edge_mask = edge_mask[random_indices] # [num_faces, max_edge]
vertex_pos = vertex_pos[random_indices] # [num_faces, max_edge, 2, 3]

Loading…
Cancel
Save