From 6580fffdbf5342d1c5d53cc3b79e5bbc3575ac74 Mon Sep 17 00:00:00 2001 From: mckay Date: Fri, 13 Dec 2024 23:07:45 +0800 Subject: [PATCH] fix: zero_grad mv afterwards --- brep2sdf/train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 62ee8dd..2a65832 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -111,8 +111,7 @@ class Trainer: total_loss = 0 for batch_idx, batch in enumerate(self.train_loader): - # 清空梯度 - self.optimizer.zero_grad() + # 获取数据并移动到设备,同时设置梯度 # 获取数据并移动到设备,同时保留计算图 @@ -129,6 +128,10 @@ class Trainer: # 这些不需要梯度 edge_mask = batch['edge_mask'].to(self.device) gt_sdf = batch['sdf'].to(self.device) + + # 前向传播前清空梯度 + self.model.zero_grad() # 清空模型梯度 + self.optimizer.zero_grad() # 清空优化器梯度 # 前向传播 pred_sdf = self.model(