Browse Source

替换stage3 训练数据

final
mckay 1 month ago
parent
commit
6c274cd293
  1. 20
      brep2sdf/train.py

20
brep2sdf/train.py

@ -579,10 +579,10 @@ class Trainer:
def train_epoch_stage3(self, epoch: int) -> float:
# --- 1. 检查输入数据 ---
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列)
# 注意:假设 self.train_surf_ncs 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列)
# 并且 SDF 值总是在最后一列
if self.sdf_data is None:
logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
if self.train_surf_ncs is None:
logger.error(f"Epoch {epoch}: self.train_surf_ncs is None. Cannot train.")
return float('inf')
self.model.train()
@ -592,9 +592,9 @@ class Trainer:
# 数据处理
# manfld
_mnfld_pnts = self.sdf_data[:, 0:3].clone().detach() # 取前3列作为点
_normals = self.sdf_data[:, 3:6].clone().detach() # 取中间3列作为法线
_gt_sdf = self.sdf_data[:, -1].clone().detach() # 取最后一列作为SDF真值
_mnfld_pnts = self.train_surf_ncs[:, 0:3].clone().detach() # 取前3列作为点
_normals = self.train_surf_ncs[:, 3:6].clone().detach() # 取中间3列作为法线
_gt_sdf = self.train_surf_ncs[:, -1].clone().detach() # 取最后一列作为SDF真值
# 检查是否需要重新计算缓存
if epoch % 10 == 1 or self.cached_train_data is None:
@ -628,7 +628,7 @@ class Trainer:
# 将数据分成多个batch
num_points = self.sdf_data.shape[0]
num_points = self.train_surf_ncs.shape[0]
num_batches = (num_points + batch_size - 1) // batch_size
for batch_idx in range(num_batches):
@ -894,7 +894,7 @@ class Trainer:
logger.info(f"Loaded model from {args.resume_checkpoint_path}")
# stage1
self.model.freeze_stage2()
self.model.freeze_stage1()
for epoch in range(start_epoch, self.config.train.num_epochs1 + 1):
# 训练一个epoch
train_loss = self.train_epoch_stage1(epoch)
@ -921,8 +921,8 @@ class Trainer:
#stage 3
self.scheduler.reset()
#self.model.freeze_stage2()
self.model.unfreeze()
self.model.freeze_stage2()
#self.model.unfreeze()
for epoch in range(cur_epoch + 1, max_stage2_epoch + self.config.train.num_epochs3 + 1):
# 训练一个epoch
train_loss = self.train_epoch_stage3(epoch)

Loading…
Cancel
Save