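"""Training entry point for NHRepNet.

Loads the NHREP dataset, builds the network from the dataset's CSG tree, and
runs a simple single-batch training loop with Adam and a StepLR schedule,
logging the loss to TensorBoard.
"""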
import os
import sys
import time
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from utils.logger import logger
from data_loader import NHREP_Dataset
from loss import LossManager
from model.network import NHRepNet  # import NHRepNet

class NHREPNet_Training:
    def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False, if_feature_sample: bool = False):
        self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize the model
        d_in = 6  # input dimension
        dims_sdf = [256, 256, 256]  # hidden layer dimensions
        csg_tree, _ = self.dataset.get_csg_tree()

        self.loss_manager = LossManager()
        self.model = NHRepNet(d_in, dims_sdf, csg_tree).to(self.device)  # instantiate the model and move it to the device
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)  # Adam optimizer
        self.scheduler = StepLR(self.optimizer, step_size=1000, gamma=0.1)  # learning-rate scheduler
        self.nepochs = 15000  # number of training epochs
        self.writer = SummaryWriter()  # TensorBoard writer
    def run_nhrepnet_training(self):
        logger.info("Starting training")
        self.model.train()  # set the model to training mode
        for epoch in range(self.nepochs):  # main training loop
            try:
                self.train_one_epoch(epoch)
                self.scheduler.step()  # update the learning rate
            except Exception as e:
                logger.error(f"Error during training: {str(e)}")
                break
        self.writer.close()  # flush and close the TensorBoard writer
    def train_one_epoch(self, epoch):
        logger.info(f"Epoch {epoch}/{self.nepochs} started")

        # Fetch the input data and move it to the device
        input_data = self.dataset.get_data().to(self.device)
        logger.info(f"Input data shape: {input_data.shape}")

        # Forward pass
        outputs = self.model(input_data)
        logger.info(f"Output data shape: {outputs.shape}")

        # Compute the loss
        loss = self.loss_manager.compute_loss(outputs)

        # Backward pass and parameter update
        self.optimizer.zero_grad()  # clear accumulated gradients
        loss.backward()             # backpropagate
        self.optimizer.step()       # update parameters

        # Only one batch per epoch, so the average loss is the single loss value
        avg_loss = loss.item()
        logger.info(f'Epoch [{epoch}/{self.nepochs}], Average Loss: {avg_loss:.4f}')
        self.writer.add_scalar('Loss/train', avg_loss, epoch)  # log the loss to TensorBoard

if __name__ == "__main__":
    data_dir = '../data/input_data'  # data directory
    name_prefix = 'broken_bullet_50k'
    train = NHREPNet_Training(data_dir, name_prefix, if_baseline=True, if_feature_sample=False)
    train.run_nhrepnet_training()