@@ -3,10 +3,8 @@ import torch
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import pickle

from brep2sdf.utils.logger import logger
from brep2sdf.data.utils import process_brep_data

@@ -22,10 +20,16 @@ class BRepSDFDataset(Dataset):
            sdf_dir: directory containing the .npz SDF files
            split: dataset split ('train', 'val', or 'test')
        """
        super().__init__()
        self.brep_dir = os.path.join(brep_dir, split)
        self.sdf_dir = os.path.join(sdf_dir, split)
        self.split = split

        # Fixed preprocessing parameters
        self.max_face = 70
        self.max_edge = 70
        self.bbox_scaled = 1.0

        # Check that the data directories exist
        if not os.path.exists(self.brep_dir):
            raise ValueError(f"B-rep directory not found: {self.brep_dir}")
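        # Usage sketch (illustrative; the paths are hypothetical placeholders):
        #
        #     dataset = BRepSDFDataset('data/brep', 'data/sdf', split='train')
        #
        # This expects pickled B-rep files under data/brep/train/ and .npz SDF
        # files under data/sdf/train/.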
@@ -59,112 +63,187 @@ class BRepSDFDataset(Dataset):
    def __getitem__(self, idx):
        """Fetch a single data sample."""
        brep_path = self.brep_data_list[idx]
        sdf_path = self.sdf_data_list[idx]
        try:
            # Use the file name (without extension) as the sample name
            name = os.path.splitext(os.path.basename(brep_path))[0]

            # Load the B-rep and SDF data
            with open(brep_path, 'rb') as f:
                brep_raw = pickle.load(f)
            sdf_data = self._load_sdf_file(sdf_path)

            try:
                # Process the B-rep data into fixed-shape features
                brep_features = process_brep_data(
                    data=brep_raw,
                    max_face=self.max_face,
                    max_edge=self.max_edge,
                    bbox_scaled=self.bbox_scaled
                )

                # Check the type and arity of the return value
                if not isinstance(brep_features, tuple):
                    logger.error(f"process_brep_data returned {type(brep_features)}, expected tuple")
                    raise ValueError("Invalid return type from process_brep_data")

                if len(brep_features) != 6:
                    logger.error(f"Expected 6 features, got {len(brep_features)}")
                    logger.error("Features returned:")
                    for i, feat in enumerate(brep_features):
                        if isinstance(feat, torch.Tensor):
                            logger.error(f"  {i}: Tensor of shape {feat.shape}")
                        else:
                            logger.error(f"  {i}: {type(feat)}")
                    raise ValueError(f"Incorrect number of features: {len(brep_features)}")

                # Unpack the processed features
                edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features

                # Build the sample dictionary
                return {
                    'name': name,
                    'edge_ncs': edge_ncs,      # [max_face, max_edge, 10, 3]
                    'edge_pos': edge_pos,      # [max_face, max_edge, 6]
                    'edge_mask': edge_mask,    # [max_face, max_edge]
                    'surf_ncs': surf_ncs,      # [max_face, 100, 3]
                    'surf_pos': surf_pos,      # [max_face, 6]
                    'vertex_pos': vertex_pos,  # [max_face, max_edge, 6]
                    'sdf': sdf_data            # [N, 4]
                }

            except Exception as e:
                logger.error(f"\nError processing B-rep data for file: {brep_path}")
                logger.error(f"Error type: {type(e).__name__}")
                logger.error(f"Error message: {str(e)}")

                # Dump the structure of the raw data
                logger.error("\nRaw data structure:")
                for key, value in brep_raw.items():
                    if isinstance(value, list):
                        logger.error(f"  {key}: list of length {len(value)}")
                        if value:
                            logger.error(f"    First element type: {type(value[0])}")
                            if hasattr(value[0], 'shape'):
                                logger.error(f"    First element shape: {value[0].shape}")
                    elif hasattr(value, 'shape'):
                        logger.error(f"  {key}: shape {value.shape}")
                    else:
                        logger.error(f"  {key}: {type(value)}")
                raise

        except Exception as e:
            logger.error(f"Error loading sample from {brep_path}: {str(e)}")
            logger.error("Data structure:")
            if 'brep_raw' in locals():
                for key, value in brep_raw.items():
                    if isinstance(value, np.ndarray):
                        logger.error(f"  {key}: type={type(value)}, dtype={value.dtype}, shape={value.shape}")
            raise

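    # Illustrative note (not part of the original file): each sample is a flat
    # dict of fixed-shape tensors plus one variable-length SDF array. With the
    # defaults max_face = max_edge = 70:
    #
    #     sample = dataset[0]
    #     sample['edge_ncs'].shape   # (70, 70, 10, 3)
    #     sample['surf_pos'].shape   # (70, 6)
    #     sample['sdf'].shape        # (N, 4) -- N varies per model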
    def _load_brep_file(self, brep_path):
        """Load a B-rep feature file."""
        try:
            # 1. Load the raw pickled data
            with open(brep_path, 'rb') as f:
                raw_data = pickle.load(f)

            brep_data = {}

            # 2. Convert the geometry data (sequences of unequal length)
            geom_keys = ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']
            for key in geom_keys:
                if key in raw_data:
                    try:
                        # Make sure the entry is a list
                        if not isinstance(raw_data[key], list):
                            raise ValueError(f"{key} is not a list")

                        # Convert each element to a tensor
                        tensors = []
                        for i, x in enumerate(raw_data[key]):
                            try:
                                # Convert to a numpy array first, then to a tensor
                                arr = np.array(x, dtype=np.float32)
                                tensor = torch.from_numpy(arr)
                                tensors.append(tensor)
                            except Exception as e:
                                logger.error(f"Error converting {key}[{i}]:")
                                logger.error(f"  Data type: {type(x)}")
                                if isinstance(x, np.ndarray):
                                    logger.error(f"  Shape: {x.shape}")
                                    logger.error(f"  dtype: {x.dtype}")
                                raise ValueError(f"Failed to convert {key}[{i}]: {str(e)}")

                        brep_data[key] = tensors

                    except Exception as e:
                        logger.error(f"Error processing {key}:")
                        logger.error(f"  Raw data type: {type(raw_data[key])}")
                        raise ValueError(f"Failed to process {key}: {str(e)}")

            # 3. Convert the fixed-shape data
            fixed_keys = ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs']
            for key in fixed_keys:
                if key in raw_data:
                    try:
                        # Convert directly from the raw data
                        arr = np.array(raw_data[key], dtype=np.float32)
                        brep_data[key] = torch.from_numpy(arr)
                    except Exception as e:
                        logger.error(f"Error converting fixed shape data {key}:")
                        logger.error(f"  Raw data type: {type(raw_data[key])}")
                        if isinstance(raw_data[key], np.ndarray):
                            logger.error(f"  Shape: {raw_data[key].shape}")
                            logger.error(f"  dtype: {raw_data[key].dtype}")
                        raise ValueError(f"Failed to convert {key}: {str(e)}")

            # 4. Convert the adjacency matrices
            adj_keys = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
            for key in adj_keys:
                if key in raw_data:
                    try:
                        # Convert to integer arrays
                        arr = np.array(raw_data[key], dtype=np.int32)
                        brep_data[key] = torch.from_numpy(arr)
                    except Exception as e:
                        logger.error(f"Error converting adjacency matrix {key}:")
                        logger.error(f"  Raw data type: {type(raw_data[key])}")
                        if isinstance(raw_data[key], np.ndarray):
                            logger.error(f"  Shape: {raw_data[key].shape}")
                            logger.error(f"  dtype: {raw_data[key].dtype}")
                        raise ValueError(f"Failed to convert {key}: {str(e)}")

            # 5. Validate that the required keys are present
            required_keys = {'surf_wcs', 'edge_wcs', 'corner_wcs'}
            missing_keys = required_keys - set(brep_data.keys())
            if missing_keys:
                raise ValueError(f"Missing required keys: {missing_keys}")

            # 6. Process the data with process_brep_data
            try:
                features = process_brep_data(
                    data=brep_data,
                    max_face=self.max_face,
                    max_edge=self.max_edge,
                    bbox_scaled=self.bbox_scaled
                )
                return features
            except Exception as e:
                logger.error("Error in process_brep_data:")
                logger.error(f"  Error message: {str(e)}")

                # Dump the input shapes
                logger.error("\nInput data shapes:")
                for key, value in brep_data.items():
                    if isinstance(value, list):
                        shapes = [t.shape for t in value]
                        logger.error(f"  {key}: list of tensors with shapes {shapes}")
                    elif isinstance(value, torch.Tensor):
                        logger.error(f"  {key}: tensor of shape {value.shape}")
                raise

        except Exception as e:
            logger.error(f"\nError loading B-rep file: {brep_path}")
            logger.error(f"Error message: {str(e)}")

            # Dump the complete data structure
            if 'brep_data' in locals():
                logger.error("\nComplete data structure:")
                for key, value in brep_data.items():
                    logger.error(f"\n{key}:")
                    logger.error(f"  Type: {type(value)}")
                    if isinstance(value, np.ndarray):
                        logger.error(f"  Shape: {value.shape}")
                        logger.error(f"  dtype: {value.dtype}")
                    elif isinstance(value, list):
                        logger.error(f"  List length: {len(value)}")
                        if len(value) > 0:
                            logger.error(f"  First element type: {type(value[0])}")
                            if isinstance(value[0], np.ndarray):
                                logger.error(f"  First element shape: {value[0].shape}")
                                logger.error(f"  First element dtype: {value[0].dtype}")
            raise

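    # A minimal sketch (not part of the original file) of why the geometry
    # entries stay lists of tensors: every surface or edge has its own point
    # count, so the entries cannot be stacked into one array. The method name
    # and sizes below are hypothetical.
    @staticmethod
    def _demo_ragged_geometry():
        ragged = [np.zeros((87, 3), dtype=np.float32),
                  np.zeros((120, 3), dtype=np.float32)]
        # np.stack(ragged) would raise ValueError because the shapes differ,
        # hence the per-element conversion in _load_brep_file.
        return [torch.from_numpy(arr) for arr in ragged]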
    def _load_sdf_file(self, sdf_path):

@@ -190,6 +269,52 @@ class BRepSDFDataset(Dataset):
            logger.error(f"Error type: {type(e).__name__}")
            logger.error(f"Error message: {str(e)}")
            raise

    @staticmethod
    def collate_fn(batch):
        """Custom batch collation function."""
        # Collect the names of all samples
        names = [item['name'] for item in batch]

        # Stack the fixed-size tensor data
        tensor_keys = ['edge_ncs', 'edge_pos', 'edge_mask',
                       'surf_ncs', 'surf_pos', 'vertex_pos']
        tensors = {
            key: torch.stack([item[key] for item in batch])
            for key in tensor_keys
        }

        # Handle the variable-length SDF data
        sdf_data = [item['sdf'] for item in batch]
        max_sdf_len = max(sdf.size(0) for sdf in sdf_data)

        # Pad every SDF sample to the batch maximum and record a validity mask
        padded_sdfs = []
        sdf_masks = []
        for sdf in sdf_data:
            pad_len = max_sdf_len - sdf.size(0)
            if pad_len > 0:
                padding = torch.zeros(pad_len, sdf.size(1),
                                      dtype=sdf.dtype, device=sdf.device)
                padded_sdf = torch.cat([sdf, padding], dim=0)
                mask = torch.cat([
                    torch.ones(sdf.size(0), dtype=torch.bool),
                    torch.zeros(pad_len, dtype=torch.bool)
                ])
            else:
                padded_sdf = sdf
                mask = torch.ones(sdf.size(0), dtype=torch.bool)
            padded_sdfs.append(padded_sdf)
            sdf_masks.append(mask)

        # Merge everything into a single batch dictionary
        batch_data = {
            'name': names,
            'sdf': torch.stack(padded_sdfs),
            'sdf_mask': torch.stack(sdf_masks),
            **tensors
        }
        return batch_data

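# A minimal usage sketch (not part of the original file): wiring the dataset
# and its custom collate_fn into a DataLoader, and using the 'sdf_mask' key to
# drop padded SDF rows. The function name and paths are hypothetical.
def _demo_dataloader(brep_dir='data/brep', sdf_dir='data/sdf'):
    from torch.utils.data import DataLoader
    dataset = BRepSDFDataset(brep_dir, sdf_dir, split='train')
    loader = DataLoader(dataset, batch_size=4,
                        collate_fn=BRepSDFDataset.collate_fn)
    batch = next(iter(loader))
    valid = batch['sdf_mask']           # [B, max_sdf_len], True for real rows
    sdf_points = batch['sdf'][valid]    # [num_valid, 4] with padding removed
    return sdf_points.shape
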
def test_dataset():
    """Test the dataset functionality."""