|
@ -39,7 +39,7 @@ parser.add_argument( |
|
|
help='强制重新进行数据预处理,忽略缓存或已有结果' |
|
|
help='强制重新进行数据预处理,忽略缓存或已有结果' |
|
|
) |
|
|
) |
|
|
parser.add_argument( |
|
|
parser.add_argument( |
|
|
'--resume-checkpoint-path', |
|
|
'--resume-checkpoint-path', '-r', |
|
|
type=str, |
|
|
type=str, |
|
|
default=None, |
|
|
default=None, |
|
|
help='从指定的checkpoint文件继续训练' |
|
|
help='从指定的checkpoint文件继续训练' |
|
@ -99,6 +99,7 @@ class Trainer: |
|
|
) |
|
|
) |
|
|
# 合并表面点数据和采样点数据 |
|
|
# 合并表面点数据和采样点数据 |
|
|
self.sdf_data = torch.cat([surface_sdf_data, sampled_sdf_data], dim=0) |
|
|
self.sdf_data = torch.cat([surface_sdf_data, sampled_sdf_data], dim=0) |
|
|
|
|
|
self.train_surf_ncs = sampled_sdf_data |
|
|
else: |
|
|
else: |
|
|
self.sdf_data = surface_sdf_data |
|
|
self.sdf_data = surface_sdf_data |
|
|
print_data_distribution(self.sdf_data) |
|
|
print_data_distribution(self.sdf_data) |
|
@ -403,7 +404,7 @@ class Trainer: |
|
|
logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}") |
|
|
logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}") |
|
|
|
|
|
|
|
|
mnfld_pnts = points[:, 0:3] |
|
|
mnfld_pnts = points[:, 0:3] |
|
|
gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device) |
|
|
gt_sdf = points[:, -1] |
|
|
|
|
|
|
|
|
normals = points[:, 3:6] |
|
|
normals = points[:, 3:6] |
|
|
|
|
|
|
|
|