diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index 967526e..7d3918c 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -307,7 +307,7 @@ class Trainer:
         self.model.train()
         total_loss = 0.0
         step = 0  # if training is batched, this should be the batch index
-        batch_size = 8192  # set an appropriate batch size
+        batch_size = 8192*16  # set an appropriate batch size
 
         # data processing
         # manfld
@@ -319,19 +319,16 @@ class Trainer:
         if epoch % 10 == 1 or self.cached_train_data is None:
             # compute masks and operators for manifold points
             # generate non-manifold points
-            _psdf_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals,local_sigma=0.001)
-            _nonmnfld_pnts = self.sampler.get_points(_mnfld_pnts, local_sigma=0.01):
+            _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
 
             # update the cache
             self.cached_train_data = {
                 "nonmnfld_pnts": _nonmnfld_pnts,
-                "psdf_pnts": _psdf_pnts,
                 "psdf": _psdf,
             }
         else:
             # read data from the cache
             _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
-            _psdf_pnts = self.cached_train_data["psdf_pnts"]
             _psdf = self.cached_train_data["psdf"]
 
         logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
@@ -353,13 +350,11 @@ class Trainer:
 
             # non-manifold points use cached data (shared by the whole batch)
             nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
-            psdf_pnts = _psdf_pnts[start_idx:end_idx]
             psdf = _psdf[start_idx:end_idx]
 
             # --- prepare model inputs, enable gradients ---
             mnfld_pnts.requires_grad_(True)  # enable gradients after the checks
             nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
-            psdf_pnts.requires_grad_(True)  # enable gradients after the checks
 
             # --- forward pass ---
             mnfld_pred = self.model.forward_background(
@@ -368,9 +363,6 @@ class Trainer:
             nonmnfld_pred = self.model.forward_background(
                 nonmnfld_pnts
             )
-            psdf_pred = self.model.forward_background(
-                psdf_pnts
-            )
@@ -390,7 +382,6 @@ class Trainer:
             loss, loss_details = self.loss_manager.compute_loss(
                 mnfld_pnts,
                 nonmnfld_pnts,
-                psdf_pnts,
                 normals,  # pass the validated normals
                 gt_sdf,
                 mnfld_pred,
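
Note: the change above collapses the two sampler calls into a single self.sampler.get_norm_points(_mnfld_pnts, _normals) call that now returns both the non-manifold points and their pseudo-SDF targets, dropping the explicit local_sigma argument and the separate psdf_pnts stream. The sketch below is a hypothetical illustration only (SamplerSketch is not the repository's Sampler), assuming the method offsets each surface point along its unit normal by Gaussian noise and uses the signed offset as the pseudo-SDF label.

# Hypothetical sketch; the real brep2sdf Sampler may differ.
import torch

class SamplerSketch:
    def __init__(self, local_sigma: float = 0.01):
        # Assumed internal default, since the diff no longer passes local_sigma explicitly
        self.local_sigma = local_sigma

    def get_norm_points(self, mnfld_pnts: torch.Tensor, normals: torch.Tensor):
        """Perturb surface points along their normals; the signed offset acts as a pseudo-SDF."""
        # One signed offset per point, drawn from N(0, local_sigma^2)
        offsets = torch.randn(mnfld_pnts.shape[0], 1, device=mnfld_pnts.device) * self.local_sigma
        unit_normals = torch.nn.functional.normalize(normals, dim=-1)
        perturbed = mnfld_pnts + offsets * unit_normals
        # Near the surface, the offset along the normal approximates the true SDF value
        return perturbed, offsets.squeeze(-1)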