Browse Source

fix: zero_grad mv afterwards

main
mckay 3 months ago
parent
commit
6580fffdbf
  1. 7
      brep2sdf/train.py

7
brep2sdf/train.py

@ -111,8 +111,7 @@ class Trainer:
total_loss = 0
for batch_idx, batch in enumerate(self.train_loader):
# 清空梯度
self.optimizer.zero_grad()
# 获取数据并移动到设备,同时设置梯度
# 获取数据并移动到设备,同时保留计算图
@ -129,6 +128,10 @@ class Trainer:
# 这些不需要梯度
edge_mask = batch['edge_mask'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
# 前向传播前清空梯度
self.model.zero_grad() # 清空模型梯度
self.optimizer.zero_grad() # 清空优化器梯度
# 前向传播
pred_sdf = self.model(

Loading…
Cancel
Save