diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 543c082..967526e 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -319,16 +319,19 @@ class Trainer: if epoch % 10 == 1 or self.cached_train_data is None: # 计算流形点的掩码和操作符 # 生成非流形点 - _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals) + _psdf_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals,local_sigma=0.001) + _nonmnfld_pnts = self.sampler.get_points(_mnfld_pnts, local_sigma=0.01): # 更新缓存 self.cached_train_data = { "nonmnfld_pnts": _nonmnfld_pnts, + "psdf_pnts": _psdf_pnts, "psdf": _psdf, } else: # 从缓存中读取数据 _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"] + _psdf_pnts = self.cached_train_data["psdf_pnts"] _psdf = self.cached_train_data["psdf"] logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf)) @@ -350,11 +353,13 @@ class Trainer: # 非流形点使用缓存数据(整个batch共享) nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx] + psdf_pnts = _psdf_pnts[start_idx:end_idx] psdf = _psdf[start_idx:end_idx] # --- 准备模型输入,启用梯度 --- mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 + psdf_pnts.requires_grad_(True) # 在检查之后启用梯度 # --- 前向传播 --- mnfld_pred = self.model.forward_background( @@ -363,6 +368,9 @@ class Trainer: nonmnfld_pred = self.model.forward_background( nonmnfld_pnts ) + psdf_pred = self.model.forward_background( + psdf_pnts + ) @@ -382,6 +390,7 @@ class Trainer: loss, loss_details = self.loss_manager.compute_loss( mnfld_pnts, nonmnfld_pnts, + psdf_pnts, normals, # 传递检查过的 normals gt_sdf, mnfld_pred,