|
|
@ -579,10 +579,10 @@ class Trainer: |
|
|
|
|
|
|
|
def train_epoch_stage3(self, epoch: int) -> float: |
|
|
|
# --- 1. 检查输入数据 --- |
|
|
|
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) |
|
|
|
# 注意:假设 self.train_surf_ncs 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) |
|
|
|
# 并且 SDF 值总是在最后一列 |
|
|
|
if self.sdf_data is None: |
|
|
|
logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.") |
|
|
|
if self.train_surf_ncs is None: |
|
|
|
logger.error(f"Epoch {epoch}: self.train_surf_ncs is None. Cannot train.") |
|
|
|
return float('inf') |
|
|
|
|
|
|
|
self.model.train() |
|
|
@ -592,9 +592,9 @@ class Trainer: |
|
|
|
|
|
|
|
# 数据处理 |
|
|
|
# manfld |
|
|
|
_mnfld_pnts = self.sdf_data[:, 0:3].clone().detach() # 取前3列作为点 |
|
|
|
_normals = self.sdf_data[:, 3:6].clone().detach() # 取中间3列作为法线 |
|
|
|
_gt_sdf = self.sdf_data[:, -1].clone().detach() # 取最后一列作为SDF真值 |
|
|
|
_mnfld_pnts = self.train_surf_ncs[:, 0:3].clone().detach() # 取前3列作为点 |
|
|
|
_normals = self.train_surf_ncs[:, 3:6].clone().detach() # 取中间3列作为法线 |
|
|
|
_gt_sdf = self.train_surf_ncs[:, -1].clone().detach() # 取最后一列作为SDF真值 |
|
|
|
|
|
|
|
# 检查是否需要重新计算缓存 |
|
|
|
if epoch % 10 == 1 or self.cached_train_data is None: |
|
|
@ -628,7 +628,7 @@ class Trainer: |
|
|
|
|
|
|
|
|
|
|
|
# 将数据分成多个batch |
|
|
|
num_points = self.sdf_data.shape[0] |
|
|
|
num_points = self.train_surf_ncs.shape[0] |
|
|
|
num_batches = (num_points + batch_size - 1) // batch_size |
|
|
|
|
|
|
|
for batch_idx in range(num_batches): |
|
|
@ -894,7 +894,7 @@ class Trainer: |
|
|
|
logger.info(f"Loaded model from {args.resume_checkpoint_path}") |
|
|
|
|
|
|
|
# stage1 |
|
|
|
self.model.freeze_stage2() |
|
|
|
self.model.freeze_stage1() |
|
|
|
for epoch in range(start_epoch, self.config.train.num_epochs1 + 1): |
|
|
|
# 训练一个epoch |
|
|
|
train_loss = self.train_epoch_stage1(epoch) |
|
|
@ -921,8 +921,8 @@ class Trainer: |
|
|
|
|
|
|
|
#stage 3 |
|
|
|
self.scheduler.reset() |
|
|
|
#self.model.freeze_stage2() |
|
|
|
self.model.unfreeze() |
|
|
|
self.model.freeze_stage2() |
|
|
|
#self.model.unfreeze() |
|
|
|
for epoch in range(cur_epoch + 1, max_stage2_epoch + self.config.train.num_epochs3 + 1): |
|
|
|
# 训练一个epoch |
|
|
|
train_loss = self.train_epoch_stage3(epoch) |
|
|
|