Browse Source

增加sdf dataloader随机采样

main
mckay 4 months ago
parent
commit
2b5860a7e7
  1. 30
      brep2sdf/data/data.py

30
brep2sdf/data/data.py

@ -271,7 +271,7 @@ class BRepSDFDataset(Dataset):
raise
def _load_sdf_file(self, sdf_path):
"""加载和处理SDF数据"""
"""加载和处理SDF数据,并进行随机采样"""
try:
# 加载SDF值
sdf_data = np.load(sdf_path)
@ -285,7 +285,35 @@ class BRepSDFDataset(Dataset):
if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")
# 随机采样
max_points = self.config.data.num_query_points # 例如4096
# 确保正负样本均衡
num_pos = min(max_points // 2, sdf_pos.shape[0])
num_neg = min(max_points // 2, sdf_neg.shape[0])
# 随机采样正样本
if sdf_pos.shape[0] > num_pos:
pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False)
sdf_pos = sdf_pos[pos_indices]
# 随机采样负样本
if sdf_neg.shape[0] > num_neg:
neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False)
sdf_neg = sdf_neg[neg_indices]
# 合并数据
sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)
# 再次随机打乱
np.random.shuffle(sdf_np)
# 如果总点数仍然超过最大限制,再次采样
if sdf_np.shape[0] > max_points:
indices = np.random.choice(sdf_np.shape[0], max_points, replace=False)
sdf_np = sdf_np[indices]
logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})")
return torch.from_numpy(sdf_np.astype(np.float32))
except Exception as e:

Loading…
Cancel
Save