|
|
@ -3,7 +3,7 @@ import torch |
|
|
|
|
|
|
|
class NormalPerPoint(): |
|
|
|
|
|
|
|
def __init__(self, global_sigma, local_sigma=0.01): |
|
|
|
def __init__(self, global_sigma, local_sigma=0.1): |
|
|
|
self.global_sigma = global_sigma |
|
|
|
self.local_sigma = local_sigma |
|
|
|
|
|
|
@ -20,3 +20,24 @@ class NormalPerPoint(): |
|
|
|
sample = torch.cat([sample_local, sample_global], dim=1) |
|
|
|
|
|
|
|
return sample |
|
|
|
|
|
|
|
def get_norm_points(self, pc_input, normals, local_sigma=None): |
|
|
|
""" |
|
|
|
返回沿法线方向偏移的点以及对应的伪 SDF 值(PSDF) |
|
|
|
:param pc_input: 输入点云,形状为 (sample_size, 3) |
|
|
|
:param normals: 点云的法线,形状为 (sample_size, 3) |
|
|
|
:param local_sigma: 局部偏移的标准差 |
|
|
|
:return: 偏移后的点 (sample_size, dim), 伪 SDF 值 |
|
|
|
""" |
|
|
|
sample_size, dim = pc_input.shape |
|
|
|
|
|
|
|
# 生成随机位移值 |
|
|
|
if local_sigma is not None: |
|
|
|
psdf = torch.randn(sample_size, device=pc_input.device) * local_sigma |
|
|
|
else: |
|
|
|
psdf = torch.randn(sample_size, device=pc_input.device) * self.local_sigma |
|
|
|
|
|
|
|
# 沿法线方向偏移 |
|
|
|
sample = pc_input + normals * psdf.unsqueeze(-1) |
|
|
|
|
|
|
|
return sample, psdf |