|
@ -315,7 +315,7 @@ class Trainer: |
|
|
start_idx = batch_idx * batch_size |
|
|
start_idx = batch_idx * batch_size |
|
|
end_idx = min((batch_idx + 1) * batch_size, num_points) |
|
|
end_idx = min((batch_idx + 1) * batch_size, num_points) |
|
|
mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 |
|
|
mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 |
|
|
nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0) # 生成非流形点 |
|
|
|
|
|
gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值 |
|
|
gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值 |
|
|
normals = None |
|
|
normals = None |
|
|
if args.use_normal: |
|
|
if args.use_normal: |
|
@ -323,7 +323,11 @@ class Trainer: |
|
|
logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.") |
|
|
logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.") |
|
|
return float('inf') |
|
|
return float('inf') |
|
|
normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach() # 取中间3列作为法线 |
|
|
normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach() # 取中间3列作为法线 |
|
|
|
|
|
nonmnfld_pnts,psdf = self.sampler.get_norm_points(mnfld_pnts,normals) # 生成非流形点 |
|
|
|
|
|
logger.debug((mnfld_pnts,nonmnfld_pnts,psdf)) |
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
|
nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0) # 生成非流形点 |
|
|
# 执行检查 |
|
|
# 执行检查 |
|
|
if self.debug_mode: |
|
|
if self.debug_mode: |
|
|
if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf') |
|
|
if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf') |
|
@ -372,7 +376,8 @@ class Trainer: |
|
|
normals, # 传递检查过的 normals |
|
|
normals, # 传递检查过的 normals |
|
|
gt_sdf, |
|
|
gt_sdf, |
|
|
mnfld_pred, |
|
|
mnfld_pred, |
|
|
nonmnfld_pred |
|
|
nonmnfld_pred, |
|
|
|
|
|
psdf |
|
|
) |
|
|
) |
|
|
else: |
|
|
else: |
|
|
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) |
|
|
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) |
|
|