Add training on non-manifold points

final
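This commit threads a NormalPerPoint sampler (imported from brep2sdf.networks.sample in the diff below) through the training loop, so that each batch of surface (manifold) points is paired with off-surface (non-manifold) samples before the forward pass and the loss. As a rough sketch of what such a sampler usually does, assuming the common IGR/SAL-style scheme of per-point Gaussian jitter plus a few uniform box samples (the class name, sigma values, and 1/8 global ratio below are illustrative assumptions, not the actual brep2sdf implementation):

import torch

class GaussianPlusUniformSampler:
    """Hypothetical stand-in for a NormalPerPoint-style sampler."""
    def __init__(self, global_sigma: float = 1.8, local_sigma: float = 0.01):
        self.global_sigma = global_sigma  # half-extent of the uniform sampling box
        self.local_sigma = local_sigma    # std of the Gaussian jitter around each surface point

    def get_points(self, pc_input: torch.Tensor) -> torch.Tensor:
        # pc_input: (B, N, 3) manifold points
        batch, n_points, dim = pc_input.shape
        # Local off-surface samples: jitter every manifold point.
        local = pc_input + torch.randn_like(pc_input) * self.local_sigma
        # Global off-surface samples: uniform points in [-global_sigma, global_sigma]^3.
        n_global = n_points // 8
        global_pts = (torch.rand(batch, n_global, dim, device=pc_input.device) * 2 - 1) * self.global_sigma
        return torch.cat([local, global_pts], dim=1)

# Usage matching the call site in the diff:
#   nonmnfld_pnts = sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0)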
mckay · 1 month ago · commit 19dfe8fcc7
brep2sdf/train.py (37 lines changed)

@@ -13,6 +13,7 @@ from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.loss import LossManager
from brep2sdf.networks.patch_graph import PatchGraph
from brep2sdf.networks.sample import NormalPerPoint
from brep2sdf.networks.learning_rate import LearningRateScheduler
from brep2sdf.utils.logger import logger
@@ -140,6 +141,8 @@ class Trainer:
weight_decay=config.train.weight_decay
)
#self.scheduler = LearningRateScheduler(self.conf.get_list('train.learning_rate_schedule'), self.conf.get_float('train.weight_decay'), self.model.parameters())
self.loss_manager = LossManager(ablation="none")
logger.gpu_memory_stats("训练器初始化后")
@@ -222,6 +225,7 @@ class Trainer:
normals = None
if args.use_normal:
normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device)
logger.debug(normals)
# --- Prepare model inputs, enable gradients ---
mnfld_points.requires_grad_(True) # enable gradients after the checks
@@ -232,6 +236,9 @@
mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)
logger.print_tensor_stats("mnfld_pred",mnfld_pred)
logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
if self.debug_mode:
# --- Check forward-pass outputs ---
logger.gpu_memory_stats("after forward pass")
@@ -302,7 +309,7 @@ class Trainer:
self.model.train()
total_loss = 0.0
step = 0 # if training is batched, this should be the batch index
batch_size = 10240 # choose an appropriate batch size
batch_size = 8192 # choose an appropriate batch size
# split the data into batches
num_points = self.sdf_data.shape[0]
@@ -311,7 +318,8 @@
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, num_points)
points = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # first 3 columns are the points
mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # first 3 columns are the points
nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0) # sample non-manifold points
gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # last column is the ground-truth SDF
normals = None
if args.use_normal:
@@ -322,19 +330,28 @@
# run checks
if self.debug_mode:
if check_tensor(points, "Input Points", epoch, step): return float('inf')
if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf')
if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf')
if args.use_normal:
# only check normals when they are requested
if check_tensor(normals, "Input Normals", epoch, step): return float('inf')
logger.debug(normals)
logger.print_tensor_stats("normals-x",normals[0])
logger.print_tensor_stats("normals-y",normals[1])
logger.print_tensor_stats("normals-z",normals[2])
# --- Prepare model inputs, enable gradients ---
points.requires_grad_(True) # enable gradients after the checks
mnfld_pnts.requires_grad_(True) # enable gradients after the checks
nonmnfld_pnts.requires_grad_(True) # enable gradients after the checks
# --- Forward pass ---
self.optimizer.zero_grad()
pred_sdf = self.model(points)
mnfld_pred = self.model(mnfld_pnts)
nonmnfld_pred = self.model(nonmnfld_pnts)
logger.print_tensor_stats("mnfld_pred",mnfld_pred)
logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
if self.debug_mode:
# --- Check forward-pass outputs ---
@@ -356,10 +373,12 @@ class Trainer:
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss(
points,
mnfld_pnts,
nonmnfld_pnts,
normals, # pass the checked normals
gt_sdf,
pred_sdf
mnfld_pred,
nonmnfld_pred
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
@@ -448,8 +467,8 @@ class Trainer:
for epoch in range(start_epoch, self.config.train.num_epochs + 1):
# train for one epoch
train_loss = self.train_epoch_stage1(epoch)
#train_loss = self.train_epoch(epoch)
#train_loss = self.train_epoch_stage1(epoch)
train_loss = self.train_epoch(epoch)
# validation
'''
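Note: the loss hunk above only changes the arguments handed to LossManager.compute_loss; the loss terms themselves live outside this diff. A minimal sketch of a loss with the same argument shape, assuming an IGR-style mix of a manifold SDF term, a normal-alignment term, and an eikonal term on the sampled non-manifold points (weights, shapes, and term names are illustrative guesses, not the actual LossManager code):

import torch
import torch.nn.functional as F

def _gradient(pred: torch.Tensor, pts: torch.Tensor) -> torch.Tensor:
    # d(pred)/d(pts); pts must have requires_grad=True, as set in the training loop above.
    return torch.autograd.grad(pred, pts, grad_outputs=torch.ones_like(pred),
                               create_graph=True)[0]

def compute_loss(mnfld_pnts, nonmnfld_pnts, normals, gt_sdf, mnfld_pred, nonmnfld_pred,
                 w_sdf=1.0, w_normal=1.0, w_eikonal=0.1):
    # Manifold term: predicted SDF should match the ground-truth SDF at the sampled points.
    sdf_term = F.mse_loss(mnfld_pred.squeeze(-1), gt_sdf)
    # Normal term: the SDF gradient at surface points should align with the given normals.
    normal_term = 0.0
    if normals is not None:
        mnfld_grad = _gradient(mnfld_pred, mnfld_pnts)
        normal_term = (mnfld_grad - normals).norm(2, dim=-1).mean()
    # Eikonal term on the off-surface samples: the gradient norm should be 1 everywhere.
    nonmnfld_grad = _gradient(nonmnfld_pred, nonmnfld_pnts)
    eikonal_term = ((nonmnfld_grad.norm(2, dim=-1) - 1.0) ** 2).mean()
    total = w_sdf * sdf_term + w_normal * normal_term + w_eikonal * eikonal_term
    return total, {"sdf": sdf_term, "normal": normal_term, "eikonal": eikonal_term}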
