Browse Source

fix: 梯度改在train里面加

main
mckay 4 months ago
parent
commit
de0cf41d59
  1. 26
      brep2sdf/train.py

26
brep2sdf/train.py

@ -113,22 +113,28 @@ class Trainer:
# 清空梯度
self.optimizer.zero_grad()
# 获取数据并移动到设备
surf_ncs = batch['surf_ncs'].to(self.device)
edge_ncs = batch['edge_ncs'].to(self.device)
surf_pos = batch['surf_pos'].to(self.device)
edge_pos = batch['edge_pos'].to(self.device)
vertex_pos = batch['vertex_pos'].to(self.device)
# 获取数据并移动到设备,同时设置梯度
surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True)
edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True)
surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True)
edge_pos = batch['edge_pos'].to(self.device).requires_grad_(True)
vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True)
points = batch['points'].to(self.device).requires_grad_(True)
# 这些不需要梯度
edge_mask = batch['edge_mask'].to(self.device)
points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
# 前向传播
pred_sdf = self.model(
surf_ncs=surf_ncs, edge_ncs=edge_ncs,
surf_pos=surf_pos, edge_pos=edge_pos,
vertex_pos=vertex_pos, edge_mask=edge_mask,
query_points=points # 只使用点坐标,不包括SDF值
surf_ncs=surf_ncs,
edge_ncs=edge_ncs,
surf_pos=surf_pos,
edge_pos=edge_pos,
vertex_pos=vertex_pos,
edge_mask=edge_mask,
query_points=points
)
# 计算损失

Loading…
Cancel
Save