9 changed files with 586 additions and 31 deletions
@ -0,0 +1,133 @@ |
|||||
|
|
||||
|
import os |
||||
|
from concurrent.futures import ProcessPoolExecutor, as_completed |
||||
|
from tqdm import tqdm |
||||
|
import logging |
||||
|
|
||||
|
# Module-level logger; handlers and levels are assumed to be configured elsewhere via the logging module.
||||
|
logger = logging.getLogger(__name__) |
||||
|
|
||||
|
# utils |
||||
|
def get_namelist(path):
    """Read a name-list file and return its non-empty lines.

    Every line is stripped of surrounding whitespace and blank lines are
    dropped. The file is decoded as UTF-8 regardless of the platform locale,
    and a leading byte-order mark (if any) is discarded so it cannot end up
    glued to the first name.

    Args:
        path: Path of the text file containing one name per line.

    Returns:
        list[str] | None: The names in file order, or ``None`` when the file
        is missing or cannot be read. Failures are logged, never raised.
    """
    try:
        # 'utf-8-sig' decodes plain UTF-8 and also strips an optional BOM.
        with open(path, 'r', encoding='utf-8-sig') as f:
            names = [stripped for stripped in map(str.strip, f) if stripped]
    except FileNotFoundError:
        logger.error(f"错误: 文件 '{path}' 未找到。")
        return None
    except Exception as e:
        # Deliberately broad: this helper is best-effort and signals any
        # read/decode problem to the caller through the ``None`` return.
        logger.error(f"读取文件 '{path}' 时出错: {e}")
        return None

    logger.info(f"从 '{path}' 读取了 {len(names)} 个名称。")
    return names
||||
|
|
||||
|
|
||||
|
def get_step_paths(names, step_root_dir, file_extensions, name_filter=None):
    """Collect one matching file path per name from per-name subdirectories.

    For every entry in ``names`` the directory ``step_root_dir/<name>`` is
    listed and the files whose (lower-cased) filename ends with one of
    ``file_extensions`` are kept, optionally narrowed by ``name_filter``.
    Exactly one path is returned per name that has at least one match.

    Args:
        names: Iterable of names, each corresponding to a subdirectory of
            ``step_root_dir``. ``None`` is treated as "name list unavailable".
        step_root_dir: Root directory containing one subdirectory per name.
        file_extensions: Extension or collection of extensions to match,
            e.g. ``['.step', '.stp']``. Matching is case-insensitive.
        name_filter: Optional callable ``name_filter(filename, name)``
            returning a truthy value for filenames that should be kept.

    Returns:
        list[str]: The selected file paths, in the order of ``names``.
        Empty when ``names`` is ``None`` or nothing matched.
    """
    if names is None:
        logger.error("无法获取名称列表,终止任务。")
        return []

    # A bare string would otherwise be split into single characters by
    # tuple(), making almost every filename "match".
    if isinstance(file_extensions, str):
        file_extensions = (file_extensions,)
    # Filenames are lower-cased before comparison, so the suffixes must be
    # lower-cased too; build the tuple once instead of once per file.
    suffixes = tuple(ext.lower() for ext in file_extensions)

    step_file_paths = []
    skipped_count = 0

    for name in names:
        step_dir = os.path.join(step_root_dir, name)
        if not os.path.isdir(step_dir):
            logger.warning(f"目录 '{step_dir}' 不存在。跳过 '{name}'。")
            skipped_count += 1
            continue

        try:
            # os.listdir returns entries in arbitrary order; sorting makes
            # the "use the first match" choice below reproducible.
            entries = sorted(os.listdir(step_dir))
        except OSError as e:
            logger.warning(f"无法访问目录 '{step_dir}': {e}。跳过 '{name}'。")
            skipped_count += 1
            continue

        step_files = [
            os.path.join(step_dir, entry)
            for entry in entries
            if entry.lower().endswith(suffixes)
            and (not name_filter or name_filter(entry, name))
        ]

        if not step_files:
            logger.warning(f"在目录 '{step_dir}' 中未找到匹配的文件。跳过 '{name}'。")
            skipped_count += 1
            continue

        if len(step_files) > 1:
            logger.warning(f"在目录 '{step_dir}' 中找到多个匹配的文件,将使用第一个: {step_files[0]}。")
        step_file_paths.append(step_files[0])

    logger.info(f"成功获取 {len(step_file_paths)} 个文件路径,跳过了 {skipped_count} 个名称。")
    return step_file_paths
||||
|
|
||||
|
def run_batch_task(task_function, args, common_args_func, file_extensions, name_filter=None):
    """Run ``task_function`` over every matching input file in parallel.

    The name list is read from ``args.name_list_path``, resolved to one file
    per name under ``args.step_root_dir``, and each file is submitted to a
    ``ProcessPoolExecutor``. Progress is shown with ``tqdm`` and a summary of
    succeeded / failed / skipped items is logged at the end.

    Args:
        task_function: Callable invoked as
            ``task_function(file_path, args.train_script, common_args)``.
            It must return a 4-tuple
            ``(input_file, success, stdout, stderr)``. Because it runs in
            worker processes it has to be picklable.
        args: Parsed command-line arguments; the attributes read here are
            ``name_list_path``, ``step_root_dir``, ``workers`` and
            ``train_script``.
        common_args_func: Callable taking ``args`` and returning the shared
            arguments passed to every task.
        file_extensions: File extensions used to select input files.
        name_filter: Optional filename filter forwarded to
            ``get_step_paths``.

    Returns:
        None
    """
    # Read the list once: get_step_paths expects the names themselves, not
    # the path of the file that contains them, and the same list is needed
    # below to work out how many names were skipped.
    names = get_namelist(args.name_list_path)
    tasks = get_step_paths(names, args.step_root_dir, file_extensions, name_filter)
    if not tasks:
        logger.info("没有找到需要处理的有效文件。")
        return

    common_args = common_args_func(args)

    success_count = 0
    failure_count = 0
    # tasks is non-empty here, so names cannot be None.
    skipped_count = len(names) - len(tasks)

    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        # Map each future back to its input so failures can be attributed
        # even when the worker raised before returning a result tuple.
        futures = {
            executor.submit(task_function, task_path, args.train_script, common_args): task_path
            for task_path in tasks
        }

        for future in tqdm(as_completed(futures), total=len(tasks), desc="运行任务"):
            input_path = futures[future]
            try:
                input_file, success, stdout, stderr = future.result()
            except Exception as e:
                failure_count += 1
                logger.error(f"处理任务 '{input_path}' 时获取结果失败: {e}")
                continue

            if success:
                # Output of successful runs is intentionally not logged.
                success_count += 1
            else:
                failure_count += 1
                logger.error(f"处理 '{input_file}' 失败。STDOUT:\n{stdout}\nSTDERR:\n{stderr}")

    logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}, 跳过: {skipped_count}")
||||
|
|
||||
|
|
||||
|
|
Loading…
Reference in new issue