Browse Source

feat: filter the data if their face_num or edge_num > max_face or max_edge.

final
mckay 6 months ago
parent
commit
0b4df54447
  1. 32
      brep2sdf/data/data.py

32
brep2sdf/data/data.py

@ -12,7 +12,7 @@ from brep2sdf.config.default_config import get_default_config
class BRepSDFDataset(Dataset): class BRepSDFDataset(Dataset):
def __init__(self, brep_dir:str, sdf_dir:str, valid_data_dir:str, split:str='train'): def __init__(self, brep_dir:str, sdf_dir:str, valid_data_dir:str, use_filter: bool=True, split:str='train'):
""" """
初始化数据集 初始化数据集
@ -52,6 +52,9 @@ class BRepSDFDataset(Dataset):
self.brep_data_list = self._load_data_list(self.brep_dir) self.brep_data_list = self._load_data_list(self.brep_dir)
self.sdf_data_list = self._load_data_list(self.sdf_dir) self.sdf_data_list = self._load_data_list(self.sdf_dir)
if use_filter:
self._filter_num_faces_and_num_edges()
# 检查数据集是否为空 # 检查数据集是否为空
if len(self.brep_data_list) == 0 : if len(self.brep_data_list) == 0 :
raise ValueError(f"No valid brep data found in {split} set") raise ValueError(f"No valid brep data found in {split} set")
@ -76,6 +79,20 @@ class BRepSDFDataset(Dataset):
#logger.info(data_list) #logger.info(data_list)
return data_list return data_list
def _filter_num_faces_and_num_edges(self):
'''
Filter the data if their face_num or edge_num > max_face or max_edge.
'''
# Collect indices of elements that satisfy the condition
filtered_indices = [
idx for idx in range(len(self.brep_data_list))
if (self._get_brep_face_and_edge(self.brep_data_list[idx]) <= (self.max_face, self.max_edge))
]
# Use filtered_indices to update brep_data_list and sdf_data_list
self.brep_data_list = [self.brep_data_list[idx] for idx in filtered_indices]
self.sdf_data_list = [self.sdf_data_list[idx] for idx in filtered_indices]
def __len__(self): def __len__(self):
return len(self.brep_data_list) return len(self.brep_data_list)
@ -87,8 +104,7 @@ class BRepSDFDataset(Dataset):
name = os.path.splitext(os.path.basename(brep_path))[0] name = os.path.splitext(os.path.basename(brep_path))[0]
# 加载B-rep和SDF数据 # 加载B-rep和SDF数据
with open(brep_path, 'rb') as f: brep_raw = self._load_brep_file(brep_path)
brep_raw = pickle.load(f)
sdf_data = self._load_sdf_file(sdf_path) sdf_data = self._load_sdf_file(sdf_path)
try: try:
@ -162,7 +178,10 @@ class BRepSDFDataset(Dataset):
logger.error(f"Error loading sample from {brep_path}: {str(e)}") logger.error(f"Error loading sample from {brep_path}: {str(e)}")
logger.error("Data structure:") logger.error("Data structure:")
raise raise
def _load_brep_file(self, brep_path):
with open(brep_path, 'rb') as f:
brep_raw = pickle.load(f)
return brep_raw
def _load_sdf_file(self, sdf_path): def _load_sdf_file(self, sdf_path):
"""加载和处理SDF数据,并进行随机采样""" """加载和处理SDF数据,并进行随机采样"""
@ -223,6 +242,11 @@ class BRepSDFDataset(Dataset):
logger.error(f"Error message: {str(e)}") logger.error(f"Error message: {str(e)}")
raise raise
def _get_brep_face_and_edge(self, brep_path: str) -> tuple[int,int]:
brep: dict = self._load_brep_file(brep_path)
face_edge_adj = brep["faceEdge_adj"]
num_faces, num_edges = face_edge_adj.shape
return num_faces, num_edges
def test_dataset(): def test_dataset():
"""测试数据集功能""" """测试数据集功能"""

Loading…
Cancel
Save