From 1297c3c28eb2ca1fa41e7cd88b32169d71106e3a Mon Sep 17 00:00:00 2001
From: mckay
Date: Wed, 7 May 2025 15:44:20 +0800
Subject: [PATCH] Train stage1 on manifold points only
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/train.py | 44 ++++++++++---------------------------------
 1 file changed, 11 insertions(+), 33 deletions(-)

diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index 7133e92..280a79b 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -102,8 +102,10 @@ class Trainer:
             self.train_surf_ncs = sampled_sdf_data
         else:
             self.sdf_data = surface_sdf_data
-        print_data_distribution(self.sdf_data)
+        logger.print_tensor_stats("sdf_data", self.sdf_data)
         logger.debug(self.sdf_data.shape)
+        logger.print_tensor_stats("train_surf_ncs", self.train_surf_ncs[:, 0:3])
+        logger.debug(self.train_surf_ncs.shape)
         logger.gpu_memory_stats("after SDF data preparation")
         # Initialize the dataset
         #self.brep_data = load_brep_file(self.config.data.pkl_path)
@@ -229,22 +231,9 @@ class Trainer:
         _gt_sdf = self.train_surf_ncs[:, -1].clone().detach() # take the last column as the SDF ground truth

         # Check whether the cache needs to be recomputed
-        if epoch % 10 == 1 or self.cached_train_data is None:
-            # Compute masks and operators for the manifold points
-            # Generate non-manifold points
-            _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
-
-            # Update the cache
-            self.cached_train_data = {
-                "nonmnfld_pnts": _nonmnfld_pnts,
-                "psdf": _psdf,
-            }
-        else:
-            # Read the data from the cache
-            _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
-            _psdf = self.cached_train_data["psdf"]
+
-        logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
+        logger.debug(_mnfld_pnts)
@@ -260,22 +249,14 @@ class Trainer:
             mnfld_pnts = _mnfld_pnts[start_idx:end_idx] # manifold points
             gt_sdf = _gt_sdf[start_idx:end_idx] # SDF ground truth
             normals = _normals[start_idx:end_idx] if args.use_normal else None # normals
-
-            # Non-manifold points use the cached data (shared across the whole batch)
-            nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
-            psdf = _psdf[start_idx:end_idx]

             # --- Prepare model inputs, enable gradients ---
             mnfld_pnts.requires_grad_(True) # enable gradients after the checks
-            nonmnfld_pnts.requires_grad_(True) # enable gradients after the checks

             # --- Forward pass ---
             mnfld_pred = self.model.forward_background(
                 mnfld_pnts
             )
-            nonmnfld_pred = self.model.forward_background(
-                nonmnfld_pnts
-            )
@@ -292,14 +273,11 @@ class Trainer:
                 #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
                 #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
                 logger.gpu_memory_stats("before loss computation")
-                loss, loss_details = self.loss_manager.compute_loss(
+                loss, loss_details = self.loss_manager.compute_loss_stage1(
                     mnfld_pnts,
-                    nonmnfld_pnts,
                     normals, # pass the checked normals
                     gt_sdf,
                     mnfld_pred,
-                    nonmnfld_pred,
-                    psdf
                 )
             else:
                 loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
@@ -307,7 +285,5 @@ class Trainer:
            # --- 4. Check the loss computation result ---
            if self.debug_mode:
-                logger.print_tensor_stats("psdf",psdf)
-                logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
                 if check_tensor(loss, "Calculated Loss", epoch, step):
                     logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                     if loss_details: logger.error(f"Loss Details: {loss_details}")
@@ -917,7 +894,7 @@ class Trainer:
                 logger.info(f"Loaded model from {args.resume_checkpoint_path}")

         # stage1
-        self.model.encoder.freeze_stage1()
+        self.model.freeze_stage2()
         for epoch in range(start_epoch, self.config.train.num_epochs1 + 1):
             # Train for one epoch
             train_loss = self.train_epoch_stage1(epoch)
@@ -929,6 +906,7 @@ class Trainer:
                 self._save_checkpoint(epoch, train_loss)
                 logger.info(f'Checkpoint saved at epoch {epoch}')

+        start_epoch = max(start_epoch, self.config.train.num_epochs1)
         # stage2 freeze_stage2
         max_stage2_epoch = self.config.train.num_epochs1+self.config.train.num_epochs2
         if start_epoch < max_stage2_epoch:
@@ -942,9 +920,10 @@ class Trainer:
             cur_epoch = start_epoch

         #stage 3
-        self.model.encoder.unfreeze()
         self.scheduler.reset()
-        for epoch in range(cur_epoch, max_stage2_epoch + self.config.train.num_epochs3 + 1):
+        #self.model.freeze_stage2()
+        self.model.unfreeze()
+        for epoch in range(cur_epoch + 1, max_stage2_epoch + self.config.train.num_epochs3 + 1):
             # Train for one epoch
             train_loss = self.train_epoch_stage3(epoch)
             #train_loss = self.train_epoch_stage2(epoch)
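--
Reviewer sketch: stage1 now calls self.loss_manager.compute_loss_stage1(mnfld_pnts,
normals, gt_sdf, mnfld_pred), whose body is outside this diff. Below is a minimal
sketch of what a manifold-only stage1 loss with that signature could look like,
assuming an SDF value term plus a normal-alignment term computed from autograd
gradients. The weights w_sdf/w_normal and the gradient() helper are illustrative
assumptions, not taken from the repository:

    import torch

    def gradient(inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
        # d(outputs)/d(inputs) via autograd; inputs must have requires_grad=True.
        return torch.autograd.grad(
            outputs=outputs,
            inputs=inputs,
            grad_outputs=torch.ones_like(outputs),
            create_graph=True,  # keep the graph so the normal term stays differentiable
        )[0]

    def compute_loss_stage1(mnfld_pnts, normals, gt_sdf, mnfld_pred,
                            w_sdf=1.0, w_normal=0.1):
        # SDF value term: background predictions at manifold points should
        # match the sampled ground-truth SDF values.
        sdf_loss = torch.nn.functional.mse_loss(mnfld_pred.squeeze(-1), gt_sdf)
        loss = w_sdf * sdf_loss
        details = {"sdf": sdf_loss.item()}

        if normals is not None:
            # Normal alignment term: the gradient of the predicted SDF at
            # surface points should agree with the supplied normals. This
            # relies on mnfld_pnts.requires_grad_(True) being set before the
            # forward pass, which train_epoch_stage1 does above.
            pred_grad = gradient(mnfld_pnts, mnfld_pred)
            normal_loss = (pred_grad - normals).norm(2, dim=-1).mean()
            loss = loss + w_normal * normal_loss
            details["normal"] = normal_loss.item()

        return loss, details

The actual LossManager implementation may weight or normalize these terms
differently (for example, adding an Eikonal term on the manifold points); the
point of the sketch is only that dropping nonmnfld_pnts/psdf removes every loss
term that required off-surface samples.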