7 changed files with 723 additions and 172 deletions
@ -1,67 +1,152 @@ |
|||||
import os |
import os |
||||
import subprocess |
import subprocess |
||||
|
import argparse |
||||
|
import logging |
||||
|
from concurrent.futures import ProcessPoolExecutor, as_completed |
||||
from tqdm import tqdm |
from tqdm import tqdm |
||||
|
|
||||
|
# 配置日志记录 |
||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
||||
|
logger = logging.getLogger(__name__) |
||||
|
|
||||
def main(): |
def run_training_process(input_step: str, train_script: str, common_args: list,
                         timeout: float = 600) -> tuple[str, bool, str, str]:
    """Run the training script on a single STEP file in a child process.

    The child is started with the interpreter that is running this script
    (``sys.executable``) so it uses the same environment, instead of whatever
    ``python`` happens to be first on ``PATH``.

    Args:
        input_step: Path of the STEP file, passed to the script via ``-i``.
        train_script: Path of the training script to execute.
        common_args: Arguments placed before ``-i`` on the command line.
        timeout: Seconds to wait for the child before giving up
            (defaults to 600, i.e. ten minutes).

    Returns:
        A tuple ``(input_step, success, stdout, stderr)``. ``success`` is
        False when the child exits non-zero, times out, or cannot be started;
        in the last case ``stderr`` carries the exception text.
    """
    import sys  # local import keeps this block independent of the file header

    def _as_text(stream) -> str:
        # TimeoutExpired exposes captured output as bytes (or None) even when
        # text=True was requested, so normalise everything to str here.
        if stream is None:
            return ""
        if isinstance(stream, bytes):
            return stream.decode(errors="replace")
        return stream

    command = [
        sys.executable or "python", train_script,
        *common_args,
        "-i", input_step,
    ]
    try:
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            check=True,  # a non-zero exit status raises CalledProcessError
            timeout=timeout,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"命令 '{' '.join(command)}' 执行失败。")
        return input_step, False, _as_text(e.stdout), _as_text(e.stderr)
    except subprocess.TimeoutExpired as e:
        logger.error(f"处理 '{input_step}' 超时。")
        return input_step, False, _as_text(e.stdout), _as_text(e.stderr)
    except Exception as e:
        # Worker boundary: report the failure as a result instead of letting
        # it propagate out of the pool.
        logger.error(f"处理 '{input_step}' 时发生意外错误: {e}")
        return input_step, False, "", str(e)
    return input_step, True, result.stdout, result.stderr
||||
|
|
||||
|
def _read_names(name_list_path):
    """Return the stripped, non-empty lines of the name list, or None if it cannot be read."""
    try:
        with open(name_list_path, 'r') as f:
            names = [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        logger.error(f"错误: 文件 '{name_list_path}' 未找到。")
        return None
    except (OSError, UnicodeDecodeError) as e:
        logger.error(f"读取文件 '{name_list_path}' 时出错: {e}")
        return None
    logger.info(f"从 '{name_list_path}' 读取了 {len(names)} 个名称。")
    return names


def _find_step_file(step_root_dir, name):
    """Return the STEP file to process for *name*, or None when the name must be skipped."""
    step_dir = os.path.join(step_root_dir, name)
    if not os.path.isdir(step_dir):
        logger.warning(f"目录 '{step_dir}' 不存在。跳过 '{name}'。")
        return None

    try:
        # os.listdir returns entries in arbitrary order; sorting makes the
        # "first match" below the same on every run.
        entries = sorted(os.listdir(step_dir))
    except OSError as e:
        logger.warning(f"无法访问目录 '{step_dir}': {e}。跳过 '{name}'。")
        return None

    # Accept .step and .stp in any letter case; the file must start with the name.
    step_files = [
        os.path.join(step_dir, f)
        for f in entries
        if f.lower().endswith((".step", ".stp")) and f.startswith(name)
    ]
    if not step_files:
        logger.warning(f"在目录 '{step_dir}' 中未找到匹配的 STEP 文件。跳过 '{name}'。")
        return None
    if len(step_files) > 1:
        logger.warning(f"在目录 '{step_dir}' 中找到多个匹配的 STEP 文件,将使用第一个: {step_files[0]}。")
    return step_files[0]


def main(args):
    """Run the training script over every STEP file named in the name list.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``name_list_path``, ``step_root_dir``, ``train_script``,
            ``workers`` and ``train_args``.

    Returns:
        None. Progress, per-file failures and the final success / failure /
        skipped counts are reported through the module logger.
    """
    names = _read_names(args.name_list_path)
    if names is None:
        return

    # Flags passed to every training run; caller-supplied extras follow them.
    common_train_args = [
        "--use-normal",
        "--only-zero-surface",
    ]
    if args.train_args:
        common_train_args.extend(args.train_args)

    # Resolve each name to one STEP file, counting the names that cannot be used.
    tasks = []
    skipped_count = 0
    for name in names:
        step_file = _find_step_file(args.step_root_dir, name)
        if step_file is None:
            skipped_count += 1
        else:
            tasks.append(step_file)

    if not tasks:
        logger.info("没有找到需要处理的有效 STEP 文件。")
        return

    logger.info(f"准备处理 {len(tasks)} 个 STEP 文件,跳过了 {skipped_count} 个名称。")

    success_count = 0
    failure_count = 0

    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        # Map each future back to its input path so a failure while fetching
        # the result can still be attributed to a file.
        futures = {
            executor.submit(run_training_process, task_path, args.train_script, common_train_args): task_path
            for task_path in tasks
        }

        for future in tqdm(as_completed(futures), total=len(tasks), desc="运行训练"):
            input_path = futures[future]
            try:
                input_step, success, stdout, stderr = future.result()
            except Exception as e:
                failure_count += 1
                logger.error(f"处理任务 '{input_path}' 时获取结果失败: {e}")
                continue

            if success:
                success_count += 1
            else:
                # Only failed runs have their captured output logged.
                failure_count += 1
                logger.error(f"处理 '{input_step}' 失败。STDOUT:\n{stdout}\nSTDERR:\n{stderr}")

    logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}, 跳过: {skipped_count}")
||||
|
|
||||
def build_parser():
    """Build the command-line parser for the batch runner.

    Returns:
        An ``argparse.ArgumentParser`` defining ``--step-root-dir``,
        ``--name-list-path``, ``--train-script``, ``--workers`` and
        ``--train-args``.
    """
    parser = argparse.ArgumentParser(description="批量运行 train.py 处理 STEP 文件。")
    parser.add_argument('--step-root-dir', type=str, default="/home/wch/brep2sdf/data/step",
                        help="包含 STEP 子目录的根目录。")
    parser.add_argument('--name-list-path', type=str, default="/home/wch/brep2sdf/data/name_list.txt",
                        help="包含要处理的名称列表的文件路径。")
    parser.add_argument('--train-script', type=str, default="train.py",
                        help="要执行的训练脚本路径。")
    parser.add_argument('--workers', type=int, default=os.cpu_count(),
                        help="用于并行处理的工作进程数。")
    # REMAINDER collects everything after --train-args verbatim, including
    # tokens that start with "--". With nargs='*' argparse treated such tokens
    # as unknown options of this script and aborted, so the example in the
    # help text could not work. --train-args must therefore be given last.
    parser.add_argument('--train-args', nargs=argparse.REMAINDER,
                        help="传递给 train.py 的额外参数 (例如 --epochs 10 --batch-size 32)。")
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())
||||
Loading…
Reference in new issue