diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index ac80975..d970cdf 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -12,7 +12,7 @@ from brep2sdf.config.default_config import get_default_config class BRepSDFDataset(Dataset): - def __init__(self, brep_dir:str, sdf_dir:str, valid_data_dir:str, split:str='train'): + def __init__(self, brep_dir:str, sdf_dir:str, valid_data_dir:str, use_filter: bool=True, split:str='train'): """ 初始化数据集 @@ -52,6 +52,9 @@ class BRepSDFDataset(Dataset): self.brep_data_list = self._load_data_list(self.brep_dir) self.sdf_data_list = self._load_data_list(self.sdf_dir) + if use_filter: + self._filter_num_faces_and_num_edges() + # 检查数据集是否为空 if len(self.brep_data_list) == 0 : raise ValueError(f"No valid brep data found in {split} set") @@ -76,6 +79,20 @@ class BRepSDFDataset(Dataset): #logger.info(data_list) return data_list + def _filter_num_faces_and_num_edges(self): + ''' + Filter the data if their face_num or edge_num > max_face or max_edge. + ''' + # Collect indices of elements that satisfy the condition + filtered_indices = [ + idx for idx in range(len(self.brep_data_list)) + if (self._get_brep_face_and_edge(self.brep_data_list[idx]) <= (self.max_face, self.max_edge)) + ] + + # Use filtered_indices to update brep_data_list and sdf_data_list + self.brep_data_list = [self.brep_data_list[idx] for idx in filtered_indices] + self.sdf_data_list = [self.sdf_data_list[idx] for idx in filtered_indices] + def __len__(self): return len(self.brep_data_list) @@ -87,8 +104,7 @@ class BRepSDFDataset(Dataset): name = os.path.splitext(os.path.basename(brep_path))[0] # 加载B-rep和SDF数据 - with open(brep_path, 'rb') as f: - brep_raw = pickle.load(f) + brep_raw = self._load_brep_file(brep_path) sdf_data = self._load_sdf_file(sdf_path) try: @@ -162,7 +178,10 @@ class BRepSDFDataset(Dataset): logger.error(f"Error loading sample from {brep_path}: {str(e)}") logger.error("Data structure:") raise - + def _load_brep_file(self, brep_path): + with open(brep_path, 'rb') as f: + brep_raw = pickle.load(f) + return brep_raw def _load_sdf_file(self, sdf_path): """加载和处理SDF数据,并进行随机采样""" @@ -222,7 +241,12 @@ class BRepSDFDataset(Dataset): logger.error(f"Error type: {type(e).__name__}") logger.error(f"Error message: {str(e)}") raise - + + def _get_brep_face_and_edge(self, brep_path: str) -> tuple[int,int]: + brep: dict = self._load_brep_file(brep_path) + face_edge_adj = brep["faceEdge_adj"] + num_faces, num_edges = face_edge_adj.shape + return num_faces, num_edges def test_dataset(): """测试数据集功能"""