From 6be600a52fd3b392d7e9c30b0384d5f56a48d4a9 Mon Sep 17 00:00:00 2001 From: mckay Date: Wed, 19 Feb 2025 21:54:07 +0800 Subject: [PATCH] feat: Add LossManager for manifold loss computation - Created `loss.py` with a `LossManager` class to handle loss calculation - Integrated `LossManager` into the training pipeline in `train.py` - Implemented a basic manifold loss computation using mean absolute value --- code/conversion/loss.py | 20 ++++++++++++++++++++ code/conversion/train.py | 4 +++- 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 code/conversion/loss.py diff --git a/code/conversion/loss.py b/code/conversion/loss.py new file mode 100644 index 0000000..c70c433 --- /dev/null +++ b/code/conversion/loss.py @@ -0,0 +1,20 @@ + + + + +class LossManager: + def __init__(self): + pass + + def compute_loss(self, outputs): + """ + 计算流型损失的逻辑 + + :param outputs: 模型的输出 + :return: 计算得到的流型损失值 + """ + + # 计算流型损失(这里使用均方误差作为示例) + manifold_loss = (outputs.abs()).mean() # 计算流型损失 + return manifold_loss + diff --git a/code/conversion/train.py b/code/conversion/train.py index b6cac0e..51fb2b0 100644 --- a/code/conversion/train.py +++ b/code/conversion/train.py @@ -13,6 +13,7 @@ from torch.optim.lr_scheduler import StepLR from tqdm import tqdm from utils.logger import logger from data_loader import NHREP_Dataset +from loss import LossManager from model.network import NHRepNet # 导入 NHRepNet class NHREPNet_Training: @@ -24,6 +25,7 @@ class NHREPNet_Training: d_in = 6 # 输入维度,例如 3D 坐标 dims_sdf = [64, 64, 64] # 隐藏层维度 csg_tree, _ = self.dataset.get_csg_tree() + self.loss_manager = LossManager() self.model = NHRepNet(d_in, dims_sdf, csg_tree).to(self.device) # 实例化模型并移动到设备 self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001) # Adam 优化器 @@ -54,7 +56,7 @@ class NHREPNet_Training: outputs = self.model(input_data) # 使用模型进行前向传播 # 计算损失 - loss = self.compute_loss(outputs) # 计算损失 + loss = self.loss_manager.compute_loss(outputs) # 计算损失 total_loss += loss.item() # 反向传播