|
@ -13,6 +13,7 @@ from brep2sdf.networks.octree import OctreeNode |
|
|
from brep2sdf.networks.loss import LossManager |
|
|
from brep2sdf.networks.loss import LossManager |
|
|
from brep2sdf.networks.patch_graph import PatchGraph |
|
|
from brep2sdf.networks.patch_graph import PatchGraph |
|
|
from brep2sdf.networks.sample import NormalPerPoint |
|
|
from brep2sdf.networks.sample import NormalPerPoint |
|
|
|
|
|
from brep2sdf.networks.learning_rate import LearningRateScheduler |
|
|
from brep2sdf.utils.logger import logger |
|
|
from brep2sdf.utils.logger import logger |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -140,6 +141,8 @@ class Trainer: |
|
|
weight_decay=config.train.weight_decay |
|
|
weight_decay=config.train.weight_decay |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
#self.scheduler = LearningRateScheduler(self.conf.get_list('train.learning_rate_schedule'), self.conf.get_float('train.weight_decay'), self.model.parameters()) |
|
|
|
|
|
|
|
|
self.loss_manager = LossManager(ablation="none") |
|
|
self.loss_manager = LossManager(ablation="none") |
|
|
logger.gpu_memory_stats("训练器初始化后") |
|
|
logger.gpu_memory_stats("训练器初始化后") |
|
|
|
|
|
|
|
@ -222,6 +225,7 @@ class Trainer: |
|
|
normals = None |
|
|
normals = None |
|
|
if args.use_normal: |
|
|
if args.use_normal: |
|
|
normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) |
|
|
normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) |
|
|
|
|
|
logger.debug(normals) |
|
|
|
|
|
|
|
|
# --- 准备模型输入,启用梯度 --- |
|
|
# --- 准备模型输入,启用梯度 --- |
|
|
mnfld_points.requires_grad_(True) # 在检查之后启用梯度 |
|
|
mnfld_points.requires_grad_(True) # 在检查之后启用梯度 |
|
@ -232,6 +236,9 @@ class Trainer: |
|
|
mnfld_pred = self.model.forward_training_volumes(mnfld_points, step) |
|
|
mnfld_pred = self.model.forward_training_volumes(mnfld_points, step) |
|
|
nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step) |
|
|
nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step) |
|
|
|
|
|
|
|
|
|
|
|
logger.print_tensor_stats("mnfld_pred",mnfld_pred) |
|
|
|
|
|
logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) |
|
|
|
|
|
|
|
|
if self.debug_mode: |
|
|
if self.debug_mode: |
|
|
# --- 检查前向传播的输出 --- |
|
|
# --- 检查前向传播的输出 --- |
|
|
logger.gpu_memory_stats("前向传播后") |
|
|
logger.gpu_memory_stats("前向传播后") |
|
@ -302,7 +309,7 @@ class Trainer: |
|
|
self.model.train() |
|
|
self.model.train() |
|
|
total_loss = 0.0 |
|
|
total_loss = 0.0 |
|
|
step = 0 # 如果你的训练是分批次的,这里应该用批次索引 |
|
|
step = 0 # 如果你的训练是分批次的,这里应该用批次索引 |
|
|
batch_size = 10240 # 设置合适的batch大小 |
|
|
batch_size = 8192 # 设置合适的batch大小 |
|
|
|
|
|
|
|
|
# 将数据分成多个batch |
|
|
# 将数据分成多个batch |
|
|
num_points = self.sdf_data.shape[0] |
|
|
num_points = self.sdf_data.shape[0] |
|
@ -311,7 +318,8 @@ class Trainer: |
|
|
for batch_idx in range(num_batches): |
|
|
for batch_idx in range(num_batches): |
|
|
start_idx = batch_idx * batch_size |
|
|
start_idx = batch_idx * batch_size |
|
|
end_idx = min((batch_idx + 1) * batch_size, num_points) |
|
|
end_idx = min((batch_idx + 1) * batch_size, num_points) |
|
|
points = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 |
|
|
mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 |
|
|
|
|
|
nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0) # 生成非流形点 |
|
|
gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值 |
|
|
gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值 |
|
|
normals = None |
|
|
normals = None |
|
|
if args.use_normal: |
|
|
if args.use_normal: |
|
@ -322,19 +330,28 @@ class Trainer: |
|
|
|
|
|
|
|
|
# 执行检查 |
|
|
# 执行检查 |
|
|
if self.debug_mode: |
|
|
if self.debug_mode: |
|
|
if check_tensor(points, "Input Points", epoch, step): return float('inf') |
|
|
if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf') |
|
|
if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf') |
|
|
if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf') |
|
|
if args.use_normal: |
|
|
if args.use_normal: |
|
|
# 只有在请求法线时才检查 normals |
|
|
# 只有在请求法线时才检查 normals |
|
|
if check_tensor(normals, "Input Normals", epoch, step): return float('inf') |
|
|
if check_tensor(normals, "Input Normals", epoch, step): return float('inf') |
|
|
|
|
|
logger.debug(normals) |
|
|
|
|
|
logger.print_tensor_stats("normals-x",normals[0]) |
|
|
|
|
|
logger.print_tensor_stats("normals-y",normals[1]) |
|
|
|
|
|
logger.print_tensor_stats("normals-z",normals[2]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- 准备模型输入,启用梯度 --- |
|
|
# --- 准备模型输入,启用梯度 --- |
|
|
points.requires_grad_(True) # 在检查之后启用梯度 |
|
|
mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
|
|
|
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
|
|
|
|
|
|
# --- 前向传播 --- |
|
|
# --- 前向传播 --- |
|
|
self.optimizer.zero_grad() |
|
|
self.optimizer.zero_grad() |
|
|
pred_sdf = self.model(points) |
|
|
mnfld_pred = self.model(mnfld_pnts) |
|
|
|
|
|
nonmnfld_pred = self.model(nonmnfld_pnts) |
|
|
|
|
|
|
|
|
|
|
|
logger.print_tensor_stats("mnfld_pred",mnfld_pred) |
|
|
|
|
|
logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) |
|
|
|
|
|
|
|
|
if self.debug_mode: |
|
|
if self.debug_mode: |
|
|
# --- 检查前向传播的输出 --- |
|
|
# --- 检查前向传播的输出 --- |
|
@ -356,10 +373,12 @@ class Trainer: |
|
|
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") |
|
|
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") |
|
|
logger.gpu_memory_stats("计算损失前") |
|
|
logger.gpu_memory_stats("计算损失前") |
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
points, |
|
|
mnfld_pnts, |
|
|
|
|
|
nonmnfld_pnts, |
|
|
normals, # 传递检查过的 normals |
|
|
normals, # 传递检查过的 normals |
|
|
gt_sdf, |
|
|
gt_sdf, |
|
|
pred_sdf |
|
|
mnfld_pred, |
|
|
|
|
|
nonmnfld_pred |
|
|
) |
|
|
) |
|
|
else: |
|
|
else: |
|
|
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) |
|
|
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) |
|
@ -448,8 +467,8 @@ class Trainer: |
|
|
|
|
|
|
|
|
for epoch in range(start_epoch, self.config.train.num_epochs + 1): |
|
|
for epoch in range(start_epoch, self.config.train.num_epochs + 1): |
|
|
# 训练一个epoch |
|
|
# 训练一个epoch |
|
|
train_loss = self.train_epoch_stage1(epoch) |
|
|
#train_loss = self.train_epoch_stage1(epoch) |
|
|
#train_loss = self.train_epoch(epoch) |
|
|
train_loss = self.train_epoch(epoch) |
|
|
|
|
|
|
|
|
# 验证 |
|
|
# 验证 |
|
|
''' |
|
|
''' |
|
|