From 6c274cd293ccfb0db431c34fd6013c54c0563306 Mon Sep 17 00:00:00 2001 From: mckay Date: Wed, 7 May 2025 15:52:44 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=BF=E6=8D=A2stage3=20=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/train.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 280a79b..8b71d7c 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -579,10 +579,10 @@ class Trainer: def train_epoch_stage3(self, epoch: int) -> float: # --- 1. 检查输入数据 --- - # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) + # 注意:假设 self.train_surf_ncs 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) # 并且 SDF 值总是在最后一列 - if self.sdf_data is None: - logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.") + if self.train_surf_ncs is None: + logger.error(f"Epoch {epoch}: self.train_surf_ncs is None. Cannot train.") return float('inf') self.model.train() @@ -592,9 +592,9 @@ class Trainer: # 数据处理 # manfld - _mnfld_pnts = self.sdf_data[:, 0:3].clone().detach() # 取前3列作为点 - _normals = self.sdf_data[:, 3:6].clone().detach() # 取中间3列作为法线 - _gt_sdf = self.sdf_data[:, -1].clone().detach() # 取最后一列作为SDF真值 + _mnfld_pnts = self.train_surf_ncs[:, 0:3].clone().detach() # 取前3列作为点 + _normals = self.train_surf_ncs[:, 3:6].clone().detach() # 取中间3列作为法线 + _gt_sdf = self.train_surf_ncs[:, -1].clone().detach() # 取最后一列作为SDF真值 # 检查是否需要重新计算缓存 if epoch % 10 == 1 or self.cached_train_data is None: @@ -628,7 +628,7 @@ class Trainer: # 将数据分成多个batch - num_points = self.sdf_data.shape[0] + num_points = self.train_surf_ncs.shape[0] num_batches = (num_points + batch_size - 1) // batch_size for batch_idx in range(num_batches): @@ -894,7 +894,7 @@ class Trainer: logger.info(f"Loaded model from {args.resume_checkpoint_path}") # stage1 - self.model.freeze_stage2() + self.model.freeze_stage1() for epoch in range(start_epoch, self.config.train.num_epochs1 + 1): # 训练一个epoch train_loss = self.train_epoch_stage1(epoch) @@ -921,8 +921,8 @@ class Trainer: #stage 3 self.scheduler.reset() - #self.model.freeze_stage2() - self.model.unfreeze() + self.model.freeze_stage2() + #self.model.unfreeze() for epoch in range(cur_epoch + 1, max_stage2_epoch + self.config.train.num_epochs3 + 1): # 训练一个epoch train_loss = self.train_epoch_stage3(epoch)