
Three-stage training now works, but stage 2 does not batch its data, so it can overflow GPU memory.

final
mckay · 1 month ago · parent commit 965c0864c3
  1. brep2sdf/networks/learning_rate.py (4 lines changed)
  2. brep2sdf/networks/loss.py (19 lines changed)
  3. brep2sdf/networks/network.py (4 lines changed)
  4. brep2sdf/train.py (194 lines changed)

brep2sdf/networks/learning_rate.py (4 lines changed)

@@ -42,8 +42,8 @@ class LearningRateScheduler:
             self.best_loss = float('inf')
             self.patience = 20
             self.decay_factor = 0.5
-            initial_lr = self.lr_schedules[0].get_learning_rate(0)
-            self.lr = initial_lr
+            self.initial_lr = self.lr_schedules[0].get_learning_rate(0)
+            self.lr = self.initial_lr
             self.epochs_since_improvement = 0
         except Exception as e:
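The only functional change here is that the starting learning rate is now kept on the scheduler as self.initial_lr instead of a local variable, which is what allows train.py (below) to call self.scheduler.reset() between stages. No reset() method appears in this commit; the following is a minimal hypothetical sketch of what such a method could look like, assuming it only restores the stored rate and clears the plateau bookkeeping:

    # Hypothetical sketch, not part of this commit: a reset() consistent with the
    # new self.initial_lr attribute and the scheduler.reset() calls added in train.py.
    def reset(self):
        self.lr = self.initial_lr                  # restore the stored starting rate
        self.best_loss = float('inf')              # forget the previous stage's best loss
        self.epochs_since_improvement = 0          # restart the patience counter
        for group in self.optimizer.param_groups:  # push the rate back into the optimizer
            group['lr'] = self.lr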

brep2sdf/networks/loss.py (19 lines changed)

@@ -189,7 +189,7 @@ class LossManager:
                      gt_sdfs,
                      mnfld_pred,
                      nonmnfld_pred,
-                     ):
+                     psdfs):
         """
         Logic for computing the manifold losses

@@ -220,18 +220,19 @@ class LossManager:
         # Compute the correction loss
         #correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
+        psdf_loss = self.psdf_loss(nonmnfld_pred, psdfs)
         # Aggregate the losses
-        loss_details = {
-            "manifold": self.weights["manifold"] * manifold_loss,
-            "normals": self.weights["normals"] * normals_loss,
-            "eikonal": self.weights["eikonal"] * eikonal_loss,
-            "offsurface": self.weights["offsurface"] * offsurface_loss,
-        }
+        loss_details = torch.stack([
+            self.weights["manifold"] * manifold_loss,
+            self.weights["normals"] * normals_loss,
+            self.weights["eikonal"] * eikonal_loss,
+            self.weights["offsurface"] * offsurface_loss,
+            self.weights["psdf"] * psdf_loss
+        ])
         # Compute the total loss
-        total_loss = sum(loss_details.values())
+        total_loss = loss_details.sum()
         return total_loss, loss_details
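With this change loss_details is no longer a name-keyed dict but a stacked tensor in a fixed order (manifold, normals, eikonal, offsurface, psdf), which is what lets train.py stack the per-patch details and take a weighted average. Any caller that still wants named sublosses has to zip the tensor with a fixed label list; a small sketch of a hypothetical helper doing that, mirroring the subloss_names list used in train.py:

    import torch

    # Sketch: recover named sublosses from the stacked loss_details tensor.
    # The order must match the stack order in the LossManager above.
    SUBLOSS_NAMES = ["manifold", "normals", "eikonal", "offsurface", "psdf"]

    def format_loss_details(loss_details: torch.Tensor) -> str:
        # loss_details has shape [5] and is already weighted by the LossManager
        return " ".join(f"{name}: {value.item():.6f}"
                        for name, value in zip(SUBLOSS_NAMES, loss_details))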

brep2sdf/networks/network.py (4 lines changed)

@@ -153,7 +153,7 @@ class Net(nn.Module):
         h = self.decoder.forward_training_volumes(feature_vectors)  # (B, D)
         return h

-    @torch.jit.export
+    @torch.jit.ignore
     def forward_without_octree(self, query_points,face_indices_mask,operator):
         """
         Forward pass

@@ -175,7 +175,7 @@ class Net(nn.Module):
         #logger.debug("step combine")
         return self.process_sdf(f_i, face_indices_mask, operator)

-    @torch.jit.export
+    @torch.jit.ignore
     def forward_training_volumes(self, surf_points, patch_id:int):
         """
         only surf sampled points
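Swapping @torch.jit.export for @torch.jit.ignore means these two training-time helpers are no longer compiled when the Net is scripted: TorchScript leaves them as plain Python methods instead of compiling their bodies along with forward. A minimal, self-contained illustration of the effect, using a toy module rather than the real Net:

    import torch
    import torch.nn as nn

    # Toy module (not the real Net) showing the effect of @torch.jit.ignore:
    # the decorated method is skipped by the TorchScript compiler and stays
    # ordinary Python, so scripting only has to handle forward().
    class Toy(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

        @torch.jit.ignore
        def training_only_helper(self, x: torch.Tensor) -> float:
            # arbitrary Python (e.g. NumPy) is fine here; the compiler never sees this body
            return float(x.numpy().sum())

    scripted = torch.jit.script(Toy())   # scripting succeeds; the helper stays uncompiled
    print(scripted(torch.ones(2)))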

brep2sdf/train.py (194 lines changed)

@@ -350,106 +350,135 @@ class Trainer:
         return total_loss  # For single-batch training, return the current loss directly
 
     def train_stage2(self, num_epoch):
+        if not args.use_normal:
+            logger.warning(f"need args.use_normal, skip stage2")
+            return float('inf')
         self.model.freeze_stage2()
         self.cached_train_data = None
         num_volumes = self.data['surf_bbox_ncs'].shape[0]
-        surf_bbox=torch.tensor(
+        surf_bbox = torch.tensor(
             self.data['surf_bbox_ncs'],
             dtype=torch.float32,
             device=self.device
         )
         logger.info(f"Start Stage 2 Training: {num_epoch} epochs")
         total_loss = 0.0
+        # Collect all valid point sets and their patch_ids
+        all_points = []
+        valid_patch_ids = []
+        nonmnfld_pnts_list, psdf_list = [], []
         for patch_id in range(num_volumes):
             points = points_in_box(self.train_surf_ncs, surf_bbox[patch_id])
-            loss = self.train_stage2_by_volume(num_epoch, patch_id, points)
-            logger.debug(f"Patch [{patch_id:2d}] | Loss: {loss:.6f}")
-            total_loss += loss
-
-        return total_loss
-
-    def train_stage2_by_volume(self, num_epoch, patch_id, points):
-        logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}")
-        points.to(self.device)
-        mnfld_pnts = points[:,0:3]
-        logger.debug(mnfld_pnts)
-        gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device)
-        if not args.use_normal:
-            logger.warning(f"need args.use_normal,skip stage2")
-            return float('inf')
-        normals = points[:,3:6]
-        logger.debug(normals)
-        nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals)  # generate off-manifold points
-
-        # --- Prepare model inputs, enable gradients ---
-        mnfld_pnts.requires_grad_(True)     # enable gradients after the checks
-        nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
+            points = points.to(self.device)
+            if points.shape[0] == 0:
+                logger.warning(f"Patch {patch_id} has no valid points.")
+                continue
+            nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:,0:3], points[:,3:6])  # generate off-manifold points
+            all_points.append(points)
+            valid_patch_ids.append(patch_id)
+            nonmnfld_pnts_list.append(nonmnfld_pnts)
+            psdf_list.append(psdf)
+
+        if not all_points:
+            logger.warning("No valid patches found.")
+            return 0.0
+
+        weights = torch.tensor([points.shape[0] for points in all_points], device=self.device).float()
+        weights /= weights.sum()
+
+        # Zero gradients
+        self.scheduler.optimizer.zero_grad()
 
+        # Forward pass and loss computation for each patch
         for epoch in range(num_epoch):
-            # --- Forward pass ---
-            mnfld_pred = self.model.forward_training_volumes(mnfld_pnts, patch_id)
-            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, patch_id)
-
-            # --- Compute loss ---
-            loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
-            loss_details = {}
-            try:
-                logger.gpu_memory_stats("计算损失前")
-                loss, loss_details = self.loss_manager.compute_loss(
-                    mnfld_pnts,
-                    nonmnfld_pnts,
-                    normals,  # pass the checked normals
-                    gt_sdf,
-                    mnfld_pred,
-                    nonmnfld_pred,
-                    psdf
-                )
-
-                # --- 4. Check the loss computation result ---
-                if self.debug_mode:
-                    logger.print_tensor_stats("psdf", psdf)
-                    logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
-                    if check_tensor(loss, "Calculated Loss", epoch, step):
-                        logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
-                        if loss_details: logger.error(f"Loss Details: {loss_details}")
-                        return float('inf')  # stop this epoch if the loss is invalid
-            except Exception as loss_e:
-                logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
-                return float('inf')  # stop this epoch if the computation fails
-            logger.gpu_memory_stats("损失计算后")
-
-            # --- Backward pass and optimization ---
-            try:
-                # Backward pass
-                self.scheduler.optimizer.zero_grad()  # zero gradients
-                loss.backward()  # backpropagate
-                self.scheduler.optimizer.step()  # update parameters
-                self.scheduler.step(loss,epoch)
-            except Exception as backward_e:
-                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
-                # To see which operation is responsible, enable anomaly detection
-                # torch.autograd.set_detect_anomaly(True)  # put this before training starts
-                return float('inf')  # stop this epoch if backward or the optimizer step fails
-
-            torch.cuda.empty_cache()
-
-            if epoch % 100 == 0:
-                # Log training progress (only valid losses)
-                logger.info(f'Train Epoch: {epoch:4d}]\t'
-                            f'Loss: {loss:.6f}')
-                if loss_details: logger.info(f"Loss Details: {loss_details}")
-
-        return loss  # last loss
+            losses = []
+            loss_detailss = []
+            for patch_id, points, nonmnfld_pnts, psdf in zip(valid_patch_ids, all_points, nonmnfld_pnts_list, psdf_list):
+                logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}")
+                mnfld_pnts = points[:, 0:3]
+                gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device)
+                normals = points[:, 3:6]
+
+                # --- Prepare model inputs, enable gradients ---
+                mnfld_pnts.requires_grad_(True)
+                nonmnfld_pnts.requires_grad_(True)
+
+                # --- Forward pass ---
+                mnfld_pred = self.model.forward_training_volumes(mnfld_pnts, patch_id)
+                nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, patch_id)
+
+                # --- Compute loss ---
+                loss = torch.tensor(float('nan'), device=self.device)
+                loss_details = {}
+                try:
+                    loss, loss_details = self.loss_manager.compute_loss_volume(
+                        mnfld_pnts,
+                        nonmnfld_pnts,
+                        normals,
+                        gt_sdf,
+                        mnfld_pred,
+                        nonmnfld_pred,
+                        psdf
+                    )
+
+                    # Check the loss computation result
+                    if self.debug_mode:
+                        if check_tensor(loss, "Calculated Loss", epoch):
+                            logger.error(f"Epoch {epoch}: Loss calculation resulted in inf/nan.")
+                            if loss_details: logger.error(f"Loss Details: {loss_details}")
+                            return float('inf')
+                except Exception as loss_e:
+                    logger.error(f"Epoch {epoch}: Error during loss calculation: {loss_e}", exc_info=True)
+                    return float('inf')
+
+                # Accumulate losses
+                losses.append(loss)
+                if epoch % 1 == 0:
+                    loss_detailss.append(loss_details)
+
+            # Average the per-patch losses, then backpropagate
+            loss_tensor = torch.stack(losses)
+            mean_loss = (loss_tensor * weights).sum()
+            mean_loss.backward()
+            # Update parameters
+            self.scheduler.optimizer.step()
+            self.scheduler.step(mean_loss, epoch)
+            # Zero gradients
+            self.scheduler.optimizer.zero_grad()
+            # Free cached memory
+            torch.cuda.empty_cache()
+
+            # Add logging here for more detailed loss information
+            if epoch % 1 == 0:
+                logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t'
+                            f'Loss: {loss:.6f}')
+                loss_details_tensor = torch.stack(loss_detailss)  # shape: [num_patches, 5]
+                # Weighted average of each subloss term
+                weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum()
+                subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]
+                logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)]))
+
+        avg_loss = sum(losses) / len(losses)
+        logger.info(f"Total Loss: {total_loss:.6f} | Avg Loss: {avg_loss:.6f}")
+        return avg_loss
 
     def train_epoch_stage2_(self, epoch: int):
         total_loss = 0.0
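The commit message's caveat shows up in this hunk: each epoch now builds the autograd graph for every patch, keeps all of them alive in `losses`, and only frees them at the single `mean_loss.backward()` call, so peak GPU memory grows with the number of patches. One standard way to bound that, sketched below with a stand-in model and data (none of this is part of the commit), is gradient accumulation: backpropagate each weighted patch loss immediately and step once per epoch, which produces the same accumulated gradients as the single weighted sum.

    import torch

    # Hypothetical sketch (not part of this commit): per-patch gradient accumulation.
    # Each (w * loss).backward() frees that patch's graph right away, so memory no
    # longer scales with the number of patches, while the summed gradients match
    # those of backpropagating the single weighted-sum loss.
    model = torch.nn.Linear(3, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    patches = [torch.randn(128, 3) for _ in range(8)]   # stand-ins for per-patch points
    weights = torch.tensor([p.shape[0] for p in patches], dtype=torch.float32)
    weights = weights / weights.sum()

    optimizer.zero_grad()
    for w, pts in zip(weights, patches):
        patch_loss = model(pts).pow(2).mean()            # stand-in for the per-patch SDF loss
        (w * patch_loss).backward()                      # frees this patch's graph immediately
    optimizer.step()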
@@ -874,14 +903,31 @@ class Trainer:
                 logger.info(f'Checkpoint saved at epoch {epoch}')
 
         # stage2 freeze_stage2
-        self.train_stage2(self.config.train.num_epochs2)
-        epoch = self.config.train.num_epochs1+self.config.train.num_epochs2
-        logger.info(f'Checkpoint saved at epoch {epoch}')
-        self._save_checkpoint(epoch, 0.0)
+        max_stage2_epoch = self.config.train.num_epochs1+self.config.train.num_epochs2
+        if start_epoch < max_stage2_epoch:
+            self.scheduler.reset()
+            self.train_stage2(self.config.train.num_epochs2)
+            cur_epoch = max_stage2_epoch
+            logger.info(f'Checkpoint saved at epoch {cur_epoch}')
+            self._save_checkpoint(cur_epoch, 0.0)
+        else:
+            logger.info(f"start_epoch:{start_epoch} > {max_stage2_epoch}, skip stage 2 training.")
+            cur_epoch = start_epoch
 
+        # stage 3
         self.model.encoder.unfreeze()
+        self.scheduler.reset()
+        for epoch in range(cur_epoch, max_stage2_epoch + self.config.train.num_epochs3 + 1):
+            # Train one epoch
+            train_loss = self.train_epoch_stage3(epoch)
+            #train_loss = self.train_epoch_stage2(epoch)
+            #train_loss = self.train_epoch(epoch)
+
+            # Save checkpoint
+            if epoch % self.config.train.save_freq == 0:
+                self._save_checkpoint(epoch, train_loss)
+                logger.info(f'Checkpoint saved at epoch {epoch}')
 
         # Training finished
         total_time = time.time() - start_time
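The resume logic above hinges on a few derived epoch numbers; a tiny standalone check with assumed epoch counts (the real values come from self.config.train, which this commit does not show) makes the bookkeeping concrete:

    # Illustrative numbers only; the real values live in config.train.
    num_epochs1, num_epochs2, num_epochs3 = 500, 500, 1000
    start_epoch = 0
    max_stage2_epoch = num_epochs1 + num_epochs2              # stage 2 ends at epoch 1000
    if start_epoch < max_stage2_epoch:
        cur_epoch = max_stage2_epoch                          # stage 2 runs; stage 3 resumes from 1000
    else:
        cur_epoch = start_epoch                               # stage 2 is skipped on resume
    stage3_epochs = range(cur_epoch, max_stage2_epoch + num_epochs3 + 1)
    print(min(stage3_epochs), max(stage3_epochs))             # 1000 2000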
