Browse Source

train psdf回退

final
mckay 1 month ago
parent
commit
037f1236e1
  1. 13
      brep2sdf/train.py

13
brep2sdf/train.py

@ -307,7 +307,7 @@ class Trainer:
self.model.train() self.model.train()
total_loss = 0.0 total_loss = 0.0
step = 0 # 如果你的训练是分批次的,这里应该用批次索引 step = 0 # 如果你的训练是分批次的,这里应该用批次索引
batch_size = 8192 # 设置合适的batch大小 batch_size = 8192*16 # 设置合适的batch大小
# 数据处理 # 数据处理
# manfld # manfld
@ -319,19 +319,16 @@ class Trainer:
if epoch % 10 == 1 or self.cached_train_data is None: if epoch % 10 == 1 or self.cached_train_data is None:
# 计算流形点的掩码和操作符 # 计算流形点的掩码和操作符
# 生成非流形点 # 生成非流形点
_psdf_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals,local_sigma=0.001) _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
_nonmnfld_pnts = self.sampler.get_points(_mnfld_pnts, local_sigma=0.01):
# 更新缓存 # 更新缓存
self.cached_train_data = { self.cached_train_data = {
"nonmnfld_pnts": _nonmnfld_pnts, "nonmnfld_pnts": _nonmnfld_pnts,
"psdf_pnts": _psdf_pnts,
"psdf": _psdf, "psdf": _psdf,
} }
else: else:
# 从缓存中读取数据 # 从缓存中读取数据
_nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"] _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
_psdf_pnts = self.cached_train_data["psdf_pnts"]
_psdf = self.cached_train_data["psdf"] _psdf = self.cached_train_data["psdf"]
logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf)) logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
@ -353,13 +350,11 @@ class Trainer:
# 非流形点使用缓存数据(整个batch共享) # 非流形点使用缓存数据(整个batch共享)
nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx] nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
psdf_pnts = _psdf_pnts[start_idx:end_idx]
psdf = _psdf[start_idx:end_idx] psdf = _psdf[start_idx:end_idx]
# --- 准备模型输入,启用梯度 --- # --- 准备模型输入,启用梯度 ---
mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度
psdf_pnts.requires_grad_(True) # 在检查之后启用梯度
# --- 前向传播 --- # --- 前向传播 ---
mnfld_pred = self.model.forward_background( mnfld_pred = self.model.forward_background(
@ -368,9 +363,6 @@ class Trainer:
nonmnfld_pred = self.model.forward_background( nonmnfld_pred = self.model.forward_background(
nonmnfld_pnts nonmnfld_pnts
) )
psdf_pred = self.model.forward_background(
psdf_pnts
)
@ -390,7 +382,6 @@ class Trainer:
loss, loss_details = self.loss_manager.compute_loss( loss, loss_details = self.loss_manager.compute_loss(
mnfld_pnts, mnfld_pnts,
nonmnfld_pnts, nonmnfld_pnts,
psdf_pnts,
normals, # 传递检查过的 normals normals, # 传递检查过的 normals
gt_sdf, gt_sdf,
mnfld_pred, mnfld_pred,

Loading…
Cancel
Save