Browse Source

feat: Add LossManager for manifold loss computation

- Created `loss.py` with a `LossManager` class to handle loss calculation
- Integrated `LossManager` into the training pipeline in `train.py`
- Implemented a basic manifold loss computation using mean absolute value
NH-Rep
mckay 4 months ago
parent
commit
6be600a52f
  1. 20
      code/conversion/loss.py
  2. 4
      code/conversion/train.py

20
code/conversion/loss.py

@ -0,0 +1,20 @@
class LossManager:
def __init__(self):
pass
def compute_loss(self, outputs):
"""
计算流型损失的逻辑
:param outputs: 模型的输出
:return: 计算得到的流型损失值
"""
# 计算流型损失(这里使用均方误差作为示例)
manifold_loss = (outputs.abs()).mean() # 计算流型损失
return manifold_loss

4
code/conversion/train.py

@ -13,6 +13,7 @@ from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from utils.logger import logger
from data_loader import NHREP_Dataset
from loss import LossManager
from model.network import NHRepNet # 导入 NHRepNet
class NHREPNet_Training:
@ -24,6 +25,7 @@ class NHREPNet_Training:
d_in = 6 # 输入维度,例如 3D 坐标
dims_sdf = [64, 64, 64] # 隐藏层维度
csg_tree, _ = self.dataset.get_csg_tree()
self.loss_manager = LossManager()
self.model = NHRepNet(d_in, dims_sdf, csg_tree).to(self.device) # 实例化模型并移动到设备
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001) # Adam 优化器
@ -54,7 +56,7 @@ class NHREPNet_Training:
outputs = self.model(input_data) # 使用模型进行前向传播
# 计算损失
loss = self.compute_loss(outputs) # 计算损失
loss = self.loss_manager.compute_loss(outputs) # 计算损失
total_loss += loss.item()
# 反向传播

Loading…
Cancel
Save