From 2b5860a7e7b3f4436bbe0b13e39620972618e83f Mon Sep 17 00:00:00 2001 From: mckay Date: Fri, 22 Nov 2024 00:40:23 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0sdf=20dataloader=E9=9A=8F?= =?UTF-8?q?=E6=9C=BA=E9=87=87=E6=A0=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/data/data.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index 0e1b322..a893de3 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -271,7 +271,7 @@ class BRepSDFDataset(Dataset): raise def _load_sdf_file(self, sdf_path): - """加载和处理SDF数据""" + """加载和处理SDF数据,并进行随机采样""" try: # 加载SDF值 sdf_data = np.load(sdf_path) @@ -285,7 +285,35 @@ class BRepSDFDataset(Dataset): if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4: raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}") + # 随机采样 + max_points = self.config.data.num_query_points # 例如4096 + + # 确保正负样本均衡 + num_pos = min(max_points // 2, sdf_pos.shape[0]) + num_neg = min(max_points // 2, sdf_neg.shape[0]) + + # 随机采样正样本 + if sdf_pos.shape[0] > num_pos: + pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False) + sdf_pos = sdf_pos[pos_indices] + + # 随机采样负样本 + if sdf_neg.shape[0] > num_neg: + neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False) + sdf_neg = sdf_neg[neg_indices] + + # 合并数据 sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0) + + # 再次随机打乱 + np.random.shuffle(sdf_np) + + # 如果总点数仍然超过最大限制,再次采样 + if sdf_np.shape[0] > max_points: + indices = np.random.choice(sdf_np.shape[0], max_points, replace=False) + sdf_np = sdf_np[indices] + + logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})") return torch.from_numpy(sdf_np.astype(np.float32)) except Exception as e: