|
|
@ -3,7 +3,7 @@ import torch |
|
|
|
|
|
|
|
class NormalPerPoint(): |
|
|
|
|
|
|
|
def __init__(self, global_sigma, local_sigma=0.1): |
|
|
|
def __init__(self, global_sigma, local_sigma=0.5): |
|
|
|
self.global_sigma = global_sigma |
|
|
|
self.local_sigma = local_sigma |
|
|
|
|
|
|
@ -31,11 +31,11 @@ class NormalPerPoint(): |
|
|
|
""" |
|
|
|
sample_size, dim = pc_input.shape |
|
|
|
|
|
|
|
# 生成随机位移值 |
|
|
|
# 生成随机位移值(确保有正有负) |
|
|
|
if local_sigma is not None: |
|
|
|
psdf = torch.randn(sample_size, device=pc_input.device) * local_sigma |
|
|
|
psdf = (torch.rand(sample_size, device=pc_input.device) * 2 - 1) * local_sigma |
|
|
|
else: |
|
|
|
psdf = torch.randn(sample_size, device=pc_input.device) * self.local_sigma |
|
|
|
psdf = (torch.rand(sample_size, device=pc_input.device) * 2 - 1) * self.local_sigma |
|
|
|
|
|
|
|
# 沿法线方向偏移 |
|
|
|
sample = pc_input + normals * psdf.unsqueeze(-1) |
|
|
|