7 changed files with 66 additions and 19 deletions
@ -0,0 +1,22 @@ |
|||
import torch |
|||
|
|||
|
|||
class NormalPerPoint(): |
|||
|
|||
def __init__(self, global_sigma, local_sigma=0.01): |
|||
self.global_sigma = global_sigma |
|||
self.local_sigma = local_sigma |
|||
|
|||
def get_points(self, pc_input, local_sigma=None): |
|||
batch_size, sample_size, dim = pc_input.shape |
|||
|
|||
if local_sigma is not None: |
|||
sample_local = pc_input + (torch.randn_like(pc_input) * local_sigma.unsqueeze(-1)) |
|||
else: |
|||
sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma) |
|||
|
|||
sample_global = (torch.rand(batch_size, sample_size // 8, dim, device=pc_input.device) * (self.global_sigma * 2)) - self.global_sigma |
|||
|
|||
sample = torch.cat([sample_local, sample_global], dim=1) |
|||
|
|||
return sample |
Loading…
Reference in new issue