
Two-stage training now works, but stage-2 parallelism is very poor

final
mckay 2 months ago
parent
commit 9689c0314e
  1. brep2sdf/networks/encoder.py (15 changed lines)
  2. brep2sdf/networks/loss.py (53 changed lines)
  3. brep2sdf/train.py (318 changed lines)

brep2sdf/networks/encoder.py (15 changed lines)

@@ -109,7 +109,7 @@ class Encoder(nn.Module):
         background_features = self.background.forward(query_points) # (B, D)
         return background_features

-    @torch.jit.export
+    @torch.jit.ignore
     def forward_training_volumes(self, surf_points: torch.Tensor, patch_id: int) -> torch.Tensor:
         """
         Feature extraction for surface sample points
@@ -119,18 +119,9 @@ class Encoder(nn.Module):
             Feature tensor (S, D)
         """
         # Get the patch features
-        patch_features = torch.zeros(surf_points.shape[0], self.feature_dim, device=surf_points.device)
-        for idx, volume in enumerate(self.feature_volumes):
-            if idx == patch_id:
-                patch_features = volume.forward(surf_points)
-        # Get the background field features
-        background_features = self.background.forward(surf_points)
-        # Blend the patch and background field features
-        combined_features = 0.7 * patch_features + 0.3 * background_features
-        return combined_features
+        patch_features = self.feature_volumes[patch_id].forward(surf_points)
+        return patch_features

     def _optimized_trilinear(self, points, bboxes, features) -> torch.Tensor:
         """Optimized vectorized trilinear interpolation"""

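The decorator swap is the functional heart of this hunk: `forward_training_volumes` now indexes `self.feature_volumes` with a runtime `patch_id`, and TorchScript cannot compile dynamic `ModuleList` indexing, so the method has to be excluded from scripting with `@torch.jit.ignore` rather than exported. A minimal sketch of the pattern, with placeholder modules rather than the repo's real `Encoder`:

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Placeholder stand-in for Encoder; only the scripting behavior matters."""
    def __init__(self, num_patches: int = 4, feature_dim: int = 8):
        super().__init__()
        self.num_patches = num_patches
        self.feature_dim = feature_dim
        self.feature_volumes = nn.ModuleList(
            nn.Linear(3, feature_dim) for _ in range(num_patches)
        )

    def forward(self, query_points: torch.Tensor) -> torch.Tensor:
        # Iterating over a ModuleList is scriptable, so inference compiles.
        out = torch.zeros(query_points.shape[0], self.feature_dim)
        for volume in self.feature_volumes:
            out = out + volume(query_points)
        return out / self.num_patches

    @torch.jit.ignore
    def forward_training_volumes(self, surf_points: torch.Tensor, patch_id: int) -> torch.Tensor:
        # Indexing a ModuleList with a runtime int is not scriptable;
        # @torch.jit.ignore keeps this training-only path in eager Python
        # instead of failing torch.jit.script().
        return self.feature_volumes[patch_id](surf_points)

scripted = torch.jit.script(TinyEncoder())  # compiles; the ignored method stays Python-only
```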
brep2sdf/networks/loss.py (53 changed lines)

@@ -182,6 +182,59 @@ class LossManager:
         return total_loss, loss_details

+    def compute_loss_volume(self,
+                            mnfld_pnts,
+                            nonmnfld_pnts,
+                            normals,
+                            gt_sdfs,
+                            mnfld_pred,
+                            nonmnfld_pred,
+                            ):
+        """
+        Compute the per-volume manifold loss.
+        :return: the total loss and a breakdown of its components
+        """
+        # Force dtype casts for consistency
+        normals = normals.to(torch.float32)
+        mnfld_pred = mnfld_pred.to(torch.float32)
+        gt_sdfs = gt_sdfs.to(torch.float32)
+
+        # Manifold loss
+        manifold_loss = self.position_loss(mnfld_pred, gt_sdfs)
+        # Normal loss
+        normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred)
+        #logger.gpu_memory_stats("after normal loss")
+        # Eikonal loss
+        eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred)
+        # Off-surface loss
+        offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred)
+        # Consistency loss
+        #consistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi)
+        # Correction loss
+        #correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
+
+        # Collect the weighted losses
+        loss_details = {
+            "manifold": self.weights["manifold"] * manifold_loss,
+            "normals": self.weights["normals"] * normals_loss,
+            "eikonal": self.weights["eikonal"] * eikonal_loss,
+            "offsurface": self.weights["offsurface"] * offsurface_loss,
+        }
+        # Total loss
+        total_loss = sum(loss_details.values())
+        return total_loss, loss_details
+
     def _compute_loss(self, mnfld_pnts, normals, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
         """
         Compute the manifold loss
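Of the four terms summed in `compute_loss_volume`, the eikonal term is the one that needs explicit autograd plumbing. For reference, a minimal sketch of how such a term is typically computed in IGR-style SDF training; the repo's `self.eikonal_loss` may differ in details:

```python
import torch

def eikonal_loss(points: torch.Tensor, pred_sdf: torch.Tensor) -> torch.Tensor:
    """Penalize the SDF gradient norm deviating from 1.
    `points` must have requires_grad=True and `pred_sdf` must be
    computed from them in the current autograd graph."""
    grad = torch.autograd.grad(
        outputs=pred_sdf,
        inputs=points,
        grad_outputs=torch.ones_like(pred_sdf),
        create_graph=True,  # keep the graph so the loss itself is differentiable
    )[0]
    return ((grad.norm(2, dim=-1) - 1.0) ** 2).mean()
```

This is also why the training loops below call `requires_grad_(True)` on both point sets before the forward pass: without it, `torch.autograd.grad` has no path back to the inputs.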

brep2sdf/train.py (318 changed lines)

@@ -5,9 +5,11 @@ import os
 import numpy as np
 import argparse
 from brep2sdf.config.default_config import get_default_config
 from brep2sdf.data.data import load_brep_file,prepare_sdf_data, print_data_distribution, check_tensor
 from brep2sdf.data.pre_process_by_mesh import process_single_step
+from brep2sdf.data.utils import points_in_box
 from brep2sdf.networks.network import Net
 from brep2sdf.networks.octree import OctreeNode
 from brep2sdf.networks.loss import LossManager
@@ -203,97 +205,7 @@ class Trainer:
     #     # Return the merged bounding box
     #     return torch.cat([global_min, global_max])
     #     return [-0.5,]  # this is wrong

-    def train_epoch_stage1_(self, epoch: int):
-        total_loss = 0.0
-        total_loss_details = {
-            "manifold": 0.0,
-            "normals": 0.0,
-            "eikonal": 0.0,
-            "offsurface": 0.0
-        }
-        accumulated_loss = 0.0  # new: accumulates the loss over several steps
-        # new: zero the gradients at the start of each epoch
-        self.optimizer.zero_grad()
-        for step, surf_points in enumerate(self.data['surf_ncs']):
-            mnfld_points = torch.tensor(surf_points, device=self.device)
-            nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0)  # sample off-manifold points
-            gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device)
-            normals = None
-            if args.use_normal:
-                normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device)
-                logger.debug(normals)
-            # --- prepare model inputs, enable gradients ---
-            mnfld_points.requires_grad_(True)   # enable gradients after the checks
-            nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
-            # --- forward pass ---
-            self.optimizer.zero_grad()
-            mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
-            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)
-            if self.debug_mode:
-                # --- check the forward outputs ---
-                logger.print_tensor_stats("mnfld_pred", mnfld_pred)
-                logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
-                logger.gpu_memory_stats("after forward pass")
-            # --- compute loss ---
-            try:
-                if args.use_normal:
-                    loss, loss_details = self.loss_manager.compute_loss(
-                        mnfld_points,
-                        nonmnfld_pnts,
-                        normals,
-                        gt_sdf,
-                        mnfld_pred,
-                        nonmnfld_pred
-                    )
-                else:
-                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
-                    loss_details = {}  # make sure the variable is initialized
-                # changed: accumulate the loss instead of calling backward immediately
-                accumulated_loss += loss / self.config.train.accumulation_steps  # assumes accumulation_steps exists in the config
-                current_loss = loss.item()
-                total_loss += current_loss
-                for key in total_loss_details:
-                    if key in loss_details:
-                        total_loss_details[key] += loss_details[key].item()
-                # new: run backward once the accumulation step count is reached
-                if (step + 1) % self.config.train.accumulation_steps == 0:
-                    self.scheduler.optimizer.zero_grad()  # clear gradients
-                    loss.backward()                       # backward pass
-                    self.scheduler.optimizer.step()       # update parameters
-                    self.scheduler.step(accumulated_loss, epoch)
-                # logging unchanged ...
-            except Exception as loss_e:
-                logger.error(f"Error in step {step}: {loss_e}")
-                continue
-            # --- memory management ---
-            del loss
-            torch.cuda.empty_cache()
-        # new: handle leftover loss that never reached the accumulation step count
-        if accumulated_loss != 0:
-            self.scheduler.optimizer.zero_grad()  # clear gradients
-            loss.backward()                       # backward pass
-            self.scheduler.optimizer.step()       # update parameters
-            self.scheduler.step(accumulated_loss, epoch)
-        # compute and log the epoch loss
-        logger.info(f'Train Epoch: {epoch:4d}]\t'
-                    f'Loss: {total_loss:.6f}')
-        logger.info(f"Loss Details: {total_loss_details}")
-        return total_loss  # return the average loss rather than the cumulative value

     def train_epoch_stage1(self, epoch: int) -> float:
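Worth noting what this deletion removes: `train_epoch_stage1_` divided each step's loss into `accumulated_loss` but then called `loss.backward()`, both in the accumulation branch and in the leftover handler, so gradients only ever came from the last step's loss and the accumulator was never reset. The `train_epoch_stage2_` added further down in this commit fixes both points with `accumulated_loss.backward()` and `accumulated_loss = 0.0`. For reference, a self-contained sketch of that corrected pattern (`model`, `loader`, and `optimizer` are stand-ins, not the Trainer's actual members):

```python
import torch

def run_epoch(model, loader, optimizer, accumulation_steps: int = 4):
    """Sum the loss tensor over N steps, then run a single backward()+step()."""
    optimizer.zero_grad()
    accumulated_loss = 0.0
    for step, (x, y) in enumerate(loader):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accumulated_loss = accumulated_loss + loss / accumulation_steps
        if (step + 1) % accumulation_steps == 0:
            accumulated_loss.backward()  # one backward over the summed graph
            optimizer.step()
            optimizer.zero_grad()
            accumulated_loss = 0.0       # reset; releases the retained graphs
    if torch.is_tensor(accumulated_loss):  # flush a leftover partial sum
        accumulated_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```

One trade-off of this style: summing the loss tensor keeps every constituent step's autograd graph alive until the single `backward()`. Calling `(loss / accumulation_steps).backward()` on each step yields the same accumulated gradients with lower peak memory.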
@@ -437,7 +349,202 @@ class Trainer:
         return total_loss  # for single-batch training, return the current loss directly

-    def train_epoch_stage2(self, epoch: int) -> float:
+    def train_stage2(self, num_epoch):
+        self.model.freeze_stage2()
+        self.cached_train_data = None
+        num_volumes = self.data['surf_bbox_ncs'].shape[0]
+        surf_bbox = torch.tensor(
+            self.data['surf_bbox_ncs'],
+            dtype=torch.float32,
+            device=self.device
+        )
+        logger.info(f"Start Stage 2 Training: {num_epoch} epochs")
+        total_loss = 0.0
+        for patch_id in range(num_volumes):
+            points = points_in_box(self.train_surf_ncs, surf_bbox[patch_id])
+            loss = self.train_stage2_by_volume(num_epoch, patch_id, points)
+            logger.debug(f"Patch [{patch_id:2d}] | Loss: {loss:.6f}")
+            total_loss += loss
+        return total_loss
+
+    def train_stage2_by_volume(self, num_epoch, patch_id, points):
+        logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}")
+        points = points.to(self.device)
+        mnfld_pnts = points[:, 0:3]
+        logger.debug(mnfld_pnts)
+        gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device)
+        if not args.use_normal:
+            logger.warning("args.use_normal is required, skipping stage 2")
+            return float('inf')
+        normals = points[:, 3:6]
+        logger.debug(normals)
+        nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals)  # sample off-manifold points
+        # --- prepare model inputs, enable gradients ---
+        mnfld_pnts.requires_grad_(True)     # enable gradients after the checks
+        nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
+        for epoch in range(num_epoch):
+            # --- forward pass ---
+            mnfld_pred = self.model.forward_training_volumes(mnfld_pnts, patch_id)
+            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, patch_id)
+            # --- compute loss ---
+            loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
+            loss_details = {}
+            try:
+                logger.gpu_memory_stats("before loss computation")
+                loss, loss_details = self.loss_manager.compute_loss(
+                    mnfld_pnts,
+                    nonmnfld_pnts,
+                    normals,  # pass the checked normals
+                    gt_sdf,
+                    mnfld_pred,
+                    nonmnfld_pred,
+                    psdf
+                )
+                # --- 4. check the loss result ---
+                if self.debug_mode:
+                    logger.print_tensor_stats("psdf", psdf)
+                    logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
+                if check_tensor(loss, "Calculated Loss", epoch, patch_id):
+                    logger.error(f"Epoch {epoch} Patch {patch_id}: Loss calculation resulted in inf/nan.")
+                    if loss_details: logger.error(f"Loss Details: {loss_details}")
+                    return float('inf')  # stop this patch if the loss is invalid
+            except Exception as loss_e:
+                logger.error(f"Epoch {epoch} Patch {patch_id}: Error during loss calculation: {loss_e}", exc_info=True)
+                return float('inf')  # stop this patch if the computation fails
+            logger.gpu_memory_stats("after loss computation")
+            # --- backward pass and optimization ---
+            try:
+                self.scheduler.optimizer.zero_grad()  # clear gradients
+                loss.backward()                       # backward pass
+                self.scheduler.optimizer.step()       # update parameters
+                self.scheduler.step(loss, epoch)
+            except Exception as backward_e:
+                logger.error(f"Epoch {epoch} Patch {patch_id}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
+                # To see which op is responsible, enable anomaly detection
+                # before training starts: torch.autograd.set_detect_anomaly(True)
+                return float('inf')  # stop this patch if backward/step fails
+            torch.cuda.empty_cache()
+            if epoch % 100 == 0:
+                # log training progress (valid losses only)
+                logger.info(f'Train Epoch: {epoch:4d}]\t'
+                            f'Loss: {loss:.6f}')
+                if loss_details: logger.info(f"Loss Details: {loss_details}")
+        return loss  # last loss
+
+    def train_epoch_stage2_(self, epoch: int):
+        total_loss = 0.0
+        total_loss_details = {
+            "manifold": 0.0,
+            "normals": 0.0,
+            "eikonal": 0.0,
+            "offsurface": 0.0
+        }
+        accumulated_loss = 0.0  # new: accumulates the loss over several steps
+        # new: zero the gradients at the start of each epoch
+        self.scheduler.optimizer.zero_grad()
+        for step, surf_points in enumerate(self.data['surf_ncs']):
+            mnfld_points = torch.tensor(surf_points, device=self.device)
+            nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0)  # sample off-manifold points
+            gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device)
+            normals = None
+            if args.use_normal:
+                normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device)
+                logger.debug(normals)
+            # --- prepare model inputs, enable gradients ---
+            mnfld_points.requires_grad_(True)   # enable gradients after the checks
+            nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
+            # --- forward pass ---
+            self.scheduler.optimizer.zero_grad()
+            mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
+            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)
+            if self.debug_mode:
+                # --- check the forward outputs ---
+                logger.print_tensor_stats("mnfld_pred", mnfld_pred)
+                logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
+                logger.gpu_memory_stats("after forward pass")
+            # --- compute loss ---
+            try:
+                if args.use_normal:
+                    loss, loss_details = self.loss_manager.compute_loss_volume(
+                        mnfld_points,
+                        nonmnfld_pnts,
+                        normals,
+                        gt_sdf,
+                        mnfld_pred,
+                        nonmnfld_pred
+                    )
+                else:
+                    loss = torch.nn.functional.mse_loss(mnfld_pred, gt_sdf)
+                    loss_details = {}  # make sure the variable is initialized
+                # changed: accumulate the loss instead of calling backward immediately
+                accumulated_loss += loss / self.config.train.accumulation_steps  # assumes accumulation_steps exists in the config
+                current_loss = loss.item()
+                total_loss += current_loss
+                for key in total_loss_details:
+                    if key in loss_details:
+                        total_loss_details[key] += loss_details[key].item()
+                # new: run backward once the accumulation step count is reached
+                if (step + 1) % self.config.train.accumulation_steps == 0:
+                    self.scheduler.optimizer.zero_grad()   # clear gradients
+                    accumulated_loss.backward()            # backward pass
+                    self.scheduler.optimizer.step()        # update parameters
+                    self.scheduler.step(accumulated_loss, epoch)
+                    accumulated_loss = 0.0
+                # logging unchanged ...
+            except Exception as loss_e:
+                logger.error(f"Error in step {step}: {loss_e}")
+                continue
+            # --- memory management ---
+            del loss
+            torch.cuda.empty_cache()
+        # new: flush leftover loss that never reached the accumulation step count
+        if accumulated_loss != 0:
+            self.scheduler.optimizer.zero_grad()   # clear gradients
+            accumulated_loss.backward()            # backward pass
+            self.scheduler.optimizer.step()        # update parameters
+            self.scheduler.step(accumulated_loss, epoch)
+        # compute and log the epoch loss
+        logger.info(f'Train Epoch: {epoch:4d}]\t'
+                    f'Loss: {total_loss:.6f}')
+        logger.info(f"Loss Details: {total_loss_details}")
+        return total_loss  # cumulative epoch loss
+
+    def train_epoch_stage3(self, epoch: int) -> float:
         # --- 1. check the input data ---
         # Note: assume self.sdf_data contains [x, y, z, nx, ny, nz, sdf] (7 columns) or [x, y, z, sdf] (4 columns)
         # and that the SDF value is always in the last column
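`train_stage2` partitions the training set per patch with the newly imported `points_in_box`. Its actual implementation lives in `brep2sdf/data/utils` and is not shown in this diff; a plausible sketch, assuming each point row is xyz plus normal and the bbox is the flattened min/max corners (the real helper may differ):

```python
import torch

def points_in_box(points: torch.Tensor, bbox: torch.Tensor) -> torch.Tensor:
    """Select rows of `points` whose xyz lie inside an axis-aligned box.
    Assumed layout: points is (N, 6) as xyz+normal, bbox is (6,) as
    [xmin, ymin, zmin, xmax, ymax, zmax]."""
    xyz = points[:, :3]
    mask = ((xyz >= bbox[:3]) & (xyz <= bbox[3:])).all(dim=1)
    return points[mask]
```

This per-patch partitioning is also the source of the poor parallelism noted in the commit message: each patch is trained for `num_epoch` epochs sequentially inside `train_stage2`, one volume at a time.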
@@ -521,6 +628,7 @@ class Trainer:
                 )
                 #logger.print_tensor_stats("psdf", psdf)
+                #logger.print_tensor_stats("nonmnfld_pnts", nonmnfld_pnts)
                 #logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)

             # --- compute loss ---
@@ -752,31 +860,27 @@ class Trainer:
             start_epoch = self._load_checkpoint(args.resume_checkpoint_path)
             logger.info(f"Loaded model from {args.resume_checkpoint_path}")

+        # stage 1
         self.model.encoder.freeze_stage1()
-        for epoch in range(start_epoch, self.config.train.num_epochs + 1):
+        for epoch in range(start_epoch, self.config.train.num_epochs1 + 1):
             # train one epoch
             train_loss = self.train_epoch_stage1(epoch)
             #train_loss = self.train_epoch_stage2(epoch)
             #train_loss = self.train_epoch(epoch)
-            # validation
-            '''
-            if epoch % self.config.train.val_freq == 0:
-                val_loss = self.validate(epoch)
-                logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')
-                # save the best model
-                if val_loss < best_val_loss:
-                    best_val_loss = val_loss
-                    self._save_model(epoch, val_loss)
-                    logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}')
-            else:
-                logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}')
-            '''
             # save checkpoints
             if epoch % self.config.train.save_freq == 0:
                 self._save_checkpoint(epoch, train_loss)
                 logger.info(f'Checkpoint saved at epoch {epoch}')

+        # stage 2: freeze_stage2
+        self.train_stage2(self.config.train.num_epochs2)
+        epoch = self.config.train.num_epochs1 + self.config.train.num_epochs2
+        logger.info(f'Checkpoint saved at epoch {epoch}')
+        self._save_checkpoint(epoch, 0.0)
+
         self.model.encoder.unfreeze()
         # training finished
         total_time = time.time() - start_time
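The schedule above leans on `freeze_stage1`, `freeze_stage2`, and `unfreeze`, which this diff calls but does not define. A sketch of the assumed semantics, selective `requires_grad_` toggling so each stage only updates its own parameters; the real methods in `brep2sdf/networks` may differ:

```python
import torch.nn as nn

def freeze_all_but(module: nn.Module, trainable: nn.Module) -> None:
    """Assumed freeze semantics: keep only `trainable`'s parameters live."""
    for p in module.parameters():
        p.requires_grad_(False)
    for p in trainable.parameters():
        p.requires_grad_(True)

# e.g. stage 2 might train only the per-patch feature volumes:
# freeze_all_but(model.encoder, model.encoder.feature_volumes)
```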
@@ -848,7 +952,7 @@ class Trainer:
     def _load_checkpoint(self, checkpoint_path):
         """Restore training state from a checkpoint"""
         try:
-            checkpoint = torch.load(checkpoint_path)
+            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
             self.model.load_state_dict(checkpoint['model_state_dict'])
             self.scheduler.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
             return checkpoint['epoch'] + 1
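Two things motivate the new `torch.load` arguments: `map_location='cpu'` lets a checkpoint saved on one GPU be restored on a machine with a different (or no) GPU, since `load_state_dict` copies tensors onto the live model's device anyway; and `weights_only=False` opts out of the safe-unpickling default that PyTorch 2.6 switched to `True`, which rejects checkpoints containing pickled non-tensor objects. A self-contained save/load pair matching the keys used above (function names here are illustrative, not the Trainer's):

```python
import torch

def save_checkpoint(path, epoch, model, optimizer):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

def load_checkpoint(path, model, optimizer):
    # map_location='cpu' avoids requiring the GPU that wrote the file;
    # weights_only=False is needed on PyTorch >= 2.6 because the dict
    # above contains pickled non-tensor entries (e.g. the epoch int is
    # fine, but optimizer state can include arbitrary Python objects).
    checkpoint = torch.load(path, map_location='cpu', weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'] + 1  # resume from the next epoch
```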
