@@ -102,8 +102,10 @@ class Trainer:
                self.train_surf_ncs = sampled_sdf_data
            else:
                self.sdf_data = surface_sdf_data

        print_data_distribution(self.sdf_data)
        logger.print_tensor_stats("sdf_data", self.sdf_data)
        logger.debug(self.sdf_data.shape)
        logger.print_tensor_stats("train_surf_ncs", self.train_surf_ncs[:, 0:3])
        logger.debug(self.train_surf_ncs.shape)
        logger.gpu_memory_stats("after SDF data preparation")

        # Initialize the dataset
        #self.brep_data = load_brep_file(self.config.data.pkl_path)
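
        # NOTE (assumption from usage): train_surf_ncs appears to pack each point as
        # [x, y, z, ..., sdf] -- columns 0:3 are the coordinates logged above, and the
        # last column is the ground-truth SDF value (see _gt_sdf in stage 1 below).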

@@ -229,22 +231,9 @@ class Trainer:
        _gt_sdf = self.train_surf_ncs[:, -1].clone().detach()  # take the last column as the ground-truth SDF

        # Check whether the cached samples need to be recomputed
        if epoch % 10 == 1 or self.cached_train_data is None:
            # Compute masks and operators for the manifold points
            # Generate non-manifold points
            _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)

            # Update the cache
            self.cached_train_data = {
                "nonmnfld_pnts": _nonmnfld_pnts,
                "psdf": _psdf,
            }
        else:
            # Read the data from the cache
            _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
            _psdf = self.cached_train_data["psdf"]
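
        # NOTE: with `epoch % 10 == 1`, fresh non-manifold points are drawn at epochs
        # 1, 11, 21, ... (or whenever the cache is empty); every other epoch reuses
        # the cached samples, saving sampler calls at the cost of staler points.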

        logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
        logger.debug(_mnfld_pnts)

@@ -260,22 +249,14 @@ class Trainer:
            mnfld_pnts = _mnfld_pnts[start_idx:end_idx]  # manifold (surface) points
            gt_sdf = _gt_sdf[start_idx:end_idx]  # ground-truth SDF
            normals = _normals[start_idx:end_idx] if args.use_normal else None  # normals

            # Non-manifold points use the cached data (shared across the whole batch)
            nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
            psdf = _psdf[start_idx:end_idx]

            # --- Prepare the model inputs and enable gradients ---
            mnfld_pnts.requires_grad_(True)  # enable gradients after the checks
            nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
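
            # NOTE (assumption from the surrounding code): gradients are enabled on the
            # input coordinates, not the model, presumably so the loss can differentiate
            # the predicted SDF in space, e.g. for normal-alignment or eikonal-style terms.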

            # --- Forward pass ---
            mnfld_pred = self.model.forward_background(mnfld_pnts)
            nonmnfld_pred = self.model.forward_background(nonmnfld_pnts)

@@ -292,14 +273,11 @@ class Trainer:
                #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
                #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
                logger.gpu_memory_stats("before computing the loss")
-                loss, loss_details = self.loss_manager.compute_loss(
+                loss, loss_details = self.loss_manager.compute_loss_stage1(
                    mnfld_pnts,
                    nonmnfld_pnts,
                    normals,  # pass the checked normals
                    gt_sdf,
                    mnfld_pred,
                    nonmnfld_pred,
                    psdf
                )
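
                # NOTE (assumption): psdf appears to be a pseudo-SDF target for the
                # sampled non-manifold points, produced alongside them by
                # sampler.get_norm_points() in the caching step above.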
            else:
                loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)

@@ -307,7 +285,6 @@ class Trainer:
            # --- 4. Check the loss computation result ---
            if self.debug_mode:
                logger.print_tensor_stats("psdf", psdf)
                logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
            if check_tensor(loss, "Calculated Loss", epoch, step):
                logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                if loss_details: logger.error(f"Loss Details: {loss_details}")

@@ -917,7 +894,7 @@ class Trainer:
        logger.info(f"Loaded model from {args.resume_checkpoint_path}")

        # stage 1
        self.model.encoder.freeze_stage1()
        self.model.freeze_stage2()
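
        # NOTE: three-stage schedule -- stage 1 trains with encoder.freeze_stage1()
        # and model.freeze_stage2() applied; stage 3 (below) calls encoder.unfreeze()
        # and model.unfreeze() so all parameters train. (Method names per this codebase.)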
        for epoch in range(start_epoch, self.config.train.num_epochs1 + 1):
            # train for one epoch
            train_loss = self.train_epoch_stage1(epoch)

@@ -929,6 +906,7 @@ class Trainer:
                self._save_checkpoint(epoch, train_loss)
                logger.info(f'Checkpoint saved at epoch {epoch}')

        start_epoch = max(start_epoch, self.config.train.num_epochs1)
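
        # NOTE: clamps start_epoch to the end of stage 1, so a resume from a
        # checkpoint taken during stage 1 enters the stage-2 range below at the
        # stage boundary instead of re-running stage-1 epochs in stage 2.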

        # stage 2: freeze_stage2
        max_stage2_epoch = self.config.train.num_epochs1 + self.config.train.num_epochs2
        if start_epoch < max_stage2_epoch:

@@ -942,9 +920,10 @@ class Trainer:
        cur_epoch = start_epoch

        # stage 3
        self.model.encoder.unfreeze()
-        self.scheduler.reset()
-        for epoch in range(cur_epoch, max_stage2_epoch + self.config.train.num_epochs3 + 1):
+        #self.model.freeze_stage2()
+        self.model.unfreeze()
+        for epoch in range(cur_epoch + 1, max_stage2_epoch + self.config.train.num_epochs3 + 1):
            # train for one epoch
            train_loss = self.train_epoch_stage3(epoch)
            #train_loss = self.train_epoch_stage2(epoch)
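
        # NOTE: the new range starts at cur_epoch + 1, presumably so the first
        # stage-3 epoch does not repeat the last epoch already trained (and
        # checkpointed) in stage 2.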