Browse Source

训练sample修改

final
mckay 2 months ago
parent
commit
4ef6d1638f
  1. 8
      brep2sdf/networks/sample.py

8
brep2sdf/networks/sample.py

@ -3,7 +3,7 @@ import torch
class NormalPerPoint(): class NormalPerPoint():
def __init__(self, global_sigma, local_sigma=0.1): def __init__(self, global_sigma, local_sigma=0.5):
self.global_sigma = global_sigma self.global_sigma = global_sigma
self.local_sigma = local_sigma self.local_sigma = local_sigma
@ -31,11 +31,11 @@ class NormalPerPoint():
""" """
sample_size, dim = pc_input.shape sample_size, dim = pc_input.shape
# 生成随机位移值 # 生成随机位移值(确保有正有负)
if local_sigma is not None: if local_sigma is not None:
psdf = torch.randn(sample_size, device=pc_input.device) * local_sigma psdf = (torch.rand(sample_size, device=pc_input.device) * 2 - 1) * local_sigma
else: else:
psdf = torch.randn(sample_size, device=pc_input.device) * self.local_sigma psdf = (torch.rand(sample_size, device=pc_input.device) * 2 - 1) * self.local_sigma
# 沿法线方向偏移 # 沿法线方向偏移
sample = pc_input + normals * psdf.unsqueeze(-1) sample = pc_input + normals * psdf.unsqueeze(-1)

Loading…
Cancel
Save