|
|
@ -113,22 +113,28 @@ class Trainer: |
|
|
|
# 清空梯度 |
|
|
|
self.optimizer.zero_grad() |
|
|
|
|
|
|
|
# 获取数据并移动到设备 |
|
|
|
surf_ncs = batch['surf_ncs'].to(self.device) |
|
|
|
edge_ncs = batch['edge_ncs'].to(self.device) |
|
|
|
surf_pos = batch['surf_pos'].to(self.device) |
|
|
|
edge_pos = batch['edge_pos'].to(self.device) |
|
|
|
vertex_pos = batch['vertex_pos'].to(self.device) |
|
|
|
# 获取数据并移动到设备,同时设置梯度 |
|
|
|
surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True) |
|
|
|
edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True) |
|
|
|
surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True) |
|
|
|
edge_pos = batch['edge_pos'].to(self.device).requires_grad_(True) |
|
|
|
vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True) |
|
|
|
points = batch['points'].to(self.device).requires_grad_(True) |
|
|
|
|
|
|
|
# 这些不需要梯度 |
|
|
|
edge_mask = batch['edge_mask'].to(self.device) |
|
|
|
points = batch['points'].to(self.device) |
|
|
|
gt_sdf = batch['sdf'].to(self.device) |
|
|
|
|
|
|
|
# 前向传播 |
|
|
|
pred_sdf = self.model( |
|
|
|
surf_ncs=surf_ncs, edge_ncs=edge_ncs, |
|
|
|
surf_pos=surf_pos, edge_pos=edge_pos, |
|
|
|
vertex_pos=vertex_pos, edge_mask=edge_mask, |
|
|
|
query_points=points # 只使用点坐标,不包括SDF值 |
|
|
|
surf_ncs=surf_ncs, |
|
|
|
edge_ncs=edge_ncs, |
|
|
|
surf_pos=surf_pos, |
|
|
|
edge_pos=edge_pos, |
|
|
|
vertex_pos=vertex_pos, |
|
|
|
edge_mask=edge_mask, |
|
|
|
query_points=points |
|
|
|
) |
|
|
|
|
|
|
|
# 计算损失 |
|
|
|