From de0cf41d598d6fb589aa73ca3e2fc501a747ff31 Mon Sep 17 00:00:00 2001 From: mckay Date: Sat, 30 Nov 2024 18:46:58 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=A2=AF=E5=BA=A6=E6=94=B9=E5=9C=A8trai?= =?UTF-8?q?n=E9=87=8C=E9=9D=A2=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/train.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/brep2sdf/train.py b/brep2sdf/train.py index ee098a1..0023c8f 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -113,22 +113,28 @@ class Trainer: # 清空梯度 self.optimizer.zero_grad() - # 获取数据并移动到设备 - surf_ncs = batch['surf_ncs'].to(self.device) - edge_ncs = batch['edge_ncs'].to(self.device) - surf_pos = batch['surf_pos'].to(self.device) - edge_pos = batch['edge_pos'].to(self.device) - vertex_pos = batch['vertex_pos'].to(self.device) + # 获取数据并移动到设备,同时设置梯度 + surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True) + edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True) + surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True) + edge_pos = batch['edge_pos'].to(self.device).requires_grad_(True) + vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True) + points = batch['points'].to(self.device).requires_grad_(True) + + # 这些不需要梯度 edge_mask = batch['edge_mask'].to(self.device) points = batch['points'].to(self.device) gt_sdf = batch['sdf'].to(self.device) # 前向传播 pred_sdf = self.model( - surf_ncs=surf_ncs, edge_ncs=edge_ncs, - surf_pos=surf_pos, edge_pos=edge_pos, - vertex_pos=vertex_pos, edge_mask=edge_mask, - query_points=points # 只使用点坐标,不包括SDF值 + surf_ncs=surf_ncs, + edge_ncs=edge_ncs, + surf_pos=surf_pos, + edge_pos=edge_pos, + vertex_pos=vertex_pos, + edge_mask=edge_mask, + query_points=points ) # 计算损失