From 9689c0314e34c0178f2193035ee55e5e378ca56d Mon Sep 17 00:00:00 2001
From: mckay
Date: Mon, 5 May 2025 19:22:58 +0800
Subject: [PATCH] Two-stage training works, but stage-two parallelism is poor
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/networks/encoder.py |  15 +-
 brep2sdf/networks/loss.py    |  53 ++++++
 brep2sdf/train.py            | 322 +++++++++++++++++++++++------------
 3 files changed, 269 insertions(+), 121 deletions(-)

diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py
index 2d0c074..0ede43f 100644
--- a/brep2sdf/networks/encoder.py
+++ b/brep2sdf/networks/encoder.py
@@ -109,7 +109,7 @@ class Encoder(nn.Module):
         background_features = self.background.forward(query_points) # (B, D)
         return background_features
 
-    @torch.jit.export
+    @torch.jit.ignore
     def forward_training_volumes(self, surf_points: torch.Tensor, patch_id: int) -> torch.Tensor:
         """
         Feature extraction for surface sample points
@@ -119,18 +119,9 @@ class Encoder(nn.Module):
            Feature tensor (S, D)
         """
        # Get the patch features
-        patch_features = torch.zeros(surf_points.shape[0], self.feature_dim, device=surf_points.device)
-        for idx, volume in enumerate(self.feature_volumes):
-            if idx == patch_id:
-                patch_features = volume.forward(surf_points)
+        patch_features = self.feature_volumes[patch_id].forward(surf_points)
 
-        # Get the background field features
-        background_features = self.background.forward(surf_points)
-
-        # Blend the patch and background field features
-        combined_features = 0.7 * patch_features + 0.3 * background_features
-
-        return combined_features
+        return patch_features
 
     def _optimized_trilinear(self, points, bboxes, features)-> torch.Tensor:
         """Optimized vectorized trilinear interpolation"""
diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py
index 1f282e1..70fb526 100644
--- a/brep2sdf/networks/loss.py
+++ b/brep2sdf/networks/loss.py
@@ -182,6 +182,59 @@ class LossManager:
 
         return total_loss, loss_details
 
+    def compute_loss_volume(self,
+                            mnfld_pnts,
+                            nonmnfld_pnts,
+                            normals,
+                            gt_sdfs,
+                            mnfld_pred,
+                            nonmnfld_pred,
+                            ):
+        """
+        Compute the loss terms for a single feature volume.
+
+        :param mnfld_pnts: on-surface (manifold) sample points
+        :param nonmnfld_pnts: off-surface (non-manifold) sample points
+        :param normals: ground-truth normals at mnfld_pnts
+        :param gt_sdfs: ground-truth SDF values at mnfld_pnts
+        :param mnfld_pred: predicted SDF at mnfld_pnts
+        :param nonmnfld_pred: predicted SDF at nonmnfld_pnts
+        :return: (total_loss, loss_details)
+        """
+        # Force float32 to keep dtypes consistent
+        normals = normals.to(torch.float32)
+        mnfld_pred = mnfld_pred.to(torch.float32)
+        gt_sdfs = gt_sdfs.to(torch.float32)
+
+        # Manifold (surface fitting) loss
+        manifold_loss = self.position_loss(mnfld_pred, gt_sdfs)
+
+        # Normals loss
+        normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred)
+        #logger.gpu_memory_stats("after normals loss")
+
+        # Eikonal loss
+        eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred)
+
+        # Off-surface loss
+        offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred)
+
+        # Consistency loss (disabled)
+        #consistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi)
+
+        # Correction loss (disabled)
+        #correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
+
+        # Weighted loss terms
+        loss_details = {
+            "manifold": self.weights["manifold"] * manifold_loss,
+            "normals": self.weights["normals"] * normals_loss,
+            "eikonal": self.weights["eikonal"] * eikonal_loss,
+            "offsurface": self.weights["offsurface"] * offsurface_loss,
+        }
+
+        # Total loss
+        total_loss = sum(loss_details.values())
+
+        return total_loss, loss_details
+
     def _compute_loss(self, mnfld_pnts, normals, mnfld_pred_all, nonmnfld_pnts, nonmnfld_pred_all, n_batchsize, n_branch, n_patch_batch, n_patch_last):
         """
         Logic for computing the manifold losses
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index 7d3918c..39d0a73 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -5,9 +5,11 @@
 import os
 import numpy as np
 import argparse
+
 from brep2sdf.config.default_config import get_default_config
 from brep2sdf.data.data import load_brep_file,prepare_sdf_data, print_data_distribution, check_tensor
 from brep2sdf.data.pre_process_by_mesh import process_single_step
+from brep2sdf.data.utils import points_in_box
 from brep2sdf.networks.network import Net
 from brep2sdf.networks.octree import OctreeNode
 from brep2sdf.networks.loss import LossManager
@@ -203,97 +205,7 @@ class Trainer:
 #         # Return the merged bounding box
 #         return torch.cat([global_min, global_max])
 #     return [-0.5,] # this is wrong
-    def train_epoch_stage1_(self, epoch: int):
-        total_loss = 0.0
-        total_loss_details = {
-            "manifold": 0.0,
-            "normals": 0.0,
-            "eikonal": 0.0,
-            "offsurface": 0.0
-        }
-        accumulated_loss = 0.0 # accumulates the loss over several steps
-
-        # Zero the gradients at the start of each epoch
-        self.optimizer.zero_grad()
-
-        for step, surf_points in enumerate(self.data['surf_ncs']):
-            mnfld_points = torch.tensor(surf_points, device=self.device)
-            nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0) # sample off-surface (non-manifold) points
-            gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device)
-            normals = None
-            if args.use_normal:
-                normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device)
-                logger.debug(normals)
-
-            # --- Prepare model inputs, enable gradients ---
-            mnfld_points.requires_grad_(True) # enable gradients after the checks
-            nonmnfld_pnts.requires_grad_(True) # enable gradients after the checks
-
-            # --- Forward pass ---
-            self.optimizer.zero_grad()
-            mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
-            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)
-
-            if self.debug_mode:
-                # --- Check the forward pass outputs ---
-                logger.print_tensor_stats("mnfld_pred",mnfld_pred)
-                logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
-                logger.gpu_memory_stats("after forward pass")
-
-            # --- Compute the loss ---
-            try:
-                if args.use_normal:
-                    loss, loss_details = self.loss_manager.compute_loss(
-                        mnfld_points,
-                        nonmnfld_pnts,
-                        normals,
-                        gt_sdf,
-                        mnfld_pred,
-                        nonmnfld_pred
-                    )
-                else:
-                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
-                    loss_details = {} # make sure the variable is initialized
-
-                # Accumulate the loss instead of calling backward immediately
-                accumulated_loss += loss / self.config.train.accumulation_steps # assumes accumulation_steps is set in the config
-                current_loss = loss.item()
-                total_loss += current_loss
-                for key in total_loss_details:
-                    if key in loss_details:
-                        total_loss_details[key] += loss_details[key].item()
-
-                # Run backward once the accumulation step count is reached
-                if (step + 1) % self.config.train.accumulation_steps == 0:
-                    self.scheduler.optimizer.zero_grad() # clear gradients
-                    loss.backward() # backward pass
-                    self.scheduler.optimizer.step() # update parameters
-                    self.scheduler.step(accumulated_loss,epoch)
-
-                # logging unchanged ...
-
-            except Exception as loss_e:
-                logger.error(f"Error in step {step}: {loss_e}")
-                continue
-
-            # --- Memory management ---
-            del loss
-            torch.cuda.empty_cache()
-
-        # Flush any remaining loss that did not reach the accumulation step count
-        if accumulated_loss != 0:
-            self.scheduler.optimizer.zero_grad() # clear gradients
-            loss.backward() # backward pass
-            self.scheduler.optimizer.step() # update parameters
-            self.scheduler.step(accumulated_loss,epoch)
-
-        # Compute and log the epoch loss
-        logger.info(f'Train Epoch: {epoch:4d}]\t'
-                   f'Loss: {total_loss:.6f}')
-        logger.info(f"Loss Details: {total_loss_details}")
-        return total_loss # return the average loss rather than the cumulative value
+    def train_epoch_stage1(self, epoch: int) -> float:
@@ -437,7 +349,202 @@ class Trainer:
         return total_loss # for single-batch training, return the current loss directly
 
-    def train_epoch_stage2(self, epoch: int) -> float:
+    def train_stage2(self, num_epoch):
+        self.model.freeze_stage2()
+        self.cached_train_data = None
+
+        num_volumes = self.data['surf_bbox_ncs'].shape[0]
+        surf_bbox = torch.tensor(
+            self.data['surf_bbox_ncs'],
+            dtype=torch.float32,
+            device=self.device
+        )
+        logger.info(f"Start Stage 2 Training: {num_epoch} epochs")
+        total_loss = 0.0
+        for patch_id in range(num_volumes):
+            # Gather the surface points that fall inside this patch's bounding box
+            points = points_in_box(self.train_surf_ncs, surf_bbox[patch_id])
+            loss = self.train_stage2_by_volume(num_epoch, patch_id, points)
+            logger.debug(f"Patch [{patch_id:2d}] | Loss: {loss:.6f}")
+            total_loss += loss
+
+        return total_loss
+
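+    # Fits a single feature volume: the points inside this patch's bounding
+    # box are trained for num_epoch epochs (freeze_stage2 above controls which
+    # parameters remain trainable). Returns the last epoch's loss, or inf if
+    # training had to be aborted.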
+    def train_stage2_by_volume(self, num_epoch, patch_id, points):
+        logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}")
+        points = points.to(self.device) # .to() is not in-place; reassign
+        mnfld_pnts = points[:,0:3]
+        logger.debug(mnfld_pnts)
+        gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device)
+        if not args.use_normal:
+            logger.warning("args.use_normal is required; skipping stage 2")
+            return float('inf')
+        normals = points[:,3:6]
+        logger.debug(normals)
+        nonmnfld_pnts, psdf = self.sampler.get_norm_points(mnfld_pnts, normals) # sample off-surface (non-manifold) points
+
+        # --- Prepare model inputs, enable gradients ---
+        mnfld_pnts.requires_grad_(True) # enable gradients after the checks
+        nonmnfld_pnts.requires_grad_(True) # enable gradients after the checks
+
+        for epoch in range(num_epoch):
+            # --- Forward pass ---
+            mnfld_pred = self.model.forward_training_volumes(mnfld_pnts, patch_id)
+            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, patch_id)
+
+            # --- Compute the loss ---
+            loss = torch.tensor(float('nan'), device=self.device) # initialize to NaN in case the computation fails
+            loss_details = {}
+            try:
+                logger.gpu_memory_stats("before loss computation")
+                loss, loss_details = self.loss_manager.compute_loss(
+                    mnfld_pnts,
+                    nonmnfld_pnts,
+                    normals, # pass the validated normals
+                    gt_sdf,
+                    mnfld_pred,
+                    nonmnfld_pred,
+                    psdf
+                )
+
+                # --- Check the loss result ---
+                if self.debug_mode:
+                    logger.print_tensor_stats("psdf",psdf)
+                    logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
+                if check_tensor(loss, "Calculated Loss", epoch, patch_id):
+                    logger.error(f"Patch {patch_id} Epoch {epoch}: Loss calculation resulted in inf/nan.")
+                    if loss_details: logger.error(f"Loss Details: {loss_details}")
+                    return float('inf') # abort this patch if the loss is invalid
+
+            except Exception as loss_e:
+                logger.error(f"Patch {patch_id} Epoch {epoch}: Error during loss calculation: {loss_e}", exc_info=True)
+                return float('inf') # abort this patch if the loss computation failed
+            logger.gpu_memory_stats("after loss computation")
+
+            # --- Backward pass and optimization ---
+            try:
+                self.scheduler.optimizer.zero_grad() # clear gradients
+                loss.backward() # backward pass
+                self.scheduler.optimizer.step() # update parameters
+                self.scheduler.step(loss,epoch)
+            except Exception as backward_e:
+                logger.error(f"Patch {patch_id} Epoch {epoch}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
+                # To locate the offending op, enable anomaly detection before training:
+                # torch.autograd.set_detect_anomaly(True)
+                return float('inf') # abort this patch if backward or the optimizer step failed
+
+            torch.cuda.empty_cache()
+
+            if epoch % 100 == 0:
+                # Log training progress (only valid losses reach this point)
+                logger.info(f'Train Epoch: [{epoch:4d}]\t'
+                            f'Loss: {loss.item():.6f}')
+                if loss_details: logger.info(f"Loss Details: {loss_details}")
+
+        return loss.item() # loss of the last epoch
+
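+    # Alternative stage-2 loop: sweeps every patch in surf_ncs once per epoch,
+    # scores each with compute_loss_volume, and accumulates the losses for
+    # config.train.accumulation_steps steps before each optimizer update.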
+    def train_epoch_stage2_(self, epoch: int):
+        total_loss = 0.0
+        total_loss_details = {
+            "manifold": 0.0,
+            "normals": 0.0,
+            "eikonal": 0.0,
+            "offsurface": 0.0
+        }
+        accumulated_loss = 0.0 # accumulates the loss over several steps
+
+        # Zero the gradients at the start of each epoch
+        self.scheduler.optimizer.zero_grad()
+
+        for step, surf_points in enumerate(self.data['surf_ncs']):
+            mnfld_points = torch.tensor(surf_points, device=self.device)
+            nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0) # sample off-surface (non-manifold) points
+            gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device)
+            normals = None
+            if args.use_normal:
+                normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device)
+                logger.debug(normals)
+
+            # --- Prepare model inputs, enable gradients ---
+            mnfld_points.requires_grad_(True) # enable gradients after the checks
+            nonmnfld_pnts.requires_grad_(True) # enable gradients after the checks
+
+            # --- Forward pass ---
+            self.scheduler.optimizer.zero_grad()
+            mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
+            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)
+
+            if self.debug_mode:
+                # --- Check the forward pass outputs ---
+                logger.print_tensor_stats("mnfld_pred",mnfld_pred)
+                logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
+                logger.gpu_memory_stats("after forward pass")
+
+            # --- Compute the loss ---
+            try:
+                if args.use_normal:
+                    loss, loss_details = self.loss_manager.compute_loss_volume(
+                        mnfld_points,
+                        nonmnfld_pnts,
+                        normals,
+                        gt_sdf,
+                        mnfld_pred,
+                        nonmnfld_pred
+                    )
+                else:
+                    loss = torch.nn.functional.mse_loss(mnfld_pred, gt_sdf) # was pred_sdf, which is undefined here
+                    loss_details = {} # make sure the variable is initialized
+
+                # Accumulate the loss instead of calling backward immediately
+                accumulated_loss += loss / self.config.train.accumulation_steps # assumes accumulation_steps is set in the config
+                current_loss = loss.item()
+                total_loss += current_loss
+                for key in total_loss_details:
+                    if key in loss_details:
+                        total_loss_details[key] += loss_details[key].item()
+
+                # Run backward once the accumulation step count is reached
+                if (step + 1) % self.config.train.accumulation_steps == 0:
+                    self.scheduler.optimizer.zero_grad() # clear gradients
+                    accumulated_loss.backward() # backward pass
+                    self.scheduler.optimizer.step() # update parameters
+                    self.scheduler.step(accumulated_loss,epoch)
+                    accumulated_loss = 0.0
+
+                # logging unchanged ...
+
+            except Exception as loss_e:
+                logger.error(f"Error in step {step}: {loss_e}")
+                continue
+
+            # --- Memory management ---
+            del loss
+            torch.cuda.empty_cache()
+
+        # Flush any remaining loss that did not reach the accumulation step count
+        if accumulated_loss != 0:
+            self.scheduler.optimizer.zero_grad() # clear gradients
+            accumulated_loss.backward() # backward pass
+            self.scheduler.optimizer.step() # update parameters
+            self.scheduler.step(accumulated_loss,epoch)
+
+        # Compute and log the epoch loss
+        logger.info(f'Train Epoch: [{epoch:4d}]\t'
+                   f'Loss: {total_loss:.6f}')
+        logger.info(f"Loss Details: {total_loss_details}")
+        return total_loss # total (cumulative) loss over the epoch
+
+    def train_epoch_stage3(self, epoch: int) -> float:
         # --- 1. Check the input data ---
         # Note: assumes self.sdf_data contains [x, y, z, nx, ny, nz, sdf] (7 columns) or [x, y, z, sdf] (4 columns),
         # and that the SDF value is always in the last column
@@ -521,6 +628,7 @@ class Trainer:
             )
 
             #logger.print_tensor_stats("psdf",psdf)
+            #logger.print_tensor_stats("nonmnfld_pnts",nonmnfld_pnts)
             #logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
 
             # --- Compute the loss ---
@@ -752,31 +860,27 @@ class Trainer:
             start_epoch = self._load_checkpoint(args.resume_checkpoint_path)
             logger.info(f"Loaded model from {args.resume_checkpoint_path}")
 
+        # Stage 1
         self.model.encoder.freeze_stage1()
-        for epoch in range(start_epoch, self.config.train.num_epochs + 1):
+        for epoch in range(start_epoch, self.config.train.num_epochs1 + 1):
             # Train one epoch
             train_loss = self.train_epoch_stage1(epoch)
             #train_loss = self.train_epoch_stage2(epoch)
             #train_loss = self.train_epoch(epoch)
-
-            # Validation
-            '''
-            if epoch % self.config.train.val_freq == 0:
-                val_loss = self.validate(epoch)
-                logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')
-
-                # Save the best model
-                if val_loss < best_val_loss:
-                    best_val_loss = val_loss
-                    self._save_model(epoch, val_loss)
-                    logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}')
-            else:
-                logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}')
-            '''
+
             # Save a checkpoint
             if epoch % self.config.train.save_freq == 0:
                 self._save_checkpoint(epoch, train_loss)
                 logger.info(f'Checkpoint saved at epoch {epoch}')
+
+        # Stage 2 (train_stage2 applies freeze_stage2 internally)
+        self.train_stage2(self.config.train.num_epochs2)
+        epoch = self.config.train.num_epochs1 + self.config.train.num_epochs2
+        self._save_checkpoint(epoch, 0.0)
+        logger.info(f'Checkpoint saved at epoch {epoch}')
+        self.model.encoder.unfreeze()
 
         # Training finished
         total_time = time.time() - start_time
@@ -848,7 +952,7 @@ class Trainer:
     def _load_checkpoint(self, checkpoint_path):
         """Restore training state from a checkpoint"""
         try:
-            checkpoint = torch.load(checkpoint_path)
+            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
             self.model.load_state_dict(checkpoint['model_state_dict'])
             self.scheduler.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
             return checkpoint['epoch'] + 1
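
Note on the loss-accumulation pattern used by train_epoch_stage2_ above: the per-step losses are summed into accumulated_loss and backward() runs only on the flush, so every step's autograd graph stays alive until the optimizer update. A minimal, self-contained sketch of the same pattern in plain PyTorch follows; model, optimizer, batches, and accumulation_steps below are illustrative placeholders, not names from this repository:

    import torch

    model = torch.nn.Linear(3, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    accumulation_steps = 4  # stand-in for config.train.accumulation_steps
    batches = [torch.randn(16, 3) for _ in range(10)]

    accumulated_loss = 0.0
    for step, batch in enumerate(batches):
        pred = model(batch)
        loss = torch.nn.functional.mse_loss(pred, torch.zeros_like(pred))
        # Scale each step so the flushed gradient is an average, not a sum.
        accumulated_loss = accumulated_loss + loss / accumulation_steps
        if (step + 1) % accumulation_steps == 0:
            optimizer.zero_grad()
            accumulated_loss.backward()  # one backward over the scaled sum
            optimizer.step()
            accumulated_loss = 0.0       # drop the value and its autograd graph
    # Flush the remainder that never reached a full accumulation window.
    if torch.is_tensor(accumulated_loss):
        optimizer.zero_grad()
        accumulated_loss.backward()
        optimizer.step()

Summing loss tensors this way trades memory for fewer backward calls; calling loss.backward() on every step and stepping the optimizer once per window keeps only one graph alive at a time and is the more common formulation.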