|
@ -319,16 +319,19 @@ class Trainer: |
|
|
if epoch % 10 == 1 or self.cached_train_data is None: |
|
|
if epoch % 10 == 1 or self.cached_train_data is None: |
|
|
# 计算流形点的掩码和操作符 |
|
|
# 计算流形点的掩码和操作符 |
|
|
# 生成非流形点 |
|
|
# 生成非流形点 |
|
|
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals) |
|
|
_psdf_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals,local_sigma=0.001) |
|
|
|
|
|
_nonmnfld_pnts = self.sampler.get_points(_mnfld_pnts, local_sigma=0.01): |
|
|
|
|
|
|
|
|
# 更新缓存 |
|
|
# 更新缓存 |
|
|
self.cached_train_data = { |
|
|
self.cached_train_data = { |
|
|
"nonmnfld_pnts": _nonmnfld_pnts, |
|
|
"nonmnfld_pnts": _nonmnfld_pnts, |
|
|
|
|
|
"psdf_pnts": _psdf_pnts, |
|
|
"psdf": _psdf, |
|
|
"psdf": _psdf, |
|
|
} |
|
|
} |
|
|
else: |
|
|
else: |
|
|
# 从缓存中读取数据 |
|
|
# 从缓存中读取数据 |
|
|
_nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"] |
|
|
_nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"] |
|
|
|
|
|
_psdf_pnts = self.cached_train_data["psdf_pnts"] |
|
|
_psdf = self.cached_train_data["psdf"] |
|
|
_psdf = self.cached_train_data["psdf"] |
|
|
|
|
|
|
|
|
logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf)) |
|
|
logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf)) |
|
@ -350,11 +353,13 @@ class Trainer: |
|
|
|
|
|
|
|
|
# 非流形点使用缓存数据(整个batch共享) |
|
|
# 非流形点使用缓存数据(整个batch共享) |
|
|
nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx] |
|
|
nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx] |
|
|
|
|
|
psdf_pnts = _psdf_pnts[start_idx:end_idx] |
|
|
psdf = _psdf[start_idx:end_idx] |
|
|
psdf = _psdf[start_idx:end_idx] |
|
|
|
|
|
|
|
|
# --- 准备模型输入,启用梯度 --- |
|
|
# --- 准备模型输入,启用梯度 --- |
|
|
mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
|
|
|
psdf_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
|
|
|
|
|
|
# --- 前向传播 --- |
|
|
# --- 前向传播 --- |
|
|
mnfld_pred = self.model.forward_background( |
|
|
mnfld_pred = self.model.forward_background( |
|
@ -363,6 +368,9 @@ class Trainer: |
|
|
nonmnfld_pred = self.model.forward_background( |
|
|
nonmnfld_pred = self.model.forward_background( |
|
|
nonmnfld_pnts |
|
|
nonmnfld_pnts |
|
|
) |
|
|
) |
|
|
|
|
|
psdf_pred = self.model.forward_background( |
|
|
|
|
|
psdf_pnts |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -382,6 +390,7 @@ class Trainer: |
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
mnfld_pnts, |
|
|
mnfld_pnts, |
|
|
nonmnfld_pnts, |
|
|
nonmnfld_pnts, |
|
|
|
|
|
psdf_pnts, |
|
|
normals, # 传递检查过的 normals |
|
|
normals, # 传递检查过的 normals |
|
|
gt_sdf, |
|
|
gt_sdf, |
|
|
mnfld_pred, |
|
|
mnfld_pred, |
|
|