|
|
@ -29,49 +29,161 @@ class BRepSDFDataset(Dataset): |
|
|
|
raise ValueError(f"SDF directory not found: {self.sdf_dir}") |
|
|
|
|
|
|
|
# 加载数据列表 |
|
|
|
self.data_list = self._load_data_list() |
|
|
|
|
|
|
|
self.brep_data_list = self._load_data_list(self.brep_dir) |
|
|
|
self.sdf_data_list = self._load_data_list(self.sdf_dir) |
|
|
|
|
|
|
|
# 检查数据集是否为空 |
|
|
|
if len(self.data_list) == 0: |
|
|
|
raise ValueError(f"No valid data found in {split} set") |
|
|
|
if len(self.brep_data_list) == 0 : |
|
|
|
raise ValueError(f"No valid brep data found in {split} set") |
|
|
|
if len(self.sdf_data_list) == 0: |
|
|
|
raise ValueError(f"No valid sdf data found in {split} set") |
|
|
|
|
|
|
|
logger.info(f"Loaded {split} dataset with {len(self.data_list)} samples") |
|
|
|
logger.info(f"Loaded {split} dataset with {len(self.brep_data_list)} samples") |
|
|
|
|
|
|
|
def _load_data_list(self): |
|
|
|
# data_dir 为 self.brep_dir or sdf_dir |
|
|
|
def _load_data_list(self, data_dir): |
|
|
|
data_list = [] |
|
|
|
for sample_dir in os.listdir(self.brep_dir): |
|
|
|
sample_path = os.path.join(self.brep_dir, sample_dir) |
|
|
|
if os.path.isdir(sample_path): |
|
|
|
data_list.append(sample_path) |
|
|
|
for sample_file in os.listdir(data_dir): |
|
|
|
path = os.path.join(data_dir, sample_file) |
|
|
|
data_list.append(path) |
|
|
|
|
|
|
|
#logger.info(data_list) |
|
|
|
return data_list |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return len(self.data_list) |
|
|
|
return len(self.brep_data_list) |
|
|
|
|
|
|
|
def __getitem__(self, idx): |
|
|
|
sample_path = self.data_list[idx] |
|
|
|
|
|
|
|
# 解析 .step 文件 |
|
|
|
step_file = os.path.join(sample_path, 'model.step') |
|
|
|
brep_features = self._parse_step_file(step_file) |
|
|
|
|
|
|
|
# 加载 .sdf 文件 |
|
|
|
sdf_file = os.path.join(sample_path, 'sdf.npy') |
|
|
|
sdf = np.load(sdf_file) |
|
|
|
|
|
|
|
# 转换为 torch 张量 |
|
|
|
brep_features = torch.tensor(brep_features, dtype=torch.float32) |
|
|
|
sdf = torch.tensor(sdf, dtype=torch.float32) |
|
|
|
"""获取单个数据样本""" |
|
|
|
brep_path = self.brep_data_list[idx] |
|
|
|
sdf_path = self.sdf_data_list[idx] |
|
|
|
|
|
|
|
return { |
|
|
|
'brep_features': brep_features, |
|
|
|
'sdf': sdf |
|
|
|
} |
|
|
|
try: |
|
|
|
# 获取文件名(不含扩展名)作为sample name |
|
|
|
name = os.path.splitext(os.path.basename(brep_path))[0] |
|
|
|
|
|
|
|
# 加载B-rep和SDF数据 |
|
|
|
brep_data = self._load_brep_file(brep_path) |
|
|
|
sdf_data = self._load_sdf_file(sdf_path) |
|
|
|
|
|
|
|
# 修改返回格式,将sdf_data作为一个键值对添加 |
|
|
|
return { |
|
|
|
'name': name, |
|
|
|
**brep_data, # 解包B-rep特征 |
|
|
|
'sdf': sdf_data # 添加SDF数据作为一个键 |
|
|
|
} |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error loading sample from {brep_path}: {str(e)}") |
|
|
|
logger.error("Data structure:") |
|
|
|
if 'brep_data' in locals(): |
|
|
|
for key, value in brep_data.items(): |
|
|
|
if isinstance(value, np.ndarray): |
|
|
|
logger.error(f" {key}: type={type(value)}, dtype={value.dtype}, shape={value.shape}") |
|
|
|
raise |
|
|
|
|
|
|
|
def _parse_step_file(self, step_file): |
|
|
|
# 解析 .step 文件的逻辑 |
|
|
|
# 返回 B-rep 特征 |
|
|
|
pass |
|
|
|
def _load_brep_file(self, brep_path): |
|
|
|
"""加载B-rep特征文件""" |
|
|
|
try: |
|
|
|
with open(brep_path, 'rb') as f: |
|
|
|
brep_data = pickle.load(f) |
|
|
|
|
|
|
|
features = {} |
|
|
|
|
|
|
|
# 1. 处理几何数据(不等长序列) |
|
|
|
for key in ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']: |
|
|
|
if key in brep_data: |
|
|
|
try: |
|
|
|
features[key] = [ |
|
|
|
torch.from_numpy(np.array(x, dtype=np.float32)) |
|
|
|
for x in brep_data[key] |
|
|
|
] |
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error converting {key}:") |
|
|
|
logger.error(f" Type: {type(brep_data[key])}") |
|
|
|
if isinstance(brep_data[key], list): |
|
|
|
logger.error(f" List length: {len(brep_data[key])}") |
|
|
|
if len(brep_data[key]) > 0: |
|
|
|
logger.error(f" First element type: {type(brep_data[key][0])}") |
|
|
|
logger.error(f" First element shape: {brep_data[key][0].shape}") |
|
|
|
logger.error(f" First element dtype: {brep_data[key][0].dtype}") |
|
|
|
raise ValueError(f"Failed to convert {key}: {str(e)}") |
|
|
|
|
|
|
|
# 2. 处理固定形状的数据 |
|
|
|
for key in ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs']: |
|
|
|
if key in brep_data: |
|
|
|
try: |
|
|
|
data = np.array(brep_data[key], dtype=np.float32) |
|
|
|
features[key] = torch.from_numpy(data) |
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error converting {key}:") |
|
|
|
logger.error(f" Type: {type(brep_data[key])}") |
|
|
|
if isinstance(brep_data[key], np.ndarray): |
|
|
|
logger.error(f" Shape: {brep_data[key].shape}") |
|
|
|
logger.error(f" dtype: {brep_data[key].dtype}") |
|
|
|
raise ValueError(f"Failed to convert {key}: {str(e)}") |
|
|
|
|
|
|
|
# 3. 处理邻接矩阵 |
|
|
|
for key in ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']: |
|
|
|
if key in brep_data: |
|
|
|
try: |
|
|
|
data = np.array(brep_data[key], dtype=np.int32) |
|
|
|
features[key] = torch.from_numpy(data) |
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error converting {key}:") |
|
|
|
logger.error(f" Type: {type(brep_data[key])}") |
|
|
|
if isinstance(brep_data[key], np.ndarray): |
|
|
|
logger.error(f" Shape: {brep_data[key].shape}") |
|
|
|
logger.error(f" dtype: {brep_data[key].dtype}") |
|
|
|
raise ValueError(f"Failed to convert {key}: {str(e)}") |
|
|
|
|
|
|
|
return features |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error(f"\nError loading B-rep file: {brep_path}") |
|
|
|
logger.error(f"Error message: {str(e)}") |
|
|
|
|
|
|
|
# 打印完整的数据结构信息 |
|
|
|
if 'brep_data' in locals(): |
|
|
|
logger.error("\nComplete data structure:") |
|
|
|
for key, value in brep_data.items(): |
|
|
|
logger.error(f"\n{key}:") |
|
|
|
logger.error(f" Type: {type(value)}") |
|
|
|
if isinstance(value, np.ndarray): |
|
|
|
logger.error(f" Shape: {value.shape}") |
|
|
|
logger.error(f" dtype: {value.dtype}") |
|
|
|
elif isinstance(value, list): |
|
|
|
logger.error(f" List length: {len(value)}") |
|
|
|
if len(value) > 0: |
|
|
|
logger.error(f" First element type: {type(value[0])}") |
|
|
|
if isinstance(value[0], np.ndarray): |
|
|
|
logger.error(f" First element shape: {value[0].shape}") |
|
|
|
logger.error(f" First element dtype: {value[0].dtype}") |
|
|
|
raise |
|
|
|
|
|
|
|
def _load_sdf_file(self, sdf_path): |
|
|
|
"""加载和处理SDF数据""" |
|
|
|
try: |
|
|
|
# 加载SDF值 |
|
|
|
sdf_data = np.load(sdf_path) |
|
|
|
if 'pos' not in sdf_data or 'neg' not in sdf_data: |
|
|
|
raise ValueError("Missing pos/neg data in SDF file") |
|
|
|
|
|
|
|
sdf_pos = sdf_data['pos'] # (N1, 4) |
|
|
|
sdf_neg = sdf_data['neg'] # (N2, 4) |
|
|
|
|
|
|
|
# 添加数据验证 |
|
|
|
if sdf_pos.shape[1] != 4 or sdf_neg.shape[1] != 4: |
|
|
|
raise ValueError(f"Invalid SDF data shape: pos={sdf_pos.shape}, neg={sdf_neg.shape}") |
|
|
|
|
|
|
|
sdf_np = np.concatenate([sdf_pos, sdf_neg], axis=0) |
|
|
|
return torch.from_numpy(sdf_np.astype(np.float32)) |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
logger.error(f"Error loading SDF from {sdf_path}") |
|
|
|
logger.error(f"Error type: {type(e).__name__}") |
|
|
|
logger.error(f"Error message: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
def test_dataset(): |
|
|
|
"""测试数据集功能""" |
|
|
@ -88,11 +200,10 @@ def test_dataset(): |
|
|
|
logger.info(f"Split: {split}") |
|
|
|
|
|
|
|
# ... (其余测试代码保持不变) ... |
|
|
|
dataset = BRepSDFDataset(brep_data_dir='/home/wch/myDeepSDF/test_data/pkl', sdf_data_dir='/home/wch/myDeepSDF/test_data/sdf', split='train') |
|
|
|
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True) |
|
|
|
dataset = BRepSDFDataset(brep_dir='/home/wch/brep2sdf/test_data/pkl', sdf_dir='/home/wch/brep2sdf/test_data/sdf', split='train') |
|
|
|
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True) |
|
|
|
|
|
|
|
for batch in dataloader: |
|
|
|
print(batch['brep_features'].shape) |
|
|
|
print(batch['sdf'].shape) |
|
|
|
break |
|
|
|
except Exception as e: |
|
|
|