diff --git a/brep2sdf/networks/sample.py b/brep2sdf/networks/sample.py index bcb16ae..2e12f74 100644 --- a/brep2sdf/networks/sample.py +++ b/brep2sdf/networks/sample.py @@ -3,7 +3,7 @@ import torch class NormalPerPoint(): - def __init__(self, global_sigma, local_sigma=0.1): + def __init__(self, global_sigma, local_sigma=0.5): self.global_sigma = global_sigma self.local_sigma = local_sigma @@ -31,11 +31,11 @@ class NormalPerPoint(): """ sample_size, dim = pc_input.shape - # 生成随机位移值 + # 生成随机位移值(确保有正有负) if local_sigma is not None: - psdf = torch.randn(sample_size, device=pc_input.device) * local_sigma + psdf = (torch.rand(sample_size, device=pc_input.device) * 2 - 1) * local_sigma else: - psdf = torch.randn(sample_size, device=pc_input.device) * self.local_sigma + psdf = (torch.rand(sample_size, device=pc_input.device) * 2 - 1) * self.local_sigma # 沿法线方向偏移 sample = pc_input + normals * psdf.unsqueeze(-1)