import os
import sys
import time
import traceback

# Put the project root on the import path and make it the working directory, so
# the project-local imports below and the relative data path used in __main__
# resolve against it no matter where the script is launched from.
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

from utils.logger import logger
from data_loader import NHREP_Dataset
from loss import LossManager
from model.network import NHRepNet  # import NHRepNet


class NHREPNet_Training:
    """Train an NHRepNet on one NHREP_Dataset and log the loss to TensorBoard."""

    def __init__(
        self,
        data_dir,
        name_prefix: str,
        if_baseline: bool = False,
        if_feature_sample: bool = False,
        *,
        lr: float = 0.001,
        nepochs: int = 15000,
        lr_step_size: int = 1000,
        lr_gamma: float = 0.1,
    ):
        """Build the dataset, model, optimizer, LR scheduler and TensorBoard writer.

        Args:
            data_dir: directory passed to ``NHREP_Dataset``.
            name_prefix: dataset name prefix passed to ``NHREP_Dataset``.
            if_baseline: forwarded to ``NHREP_Dataset``.
            if_feature_sample: forwarded to ``NHREP_Dataset``.
            lr: Adam learning rate.
            nepochs: number of training epochs run by ``run_nhrepnet_training``.
            lr_step_size: StepLR period, in epochs.
            lr_gamma: StepLR multiplicative decay factor.
        """
        self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Model configuration.
        # NOTE(review): the original comment described this as "3D coordinates",
        # but the value is 6 -- presumably xyz plus 3 extra per-point features;
        # confirm against what NHREP_Dataset.get_data() returns.
        d_in = 6
        dims_sdf = [256, 256, 256]  # hidden layer widths
        csg_tree, _ = self.dataset.get_csg_tree()

        self.loss_manager = LossManager()
        self.model = NHRepNet(d_in, dims_sdf, csg_tree).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        # NOTE(review): with the defaults (step 1000, gamma 0.1, 15000 epochs) the
        # learning rate ends around 1e-18, i.e. training effectively stops after a
        # few thousand epochs -- confirm this schedule is intended.
        self.scheduler = StepLR(self.optimizer, step_size=lr_step_size, gamma=lr_gamma)
        self.nepochs = nepochs
        self.writer = SummaryWriter()  # TensorBoard writer (default log directory)

    def run_nhrepnet_training(self):
        """Run the training loop for ``self.nepochs`` epochs.

        An exception raised during an epoch is logged together with its traceback
        and ends training early; it is not re-raised. The TensorBoard writer is
        always closed on exit so that buffered events are flushed to disk.
        """
        logger.info("开始训练")
        self.model.train()  # put the model in training mode
        try:
            for epoch in range(self.nepochs):
                try:
                    self.train_one_epoch(epoch)
                    self.scheduler.step()  # advance the LR schedule once per epoch
                except Exception as e:
                    # Include the traceback: logging only str(e) hides where it failed.
                    logger.error(f"训练过程中发生错误: {str(e)}\n{traceback.format_exc()}")
                    break
        finally:
            self.writer.close()

    def train_one_epoch(self, epoch):
        """Perform one optimizer step on the data returned by ``dataset.get_data()``.

        Args:
            epoch: zero-based epoch index, used in log messages and as the
                TensorBoard global step.
        """
        logger.info(f"Epoch {epoch}/{self.nepochs} 开始")

        # get_data() is called once per epoch and its result is used for a
        # single forward/backward pass.
        input_data = self.dataset.get_data().to(self.device)
        logger.info(f"输入数据: {input_data.shape}")

        outputs = self.model(input_data)
        logger.info(f"输出数据: {outputs.shape}")

        loss = self.loss_manager.compute_loss(outputs)
        # Only one step per epoch, so the epoch's "average" loss is this loss.
        avg_loss = loss.item()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        logger.info(f'Epoch [{epoch}/{self.nepochs}], Average Loss: {avg_loss:.4f}')
        self.writer.add_scalar('Loss/train', avg_loss, epoch)  # log loss to TensorBoard


if __name__ == "__main__":
    # Resolved relative to project_dir, since the working directory was changed above.
    data_dir = '../data/input_data'
    name_prefix = 'broken_bullet_50k'
    train = NHREPNet_Training(data_dir, name_prefix, if_baseline=True, if_feature_sample=False)
    train.run_nhrepnet_training()