|
|
@ -12,7 +12,7 @@ from brep2sdf.config.default_config import get_default_config |
|
|
|
|
|
|
|
|
|
|
|
class BRepSDFDataset(Dataset): |
|
|
|
def __init__(self, brep_dir:str, sdf_dir:str, valid_data_dir:str, split:str='train'): |
|
|
|
def __init__(self, brep_dir:str, sdf_dir:str, valid_data_dir:str, use_filter: bool=True, split:str='train'): |
|
|
|
""" |
|
|
|
初始化数据集 |
|
|
|
|
|
|
@ -52,6 +52,9 @@ class BRepSDFDataset(Dataset): |
|
|
|
self.brep_data_list = self._load_data_list(self.brep_dir) |
|
|
|
self.sdf_data_list = self._load_data_list(self.sdf_dir) |
|
|
|
|
|
|
|
if use_filter: |
|
|
|
self._filter_num_faces_and_num_edges() |
|
|
|
|
|
|
|
# 检查数据集是否为空 |
|
|
|
if len(self.brep_data_list) == 0 : |
|
|
|
raise ValueError(f"No valid brep data found in {split} set") |
|
|
@ -76,6 +79,20 @@ class BRepSDFDataset(Dataset): |
|
|
|
#logger.info(data_list) |
|
|
|
return data_list |
|
|
|
|
|
|
|
def _filter_num_faces_and_num_edges(self): |
|
|
|
''' |
|
|
|
Filter the data if their face_num or edge_num > max_face or max_edge. |
|
|
|
''' |
|
|
|
# Collect indices of elements that satisfy the condition |
|
|
|
filtered_indices = [ |
|
|
|
idx for idx in range(len(self.brep_data_list)) |
|
|
|
if (self._get_brep_face_and_edge(self.brep_data_list[idx]) <= (self.max_face, self.max_edge)) |
|
|
|
] |
|
|
|
|
|
|
|
# Use filtered_indices to update brep_data_list and sdf_data_list |
|
|
|
self.brep_data_list = [self.brep_data_list[idx] for idx in filtered_indices] |
|
|
|
self.sdf_data_list = [self.sdf_data_list[idx] for idx in filtered_indices] |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return len(self.brep_data_list) |
|
|
|
|
|
|
@ -87,8 +104,7 @@ class BRepSDFDataset(Dataset): |
|
|
|
name = os.path.splitext(os.path.basename(brep_path))[0] |
|
|
|
|
|
|
|
# 加载B-rep和SDF数据 |
|
|
|
with open(brep_path, 'rb') as f: |
|
|
|
brep_raw = pickle.load(f) |
|
|
|
brep_raw = self._load_brep_file(brep_path) |
|
|
|
sdf_data = self._load_sdf_file(sdf_path) |
|
|
|
|
|
|
|
try: |
|
|
@ -162,7 +178,10 @@ class BRepSDFDataset(Dataset): |
|
|
|
logger.error(f"Error loading sample from {brep_path}: {str(e)}") |
|
|
|
logger.error("Data structure:") |
|
|
|
raise |
|
|
|
|
|
|
|
def _load_brep_file(self, brep_path): |
|
|
|
with open(brep_path, 'rb') as f: |
|
|
|
brep_raw = pickle.load(f) |
|
|
|
return brep_raw |
|
|
|
|
|
|
|
def _load_sdf_file(self, sdf_path): |
|
|
|
"""加载和处理SDF数据,并进行随机采样""" |
|
|
@ -222,7 +241,12 @@ class BRepSDFDataset(Dataset): |
|
|
|
logger.error(f"Error type: {type(e).__name__}") |
|
|
|
logger.error(f"Error message: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
|
|
|
|
def _get_brep_face_and_edge(self, brep_path: str) -> tuple[int,int]: |
|
|
|
brep: dict = self._load_brep_file(brep_path) |
|
|
|
face_edge_adj = brep["faceEdge_adj"] |
|
|
|
num_faces, num_edges = face_edge_adj.shape |
|
|
|
return num_faces, num_edges |
|
|
|
|
|
|
|
def test_dataset(): |
|
|
|
"""测试数据集功能""" |
|
|
|