36 lines
1.0 KiB
36 lines
1.0 KiB
import torch
|
|
import utils.general as utils
|
|
import abc
|
|
|
|
|
|
class Sampler(metaclass=abc.ABCMeta):
    """Abstract interface for objects that generate sample points from a point cloud."""

    @abc.abstractmethod
    def get_points(self, pc_input):
        """Produce sample points for ``pc_input``.

        Concrete subclasses must override this method; the base
        implementation does nothing and returns ``None``.
        """

    @staticmethod
    def get_sampler(sampler_type):
        """Look up a sampler by name.

        The name is appended to the ``model.sample.`` prefix and the
        resulting dotted path is handed to ``utils.get_class``; whatever
        that helper returns is passed straight back to the caller
        (presumably the class object -- confirm against ``utils.general``).
        """
        dotted_path = "model.sample.{0}".format(sampler_type)
        return utils.get_class(dotted_path)
|
|
|
|
|
|
class NormalPerPoint(Sampler):
    """Sampler mixing per-point Gaussian jitter with uniform box samples.

    For every input point one "local" sample is drawn by adding zero-mean
    Gaussian noise scaled by ``local_sigma``. In addition,
    ``sample_size // 8`` "global" samples per batch element are drawn
    uniformly from ``[-global_sigma, global_sigma)`` along every coordinate.
    """

    def __init__(self, global_sigma, local_sigma=0.01):
        """Store the sampling scales.

        Args:
            global_sigma: Half-width of the uniform range used for the
                global samples.
            local_sigma: Default standard deviation of the Gaussian noise
                added to each input point.
        """
        self.global_sigma = global_sigma
        self.local_sigma = local_sigma

    def get_points(self, pc_input, local_sigma=None):
        """Return local and global samples concatenated along dim 1.

        Args:
            pc_input: Rank-3 tensor unpacked as
                ``(batch_size, sample_size, dim)``.
            local_sigma: Optional override for the noise scale. A tensor
                gets a trailing singleton axis appended so it broadcasts
                over the last axis of ``pc_input`` (presumably one sigma
                per point, shape ``(batch_size, sample_size)`` -- confirm
                with callers). A plain number is used as a single scale
                for all points. ``None`` falls back to ``self.local_sigma``.

        Returns:
            Tensor of shape
            ``(batch_size, sample_size + sample_size // 8, dim)``: the
            jittered input points followed by the uniform global samples.
        """
        batch_size, sample_size, dim = pc_input.shape

        if local_sigma is None:
            noise_scale = self.local_sigma
        elif torch.is_tensor(local_sigma):
            # Trailing axis lets the sigma broadcast across coordinates.
            noise_scale = local_sigma.unsqueeze(-1)
        else:
            # Scalar override: broadcasts as-is, no unsqueeze available.
            noise_scale = local_sigma

        sample_local = pc_input + torch.randn_like(pc_input) * noise_scale

        # Match the input dtype/device so the concatenation below neither
        # silently promotes nor fails on a dtype mismatch.
        unit_uniform = torch.rand(
            batch_size,
            sample_size // 8,
            dim,
            dtype=pc_input.dtype,
            device=pc_input.device,
        )
        # Map [0, 1) onto [-global_sigma, global_sigma).
        sample_global = unit_uniform * (self.global_sigma * 2) - self.global_sigma

        return torch.cat([sample_local, sample_global], dim=1)
|