|
@ -613,7 +613,7 @@ class Trainer: |
|
|
self.model.train() |
|
|
self.model.train() |
|
|
total_loss = 0.0 |
|
|
total_loss = 0.0 |
|
|
step = 0 # 如果你的训练是分批次的,这里应该用批次索引 |
|
|
step = 0 # 如果你的训练是分批次的,这里应该用批次索引 |
|
|
batch_size = 50000 # 设置合适的batch大小 |
|
|
batch_size = 25000 # 设置合适的batch大小 |
|
|
|
|
|
|
|
|
# 数据处理 |
|
|
# 数据处理 |
|
|
# manfld |
|
|
# manfld |
|
@ -627,7 +627,7 @@ class Trainer: |
|
|
_, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts) |
|
|
_, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts) |
|
|
|
|
|
|
|
|
# 生成非流形点 |
|
|
# 生成非流形点 |
|
|
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals, 0.01) |
|
|
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals, 0.1) |
|
|
_, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts) |
|
|
_, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts) |
|
|
|
|
|
|
|
|
# 更新缓存 |
|
|
# 更新缓存 |
|
@ -905,7 +905,7 @@ class Trainer: |
|
|
return total_loss # 对于单批次训练,直接返回当前损失 |
|
|
return total_loss # 对于单批次训练,直接返回当前损失 |
|
|
|
|
|
|
|
|
def validate(self, epoch, loss): |
|
|
def validate(self, epoch, loss): |
|
|
if loss < self.best_loss: |
|
|
if epoch > self.config.train.num_epochs3 / 5 and loss < self.best_loss: |
|
|
self.best_loss = loss |
|
|
self.best_loss = loss |
|
|
self._save_checkpoint(-1, loss) # 存 best |
|
|
self._save_checkpoint(-1, loss) # 存 best |
|
|
logger.info(f'Best Epoch: {epoch}\tAverage Loss: {loss:.6f}') |
|
|
logger.info(f'Best Epoch: {epoch}\tAverage Loss: {loss:.6f}') |
|
|