36 lines
1.0 KiB

import torch
import utils.general as utils
import abc
class Sampler(metaclass=abc.ABCMeta):
    """Abstract base class for point samplers.

    Subclasses must implement :meth:`get_points`. :meth:`get_sampler` turns a
    sampler type name into the dotted path ``model.sample.<sampler_type>`` and
    resolves it through ``utils.get_class``.
    """

    @abc.abstractmethod
    def get_points(self, pc_input):
        """Produce sample points from ``pc_input``; implemented by subclasses."""

    @staticmethod
    def get_sampler(sampler_type):
        """Return whatever ``utils.get_class`` yields for the sampler path.

        Args:
            sampler_type: name appended to ``model.sample.`` to form the
                dotted path handed to ``utils.get_class`` (how that path is
                resolved is defined by ``utils.get_class``, not here).
        """
        dotted_path = f"model.sample.{sampler_type}"
        return utils.get_class(dotted_path)
class NormalPerPoint(Sampler):
    """Sampler that jitters input points and adds uniformly drawn global points.

    The result of :meth:`get_points` concatenates, along dim 1:
      * a "local" copy of every input point perturbed by Gaussian noise, and
      * ``sample_size // 8`` "global" points drawn uniformly from
        ``[-global_sigma, global_sigma)`` in every coordinate.

    Args:
        global_sigma: half-width of the range the global samples are drawn from.
        local_sigma: default standard deviation of the Gaussian jitter.
    """

    def __init__(self, global_sigma, local_sigma=0.01):
        self.global_sigma = global_sigma
        self.local_sigma = local_sigma

    def get_points(self, pc_input, local_sigma=None):
        """Draw local (jittered) and global (uniform) samples for ``pc_input``.

        Args:
            pc_input: tensor unpacked as ``(batch_size, sample_size, dim)``.
            local_sigma: optional override for ``self.local_sigma``.
                * A tensor is treated as a per-point sigma and gets a trailing
                  axis so it broadcasts over ``dim`` (presumably shaped
                  ``(batch_size, sample_size)`` — confirm against callers).
                * A plain Python number is used as a single sigma for all points.

        Returns:
            Tensor of shape ``(batch_size, sample_size + sample_size // 8, dim)``.
            The local samples come first, followed by the global samples.
        """
        batch_size, sample_size, dim = pc_input.shape

        if local_sigma is None:
            noise_scale = self.local_sigma
        elif torch.is_tensor(local_sigma):
            # Per-point sigma: add a trailing axis so it broadcasts over coordinates.
            noise_scale = local_sigma.unsqueeze(-1)
        else:
            # Scalar override (int/float): Python numbers have no `.unsqueeze`,
            # so use the value directly as a broadcast scale.
            noise_scale = local_sigma
        sample_local = pc_input + (torch.randn_like(pc_input) * noise_scale)

        # One global sample per eight input points, uniform in
        # [-global_sigma, global_sigma). Create it with pc_input's dtype so the
        # concatenation below does not mix e.g. float32 with float64/float16.
        n_global = sample_size // 8
        uniform = torch.rand(
            batch_size, n_global, dim,
            device=pc_input.device,
            dtype=pc_input.dtype,
        )
        sample_global = (uniform * (self.global_sigma * 2)) - self.global_sigma

        return torch.cat([sample_local, sample_global], dim=1)