diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py
index 5065e9b..40c7296 100644
--- a/brep2sdf/data/data.py
+++ b/brep2sdf/data/data.py
@@ -11,7 +11,7 @@ from brep2sdf.data.utils import process_brep_data
 
 class BRepSDFDataset(Dataset):
 
-    def __init__(self, brep_dir:str, sdf_dir:str, split:str='train'):
+    def __init__(self, brep_dir:str, sdf_dir:str, valid_data_dir:str, split:str='train'):
         """
         Initialize the dataset
 
@@ -37,6 +37,14 @@ class BRepSDFDataset(Dataset):
             raise ValueError(f"SDF directory not found: {self.sdf_dir}")
 
         # Load the data lists
+        # If the valid-data file exists, load the list of valid sample names
+        valid_data_file = os.path.join(valid_data_dir, f'{split}_success.txt')
+        if os.path.isfile(valid_data_file):
+            self.valid_data_list = self._load_valid_list(valid_data_file)
+        else:
+            raise ValueError(f"Valid data file not found: {valid_data_file}")
+
         self.brep_data_list = self._load_data_list(self.brep_dir)
         self.sdf_data_list = self._load_data_list(self.sdf_dir)
 
@@ -47,13 +55,19 @@ class BRepSDFDataset(Dataset):
             raise ValueError(f"No valid sdf data found in {split} set")
 
         logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples")
+
+    def _load_valid_list(self, valid_data_file: str):
+        """Read the success list: one base file name (without extension) per line."""
+        with open(valid_data_file, 'r') as f:
+            valid_list = [line.strip() for line in f if line.strip()]
+        return valid_list
 
     # data_dir is self.brep_dir or sdf_dir
     def _load_data_list(self, data_dir):
         data_list = []
         for sample_file in os.listdir(data_dir):
-            path = os.path.join(data_dir, sample_file)
-            data_list.append(path)
+            # Keep only samples whose base name appears in the valid list
+            if sample_file.split('.')[0] in self.valid_data_list:
+                path = os.path.join(data_dir, sample_file)
+                data_list.append(path)
         #logger.info(data_list)
         return data_list
 
@@ -138,6 +152,7 @@ class BRepSDFDataset(Dataset):
             raise
 
+
     def _load_brep_file(self, brep_path):
         """Load a B-rep feature file"""
         try:
@@ -331,7 +346,7 @@ def test_dataset():
         logger.info(f"Split: {split}")
         # ... (rest of the test code unchanged) ...
 
-    dataset = BRepSDFDataset(brep_dir='/home/wch/brep2sdf/test_data/pkl', sdf_dir='/home/wch/brep2sdf/test_data/sdf', split='train')
+    dataset = BRepSDFDataset(brep_dir='/home/wch/brep2sdf/test_data/pkl', sdf_dir='/home/wch/brep2sdf/test_data/sdf', valid_data_dir="/home/wch/brep2sdf/test_data/result/pkl", split='train')
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
 
     for batch in dataloader:
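The snippet below is not part of the patch. It is a minimal sketch of the success-list filtering that `_load_data_list` now performs, assuming the directory layout used in `test_dataset()` above; the paths and split name are illustrative only.

```python
import os

# Illustrative paths, taken from the test_dataset() call above.
valid_data_dir = "/home/wch/brep2sdf/test_data/result/pkl"
brep_dir = "/home/wch/brep2sdf/test_data/pkl"
split = "train"

# The success list stores one base file name (no extension) per line.
valid_data_file = os.path.join(valid_data_dir, f"{split}_success.txt")
with open(valid_data_file) as f:
    valid_names = {line.strip() for line in f if line.strip()}

# Mirror of _load_data_list: keep only files whose base name is in the success list.
kept = [
    os.path.join(brep_dir, fname)
    for fname in os.listdir(brep_dir)
    if fname.split(".")[0] in valid_names
]
print(f"{len(kept)} samples pass the {split} success-list filter")
```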
diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py
index 164c68c..f847303 100644
--- a/brep2sdf/data/utils.py
+++ b/brep2sdf/data/utils.py
@@ -966,33 +966,48 @@ def process_brep_data(
         {
             'surf_ncs': np.ndarray,      # normalized face point clouds [num_faces, 100, 3]
             'edge_ncs': np.ndarray,      # normalized edge point clouds [num_edges, 10, 3]
-            'corner_wcs': np.ndarray,    # vertex coordinates [num_edges, 2, 3]
+            'corner_wcs': np.ndarray,    # vertex coordinates [num_edges, 2, 3] - the two endpoints of each edge
             'faceEdge_adj': np.ndarray,  # face-edge adjacency matrix [num_faces, num_edges]
-            'surf_pos': np.ndarray,      # face positions [num_faces, 6]
-            'edge_pos': np.ndarray,      # edge positions [num_edges, 6]
+            'surf_pos': np.ndarray,      # face positions (bounding boxes) [num_faces, 6]
+            'edge_pos': np.ndarray,      # edge positions (bounding boxes) [num_edges, 6]
         }
-        max_face (int): maximum number of faces
-        max_edge (int): maximum number of edges
+        max_face (int): maximum number of faces, used for padding
+        max_edge (int): maximum number of edges, used for padding
         bbox_scaled (float): bounding-box scale factor
         aug (bool): whether to apply data augmentation
         data_class (Optional[int]): class label of the sample
 
     Returns:
         Tuple[torch.Tensor, ...]: a tuple containing the following tensors:
-            - edge_ncs: normalized edge features [num_faces, max_edge, 10, 3]
-            - edge_pos: edge positions [num_faces, max_edge, 6]
-            - edge_mask: edge mask [num_faces, max_edge]
+            - edge_ncs: normalized edge features [max_face, max_edge, 10, 3]
+            - edge_pos: edge positions [max_face, max_edge, 6]
+            - edge_mask: edge mask [max_face, max_edge]
             - surf_ncs: normalized face features [max_face, 100, 3]
             - surf_pos: face positions [max_face, 6]
-            - vertex_pos: vertex positions [max_face, max_edge, 6]
+            - vertex_pos: vertex positions [max_face, max_edge, 2, 3] - the two endpoints of each edge of each face
             - data_class: (optional) class label [1]
+
+    Processing pipeline:
+        1. Data augmentation (optional):
+            - randomly rotate the geometric elements
+            - recompute the bounding boxes
+        2. Feature duplication:
+            - copy edge and vertex features according to the face-edge adjacency
+            - keep the [2, 3] vertex-pair structure
+        3. Feature shuffling:
+            - randomly shuffle the edge order within each face
+            - randomly shuffle the order of the faces
+        4. Padding:
+            - pad to the maximum number of edges
+            - pad to the maximum number of faces
+        5. Convert to tensors
     """
     # Unpack the data
     #_, _, surf_ncs, edge_ncs, corner_wcs, _, _, faceEdge_adj, surf_pos, edge_pos, _ = data.values()
     # Fetch the needed entries directly
     surf_ncs = data['surf_ncs']              # (num_faces,) -> each element has shape (N, 3)
     edge_ncs = data['edge_ncs']              # (num_edges, 100, 3)
-    corner_wcs = data['corner_wcs']          # (num_corners, 3)
+    corner_wcs = data['corner_wcs']          # (num_edges, 2, 3)
     faceEdge_adj = data['faceEdge_adj']      # (num_faces, num_edges)
     edgeCorner_adj = data['edgeCorner_adj']  # (num_edges, 2) each edge connects 2 vertices
     surf_pos = data['surf_bbox_wcs']         # (num_faces, 6)
@@ -1011,12 +1026,7 @@ def process_brep_data(
         # Rotate all geometric elements, keeping their shapes
         surfpos_corners = rotate_axis(surfpos_corners, angle, axis, normalized=True)
         edgepos_corners = rotate_axis(edgepos_corners, angle, axis, normalized=True)
-        corner_wcs = rotate_axis(corner_wcs, angle, axis, normalized=True)
-
-        # Rotate the point cloud of each face
-        for i in range(len(surf_ncs)):
-            surf_ncs[i] = rotate_axis(surf_ncs[i], angle, axis, normalized=False)
-        edge_ncs = rotate_axis(edge_ncs, angle, axis, normalized=False)
+        corner_wcs = rotate_axis(corner_wcs, angle, axis, normalized=True)  # rotate directly, the shape is preserved
 
         # Recompute the bounding boxes
@@ -1031,25 +1041,20 @@ def process_brep_data(
     corner_wcs = corner_wcs * bbox_scaled    # [num_edges, 2, 3]
 
     # Feature duplication
-    edge_pos_duplicated: List[np.ndarray] = []    # each element has shape [num_edges, 6]
-    vertex_pos_duplicated: List[np.ndarray] = []  # each element has shape [num_edges, 6]
-    edge_ncs_duplicated: List[np.ndarray] = []    # each element has shape [num_edges, 10, 3]
+    edge_pos_duplicated = []     # [num_edges_per_face, 6]
+    vertex_pos_duplicated = []   # [num_edges_per_face, 2, 3]
+    edge_ncs_duplicated = []     # [num_edges_per_face, 10, 3]
 
     for adj in faceEdge_adj:  # [num_faces, num_edges]
-        edge_ncs_duplicated.append(edge_ncs[adj])  # [num_edges, 10, 3]
-        edge_pos_duplicated.append(edge_pos[adj])  # [num_edges, 6]
-        #corners = corner_wcs[adj] # [num_vertax, 3] FIXME
-        edge_indices = np.where(adj)[0]                 # indices of the edges of the current face
-        corner_indices = edgeCorner_adj[edge_indices]   # vertex indices of those edges
-        corners = corner_wcs[corner_indices]            # vertex coordinates
-        logger.debug(corners)
-
-        corners_sorted = []
-        for corner in corners:  # [2, 3]
-            sorted_indices = np.lexsort((corner[:, 2], corner[:, 1], corner[:, 0]))
-            corners_sorted.append(corner[sorted_indices].flatten())  # [6]
-        corners = np.stack(corners_sorted)              # [num_edges_per_face, 6]
-        vertex_pos_duplicated.append(corners)
+        edge_indices = np.where(adj)[0]                 # indices of the edges of the current face
+
+        # Duplicate the edge features
+        edge_ncs_duplicated.append(edge_ncs[edge_indices])  # [num_edges_per_face, 10, 3]
+        edge_pos_duplicated.append(edge_pos[edge_indices])  # [num_edges_per_face, 6]
+
+        # Take the vertex pair of each edge directly
+        vertex_pairs = corner_wcs[edge_indices]         # [num_edges_per_face, 2, 3]
+        vertex_pos_duplicated.append(vertex_pairs)
 
     # Shuffle and pad the edge features
     edge_pos_new = []   # final shape: [num_faces, max_edge, 6]
     edge_ncs_new = []   # final shape: [num_faces, max_edge, 10, 3]
     vert_pos_new = []   # final shape: [num_faces, max_edge, 6]
     edge_mask = []      # final shape: [num_faces, max_edge]
 
@@ -1058,14 +1063,19 @@ def process_brep_data(
     for pos, ncs, vert in zip(edge_pos_duplicated, edge_ncs_duplicated, vertex_pos_duplicated):
-        random_indices = np.random.permutation(pos.shape[0])
-        pos = pos[random_indices]    # [num_edges_per_face, 6]
-        ncs = ncs[random_indices]    # [num_edges_per_face, 10, 3]
-        vert = vert[random_indices]  # [num_edges_per_face, 6]
+        # Generate a random permutation
+        num_edges = pos.shape[0]
+        random_indices = np.random.permutation(num_edges)
+
+        # Shuffle all features consistently
+        pos = pos[random_indices]    # [num_edges_per_face, 6]
+        ncs = ncs[random_indices]    # [num_edges_per_face, 10, 3]
+        vert = vert[random_indices]  # [num_edges_per_face, 2, 3]
+
         # Pad to the maximum number of edges
         pos, mask = pad_zero(pos, max_edge, return_mask=True)  # [max_edge, 6], [max_edge]
         ncs = pad_zero(ncs, max_edge)    # [max_edge, 10, 3]
-        vert = pad_zero(vert, max_edge)  # [max_edge, 6]
+        vert = pad_zero(vert, max_edge)  # [max_edge, 2, 3]
 
         edge_pos_new.append(pos)
         edge_ncs_new.append(ncs)
@@ -1075,7 +1085,7 @@ def process_brep_data(
     edge_pos = np.stack(edge_pos_new)    # [num_faces, max_edge, 6]
     edge_ncs = np.stack(edge_ncs_new)    # [num_faces, max_edge, 10, 3]
     edge_mask = np.stack(edge_mask)      # [num_faces, max_edge]
-    vertex_pos = np.stack(vert_pos_new)  # [num_faces, max_edge, 6]
+    vertex_pos = np.stack(vert_pos_new)  # [num_faces, max_edge, 2, 3]
 
     # Shuffle the face features
     random_indices = np.random.permutation(surf_pos.shape[0])
@@ -1087,14 +1097,14 @@ def process_brep_data(
     #surf_ncs = surf_ncs[random_indices]      # [num_faces, 100, 3]
     edge_ncs = edge_ncs[random_indices]       # [num_faces, max_edge, 10, 3]
     edge_mask = edge_mask[random_indices]     # [num_faces, max_edge]
-    vertex_pos = vertex_pos[random_indices]   # [num_faces, max_edge, 6]
+    vertex_pos = vertex_pos[random_indices]   # [num_faces, max_edge, 2, 3]
 
     # Pad to the maximum number of faces
     surf_pos = pad_zero(surf_pos, max_face)   # [max_face, 6]
     surf_ncs = pad_zero(surf_ncs, max_face)   # [max_face, 100, 3]
     edge_pos = pad_zero(edge_pos, max_face)   # [max_face, max_edge, 6]
     edge_ncs = pad_zero(edge_ncs, max_face)   # [max_face, max_edge, 10, 3]
-    vertex_pos = pad_zero(vertex_pos, max_face)  # [max_face, max_edge, 6]
+    vertex_pos = pad_zero(vertex_pos, max_face)  # [max_face, max_edge, 2, 3]
 
     # Extend the edge mask
     padding = np.zeros((max_face-len(edge_mask), *edge_mask.shape[1:])) == 0  # [max_face-num_faces, max_edge]
@@ -1108,7 +1118,7 @@ def process_brep_data(
             torch.BoolTensor(edge_mask),      # [max_face, max_edge]
             torch.FloatTensor(surf_ncs),      # [max_face, 100, 3]
             torch.FloatTensor(surf_pos),      # [max_face, 6]
-            torch.FloatTensor(vertex_pos),    # [max_face, max_edge, 6]
+            torch.FloatTensor(vertex_pos),    # [max_face, max_edge, 2, 3]
             torch.LongTensor([data_class+1])  # [1]
         )
     else:
@@ -1118,5 +1128,5 @@ def process_brep_data(
             torch.BoolTensor(edge_mask),      # [max_face, max_edge]
             torch.FloatTensor(surf_ncs),      # [max_face, 100, 3]
             torch.FloatTensor(surf_pos),      # [max_face, 6]
-            torch.FloatTensor(vertex_pos)     # [max_face, max_edge, 6]
+            torch.FloatTensor(vertex_pos)     # [max_face, max_edge, 2, 3]
         )
\ No newline at end of file
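Not part of the patch: a toy numpy illustration of why a `corner_wcs` of shape [num_edges, 2, 3] simplifies `process_brep_data` — per-face vertex pairs can be gathered straight through `faceEdge_adj`. The `pad_to` helper is a stand-in for the repository's `pad_zero`, and all shapes here are invented for the example.

```python
import numpy as np

num_edges, max_edge = 3, 4
corner_wcs = np.random.rand(num_edges, 2, 3).astype(np.float32)   # two endpoints per edge
faceEdge_adj = np.array([[1, 1, 0],
                         [0, 1, 1]], dtype=bool)                  # face-edge adjacency, 2 faces

def pad_to(arr, n):
    """Zero-pad along axis 0 up to length n (stand-in for pad_zero)."""
    pad = np.zeros((n - arr.shape[0], *arr.shape[1:]), dtype=arr.dtype)
    return np.concatenate([arr, pad], axis=0)

vertex_pos = []
for adj in faceEdge_adj:
    edge_indices = np.where(adj)[0]
    vertex_pairs = corner_wcs[edge_indices]        # [num_edges_per_face, 2, 3]
    vertex_pos.append(pad_to(vertex_pairs, max_edge))

vertex_pos = np.stack(vertex_pos)                  # [num_faces, max_edge, 2, 3]
print(vertex_pos.shape)                            # (2, 4, 2, 3)
```

Keeping the two endpoints as a [2, 3] pair per edge is what removes the old lexsort-and-flatten step from the duplication loop.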
diff --git a/brep2sdf/scripts/process_brep.py b/brep2sdf/scripts/process_brep.py
index 6a9472f..d70f309 100644
--- a/brep2sdf/scripts/process_brep.py
+++ b/brep2sdf/scripts/process_brep.py
@@ -43,20 +43,20 @@ def normalize(surfs, edges, corners):
     Args:
         surfs: list of face point sets
        edges: list of edge point sets
-        corners: list of vertex coordinates
+        corners: array of vertex coordinates [num_edges, 2, 3]
 
     Returns:
         surfs_wcs: face point sets in the original (world) coordinate system
         edges_wcs: edge point sets in the original (world) coordinate system
         surfs_ncs: face point sets in the normalized coordinate system
         edges_ncs: edge point sets in the normalized coordinate system
-        corner_wcs: normalized vertex coordinates
+        corner_wcs: normalized vertex coordinates [num_edges, 2, 3]
     """
     if len(corners) == 0:
         return None, None, None, None, None
 
     # Compute the bounding box and scale factor
-    corners_array = np.array(corners)
+    corners_array = corners.reshape(-1, 3)  # [num_edges*2, 3]
     center = (corners_array.max(0) + corners_array.min(0)) / 2          # center point
     scale = 1.0 / (corners_array.max(0) - corners_array.min(0)).max()   # scale factor
@@ -78,8 +78,8 @@ def normalize(surfs, edges, corners):
         edges_wcs.append(edge_wcs)
         edges_ncs.append(edge_ncs)
 
-    # Normalize the vertex coordinates
-    corner_wcs = (corners_array - center) * scale
+    # Normalize the vertex coordinates - keep the [num_edges, 2, 3] shape
+    corner_wcs = (corners - center) * scale  # broadcasting preserves the original dimensions
 
     return (np.array(surfs_wcs, dtype=object),
             np.array(edges_wcs, dtype=object),
@@ -194,8 +194,8 @@ def parse_solid(step_path):
         'edge_wcs': np.ndarray(dtype=object)   # shape (N,); each element is a float32 array of shape (100, 3) with the sampled edge points
         'surf_ncs': np.ndarray(dtype=object)   # shape (N,); each element is a float32 array of shape (M, 3) with the normalized face point cloud
         'edge_ncs': np.ndarray(dtype=object)   # shape (N,); each element is a float32 array of shape (100, 3) with the normalized edge sample points
-        'corner_wcs': np.ndarray(dtype=float32)    # shape (K, 3); coordinates of all vertices
-        'corner_unique': np.ndarray(dtype=float32) # shape (L, 3); deduplicated vertex coordinates
+        'corner_wcs': np.ndarray(dtype=float32)    # shape (num_edges, 2, 3); the two endpoint coordinates of each edge
+        'corner_unique': np.ndarray(dtype=float32) # shape (num_vertices, 3); unique vertex coordinates, num_vertices <= num_edges * 2
 
         # Topology
         'edgeFace_adj': np.ndarray(dtype=int32)    # shape (num_edges, num_faces); edge-face adjacency
@@ -282,15 +282,29 @@ def parse_solid(step_path):
     # Get the adjacency information
     edgeFace_adj, faceEdge_adj, edgeCorner_adj = get_adjacency_info(shape)
 
-    # Convert to numpy arrays, but keep the lists as lists
-    face_pnts = list(face_pnts)  # make sure it is a list
-    edge_pnts = list(edge_pnts)  # make sure it is a list
-    corner_pnts = np.array(corner_pnts, dtype=np.float32)
+    # Convert to numpy arrays
+    face_pnts = list(face_pnts)
+    edge_pnts = list(edge_pnts)
+    corner_pnts = np.array(corner_pnts, dtype=np.float32)  # [num_vertices, 3]
+
+    # Regroup the vertex data into two endpoints per edge
+    corner_pairs = []
+    for edge_idx in range(len(edge_pnts)):
+        v1_idx, v2_idx = edgeCorner_adj[edge_idx]
+        v1_pos = corner_pnts[v1_idx]
+        v2_pos = corner_pnts[v2_idx]
+        # Order the endpoints lexicographically so every edge stores them consistently
+        if tuple(v1_pos) > tuple(v2_pos):
+            v1_pos, v2_pos = v2_pos, v1_pos
+        corner_pairs.append(np.stack([v1_pos, v2_pos]))
+
+    corner_pairs = np.stack(corner_pairs)  # [num_edges, 2, 3]
 
     surf_bbox_wcs = np.array(surf_bbox_wcs, dtype=np.float32)
     edge_bbox_wcs = np.array(edge_bbox_wcs, dtype=np.float32)
 
     # Normalize the CAD model
-    surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs = normalize(face_pnts, edge_pnts, corner_pnts)
+    surfs_wcs, edges_wcs, surfs_ncs, edges_ncs, corner_wcs = normalize(
+        face_pnts, edge_pnts, corner_pairs)
 
     # Create result dictionary
     data = {
@@ -298,13 +312,13 @@ def parse_solid(step_path):
         'surf_wcs': surfs_wcs,
         'edge_wcs': edges_wcs,
         'surf_ncs': surfs_ncs,
         'edge_ncs': edges_ncs,
-        'corner_wcs': corner_wcs.astype(np.float32),
+        'corner_wcs': corner_wcs,  # [num_edges, 2, 3]
         'edgeFace_adj': edgeFace_adj,
         'edgeCorner_adj': edgeCorner_adj,
         'faceEdge_adj': faceEdge_adj,
         'surf_bbox_wcs': surf_bbox_wcs,
         'edge_bbox_wcs': edge_bbox_wcs,
-        'corner_unique': np.unique(corner_wcs, axis=0).astype(np.float32)
+        'corner_unique': np.unique(corner_wcs.reshape(-1, 3), axis=0).astype(np.float32)  # flatten first, then deduplicate
     }
 
     return data
@@ -316,18 +330,91 @@ def load_step(step_path):
     reader.TransferRoots()
     return [reader.OneShape()]
 
-def process_single_step( 
-    step_path:str, 
-    output_path:str=None, 
-    timeout:int=300
-) -> dict:
-    """Process single STEP file"""
+def check_data_format(data, step_file):
+    """Check that the data format and dimensions meet the requirements"""
+    try:
+        # Required keys
+        required_keys = [
+            # geometric data
+            'surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs',
+            'corner_wcs', 'corner_unique',
+            # topology
+            'edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj',
+            # bounding-box data
+            'surf_bbox_wcs', 'edge_bbox_wcs'
+        ]
+
+        # Make sure every key is present
+        for key in required_keys:
+            if key not in data:
+                return False, f"Missing required key: {key}"
+
+        # Check the geometric data
+        geometry_arrays = ['surf_wcs', 'edge_wcs', 'surf_ncs', 'edge_ncs']
+        for key in geometry_arrays:
+            if not isinstance(data[key], np.ndarray) or data[key].dtype != object:
+                return False, f"{key} should be a numpy array with dtype=object"
+
+        # Check the vertex data
+        if not isinstance(data['corner_wcs'], np.ndarray) or data['corner_wcs'].dtype != np.float32:
+            return False, "corner_wcs should be a numpy array with dtype=float32"
+        if len(data['corner_wcs'].shape) != 3 or data['corner_wcs'].shape[1:] != (2, 3):
+            return False, f"corner_wcs should have shape (num_edges, 2, 3), got {data['corner_wcs'].shape}"
+
+        if not isinstance(data['corner_unique'], np.ndarray) or data['corner_unique'].dtype != np.float32:
+            return False, "corner_unique should be a numpy array with dtype=float32"
+        if len(data['corner_unique'].shape) != 2 or data['corner_unique'].shape[1] != 3:
+            return False, f"corner_unique should have shape (N, 3), got {data['corner_unique'].shape}"
+
+        # Check the topology
+        num_faces = len(data['surf_wcs'])
+        num_edges = len(data['edge_wcs'])
+
+        # Check the adjacency matrices
+        adj_checks = [
+            ('edgeFace_adj', (num_edges, num_faces)),
+            ('faceEdge_adj', (num_faces, num_edges)),
+            ('edgeCorner_adj', (num_edges, 2))
+        ]
+
+        for key, expected_shape in adj_checks:
+            if not isinstance(data[key], np.ndarray) or data[key].dtype != np.int32:
+                return False, f"{key} should be a numpy array with dtype=int32"
+            if data[key].shape != expected_shape:
+                return False, f"{key} shape mismatch: expected {expected_shape}, got {data[key].shape}"
+
+        # Check the bounding-box data
+        bbox_checks = [
+            ('surf_bbox_wcs', (num_faces, 6)),
+            ('edge_bbox_wcs', (num_edges, 6))
+        ]
+
+        for key, expected_shape in bbox_checks:
+            if not isinstance(data[key], np.ndarray) or data[key].dtype != np.float32:
+                return False, f"{key} should be a numpy array with dtype=float32"
+            if data[key].shape != expected_shape:
+                return False, f"{key} shape mismatch: expected {expected_shape}, got {data[key].shape}"
+
+        return True, ""
+
+    except Exception as e:
+        return False, f"Format check failed: {str(e)}"
+
+def process_single_step(step_path:str, output_path:str=None, timeout:int=300) -> dict:
+    """Process a single STEP file"""
     try:
         # Parse the STEP file
         data = parse_solid(step_path)
         if data is None:
-            logger.error("Failed to parse STEP file")
+            logger.error(f"Failed to parse STEP file: {step_path}")
+            return None
+
+        # Check the data format
+        is_valid, msg = check_data_format(data, step_path)
+        if not is_valid:
+            logger.error(f"Data format check failed for {step_path}: {msg}")
             return None
+
         # Save the results
         if output_path:
             try:
@@ -335,14 +422,17 @@ def process_single_step(
                 os.makedirs(os.path.dirname(output_path), exist_ok=True)
                 with open(output_path, 'wb') as f:
                     pickle.dump(data, f)
-                logger.info("Results saved successfully")
+                logger.info(f"Results saved successfully: {output_path}")
+                return data
             except Exception as e:
-                logger.error(f'Not saving due to error: {str(e)}')
+                logger.error(f'Failed to save {output_path}: {str(e)}')
+                return None
 
         return data
+
     except Exception as e:
-        logger.error(f'Not saving due to error: {str(e)}')
-        return 0
+        logger.error(f'Error processing {step_path}: {str(e)}')
+        return None
 
 def test(step_file_path, output_path=None):
     """
@@ -354,27 +444,37 @@ def test(step_file_path, output_path=None):
         # Parse the STEP file
         data = parse_solid(step_file_path)
         if data is None:
-            logger.error("Failed to parse STEP file")
+            logger.error(f"Failed to parse STEP file: {step_file_path}")
+            return None
+
+        # Check the data format
+        is_valid, msg = check_data_format(data, step_file_path)
+        if not is_valid:
+            logger.error(f"Data format check failed for {step_file_path}: {msg}")
             return None
 
         # Print statistics
         logger.info("\nStatistics:")
         logger.info(f"Number of surfaces: {len(data['surf_wcs'])}")
         logger.info(f"Number of edges: {len(data['edge_wcs'])}")
-        logger.info(f"Number of corners: {len(data['corner_unique'])}")
+        logger.info(f"Number of corners: {len(data['corner_wcs'])}")  # corrected to use corner_wcs
 
         # Save the results
         if output_path:
-            logger.info(f"Saving results to: {output_path}")
-            os.makedirs(os.path.dirname(output_path), exist_ok=True)
-            with open(output_path, 'wb') as f:
-                pickle.dump(data, f)
-            logger.info("Results saved successfully")
+            try:
+                logger.info(f"Saving results to: {output_path}")
+                os.makedirs(os.path.dirname(output_path), exist_ok=True)
+                with open(output_path, 'wb') as f:
+                    pickle.dump(data, f)
+                logger.info(f"Results saved successfully: {output_path}")
+            except Exception as e:
+                logger.error(f"Failed to save {output_path}: {str(e)}")
+                return None
 
         return data
 
     except Exception as e:
-        logger.error(f"Error processing STEP file: {str(e)}")
+        logger.error(f"Error processing {step_file_path}: {str(e)}")
         return None
 
 def process_furniture_step(data_path):
@@ -407,13 +507,26 @@
 
 def main():
-    """
-    Main entry point: process multiple STEP files
-    """
+    """Main entry point: process multiple STEP files"""
     # Path constants
     INPUT = '/mnt/disk2/dataset/furniture/step/furniture_dataset_step/'
     OUTPUT = '../test_data/pkl/'
-    RESULT = '../test_data/result/'  # stores the success/failure records
+    RESULT = '../test_data/result/pkl/'  # stores the success/failure records
+
+    # Helper to clean an output directory
+    def clean_directory(directory):
+        if os.path.exists(directory):
+            logger.info(f"Cleaning directory: {directory}")
+            for root, dirs, files in os.walk(directory, topdown=False):
+                for name in files:
+                    os.remove(os.path.join(root, name))
+                for name in dirs:
+                    os.rmdir(os.path.join(root, name))
+            logger.info(f"Directory cleaned: {directory}")
+
+    # Remove previous outputs
+    clean_directory(OUTPUT)
+    clean_directory(RESULT)
 
     # Make sure the output directories exist
     os.makedirs(OUTPUT, exist_ok=True)
@@ -424,6 +537,9 @@ def main():
     total_processed = 0
     total_success = 0
 
+    # Record the start time
+    start_time = datetime.now()
+
     # Process the files split by split
     for split in ['train', 'val', 'test']:
         current_step_dirs = step_dirs_dict[split]
@@ -431,14 +547,12 @@ def main():
             logger.warning(f"No files found in {split} split")
             continue
 
-        # Make sure the split directories exist
+        # Make sure the output directory exists
         split_output_dir = os.path.join(OUTPUT, split)
-        split_result_dir = os.path.join(RESULT, split)
         os.makedirs(split_output_dir, exist_ok=True)
-        os.makedirs(split_result_dir, exist_ok=True)
 
-        success_files = []  # names of successfully processed files
-        failed_files = []   # names of failed files and the reason
+        success_files = []  # only base file names (no extension)
+        failed_files = []   # only base file names (no extension)
 
         # Process the files in parallel
         with ProcessPoolExecutor(max_workers=os.cpu_count() // 2) as executor:
@@ -450,41 +564,50 @@ def main():
                 futures[future] = step_file
 
             # Collect the results
-            for future in tqdm(as_completed(futures), total=len(current_step_dirs), 
+            for future in tqdm(as_completed(futures), total=len(current_step_dirs),
                                desc=f"Processing {split} set"):
+                step_file = futures[future]
+                base_name = step_file.replace('.step', '')  # file name without extension
                 try:
                     status = future.result(timeout=300)
                     if status is not None:
-                        success_files.append(futures[future])
+                        success_files.append(base_name)
                         total_success += 1
-                except TimeoutError:
-                    logger.error(f"Timeout occurred while processing {futures[future]}")
-                    failed_files.append((futures[future], "Timeout"))
-                except Exception as e:
-                    logger.error(f"Error processing {futures[future]}: {str(e)}")
-                    failed_files.append((futures[future], str(e)))
+                    else:
+                        failed_files.append(base_name)
+                except Exception:
+                    # covers timeouts as well as any other processing error
+                    failed_files.append(base_name)
                 finally:
                     total_processed += 1
 
-            # Save the list of successful files
-            success_file_path = os.path.join(split_result_dir, 'success.txt')
-            with open(success_file_path, 'w', encoding='utf-8') as f:
+            # Save the processing results
+            os.makedirs(RESULT, exist_ok=True)
+
+            # Save the list of successful files (base names only)
+            success_path = os.path.join(RESULT, f'{split}_success.txt')
+            with open(success_path, 'w') as f:
                 f.write('\n'.join(success_files))
-            logger.info(f"Saved {len(success_files)} successful files to {success_file_path}")
 
-            # Save the list of failed files (with error messages)
-            failed_file_path = os.path.join(split_result_dir, 'failed.txt')
-            with open(failed_file_path, 'w', encoding='utf-8') as f:
-                for file, error in failed_files:
-                    f.write(f"{file}: {error}\n")
-            logger.info(f"Saved {len(failed_files)} failed files to {failed_file_path}")
+            # Save the list of failed files (base names only)
+            failed_path = os.path.join(RESULT, f'{split}_failed.txt')
+            with open(failed_path, 'w') as f:
+                f.write('\n'.join(failed_files))
+
+            logger.info(f"{split} set - Success: {len(success_files)}, Failed: {len(failed_files)}")
 
     # Print the final statistics
+    end_time = datetime.now()
+    duration = end_time - start_time
+
     if total_processed > 0:
         success_rate = (total_success / total_processed) * 100
-        logger.info(f"Processing completed:")
+        logger.info("\nProcessing Summary:")
+        logger.info(f"Start time: {start_time}")
+        logger.info(f"End time: {end_time}")
+        logger.info(f"Duration: {duration}")
         logger.info(f"Total files processed: {total_processed}")
         logger.info(f"Successfully processed: {total_success}")
+        logger.info(f"Failed: {total_processed - total_success}")
        logger.info(f"Success rate: {success_rate:.2f}%")
     else:
         logger.warning("No files were processed")
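Not part of the patch: a self-contained sketch of the endpoint-pair construction added to `parse_solid`, on toy data. The lexicographic swap is one way to make the stored endpoint order deterministic; the vertex coordinates and adjacency below are invented for the example.

```python
import numpy as np

corner_pnts = np.array([[0.0, 0.0, 0.0],
                        [1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0]], dtype=np.float32)          # unique vertices
edgeCorner_adj = np.array([[0, 1], [1, 2], [2, 0]], dtype=np.int32)  # two vertex ids per edge

corner_pairs = []
for v1_idx, v2_idx in edgeCorner_adj:
    v1_pos, v2_pos = corner_pnts[v1_idx], corner_pnts[v2_idx]
    # Order the endpoints lexicographically so every edge stores them consistently
    if tuple(v1_pos) > tuple(v2_pos):
        v1_pos, v2_pos = v2_pos, v1_pos
    corner_pairs.append(np.stack([v1_pos, v2_pos]))

corner_pairs = np.stack(corner_pairs)                              # [num_edges, 2, 3]
corner_unique = np.unique(corner_pairs.reshape(-1, 3), axis=0)     # at most num_edges * 2 rows
print(corner_pairs.shape, corner_unique.shape)
```

Downstream, `BRepSDFDataset` only sees samples whose base names appear in the `{split}_success.txt` files written here, which is what ties the two halves of this change together.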