Browse Source

优化训练脚本

final
mckay 2 months ago
parent
commit
6a1e871e7b
  1. 142
      brep2sdf/batch_train.py
  2. 4
      brep2sdf/config/default_config.py
  3. 2
      brep2sdf/train.py

142
brep2sdf/batch_train.py

@ -9,6 +9,22 @@ from tqdm import tqdm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# =================================
# utils
def get_namelist(path):
try:
with open(path, 'r') as f:
names = [line.strip() for line in f if line.strip()]
logger.info(f"'{path}' 读取了 {len(names)} 个名称。")
return names
except FileNotFoundError:
logger.error(f"错误: 文件 '{path}' 未找到。")
return
except Exception as e:
logger.error(f"读取文件 '{path}' 时出错: {e}")
return
def run_training_process(input_step: str, train_script: str, common_args: list) -> tuple[str, bool, str, str]: def run_training_process(input_step: str, train_script: str, common_args: list) -> tuple[str, bool, str, str]:
""" """
为单个 STEP 文件运行 train.py 子进程 为单个 STEP 文件运行 train.py 子进程
@ -45,18 +61,9 @@ def run_training_process(input_step: str, train_script: str, common_args: list)
logger.error(f"处理 '{input_step}' 时发生意外错误: {e}") logger.error(f"处理 '{input_step}' 时发生意外错误: {e}")
return input_step, False, "", str(e) return input_step, False, "", str(e)
def main(args): def batch_train(args):
# 读取名称列表 # 读取名称列表
try: names = get_namelist(args.name_list_path)
with open(args.name_list_path, 'r') as f:
names = [line.strip() for line in f if line.strip()]
logger.info(f"'{args.name_list_path}' 读取了 {len(names)} 个名称。")
except FileNotFoundError:
logger.error(f"错误: 文件 '{args.name_list_path}' 未找到。")
return
except Exception as e:
logger.error(f"读取文件 '{args.name_list_path}' 时出错: {e}")
return
# 准备 train.py 的通用参数 # 准备 train.py 的通用参数
# 注意:从命令行参数或其他配置中获取这些参数通常更好 # 注意:从命令行参数或其他配置中获取这些参数通常更好
@ -135,10 +142,123 @@ def main(args):
logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}, 跳过: {skipped_count}") logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}, 跳过: {skipped_count}")
#===========================
# Iso
def run_isosurfacing_process(input_path, output_dir, use_gpu=True):
"""
运行 IsoSurfacing.py 脚本生成等值面
参数:
- input_path: 输入 .pt 文件路径
- output_dir: 输出目录
- use_gpu: 是否使用 GPU
返回:
- input_path: 输入文件路径
- success: 是否成功
- stdout: 标准输出
- stderr: 标准错误输出
"""
try:
# 构造输出文件名
base_name = os.path.splitext(os.path.basename(input_path))[0]
output_path = os.path.join(output_dir, f"{base_name}.ply")
# 构造命令
command = [
"python", "IsoSurfacing.py",
"-i", input_path,
"-o", output_path
]
if use_gpu:
command.append("--use-gpu")
# 执行命令
result = subprocess.run(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
timeout=300 # 设置超时时间(秒)
)
# 检查是否成功
if result.returncode == 0:
return input_path, True, result.stdout, result.stderr
else:
return input_path, False, result.stdout, result.stderr
except subprocess.TimeoutExpired:
return input_path, False, "", "处理超时。"
except Exception as e:
return input_path, False, "", str(e)
def batch_Iso(args):
# python IsoSurfacing.py -i /home/wch/brep2sdf/data/output_data/00000054.pt -o /home/wch/brep2sdf/data/output_data/00000054_3.ply --use-gpu
"""
批量处理 .pt 文件生成等值面并保存为 .ply 文件
"""
names = get_namelist(args.name_list_path)
# 检查输入目录是否存在
if not os.path.exists(args.pt_dir):
logger.error(f"错误: 输入目录 '{args.pt_dir}' 不存在。")
return
# 获取所有 .pt 文件
try:
pt_files = [
os.path.join(args.pt_dir, f"{name}.pt")
for name in names
]
pt_files = [f for f in pt_files if os.path.exists(f)]
except OSError as e:
logger.error(f"无法访问输入目录 '{args.pt_dir}': {e}")
return
if not pt_files:
logger.info(f"输入目录 '{args.pt_dir}' 中未找到任何 .pt 文件。")
return
logger.info(f"在输入目录中找到 {len(pt_files)} 个 .pt 文件。")
success_count = 0
failure_count = 0
# 使用 ProcessPoolExecutor 进行并行处理
with ProcessPoolExecutor(max_workers=args.workers) as executor:
# 提交所有任务
futures = {
executor.submit(run_isosurfacing_process, pt_path, args.pt_dir, True): pt_path
for pt_path in pt_files
}
# 使用 tqdm 显示进度并处理结果
for future in tqdm(as_completed(futures), total=len(pt_files), desc="生成等值面"):
input_path = futures[future]
try:
input_pt, success, stdout, stderr = future.result()
if success:
success_count += 1
# 可以选择记录成功的 stdout/stderr,但通常只记录失败的更有用
# logger.debug(f"成功处理 '{input_pt}'. STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
else:
failure_count += 1
logger.error(f"处理 '{input_pt}' 失败。STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
except Exception as e:
failure_count += 1
logger.error(f"处理任务 '{input_path}' 时获取结果失败: {e}")
logger.info(f"处理完成。成功: {success_count}, 失败: {failure_count}")
def main(args):
#batch_train(args)
batch_Iso(args)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description="批量运行 train.py 处理 STEP 文件。") parser = argparse.ArgumentParser(description="批量运行 train.py 处理 STEP 文件。")
parser.add_argument('--step-root-dir', type=str, default="/home/wch/brep2sdf/data/step", parser.add_argument('--step-root-dir', type=str, default="/home/wch/brep2sdf/data/step",
help="包含 STEP 子目录的根目录。") help="包含 STEP 子目录的根目录。")
parser.add_argument('--pt-dir', type=str, default="/home/wch/brep2sdf/data/output_data",
help="包含 pt 文件的根目录。")
parser.add_argument('--name-list-path', type=str, default="/home/wch/brep2sdf/data/name_list.txt", parser.add_argument('--name-list-path', type=str, default="/home/wch/brep2sdf/data/name_list.txt",
help="包含要处理的名称列表的文件路径。") help="包含要处理的名称列表的文件路径。")
parser.add_argument('--train-script', type=str, default="train.py", parser.add_argument('--train-script', type=str, default="train.py",

4
brep2sdf/config/default_config.py

@ -47,7 +47,7 @@ class TrainConfig:
# 基本训练参数 # 基本训练参数
batch_size: int = 8 batch_size: int = 8
num_workers: int = 4 num_workers: int = 4
num_epochs: int = 1 num_epochs: int = 1500
learning_rate: float = 0.0001 learning_rate: float = 0.0001
min_lr: float = 1e-5 min_lr: float = 1e-5
weight_decay: float = 0.01 weight_decay: float = 0.01
@ -62,7 +62,7 @@ class TrainConfig:
warmup_epochs: int = 5 warmup_epochs: int = 5
# 保存和验证 # 保存和验证
save_freq: int = 20 # 每多少个epoch保存一次 save_freq: int = 100 # 每多少个epoch保存一次
val_freq: int = 1 # 每多少个epoch验证一次 val_freq: int = 1 # 每多少个epoch验证一次
# 保存路径 # 保存路径

2
brep2sdf/train.py

@ -111,6 +111,8 @@ class Trainer:
self.loss_manager = LossManager(ablation="none") self.loss_manager = LossManager(ablation="none")
logger.info(f"初始化完成,正在处理模型 {self.model_name}")
def build_tree(self,surf_bbox, max_depth=6): def build_tree(self,surf_bbox, max_depth=6):
num_faces = surf_bbox.shape[0] num_faces = surf_bbox.shape[0]

Loading…
Cancel
Save