|
|
@@ -5,9 +5,11 @@ import os
|
|
|
import numpy as np |
|
|
|
import argparse |
|
|
|
|
|
|
|
|
|
|
|
from brep2sdf.config.default_config import get_default_config |
|
|
|
from brep2sdf.data.data import load_brep_file, prepare_sdf_data, print_data_distribution, check_tensor
|
|
|
from brep2sdf.data.pre_process_by_mesh import process_single_step |
|
|
|
from brep2sdf.data.utils import points_in_box |
|
|
|
from brep2sdf.networks.network import Net |
|
|
|
from brep2sdf.networks.octree import OctreeNode |
|
|
|
from brep2sdf.networks.loss import LossManager |
|
|
@@ -203,97 +205,7 @@ class Trainer:
|
|
|
# # Return the merged bounding box
# return torch.cat([global_min, global_max])
# return [-0.5,]  # this is wrong
|
|
|
def train_epoch_stage1_(self, epoch: int): |
|
|
|
total_loss = 0.0 |
|
|
|
total_loss_details = { |
|
|
|
"manifold": 0.0, |
|
|
|
"normals": 0.0, |
|
|
|
"eikonal": 0.0, |
|
|
|
"offsurface": 0.0 |
|
|
|
} |
|
|
|
accumulated_loss = 0.0  # new: accumulate the loss over several steps
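# Gradient accumulation (assumption: self.config.train.accumulation_steps exists in the config):
# losses from several steps are summed, scaled by 1/accumulation_steps, and backpropagated once
# per window so that a single optimizer step approximates a larger batch, roughly:
#   accumulated_loss += loss / accumulation_steps
#   if (step + 1) % accumulation_steps == 0:
#       accumulated_loss.backward(); optimizer.step(); optimizer.zero_grad(); accumulated_loss = 0.0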
|
|
|
|
|
|
|
# new: clear gradients at the start of each epoch
self.optimizer.zero_grad()
|
|
|
|
|
|
|
for step, surf_points in enumerate(self.data['surf_ncs']): |
|
|
|
mnfld_points = torch.tensor(surf_points, device=self.device) |
|
|
|
nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0)  # sample off-surface (non-manifold) points
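# Note (assumption): self.sampler.get_points is expected to perturb the surface points to
# produce off-surface samples; these feed the eikonal / off-surface loss terms below.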
|
|
|
gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device) |
|
|
|
normals = None |
|
|
|
if args.use_normal: |
|
|
|
normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) |
|
|
|
logger.debug(normals) |
|
|
|
|
|
|
|
# --- Prepare model inputs, enable gradients ---
mnfld_points.requires_grad_(True)  # enable gradients after the checks
nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
|
|
|
|
|
|
|
# --- Forward pass ---
self.optimizer.zero_grad()
|
|
|
mnfld_pred = self.model.forward_training_volumes(mnfld_points, step) |
|
|
|
nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step) |
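# Note (assumption): forward_training_volumes(points, idx) evaluates only the per-patch
# sub-network selected by idx, so each surface patch is fitted independently in stage 1.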
|
|
|
|
|
|
|
if self.debug_mode: |
|
|
|
# --- Check the forward-pass outputs ---
logger.print_tensor_stats("mnfld_pred", mnfld_pred)
logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
logger.gpu_memory_stats("after forward pass")
|
|
|
|
|
|
|
# --- Compute the loss ---
|
|
|
try: |
|
|
|
if args.use_normal: |
|
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
|
mnfld_points, |
|
|
|
nonmnfld_pnts, |
|
|
|
normals, |
|
|
|
gt_sdf, |
|
|
|
mnfld_pred, |
|
|
|
nonmnfld_pred |
|
|
|
) |
|
|
|
else: |
|
|
|
loss = torch.nn.functional.mse_loss(mnfld_pred, gt_sdf)
loss_details = {}  # make sure the variable is initialized
|
|
|
|
|
|
|
# change: accumulate the loss instead of calling backward immediately
accumulated_loss += loss / self.config.train.accumulation_steps  # assumes accumulation_steps is set in the config
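# Dividing by accumulation_steps keeps the magnitude of the accumulated gradient comparable
# to a single-step update, so the learning rate does not have to be rescaled.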
|
|
|
current_loss = loss.item() |
|
|
|
total_loss += current_loss |
|
|
|
for key in total_loss_details: |
|
|
|
if key in loss_details: |
|
|
|
total_loss_details[key] += loss_details[key].item() |
|
|
|
|
|
|
|
# new: run backpropagation once the accumulation step count is reached
if (step + 1) % self.config.train.accumulation_steps == 0:
self.scheduler.optimizer.zero_grad()  # clear gradients
accumulated_loss.backward()  # backpropagate the accumulated loss
self.scheduler.optimizer.step()  # update parameters
self.scheduler.step(accumulated_loss, epoch)
accumulated_loss = 0.0  # reset the accumulator
|
|
|
|
|
|
|
# logging unchanged ...
|
|
|
|
|
|
|
except Exception as loss_e: |
|
|
|
logger.error(f"Error in step {step}: {loss_e}") |
|
|
|
continue |
|
|
|
|
|
|
|
# --- Memory management ---
|
|
|
del loss |
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
# new: handle the remaining loss that did not reach a full accumulation window
if accumulated_loss != 0:
self.scheduler.optimizer.zero_grad()  # clear gradients
accumulated_loss.backward()  # backpropagate the remaining accumulated loss
self.scheduler.optimizer.step()  # update parameters
self.scheduler.step(accumulated_loss, epoch)
|
|
|
|
|
|
|
# Compute and log the epoch loss
logger.info(f'Train Epoch: {epoch:4d}\t'
f'Loss: {total_loss:.6f}')
logger.info(f"Loss Details: {total_loss_details}")
return total_loss  # total loss accumulated over this epoch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_epoch_stage1(self, epoch: int) -> float: |
|
|
@@ -437,7 +349,202 @@ class Trainer:
|
|
|
|
|
|
|
return total_loss  # for single-batch training, return the current loss directly
|
|
|
|
|
|
|
def train_epoch_stage2(self, epoch: int) -> float: |
|
|
|
def train_stage2(self, num_epoch): |
|
|
|
self.model.freeze_stage2() |
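# Assumption: freeze_stage2 freezes the parts of the network that are not trained in
# stage 2, so each per-patch volume below is optimized in isolation.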
|
|
|
self.cached_train_data = None |
|
|
|
|
|
|
|
num_volumes = self.data['surf_bbox_ncs'].shape[0] |
|
|
|
surf_bbox=torch.tensor( |
|
|
|
self.data['surf_bbox_ncs'], |
|
|
|
dtype=torch.float32, |
|
|
|
device=self.device |
|
|
|
) |
|
|
|
logger.info(f"Start Stage 2 Training: {num_epoch} epochs") |
|
|
|
total_loss = 0.0 |
|
|
|
for patch_id in range(num_volumes): |
|
|
|
points = points_in_box(self.train_surf_ncs, surf_bbox[patch_id]) |
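# Assumption: points_in_box returns the rows of train_surf_ncs whose xyz coordinates fall
# inside the given patch bounding box, so each patch only sees its own local samples.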
|
|
|
loss = self.train_stage2_by_volume(num_epoch, patch_id, points) |
|
|
|
logger.debug(f"Patch [{patch_id:2d}] | Loss: {loss:.6f}") |
|
|
|
total_loss += loss |
|
|
|
|
|
|
|
return total_loss |
|
|
|
|
|
|
|
|
|
|
|
def train_stage2_by_volume(self, num_epoch, patch_id, points): |
|
|
|
logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}") |
|
|
|
points = points.to(self.device)
|
|
|
mnfld_pnts = points[:,0:3] |
|
|
|
logger.debug(mnfld_pnts) |
|
|
|
gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device) |
|
|
|
if not args.use_normal: |
|
|
|
logger.warning("args.use_normal is required, skipping stage 2")
|
|
|
return float('inf') |
|
|
|
normals = points[:,3:6] |
|
|
|
logger.debug(normals) |
|
|
|
nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals)  # sample off-surface points along the normals
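# Assumption: get_norm_points offsets each surface point along its normal and returns the
# sampled points together with psdf, the pseudo signed distance implied by the offset.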
|
|
|
|
|
|
|
# --- Prepare model inputs, enable gradients ---
mnfld_pnts.requires_grad_(True)  # enable gradients after the checks
nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
|
|
|
|
|
|
|
for epoch in range(num_epoch): |
|
|
|
# --- Forward pass ---
|
|
|
mnfld_pred = self.model.forward_training_volumes(mnfld_pnts, patch_id) |
|
|
|
nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, patch_id) |
|
|
|
|
|
|
|
|
|
|
|
# --- Compute the loss ---
loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
|
|
|
loss_details = {} |
|
|
|
try: |
|
|
|
logger.gpu_memory_stats("before loss computation")
|
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
|
mnfld_pnts, |
|
|
|
nonmnfld_pnts, |
|
|
|
normals,  # pass the checked normals
|
|
|
gt_sdf, |
|
|
|
mnfld_pred, |
|
|
|
nonmnfld_pred, |
|
|
|
psdf |
|
|
|
) |
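# Assumption: in stage 2 compute_loss additionally takes psdf so that the sampled
# off-surface points are supervised with their approximate signed distances.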
|
|
|
|
|
|
|
# --- 4. Check the loss computation result ---
|
|
|
if self.debug_mode: |
|
|
|
logger.print_tensor_stats("psdf",psdf) |
|
|
|
logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) |
|
|
|
if check_tensor(loss, "Calculated Loss", epoch, patch_id):
logger.error(f"Epoch {epoch} Patch {patch_id}: Loss calculation resulted in inf/nan.")
if loss_details: logger.error(f"Loss Details: {loss_details}")
return float('inf')  # stop training this patch if the loss is invalid
|
|
|
|
|
|
|
except Exception as loss_e:
logger.error(f"Epoch {epoch} Patch {patch_id}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf')  # stop training this patch if the computation fails
|
|
|
logger.gpu_memory_stats("after loss computation")
|
|
|
|
|
|
|
# --- Backward pass and optimization ---
try:
self.scheduler.optimizer.zero_grad()  # clear gradients
loss.backward()  # backpropagate
self.scheduler.optimizer.step()  # update parameters
self.scheduler.step(loss, epoch)
except Exception as backward_e:
logger.error(f"Epoch {epoch} Patch {patch_id}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
# to see which operation caused it, enable anomaly detection
# torch.autograd.set_detect_anomaly(True)  # place this before training starts
return float('inf')  # stop training this patch if backward or the optimizer step fails
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
if epoch % 100 == 0:
# log training progress (only valid losses are logged)
logger.info(f'Train Epoch: {epoch:4d}\t'
f'Loss: {loss.item():.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
|
|
|
|
|
|
|
return loss.item()  # last loss of this patch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_epoch_stage2_(self, epoch: int): |
|
|
|
total_loss = 0.0 |
|
|
|
total_loss_details = { |
|
|
|
"manifold": 0.0, |
|
|
|
"normals": 0.0, |
|
|
|
"eikonal": 0.0, |
|
|
|
"offsurface": 0.0 |
|
|
|
} |
|
|
|
accumulated_loss = 0.0  # new: accumulate the loss over several steps
# new: clear gradients at the start of each epoch
self.scheduler.optimizer.zero_grad()
|
|
|
|
|
|
|
for step, surf_points in enumerate(self.data['surf_ncs']): |
|
|
|
mnfld_points = torch.tensor(surf_points, device=self.device) |
|
|
|
nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0)  # sample off-surface (non-manifold) points
|
|
|
gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device) |
|
|
|
normals = None |
|
|
|
if args.use_normal: |
|
|
|
normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) |
|
|
|
logger.debug(normals) |
|
|
|
|
|
|
|
# --- Prepare model inputs, enable gradients ---
mnfld_points.requires_grad_(True)  # enable gradients after the checks
nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
|
|
|
|
|
|
|
# --- Forward pass ---
|
|
|
self.scheduler.optimizer.zero_grad() |
|
|
|
mnfld_pred = self.model.forward_training_volumes(mnfld_points, step) |
|
|
|
nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step) |
|
|
|
|
|
|
|
if self.debug_mode: |
|
|
|
# --- Check the forward-pass outputs ---
logger.print_tensor_stats("mnfld_pred", mnfld_pred)
logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
logger.gpu_memory_stats("after forward pass")
|
|
|
|
|
|
|
# --- Compute the loss ---
|
|
|
try: |
|
|
|
if args.use_normal: |
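# Assumption: compute_loss_volume is the per-volume counterpart of compute_loss, combining
# the same manifold / normal / eikonal / off-surface terms for a single patch network.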
|
|
|
loss, loss_details = self.loss_manager.compute_loss_volume( |
|
|
|
mnfld_points, |
|
|
|
nonmnfld_pnts, |
|
|
|
normals, |
|
|
|
gt_sdf, |
|
|
|
mnfld_pred, |
|
|
|
nonmnfld_pred |
|
|
|
) |
|
|
|
else:
loss = torch.nn.functional.mse_loss(mnfld_pred, gt_sdf)
loss_details = {}  # make sure the variable is initialized
|
|
|
|
|
|
|
# change: accumulate the loss instead of calling backward immediately
accumulated_loss += loss / self.config.train.accumulation_steps  # assumes accumulation_steps is set in the config
|
|
|
current_loss = loss.item() |
|
|
|
total_loss += current_loss |
|
|
|
for key in total_loss_details: |
|
|
|
if key in loss_details: |
|
|
|
total_loss_details[key] += loss_details[key].item() |
|
|
|
|
|
|
|
# new: run backpropagation once the accumulation step count is reached
if (step + 1) % self.config.train.accumulation_steps == 0:
self.scheduler.optimizer.zero_grad()  # clear gradients
accumulated_loss.backward()  # backpropagate the accumulated loss
self.scheduler.optimizer.step()  # update parameters
self.scheduler.step(accumulated_loss, epoch)
accumulated_loss = 0.0  # reset the accumulator
|
|
|
|
|
|
|
# logging unchanged ...
|
|
|
|
|
|
|
except Exception as loss_e: |
|
|
|
logger.error(f"Error in step {step}: {loss_e}") |
|
|
|
continue |
|
|
|
|
|
|
|
# --- Memory management ---
|
|
|
del loss |
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
# new: handle the remaining loss that did not reach a full accumulation window
if accumulated_loss != 0:
self.scheduler.optimizer.zero_grad()  # clear gradients
accumulated_loss.backward()  # backpropagate the remaining accumulated loss
self.scheduler.optimizer.step()  # update parameters
self.scheduler.step(accumulated_loss, epoch)
|
|
|
|
|
|
|
# Compute and log the epoch loss
logger.info(f'Train Epoch: {epoch:4d}\t'
f'Loss: {total_loss:.6f}')
logger.info(f"Loss Details: {total_loss_details}")
return total_loss  # total loss accumulated over this epoch
|
|
|
|
|
|
|
def train_epoch_stage3(self, epoch: int) -> float: |
|
|
|
# --- 1. Check the input data ---
# Note: assumes self.sdf_data contains [x, y, z, nx, ny, nz, sdf] (7 columns) or [x, y, z, sdf] (4 columns),
# and that the SDF value is always in the last column
|
|
@@ -521,6 +628,7 @@ class Trainer:
|
|
|
) |
|
|
|
|
|
|
|
#logger.print_tensor_stats("psdf",psdf) |
|
|
|
#logger.print_tensor_stats("nonmnfld_pnts",nonmnfld_pnts) |
|
|
|
#logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) |
|
|
|
|
|
|
|
# --- Compute the loss ---
|
|
@@ -752,31 +860,27 @@ class Trainer:
|
|
|
start_epoch = self._load_checkpoint(args.resume_checkpoint_path) |
|
|
|
logger.info(f"Loaded model from {args.resume_checkpoint_path}") |
|
|
|
|
|
|
|
# stage1 |
|
|
|
self.model.encoder.freeze_stage1() |
|
|
|
for epoch in range(start_epoch, self.config.train.num_epochs + 1): |
|
|
|
for epoch in range(start_epoch, self.config.train.num_epochs1 + 1): |
|
|
|
# train one epoch
|
|
|
train_loss = self.train_epoch_stage1(epoch) |
|
|
|
#train_loss = self.train_epoch_stage2(epoch) |
|
|
|
#train_loss = self.train_epoch(epoch) |
|
|
|
|
|
|
|
# validation
|
|
|
''' |
|
|
|
if epoch % self.config.train.val_freq == 0: |
|
|
|
val_loss = self.validate(epoch) |
|
|
|
logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}') |
|
|
|
|
|
|
|
# save the best model
|
|
|
if val_loss < best_val_loss: |
|
|
|
best_val_loss = val_loss |
|
|
|
self._save_model(epoch, val_loss) |
|
|
|
logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}') |
|
|
|
else: |
|
|
|
logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}') |
|
|
|
''' |
|
|
|
|
|
|
|
# save a checkpoint
|
|
|
if epoch % self.config.train.save_freq == 0: |
|
|
|
self._save_checkpoint(epoch, train_loss) |
|
|
|
logger.info(f'Checkpoint saved at epoch {epoch}') |
|
|
|
|
|
|
|
# stage2 freeze_stage2 |
|
|
|
|
|
|
|
|
|
|
|
self.train_stage2(self.config.train.num_epochs2) |
|
|
|
epoch = self.config.train.num_epochs1+self.config.train.num_epochs2 |
|
|
|
logger.info(f'Checkpoint saved at epoch {epoch}') |
|
|
|
self._save_checkpoint(epoch, 0.0) |
|
|
|
|
|
|
|
self.model.encoder.unfreeze() |
|
|
|
# training finished
|
|
|
total_time = time.time() - start_time |
|
|
@@ -848,7 +952,7 @@ class Trainer:
|
|
|
def _load_checkpoint(self, checkpoint_path): |
|
|
|
"""从检查点恢复训练状态""" |
|
|
|
try: |
|
|
|
checkpoint = torch.load(checkpoint_path) |
|
|
|
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) |
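# weights_only=False keeps the pre-PyTorch-2.6 unpickling behaviour (the default flipped to True
# in 2.6); map_location='cpu' avoids allocating GPU memory while the checkpoint is restored.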
|
|
|
self.model.load_state_dict(checkpoint['model_state_dict']) |
|
|
|
self.scheduler.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) |
|
|
|
return checkpoint['epoch'] + 1 |
|
|
|