From 2b5860a7e7b3f4436bbe0b13e39620972618e83f Mon Sep 17 00:00:00 2001
From: mckay <wchpub@163.com>
Date: Fri, 22 Nov 2024 00:40:23 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0sdf=20dataloader=E9=9A=8F?=
 =?UTF-8?q?=E6=9C=BA=E9=87=87=E6=A0=B7?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/data/data.py | 30 +++++++++++++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py
index 0e1b322..a893de3 100644
--- a/brep2sdf/data/data.py
+++ b/brep2sdf/data/data.py
@@ -271,7 +271,7 @@ class BRepSDFDataset(Dataset):
             raise
 
     def _load_sdf_file(self, sdf_path):
-        """加载和处理SDF数据"""
+        """加载和处理SDF数据,并进行随机采样"""
         try:
             # 加载SDF值
             sdf_data = np.load(sdf_path)
@@ -285,7 +285,35 @@ class BRepSDFDataset(Dataset):
             if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4:
                 raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}")
             
+            # 随机采样
+            max_points = self.config.data.num_query_points  # 例如4096
+            
+            # 确保正负样本均衡
+            num_pos = min(max_points // 2, sdf_pos.shape[0])
+            num_neg = min(max_points // 2, sdf_neg.shape[0])
+            
+            # 随机采样正样本
+            if sdf_pos.shape[0] > num_pos:
+                pos_indices = np.random.choice(sdf_pos.shape[0], num_pos, replace=False)
+                sdf_pos = sdf_pos[pos_indices]
+                
+            # 随机采样负样本
+            if sdf_neg.shape[0] > num_neg:
+                neg_indices = np.random.choice(sdf_neg.shape[0], num_neg, replace=False)
+                sdf_neg = sdf_neg[neg_indices]
+            
+            # 合并数据
             sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0)
+            
+            # 再次随机打乱
+            np.random.shuffle(sdf_np)
+            
+            # 如果总点数仍然超过最大限制,再次采样
+            if sdf_np.shape[0] > max_points:
+                indices = np.random.choice(sdf_np.shape[0], max_points, replace=False)
+                sdf_np = sdf_np[indices]
+                
+            logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})")
             return torch.from_numpy(sdf_np.astype(np.float32))
             
         except Exception as e: