From 6a1e871e7bff66310bd6cd010df4d80a9332c7c6 Mon Sep 17 00:00:00 2001 From: mckay Date: Thu, 17 Apr 2025 15:12:08 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E8=AE=AD=E7=BB=83=E8=84=9A?= =?UTF-8?q?=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/batch_train.py | 142 +++++++++++++++++++++++++++--- brep2sdf/config/default_config.py | 4 +- brep2sdf/train.py | 2 + 3 files changed, 135 insertions(+), 13 deletions(-) diff --git a/brep2sdf/batch_train.py b/brep2sdf/batch_train.py index 645986f..86d2678 100644 --- a/brep2sdf/batch_train.py +++ b/brep2sdf/batch_train.py @@ -9,6 +9,22 @@ from tqdm import tqdm logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) +# ================================= +# utils +def get_namelist(path): + try: + with open(path, 'r') as f: + names = [line.strip() for line in f if line.strip()] + logger.info(f"从 '{path}' 读取了 {len(names)} 个名称。") + return names + except FileNotFoundError: + logger.error(f"错误: 文件 '{path}' 未找到。") + return + except Exception as e: + logger.error(f"读取文件 '{path}' 时出错: {e}") + return + + def run_training_process(input_step: str, train_script: str, common_args: list) -> tuple[str, bool, str, str]: """ 为单个 STEP 文件运行 train.py 子进程。 @@ -45,18 +61,9 @@ def run_training_process(input_step: str, train_script: str, common_args: list) logger.error(f"处理 '{input_step}' 时发生意外错误: {e}") return input_step, False, "", str(e) -def main(args): +def batch_train(args): # 读取名称列表 - try: - with open(args.name_list_path, 'r') as f: - names = [line.strip() for line in f if line.strip()] - logger.info(f"从 '{args.name_list_path}' 读取了 {len(names)} 个名称。") - except FileNotFoundError: - logger.error(f"错误: 文件 '{args.name_list_path}' 未找到。") - return - except Exception as e: - logger.error(f"读取文件 '{args.name_list_path}' 时出错: {e}") - return + names = get_namelist(args.name_list_path) # 准备 train.py 的通用参数 # 注意:从命令行参数或其他配置中获取这些参数通常更好 @@ -135,10 +142,123 @@ def main(args): logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}, 跳过: {skipped_count}") +#=========================== +# Iso +def run_isosurfacing_process(input_path, output_dir, use_gpu=True): + """ + 运行 IsoSurfacing.py 脚本生成等值面。 + + 参数: + - input_path: 输入 .pt 文件路径 + - output_dir: 输出目录 + - use_gpu: 是否使用 GPU + + 返回: + - input_path: 输入文件路径 + - success: 是否成功 + - stdout: 标准输出 + - stderr: 标准错误输出 + """ + try: + # 构造输出文件名 + base_name = os.path.splitext(os.path.basename(input_path))[0] + output_path = os.path.join(output_dir, f"{base_name}.ply") + + # 构造命令 + command = [ + "python", "IsoSurfacing.py", + "-i", input_path, + "-o", output_path + ] + if use_gpu: + command.append("--use-gpu") + + # 执行命令 + result = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=300 # 设置超时时间(秒) + ) + + # 检查是否成功 + if result.returncode == 0: + return input_path, True, result.stdout, result.stderr + else: + return input_path, False, result.stdout, result.stderr + + except subprocess.TimeoutExpired: + return input_path, False, "", "处理超时。" + except Exception as e: + return input_path, False, "", str(e) + +def batch_Iso(args): + # python IsoSurfacing.py -i /home/wch/brep2sdf/data/output_data/00000054.pt -o /home/wch/brep2sdf/data/output_data/00000054_3.ply --use-gpu + """ + 批量处理 .pt 文件,生成等值面并保存为 .ply 文件。 + """ + names = get_namelist(args.name_list_path) + # 检查输入目录是否存在 + if not os.path.exists(args.pt_dir): + logger.error(f"错误: 输入目录 '{args.pt_dir}' 不存在。") + return + # 获取所有 .pt 文件 + try: + pt_files = [ + os.path.join(args.pt_dir, f"{name}.pt") + for name in names + ] + pt_files = [f for f in pt_files if os.path.exists(f)] + except OSError as e: + logger.error(f"无法访问输入目录 '{args.pt_dir}': {e}") + return + + if not pt_files: + logger.info(f"输入目录 '{args.pt_dir}' 中未找到任何 .pt 文件。") + return + + logger.info(f"在输入目录中找到 {len(pt_files)} 个 .pt 文件。") + + success_count = 0 + failure_count = 0 + + # 使用 ProcessPoolExecutor 进行并行处理 + with ProcessPoolExecutor(max_workers=args.workers) as executor: + # 提交所有任务 + futures = { + executor.submit(run_isosurfacing_process, pt_path, args.pt_dir, True): pt_path + for pt_path in pt_files + } + + # 使用 tqdm 显示进度并处理结果 + for future in tqdm(as_completed(futures), total=len(pt_files), desc="生成等值面"): + input_path = futures[future] + try: + input_pt, success, stdout, stderr = future.result() + if success: + success_count += 1 + # 可以选择记录成功的 stdout/stderr,但通常只记录失败的更有用 + # logger.debug(f"成功处理 '{input_pt}'. STDOUT:\n{stdout}\nSTDERR:\n{stderr}") + else: + failure_count += 1 + logger.error(f"处理 '{input_pt}' 失败。STDOUT:\n{stdout}\nSTDERR:\n{stderr}") + except Exception as e: + failure_count += 1 + logger.error(f"处理任务 '{input_path}' 时获取结果失败: {e}") + + logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}") +def main(args): + #batch_train(args) + batch_Iso(args) + + if __name__ == '__main__': parser = argparse.ArgumentParser(description="批量运行 train.py 处理 STEP 文件。") parser.add_argument('--step-root-dir', type=str, default="/home/wch/brep2sdf/data/step", help="包含 STEP 子目录的根目录。") + parser.add_argument('--pt-dir', type=str, default="/home/wch/brep2sdf/data/output_data", + help="包含 pt 文件的根目录。") parser.add_argument('--name-list-path', type=str, default="/home/wch/brep2sdf/data/name_list.txt", help="包含要处理的名称列表的文件路径。") parser.add_argument('--train-script', type=str, default="train.py", diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index a4b9d1e..921223e 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -47,7 +47,7 @@ class TrainConfig: # 基本训练参数 batch_size: int = 8 num_workers: int = 4 - num_epochs: int = 1 + num_epochs: int = 1500 learning_rate: float = 0.0001 min_lr: float = 1e-5 weight_decay: float = 0.01 @@ -62,7 +62,7 @@ class TrainConfig: warmup_epochs: int = 5 # 保存和验证 - save_freq: int = 20 # 每多少个epoch保存一次 + save_freq: int = 100 # 每多少个epoch保存一次 val_freq: int = 1 # 每多少个epoch验证一次 # 保存路径 diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 20ef0ea..802056e 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -111,6 +111,8 @@ class Trainer: self.loss_manager = LossManager(ablation="none") + logger.info(f"初始化完成,正在处理模型 {self.model_name}") + def build_tree(self,surf_bbox, max_depth=6): num_faces = surf_bbox.shape[0]