|
|
@ -271,7 +271,7 @@ class BRepSDFDataset(Dataset): |
|
|
|
raise |
|
|
|
|
|
|
|
def _load_sdf_file(self, sdf_path): |
|
|
|
"""加载和处理SDF数据""" |
|
|
|
"""加载和处理SDF数据,并进行随机采样""" |
|
|
|
try: |
|
|
|
# 加载SDF值 |
|
|
|
sdf_data = np.load(sdf_path) |
|
|
@ -285,7 +285,35 @@ class BRepSDFDataset(Dataset): |
|
|
|
if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4: |
|
|
|
raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}") |
|
|
|
|
|
|
|
# 随机采样 |
|
|
|
max_points = self.config.data.num_query_points # 例如4096 |
|
|
|
|
|
|
|
# 确保正负样本均衡 |
|
|
|
num_pos = min(max_points // 2, sdf_pos.shape[0]) |
|
|
|
num_neg = min(max_points // 2, sdf_neg.shape[0]) |
|
|
|
|
|
|
|
# 随机采样正样本 |
|
|
|
if sdf_pos.shape[0] > num_pos: |
|
|
|
pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False) |
|
|
|
sdf_pos = sdf_pos[pos_indices] |
|
|
|
|
|
|
|
# 随机采样负样本 |
|
|
|
if sdf_neg.shape[0] > num_neg: |
|
|
|
neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False) |
|
|
|
sdf_neg = sdf_neg[neg_indices] |
|
|
|
|
|
|
|
# 合并数据 |
|
|
|
sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0) |
|
|
|
|
|
|
|
# 再次随机打乱 |
|
|
|
np.random.shuffle(sdf_np) |
|
|
|
|
|
|
|
# 如果总点数仍然超过最大限制,再次采样 |
|
|
|
if sdf_np.shape[0] > max_points: |
|
|
|
indices = np.random.choice(sdf_np.shape[0], max_points, replace=False) |
|
|
|
sdf_np = sdf_np[indices] |
|
|
|
|
|
|
|
logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})") |
|
|
|
return torch.from_numpy(sdf_np.astype(np.float32)) |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|