diff --git a/brep2sdf/train.py b/brep2sdf/train.py index ee098a1..0023c8f 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -113,22 +113,28 @@ class Trainer: # 清空梯度 self.optimizer.zero_grad() - # 获取数据并移动到设备 - surf_ncs = batch['surf_ncs'].to(self.device) - edge_ncs = batch['edge_ncs'].to(self.device) - surf_pos = batch['surf_pos'].to(self.device) - edge_pos = batch['edge_pos'].to(self.device) - vertex_pos = batch['vertex_pos'].to(self.device) + # 获取数据并移动到设备,同时设置梯度 + surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True) + edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True) + surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True) + edge_pos = batch['edge_pos'].to(self.device).requires_grad_(True) + vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True) + points = batch['points'].to(self.device).requires_grad_(True) + + # 这些不需要梯度 edge_mask = batch['edge_mask'].to(self.device) points = batch['points'].to(self.device) gt_sdf = batch['sdf'].to(self.device) # 前向传播 pred_sdf = self.model( - surf_ncs=surf_ncs, edge_ncs=edge_ncs, - surf_pos=surf_pos, edge_pos=edge_pos, - vertex_pos=vertex_pos, edge_mask=edge_mask, - query_points=points # 只使用点坐标,不包括SDF值 + surf_ncs=surf_ncs, + edge_ncs=edge_ncs, + surf_pos=surf_pos, + edge_pos=edge_pos, + vertex_pos=vertex_pos, + edge_mask=edge_mask, + query_points=points ) # 计算损失