|
@ -307,7 +307,7 @@ class Trainer: |
|
|
self.model.train() |
|
|
self.model.train() |
|
|
total_loss = 0.0 |
|
|
total_loss = 0.0 |
|
|
step = 0 # 如果你的训练是分批次的,这里应该用批次索引 |
|
|
step = 0 # 如果你的训练是分批次的,这里应该用批次索引 |
|
|
batch_size = 8192 # 设置合适的batch大小 |
|
|
batch_size = 8192*16 # 设置合适的batch大小 |
|
|
|
|
|
|
|
|
# 数据处理 |
|
|
# 数据处理 |
|
|
# manfld |
|
|
# manfld |
|
@ -319,19 +319,16 @@ class Trainer: |
|
|
if epoch % 10 == 1 or self.cached_train_data is None: |
|
|
if epoch % 10 == 1 or self.cached_train_data is None: |
|
|
# 计算流形点的掩码和操作符 |
|
|
# 计算流形点的掩码和操作符 |
|
|
# 生成非流形点 |
|
|
# 生成非流形点 |
|
|
_psdf_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals,local_sigma=0.001) |
|
|
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals) |
|
|
_nonmnfld_pnts = self.sampler.get_points(_mnfld_pnts, local_sigma=0.01): |
|
|
|
|
|
|
|
|
|
|
|
# 更新缓存 |
|
|
# 更新缓存 |
|
|
self.cached_train_data = { |
|
|
self.cached_train_data = { |
|
|
"nonmnfld_pnts": _nonmnfld_pnts, |
|
|
"nonmnfld_pnts": _nonmnfld_pnts, |
|
|
"psdf_pnts": _psdf_pnts, |
|
|
|
|
|
"psdf": _psdf, |
|
|
"psdf": _psdf, |
|
|
} |
|
|
} |
|
|
else: |
|
|
else: |
|
|
# 从缓存中读取数据 |
|
|
# 从缓存中读取数据 |
|
|
_nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"] |
|
|
_nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"] |
|
|
_psdf_pnts = self.cached_train_data["psdf_pnts"] |
|
|
|
|
|
_psdf = self.cached_train_data["psdf"] |
|
|
_psdf = self.cached_train_data["psdf"] |
|
|
|
|
|
|
|
|
logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf)) |
|
|
logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf)) |
|
@ -353,13 +350,11 @@ class Trainer: |
|
|
|
|
|
|
|
|
# 非流形点使用缓存数据(整个batch共享) |
|
|
# 非流形点使用缓存数据(整个batch共享) |
|
|
nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx] |
|
|
nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx] |
|
|
psdf_pnts = _psdf_pnts[start_idx:end_idx] |
|
|
|
|
|
psdf = _psdf[start_idx:end_idx] |
|
|
psdf = _psdf[start_idx:end_idx] |
|
|
|
|
|
|
|
|
# --- 准备模型输入,启用梯度 --- |
|
|
# --- 准备模型输入,启用梯度 --- |
|
|
mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
psdf_pnts.requires_grad_(True) # 在检查之后启用梯度 |
|
|
|
|
|
|
|
|
|
|
|
# --- 前向传播 --- |
|
|
# --- 前向传播 --- |
|
|
mnfld_pred = self.model.forward_background( |
|
|
mnfld_pred = self.model.forward_background( |
|
@ -368,9 +363,6 @@ class Trainer: |
|
|
nonmnfld_pred = self.model.forward_background( |
|
|
nonmnfld_pred = self.model.forward_background( |
|
|
nonmnfld_pnts |
|
|
nonmnfld_pnts |
|
|
) |
|
|
) |
|
|
psdf_pred = self.model.forward_background( |
|
|
|
|
|
psdf_pnts |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -390,7 +382,6 @@ class Trainer: |
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
loss, loss_details = self.loss_manager.compute_loss( |
|
|
mnfld_pnts, |
|
|
mnfld_pnts, |
|
|
nonmnfld_pnts, |
|
|
nonmnfld_pnts, |
|
|
psdf_pnts, |
|
|
|
|
|
normals, # 传递检查过的 normals |
|
|
normals, # 传递检查过的 normals |
|
|
gt_sdf, |
|
|
gt_sdf, |
|
|
mnfld_pred, |
|
|
mnfld_pred, |
|
|