From 965c0864c3ef3e680225bf9670d2c1027e931cfe Mon Sep 17 00:00:00 2001
From: mckay
Date: Tue, 6 May 2025 12:41:58 +0800
Subject: [PATCH] Three-stage training works, but stage 2 does not batch and
 can overflow GPU memory
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/networks/learning_rate.py |   4 +-
 brep2sdf/networks/loss.py          |  19 +--
 brep2sdf/networks/network.py       |   4 +-
 brep2sdf/train.py                  | 194 ++++++++++++++++++-----------
 4 files changed, 134 insertions(+), 87 deletions(-)

diff --git a/brep2sdf/networks/learning_rate.py b/brep2sdf/networks/learning_rate.py
index e2af628..6432247 100644
--- a/brep2sdf/networks/learning_rate.py
+++ b/brep2sdf/networks/learning_rate.py
@@ -42,8 +42,8 @@ class LearningRateScheduler:
             self.best_loss = float('inf')
             self.patience = 20
             self.decay_factor = 0.5
-            initial_lr = self.lr_schedules[0].get_learning_rate(0)
-            self.lr = initial_lr
+            self.initial_lr = self.lr_schedules[0].get_learning_rate(0)
+            self.lr = self.initial_lr
 
             self.epochs_since_improvement = 0
         except Exception as e:
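
Note: the only functional change above is that the initial learning rate is now
stored on the scheduler. This pairs with the `self.scheduler.reset()` calls
introduced in train.py further down; that method is not part of this patch, so
the following is only a minimal sketch of what such a reset presumably does,
assuming the scheduler keeps a reference to the optimizer as `self.optimizer`:

    # Hypothetical sketch -- reset() is called by train.py but not shown here.
    class LearningRateScheduler:  # illustrative stub of the real class
        def reset(self):
            """Restore the stored initial LR and clear the plateau-tracking
            state so each training stage starts from a fresh schedule."""
            self.lr = self.initial_lr
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.initial_lr
            self.best_loss = float('inf')
            self.epochs_since_improvement = 0
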
logger.info(f"Start Stage 2 Training: {num_epoch} epochs") + total_loss = 0.0 + + # 收集所有有效的点云数据和对应的 patch_ids + all_points = [] + valid_patch_ids = [] + nonmnfld_pnts_list, psdf_list = [], [] + for patch_id in range(num_volumes): points = points_in_box(self.train_surf_ncs, surf_bbox[patch_id]) - loss = self.train_stage2_by_volume(num_epoch, patch_id, points) - logger.debug(f"Patch [{patch_id:2d}] | Loss: {loss:.6f}") - total_loss += loss + points = points.to(self.device) + if points.shape[0] == 0: + logger.warning(f"Patch {patch_id} has no valid points.") + continue - return total_loss + nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:,0:3], points[:,3:6]) # 生成非流形点 + all_points.append(points) + valid_patch_ids.append(patch_id) + nonmnfld_pnts_list.append(nonmnfld_pnts) + psdf_list.append(psdf) + if not all_points: + logger.warning("No valid patches found.") + return 0.0 - def train_stage2_by_volume(self, num_epoch, patch_id, points): - logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}") - points.to(self.device) - mnfld_pnts = points[:,0:3] - logger.debug(mnfld_pnts) - gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device) - if not args.use_normal: - logger.warning(f"need args.use_normal,skip stage2") - return float('inf') - normals = points[:,3:6] - logger.debug(normals) - nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals) # 生成非流形点 + weights = torch.tensor([points.shape[0] for points in all_points], device=self.device).float() + weights /= weights.sum() - # --- 准备模型输入,启用梯度 --- - mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 - nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 + # 清空梯度 + self.scheduler.optimizer.zero_grad() + # 对每个 patch 进行前向传播并计算损失 + for epoch in range(num_epoch): - # --- 前向传播 --- - mnfld_pred = self.model.forward_training_volumes(mnfld_pnts, patch_id) - nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, patch_id) + losses = [] + loss_detailss = [] + for patch_id, points,nonmnfld_pnts, psdf in zip(valid_patch_ids, all_points, nonmnfld_pnts_list,psdf_list): + logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}") + + mnfld_pnts = points[:, 0:3] + gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device) + + normals = points[:, 3:6] - # --- 计算损失 --- - loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败 - loss_details = {} - try: - logger.gpu_memory_stats("计算损失前") - loss, loss_details = self.loss_manager.compute_loss( - mnfld_pnts, - nonmnfld_pnts, - normals, # 传递检查过的 normals - gt_sdf, - mnfld_pred, - nonmnfld_pred, - psdf - ) + # --- 准备模型输入,启用梯度 --- + mnfld_pnts.requires_grad_(True) + nonmnfld_pnts.requires_grad_(True) - # --- 4. 
diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py
index babdddb..b921940 100644
--- a/brep2sdf/networks/network.py
+++ b/brep2sdf/networks/network.py
@@ -153,7 +153,7 @@ class Net(nn.Module):
         h = self.decoder.forward_training_volumes(feature_vectors)  # (B, D)
         return h
 
-    @torch.jit.export
+    @torch.jit.ignore
     def forward_without_octree(self, query_points, face_indices_mask, operator):
         """
         Forward pass
@@ -175,7 +175,7 @@ class Net(nn.Module):
         #logger.debug("step combine")
         return self.process_sdf(f_i, face_indices_mask, operator)
 
-    @torch.jit.export
+    @torch.jit.ignore
     def forward_training_volumes(self, surf_points, patch_id: int):
         """
         only surf sampled points
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index 39d0a73..a8c6177 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -350,106 +350,135 @@ class Trainer:
         return total_loss  # for single-batch training, just return the current loss
 
     def train_stage2(self, num_epoch):
+        if not args.use_normal:
+            logger.warning("stage 2 requires args.use_normal; skipping stage 2")
+            return float('inf')
         self.model.freeze_stage2()
         self.cached_train_data = None
         num_volumes = self.data['surf_bbox_ncs'].shape[0]
-        surf_bbox=torch.tensor(
+        surf_bbox = torch.tensor(
             self.data['surf_bbox_ncs'],
             dtype=torch.float32,
             device=self.device
         )
         logger.info(f"Start Stage 2 Training: {num_epoch} epochs")
+
+        # Collect every valid point cloud and its patch id
+        all_points = []
+        valid_patch_ids = []
+        nonmnfld_pnts_list, psdf_list = [], []
+
         for patch_id in range(num_volumes):
             points = points_in_box(self.train_surf_ncs, surf_bbox[patch_id])
-            loss = self.train_stage2_by_volume(num_epoch, patch_id, points)
-            logger.debug(f"Patch [{patch_id:2d}] | Loss: {loss:.6f}")
-            total_loss += loss
+            points = points.to(self.device)
+            if points.shape[0] == 0:
+                logger.warning(f"Patch {patch_id} has no valid points.")
+                continue
 
-        return total_loss
+            # Sample off-surface points along the normals
+            nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:, 0:3], points[:, 3:6])
+            all_points.append(points)
+            valid_patch_ids.append(patch_id)
+            nonmnfld_pnts_list.append(nonmnfld_pnts)
+            psdf_list.append(psdf)
 
-    def train_stage2_by_volume(self, num_epoch, patch_id, points):
-        logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}")
-        points.to(self.device)
-        mnfld_pnts = points[:,0:3]
-        logger.debug(mnfld_pnts)
-        gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device)
-        if not args.use_normal:
-            logger.warning("need args.use_normal, skip stage2")
-            return float('inf')
-        normals = points[:,3:6]
-        logger.debug(normals)
-        nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals)  # sample off-surface points
+        if not all_points:
+            logger.warning("No valid patches found.")
+            return 0.0
 
-        # --- prepare model inputs, enable gradients ---
-        mnfld_pnts.requires_grad_(True)  # enable gradients after the checks
-        nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
+        # Weight each patch by its share of the training points
+        weights = torch.tensor([points.shape[0] for points in all_points], device=self.device).float()
+        weights /= weights.sum()
 
-        # --- forward pass ---
-        mnfld_pred = self.model.forward_training_volumes(mnfld_pnts, patch_id)
-        nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, patch_id)
+        # Clear gradients
+        self.scheduler.optimizer.zero_grad()
+        # Run a forward pass and loss for every patch, once per epoch
+        for epoch in range(num_epoch):
+            losses = []
+            loss_detailss = []
+            for patch_id, points, nonmnfld_pnts, psdf in zip(valid_patch_ids, all_points, nonmnfld_pnts_list, psdf_list):
+                logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}")
+
+                mnfld_pnts = points[:, 0:3]
+                gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device)
+                normals = points[:, 3:6]
 
-        # --- compute loss ---
-        loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
-        loss_details = {}
-        try:
-            logger.gpu_memory_stats("before loss computation")
-            loss, loss_details = self.loss_manager.compute_loss(
-                mnfld_pnts,
-                nonmnfld_pnts,
-                normals,  # pass the checked normals
-                gt_sdf,
-                mnfld_pred,
-                nonmnfld_pred,
-                psdf
-            )
+                # --- prepare model inputs, enable gradients ---
+                mnfld_pnts.requires_grad_(True)
+                nonmnfld_pnts.requires_grad_(True)
 
-            # --- 4. check the loss result ---
-            if self.debug_mode:
-                logger.print_tensor_stats("psdf", psdf)
-                logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
-            if check_tensor(loss, "Calculated Loss", epoch, step):
-                logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
-                if loss_details: logger.error(f"Loss Details: {loss_details}")
-                return float('inf')  # stop this epoch if the loss is invalid
+                # --- forward pass ---
+                mnfld_pred = self.model.forward_training_volumes(mnfld_pnts, patch_id)
+                nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, patch_id)
 
-        except Exception as loss_e:
-            logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
-            return float('inf')  # stop this epoch if the computation fails
-        logger.gpu_memory_stats("after loss computation")
+                # --- compute loss ---
+                loss = torch.tensor(float('nan'), device=self.device)
+                loss_details = None
+                try:
+                    loss, loss_details = self.loss_manager.compute_loss_volume(
+                        mnfld_pnts,
+                        nonmnfld_pnts,
+                        normals,
+                        gt_sdf,
+                        mnfld_pred,
+                        nonmnfld_pred,
+                        psdf
+                    )
 
-        # --- backward pass and optimization ---
-        try:
-            self.scheduler.optimizer.zero_grad()  # clear gradients
-            loss.backward()  # backward pass
-            self.scheduler.optimizer.step()  # update parameters
-            self.scheduler.step(loss, epoch)
-        except Exception as backward_e:
-            logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
-            # to find which op is responsible, enable anomaly detection
-            # torch.autograd.set_detect_anomaly(True)  # place before training starts
-            return float('inf')  # stop this epoch if backward or the optimizer step fails
+                    # Check the computed loss
+                    if self.debug_mode:
+                        if check_tensor(loss, "Calculated Loss", epoch):
+                            logger.error(f"Epoch {epoch}: Loss calculation resulted in inf/nan.")
+                            if loss_details is not None:
+                                logger.error(f"Loss Details: {loss_details}")
+                            return float('inf')
+                except Exception as loss_e:
+                    logger.error(f"Epoch {epoch}: Error during loss calculation: {loss_e}", exc_info=True)
+                    return float('inf')
+
+                # Accumulate the per-patch loss
+                losses.append(loss)
+                loss_detailss.append(loss_details)
 
-        torch.cuda.empty_cache()
-        if epoch % 100 == 0:
-            # log training progress (valid losses only)
-            logger.info(f'Train Epoch: {epoch:4d}]\t'
-                        f'Loss: {loss:.6f}')
-            if loss_details: logger.info(f"Loss Details: {loss_details}")
-        return loss  # last loss
+            # Average the per-patch losses, then run one backward pass
+            loss_tensor = torch.stack(losses)
+            mean_loss = (loss_tensor * weights).sum()
+            mean_loss.backward()
+            # Update parameters
+            self.scheduler.optimizer.step()
+            self.scheduler.step(mean_loss, epoch)
+            # Clear gradients
+            self.scheduler.optimizer.zero_grad()
+            # Release cached GPU memory
+            torch.cuda.empty_cache()
+
+            if epoch % 1 == 0:  # log every epoch; raise the modulus to thin the logs
+                logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t'
+                            f'Loss: {mean_loss:.6f}')
+                loss_details_tensor = torch.stack(loss_detailss)  # shape: [num_patches, 5]
+                # Weighted average of each loss term (weights already sum to 1)
+                weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0)
+                subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]
+                logger.info(" ".join(f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)))
+
+        # Unweighted per-patch average from the final epoch
+        avg_loss = (sum(losses) / len(losses)).item()
+        logger.info(f"Stage 2 finished | Avg Loss: {avg_loss:.6f}")
+
+        return avg_loss
 
     def train_epoch_stage2_(self, epoch: int):
         total_loss = 0.0
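
Note: as the subject line says, this rewrite still does not batch: every
patch's full point set is forwarded each epoch and all per-patch autograd
graphs stay alive until the single `mean_loss.backward()`, so peak memory grows
with the total point count. A sketch of one way to bound it; `model`,
`optimizer`, and `loss_fn` are hypothetical stand-ins for the trainer's real
components, and backpropagating each chunk immediately keeps only one graph
alive at a time:

    import torch

    def train_stage2_epoch_batched(model, optimizer, patches, loss_fn, chunk_size=4096):
        """patches: iterable of (patch_id, points[N, 6]) with xyz + normal;
        loss_fn(model, patch_id, chunk) must return a scalar loss tensor."""
        optimizer.zero_grad()
        total, n_chunks = 0.0, 0
        for patch_id, points in patches:
            for chunk in torch.split(points, chunk_size):
                loss = loss_fn(model, patch_id, chunk)
                loss.backward()          # frees this chunk's graph right away
                total += loss.item()
                n_chunks += 1
        optimizer.step()                 # one accumulated step per epoch
        return total / max(n_chunks, 1)
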
@@ -874,14 +903,31 @@ class Trainer:
             logger.info(f'Checkpoint saved at epoch {epoch}')
 
         # stage2 freeze_stage2
-
-
-        self.train_stage2(self.config.train.num_epochs2)
-        epoch = self.config.train.num_epochs1+self.config.train.num_epochs2
-        logger.info(f'Checkpoint saved at epoch {epoch}')
-        self._save_checkpoint(epoch, 0.0)
+        max_stage2_epoch = self.config.train.num_epochs1 + self.config.train.num_epochs2
+        if start_epoch < max_stage2_epoch:
+            self.scheduler.reset()
+            self.train_stage2(self.config.train.num_epochs2)
+            cur_epoch = max_stage2_epoch
+            self._save_checkpoint(cur_epoch, 0.0)
+            logger.info(f'Checkpoint saved at epoch {cur_epoch}')
+        else:
+            logger.info(f"start_epoch {start_epoch} >= {max_stage2_epoch}, skipping stage 2 training.")
+            cur_epoch = start_epoch
 
+        # stage 3
         self.model.encoder.unfreeze()
+        self.scheduler.reset()
+        for epoch in range(cur_epoch, max_stage2_epoch + self.config.train.num_epochs3 + 1):
+            # Train one epoch
+            train_loss = self.train_epoch_stage3(epoch)
+            #train_loss = self.train_epoch_stage2(epoch)
+            #train_loss = self.train_epoch(epoch)
+
+            # Save a checkpoint
+            if epoch % self.config.train.save_freq == 0:
+                self._save_checkpoint(epoch, train_loss)
+                logger.info(f'Checkpoint saved at epoch {epoch}')
 
         # Training finished
         total_time = time.time() - start_time
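
Note: with the resume logic above, the cumulative epoch counter works out as
follows (illustrative config values, not taken from this repository). Observe
that the `+ 1` makes the stage-3 loop run `num_epochs3 + 1` iterations when
training resumes exactly at the stage-2 boundary:

    num_epochs1, num_epochs2, num_epochs3 = 100, 50, 30     # illustrative only
    max_stage2_epoch = num_epochs1 + num_epochs2            # 150
    cur_epoch = max_stage2_epoch                            # fresh three-stage run
    stage3 = range(cur_epoch, max_stage2_epoch + num_epochs3 + 1)
    print(len(stage3))                                      # 31 epochs: 150..180
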