diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 280a79b..8b71d7c 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -579,10 +579,10 @@ class Trainer: def train_epoch_stage3(self, epoch: int) -> float: # --- 1. 检查输入数据 --- - # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) + # 注意:假设 self.train_surf_ncs 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) # 并且 SDF 值总是在最后一列 - if self.sdf_data is None: - logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.") + if self.train_surf_ncs is None: + logger.error(f"Epoch {epoch}: self.train_surf_ncs is None. Cannot train.") return float('inf') self.model.train() @@ -592,9 +592,9 @@ class Trainer: # 数据处理 # manfld - _mnfld_pnts = self.sdf_data[:, 0:3].clone().detach() # 取前3列作为点 - _normals = self.sdf_data[:, 3:6].clone().detach() # 取中间3列作为法线 - _gt_sdf = self.sdf_data[:, -1].clone().detach() # 取最后一列作为SDF真值 + _mnfld_pnts = self.train_surf_ncs[:, 0:3].clone().detach() # 取前3列作为点 + _normals = self.train_surf_ncs[:, 3:6].clone().detach() # 取中间3列作为法线 + _gt_sdf = self.train_surf_ncs[:, -1].clone().detach() # 取最后一列作为SDF真值 # 检查是否需要重新计算缓存 if epoch % 10 == 1 or self.cached_train_data is None: @@ -628,7 +628,7 @@ class Trainer: # 将数据分成多个batch - num_points = self.sdf_data.shape[0] + num_points = self.train_surf_ncs.shape[0] num_batches = (num_points + batch_size - 1) // batch_size for batch_idx in range(num_batches): @@ -894,7 +894,7 @@ class Trainer: logger.info(f"Loaded model from {args.resume_checkpoint_path}") # stage1 - self.model.freeze_stage2() + self.model.freeze_stage1() for epoch in range(start_epoch, self.config.train.num_epochs1 + 1): # 训练一个epoch train_loss = self.train_epoch_stage1(epoch) @@ -921,8 +921,8 @@ class Trainer: #stage 3 self.scheduler.reset() - #self.model.freeze_stage2() - self.model.unfreeze() + self.model.freeze_stage2() + #self.model.unfreeze() for epoch in range(cur_epoch + 1, max_stage2_epoch + self.config.train.num_epochs3 + 1): # 训练一个epoch train_loss = self.train_epoch_stage3(epoch)