7 changed files with 66 additions and 19 deletions
			
			
		@ -0,0 +1,22 @@ | 
				
			|||
import torch | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class NormalPerPoint(): | 
				
			|||
 | 
				
			|||
    def __init__(self, global_sigma, local_sigma=0.01): | 
				
			|||
        self.global_sigma = global_sigma | 
				
			|||
        self.local_sigma = local_sigma | 
				
			|||
 | 
				
			|||
    def get_points(self, pc_input, local_sigma=None): | 
				
			|||
        batch_size, sample_size, dim = pc_input.shape | 
				
			|||
 | 
				
			|||
        if local_sigma is not None: | 
				
			|||
            sample_local = pc_input + (torch.randn_like(pc_input) * local_sigma.unsqueeze(-1)) | 
				
			|||
        else: | 
				
			|||
            sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma) | 
				
			|||
 | 
				
			|||
        sample_global = (torch.rand(batch_size, sample_size // 8, dim, device=pc_input.device) * (self.global_sigma * 2)) - self.global_sigma | 
				
			|||
 | 
				
			|||
        sample = torch.cat([sample_local, sample_global], dim=1) | 
				
			|||
 | 
				
			|||
        return sample | 
				
			|||
					Loading…
					
					
				
		Reference in new issue