7 changed files with 723 additions and 172 deletions
@ -1,67 +1,152 @@ |
|||
import os |
|||
import subprocess |
|||
import argparse |
|||
import logging |
|||
from concurrent.futures import ProcessPoolExecutor, as_completed |
|||
from tqdm import tqdm |
|||
|
|||
# 配置日志记录 |
|||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|||
logger = logging.getLogger(__name__) |
|||
|
|||
def main(): |
|||
# 定义 STEP 文件目录和名称列表文件路径 |
|||
step_root_dir = "/home/wch/brep2sdf/data/step" |
|||
name_list_path = "/home/wch/brep2sdf/data/name_list.txt" |
|||
def run_training_process(input_step: str, train_script: str, common_args: list) -> tuple[str, bool, str, str]: |
|||
""" |
|||
为单个 STEP 文件运行 train.py 子进程。 |
|||
|
|||
Args: |
|||
input_step: 输入 STEP 文件的路径。 |
|||
train_script: train.py 脚本的路径。 |
|||
common_args: 传递给 train.py 的通用参数列表。 |
|||
|
|||
Returns: |
|||
Tuple: (输入文件路径, 是否成功, stdout, stderr) |
|||
""" |
|||
command = [ |
|||
"python", train_script, |
|||
*common_args, |
|||
"-i", input_step, |
|||
] |
|||
try: |
|||
result = subprocess.run( |
|||
command, |
|||
capture_output=True, |
|||
text=True, |
|||
check=True, # 如果返回非零退出码,则抛出 CalledProcessError |
|||
timeout=600 # 添加超时设置(例如10分钟) |
|||
) |
|||
return input_step, True, result.stdout, result.stderr |
|||
except subprocess.CalledProcessError as e: |
|||
logger.error(f"命令 '{' '.join(command)}' 执行失败。") |
|||
return input_step, False, e.stdout, e.stderr |
|||
except subprocess.TimeoutExpired as e: |
|||
logger.error(f"处理 '{input_step}' 超时。") |
|||
return input_step, False, e.stdout or "", e.stderr or "" |
|||
except Exception as e: |
|||
logger.error(f"处理 '{input_step}' 时发生意外错误: {e}") |
|||
return input_step, False, "", str(e) |
|||
|
|||
def main(args): |
|||
# 读取名称列表 |
|||
try: |
|||
with open(name_list_path, 'r') as f: |
|||
names = [line.strip() for line in f if line.strip()] # 去除空行 |
|||
with open(args.name_list_path, 'r') as f: |
|||
names = [line.strip() for line in f if line.strip()] |
|||
logger.info(f"从 '{args.name_list_path}' 读取了 {len(names)} 个名称。") |
|||
except FileNotFoundError: |
|||
print(f"Error: File '{name_list_path}' not found.") |
|||
logger.error(f"错误: 文件 '{args.name_list_path}' 未找到。") |
|||
return |
|||
except Exception as e: |
|||
print(f"Error reading file '{name_list_path}': {e}") |
|||
logger.error(f"读取文件 '{args.name_list_path}' 时出错: {e}") |
|||
return |
|||
|
|||
# 遍历名称列表并处理每个 STEP 文件 |
|||
for name in tqdm(names, desc="Processing STEP files"): |
|||
step_dir = os.path.join(step_root_dir, name) |
|||
# 准备 train.py 的通用参数 |
|||
# 注意:从命令行参数或其他配置中获取这些参数通常更好 |
|||
common_train_args = [ |
|||
"--use-normal", |
|||
"--only-zero-surface", |
|||
#"--force-reprocess", |
|||
# 可以添加更多通用参数 |
|||
] |
|||
if args.train_args: |
|||
common_train_args.extend(args.train_args) |
|||
|
|||
# 动态生成 STEP 文件路径(假设只有一个文件) |
|||
step_files = [ |
|||
os.path.join(step_dir, f) |
|||
for f in os.listdir(step_dir) |
|||
if f.endswith(".step") and f.startswith(name) |
|||
] |
|||
tasks = [] |
|||
skipped_count = 0 |
|||
# 准备所有任务 |
|||
for name in names: |
|||
step_dir = os.path.join(args.step_root_dir, name) |
|||
if not os.path.isdir(step_dir): |
|||
logger.warning(f"目录 '{step_dir}' 不存在。跳过 '{name}'。") |
|||
skipped_count += 1 |
|||
continue |
|||
|
|||
if not step_files: |
|||
print(f"Warning: No STEP files found in directory '{step_dir}'. Skipping...") |
|||
continue |
|||
step_files = [] |
|||
try: |
|||
step_files = [ |
|||
os.path.join(step_dir, f) |
|||
for f in os.listdir(step_dir) |
|||
if f.lower().endswith((".step", ".stp")) and f.startswith(name) # 支持 .stp 并忽略大小写 |
|||
] |
|||
except OSError as e: |
|||
logger.warning(f"无法访问目录 '{step_dir}': {e}。跳过 '{name}'。") |
|||
skipped_count += 1 |
|||
continue |
|||
|
|||
# 假设我们只处理第一个匹配的文件 |
|||
input_step = step_files[0] |
|||
if len(step_files) == 0: |
|||
logger.warning(f"在目录 '{step_dir}' 中未找到匹配的 STEP 文件。跳过 '{name}'。") |
|||
skipped_count += 1 |
|||
elif len(step_files) > 1: |
|||
logger.warning(f"在目录 '{step_dir}' 中找到多个匹配的 STEP 文件,将使用第一个: {step_files[0]}。") |
|||
tasks.append(step_files[0]) |
|||
else: |
|||
tasks.append(step_files[0]) |
|||
|
|||
# 构造子进程命令 |
|||
command = [ |
|||
"python", "train.py", |
|||
"--use-normal", |
|||
"-i", input_step, # 输入文件路径 |
|||
] |
|||
if not tasks: |
|||
logger.info("没有找到需要处理的有效 STEP 文件。") |
|||
return |
|||
|
|||
# 调用子进程运行命令 |
|||
try: |
|||
result = subprocess.run( |
|||
command, |
|||
capture_output=True, |
|||
text=True, |
|||
check=True # 如果返回非零退出码,则抛出 CalledProcessError |
|||
) |
|||
print(f"Processed '{input_step}' successfully.") |
|||
print("STDOUT:", result.stdout) |
|||
print("STDERR:", result.stderr) |
|||
except subprocess.CalledProcessError as e: |
|||
print(f"Error processing '{input_step}': {e}") |
|||
print("STDOUT:", e.stdout) |
|||
print("STDERR:", e.stderr) |
|||
except Exception as e: |
|||
print(f"Unexpected error processing '{input_step}': {e}") |
|||
logger.info(f"准备处理 {len(tasks)} 个 STEP 文件,跳过了 {skipped_count} 个名称。") |
|||
|
|||
success_count = 0 |
|||
failure_count = 0 |
|||
|
|||
# 使用 ProcessPoolExecutor 进行并行处理 |
|||
with ProcessPoolExecutor(max_workers=args.workers) as executor: |
|||
# 提交所有任务 |
|||
futures = { |
|||
executor.submit(run_training_process, task_path, args.train_script, common_train_args): task_path |
|||
for task_path in tasks |
|||
} |
|||
|
|||
# 使用 tqdm 显示进度并处理结果 |
|||
for future in tqdm(as_completed(futures), total=len(tasks), desc="运行训练"): |
|||
input_path = futures[future] |
|||
try: |
|||
input_step, success, stdout, stderr = future.result() |
|||
if success: |
|||
success_count += 1 |
|||
# 可以选择记录成功的 stdout/stderr,但通常只记录失败的更有用 |
|||
# logger.debug(f"成功处理 '{input_step}'. STDOUT:\n{stdout}\nSTDERR:\n{stderr}") |
|||
else: |
|||
failure_count += 1 |
|||
logger.error(f"处理 '{input_step}' 失败。STDOUT:\n{stdout}\nSTDERR:\n{stderr}") |
|||
except Exception as e: |
|||
failure_count += 1 |
|||
logger.error(f"处理任务 '{input_path}' 时获取结果失败: {e}") |
|||
|
|||
logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}, 跳过: {skipped_count}") |
|||
|
|||
if __name__ == '__main__': |
|||
main() |
|||
parser = argparse.ArgumentParser(description="批量运行 train.py 处理 STEP 文件。") |
|||
parser.add_argument('--step-root-dir', type=str, default="/home/wch/brep2sdf/data/step", |
|||
help="包含 STEP 子目录的根目录。") |
|||
parser.add_argument('--name-list-path', type=str, default="/home/wch/brep2sdf/data/name_list.txt", |
|||
help="包含要处理的名称列表的文件路径。") |
|||
parser.add_argument('--train-script', type=str, default="train.py", |
|||
help="要执行的训练脚本路径。") |
|||
parser.add_argument('--workers', type=int, default=os.cpu_count(), |
|||
help="用于并行处理的工作进程数。") |
|||
parser.add_argument('--train-args', nargs='*', |
|||
help="传递给 train.py 的额外参数 (例如 --epochs 10 --batch-size 32)。") |
|||
|
|||
args = parser.parse_args() |
|||
main(args) |
Loading…
Reference in new issue