|
|
@ -12,6 +12,7 @@ from brep2sdf.networks.network import Net |
|
|
|
from brep2sdf.networks.octree import OctreeNode |
|
|
|
from brep2sdf.networks.loss import LossManager |
|
|
|
from brep2sdf.networks.patch_graph import PatchGraph |
|
|
|
from brep2sdf.networks.sample import NormalPerPoint |
|
|
|
from brep2sdf.utils.logger import logger |
|
|
|
|
|
|
|
|
|
|
@ -142,6 +143,11 @@ class Trainer: |
|
|
|
self.loss_manager = LossManager(ablation="none") |
|
|
|
logger.gpu_memory_stats("训练器初始化后") |
|
|
|
|
|
|
|
self.sampler = NormalPerPoint( |
|
|
|
global_sigma=0.1, # 全局采样标准差 |
|
|
|
local_sigma=0.01 # 局部采样标准差 |
|
|
|
) |
|
|
|
|
|
|
|
logger.info(f"初始化完成,正在处理模型 {self.model_name}") |
|
|
|
|
|
|
|
|
|
|
@ -200,7 +206,9 @@ class Trainer: |
|
|
|
total_loss = 0.0 |
|
|
|
total_loss_details = { |
|
|
|
"manifold": 0.0, |
|
|
|
"normals": 0.0 |
|
|
|
"normals": 0.0, |
|
|
|
"eikonal": 0.0, |
|
|
|
"offsurface": 0.0 |
|
|
|
} |
|
|
|
accumulated_loss = 0.0 # 新增:用于累积多个step的loss |
|
|
|
|
|
|
@ -208,19 +216,21 @@ class Trainer: |
|
|
|
self.optimizer.zero_grad() |
|
|
|
|
|
|
|
for step, surf_points in enumerate(self.data['surf_ncs']): |
|
|
|
points = torch.tensor(surf_points, device=self.device) |
|
|
|
gt_sdf = torch.zeros(points.shape[0], device=self.device) |
|
|
|
mnfld_points = torch.tensor(surf_points, device=self.device) |
|
|
|
nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0) # 生成非流形点 |
|
|
|
gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device) |
|
|
|
normals = None |
|
|
|
if args.use_normal: |
|
|
|
normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) |
|
|
|
|
|
|
|
# --- 准备模型输入,启用梯度 --- |
|
|
|
points.requires_grad_(True) # 在检查之后启用梯度 |
|
|
|
mnfld_points.requires_grad_(True) # 在检查之后启用梯度 |
|
|
|
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
|
|
|
|
|
# --- 前向传播 --- |
|
|
|
self.optimizer.zero_grad() |
|
|
|
pred_sdf = self.model.forward_training_volumes(points, step) |
|
|
|
logger.debug(f"pred_sdf:{pred_sdf}") |
|
|
|
mnfld_pred = self.model.forward_training_volumes(mnfld_points, step) |
|
|
|
nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step) |
|
|
|
|
|
|
|
if self.debug_mode: |
|
|
|
# --- 检查前向传播的输出 --- |
|
|
@ -230,10 +240,12 @@ class Trainer: |
|
|
|
try: |
|
|
|
if args.use_normal: |
|
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
|
points, |
|
|
|
mnfld_points, |
|
|
|
nonmnfld_pnts, |
|
|
|
normals, |
|
|
|
gt_sdf, |
|
|
|
pred_sdf |
|
|
|
mnfld_pred, |
|
|
|
nonmnfld_pred |
|
|
|
) |
|
|
|
else: |
|
|
|
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) |
|
|
@ -278,6 +290,7 @@ class Trainer: |
|
|
|
return total_loss # 返回平均损失而非累计值 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_epoch(self, epoch: int) -> float: |
|
|
|
# --- 1. 检查输入数据 --- |
|
|
|
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) |
|
|
|