|
|
@ -118,21 +118,23 @@ class Trainer: |
|
|
|
surf_pos = batch['surf_pos'].to(self.device) |
|
|
|
edge_pos = batch['edge_pos'].to(self.device) |
|
|
|
vertex_pos = batch['vertex_pos'].to(self.device) |
|
|
|
sdf = batch['sdf'].to(self.device) |
|
|
|
edge_mask = batch['edge_mask'].to(self.device) |
|
|
|
points = batch['points'].to(self.device) |
|
|
|
gt_sdf = batch['sdf'].to(self.device) |
|
|
|
|
|
|
|
# 前向传播 |
|
|
|
pred_sdf = self.model( |
|
|
|
surf_ncs, edge_ncs, |
|
|
|
surf_pos, edge_pos, |
|
|
|
vertex_pos, |
|
|
|
sdf[:, :3] # 只使用点坐标,不包括SDF值 |
|
|
|
surf_ncs=surf_ncs, edge_ncs=edge_ncs, |
|
|
|
surf_pos=surf_pos, edge_pos=edge_pos, |
|
|
|
vertex_pos=vertex_pos, edge_mask=edge_mask, |
|
|
|
query_points=points # 只使用点坐标,不包括SDF值 |
|
|
|
) |
|
|
|
|
|
|
|
# 计算损失 |
|
|
|
loss = sdf_loss( |
|
|
|
pred_sdf, |
|
|
|
sdf[:, 3], # 使用SDF值 |
|
|
|
sdf[:, :3], # 使用点坐标 |
|
|
|
gt_sdf, |
|
|
|
points, |
|
|
|
grad_weight=self.config.train.grad_weight |
|
|
|
) |
|
|
|
|
|
|
@ -175,21 +177,23 @@ class Trainer: |
|
|
|
surf_pos = batch['surf_pos'].to(self.device) |
|
|
|
edge_pos = batch['edge_pos'].to(self.device) |
|
|
|
vertex_pos = batch['vertex_pos'].to(self.device) |
|
|
|
sdf = batch['sdf'].to(self.device) |
|
|
|
edge_mask = batch['edge_mask'].to(self.device) |
|
|
|
points = batch['points'].to(self.device) |
|
|
|
gt_sdf = batch['sdf'].to(self.device) |
|
|
|
|
|
|
|
# 前向传播 |
|
|
|
pred_sdf = self.model( |
|
|
|
surf_ncs, edge_ncs, |
|
|
|
surf_pos, edge_pos, |
|
|
|
vertex_pos, |
|
|
|
sdf[:, :3] |
|
|
|
surf_ncs=surf_ncs, edge_ncs=edge_ncs, |
|
|
|
surf_pos=surf_pos, edge_pos=edge_pos, |
|
|
|
vertex_pos=vertex_pos, edge_mask=edge_mask, |
|
|
|
query_points=points |
|
|
|
) |
|
|
|
|
|
|
|
# 计算损失 |
|
|
|
loss = sdf_loss( |
|
|
|
pred_sdf, |
|
|
|
sdf[:, 3], |
|
|
|
sdf[:, :3], |
|
|
|
gt_sdf, |
|
|
|
points, |
|
|
|
grad_weight=self.config.train.grad_weight |
|
|
|
) |
|
|
|
|
|
|
|