
fix: del collate and load_brep

mckay, 4 months ago · branch main · commit 2e124b7193
1 changed file: brep2sdf/data/data.py (169 lines changed)
@@ -121,7 +121,9 @@ class BRepSDFDataset(Dataset):
             # Unpack the processed features
             edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features
+            sdf_points = sdf_data[:, :3]
+            sdf_values = sdf_data[:, 3:]
 
             # Build the return dict
             return {
                 'name': name,
@@ -131,8 +133,8 @@ class BRepSDFDataset(Dataset):
                 'surf_ncs': surf_ncs,      # [max_face, 100, 3]
                 'surf_pos': surf_pos,      # [max_face, 6]
                 'vertex_pos': vertex_pos,  # [max_face, max_edge, 6]
-                'points': sdf_data[:, :3], # [num_queries, 3] xyz coordinates of all query points
-                'sdf': sdf_data[:, 3:]     # [num_queries, 1] SDF values of all query points
+                'points': sdf_points,      # [num_queries, 3] xyz coordinates of all query points
+                'sdf': sdf_values          # [num_queries, 1] SDF values of all query points
             }
         except Exception as e:
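Note on the hunk above: slicing with 3: (rather than 3) keeps the trailing dimension, so the returned sdf stays [num_queries, 1] instead of collapsing to [num_queries]. A minimal sketch on a dummy tensor (illustrative only, not part of the commit):

    import torch

    sdf_data = torch.randn(8, 4)    # stand-in for one sample's rows: x, y, z, sdf
    sdf_points = sdf_data[:, :3]    # shape (8, 3)
    sdf_values = sdf_data[:, 3:]    # shape (8, 1); sdf_data[:, 3] would yield (8,)
    assert sdf_values.shape == (8, 1)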
@@ -161,115 +163,6 @@ class BRepSDFDataset(Dataset):
             raise
 
-    def _load_brep_file(self, brep_path):
-        """Load a B-rep feature file"""
-        try:
-            # 1. Load the raw data
-            with open(brep_path, 'rb') as f:
-                raw_data = pickle.load(f)
-            brep_data = {}
-
-            # 2. Handle geometry data (variable-length sequences)
-            geom_keys = ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']
-            for key in geom_keys:
-                if key in raw_data:
-                    try:
-                        # Make sure the data is a list
-                        if not isinstance(raw_data[key], list):
-                            raise ValueError(f"{key} is not a list")
-                        # Convert each element to a tensor
-                        tensors = []
-                        for i, x in enumerate(raw_data[key]):
-                            try:
-                                # Convert to a numpy array first
-                                arr = np.array(x, dtype=np.float32)
-                                # Then to a torch tensor
-                                tensor = torch.from_numpy(arr)
-                                tensors.append(tensor)
-                            except Exception as e:
-                                logger.error(f"Error converting {key}[{i}]:")
-                                logger.error(f" Data type: {type(x)}")
-                                if isinstance(x, np.ndarray):
-                                    logger.error(f" Shape: {x.shape}")
-                                    logger.error(f" dtype: {x.dtype}")
-                                raise ValueError(f"Failed to convert {key}[{i}]: {str(e)}")
-                        brep_data[key] = tensors
-                    except Exception as e:
-                        logger.error(f"Error processing {key}:")
-                        logger.error(f" Raw data type: {type(raw_data[key])}")
-                        raise ValueError(f"Failed to process {key}: {str(e)}")
-
-            # 3. Handle fixed-shape data
-            fixed_keys = ['corner_wcs', 'corner_unique', 'surf_bbox_wcs', 'edge_bbox_wcs']
-            for key in fixed_keys:
-                if key in raw_data:
-                    try:
-                        # Convert directly from the raw data
-                        arr = np.array(raw_data[key], dtype=np.float32)
-                        brep_data[key] = torch.from_numpy(arr)
-                    except Exception as e:
-                        logger.error(f"Error converting fixed shape data {key}:")
-                        logger.error(f" Raw data type: {type(raw_data[key])}")
-                        if isinstance(raw_data[key], np.ndarray):
-                            logger.error(f" Shape: {raw_data[key].shape}")
-                            logger.error(f" dtype: {raw_data[key].dtype}")
-                        raise ValueError(f"Failed to convert {key}: {str(e)}")
-
-            # 4. Handle adjacency matrices
-            adj_keys = ['edgeFace_adj', 'edgeCorner_adj', 'faceEdge_adj']
-            for key in adj_keys:
-                if key in raw_data:
-                    try:
-                        # Convert to an integer array
-                        arr = np.array(raw_data[key], dtype=np.int32)
-                        brep_data[key] = torch.from_numpy(arr)
-                    except Exception as e:
-                        logger.error(f"Error converting adjacency matrix {key}:")
-                        logger.error(f" Raw data type: {type(raw_data[key])}")
-                        if isinstance(raw_data[key], np.ndarray):
-                            logger.error(f" Shape: {raw_data[key].shape}")
-                            logger.error(f" dtype: {raw_data[key].dtype}")
-                        raise ValueError(f"Failed to convert {key}: {str(e)}")
-
-            # 5. Verify that all required keys are present
-            required_keys = {'surf_wcs', 'edge_wcs', 'corner_wcs'}
-            missing_keys = required_keys - set(brep_data.keys())
-            if missing_keys:
-                raise ValueError(f"Missing required keys: {missing_keys}")
-
-            # 6. Run the data through process_brep_data
-            try:
-                features = process_brep_data(
-                    data=brep_data,
-                    max_face=self.max_face,
-                    max_edge=self.max_edge,
-                    bbox_scaled=self.bbox_scaled
-                )
-                return features
-            except Exception as e:
-                logger.error("Error in process_brep_data:")
-                logger.error(f" Error message: {str(e)}")
-                # Log the shapes of the input data
-                logger.error("\nInput data shapes:")
-                for key, value in brep_data.items():
-                    if isinstance(value, list):
-                        shapes = [t.shape for t in value]
-                        logger.error(f" {key}: list of tensors with shapes {shapes}")
-                    elif isinstance(value, torch.Tensor):
-                        logger.error(f" {key}: tensor of shape {value.shape}")
-                raise
-
-        except Exception as e:
-            logger.error(f"\nError loading B-rep file: {brep_path}")
-            logger.error(f"Error message: {str(e)}")
-            raise
 
     def _load_sdf_file(self, sdf_path):
         """Load and process SDF data with random sampling"""
         try:
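With _load_brep_file deleted, the per-key conversion and validation no longer runs at load time; presumably the pickled features are consumed elsewhere. For reference, a minimal sketch of a direct load-and-delegate path, assuming the pickled dict already matches what process_brep_data expects (process_brep_data, max_face, max_edge, and bbox_scaled appear in the deleted code; the loader function itself is hypothetical):

    import pickle

    def load_brep_features(brep_path, max_face, max_edge, bbox_scaled):
        # Hypothetical minimal loader: unpickle and delegate, without the
        # per-key numpy/tensor conversion that this commit removes.
        with open(brep_path, 'rb') as f:
            raw_data = pickle.load(f)
        return process_brep_data(
            data=raw_data,
            max_face=max_face,
            max_edge=max_edge,
            bbox_scaled=bbox_scaled,
        )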
@@ -321,52 +214,7 @@ class BRepSDFDataset(Dataset):
             logger.error(f"Error type: {type(e).__name__}")
             logger.error(f"Error message: {str(e)}")
             raise
-
-    @staticmethod
-    def collate_fn(batch):
-        """Custom batch collation function"""
-        # Collect the names of all samples
-        names = [item['name'] for item in batch]
-
-        # Handle the fixed-size tensor data
-        tensor_keys = ['edge_ncs', 'edge_pos', 'edge_mask',
-                       'surf_ncs', 'surf_pos', 'vertex_pos']
-        tensors = {
-            key: torch.stack([item[key] for item in batch])
-            for key in tensor_keys
-        }
-
-        # Handle the variable-length SDF data
-        sdf_data = [item['sdf'] for item in batch]
-        max_sdf_len = max(sdf.size(0) for sdf in sdf_data)
-
-        # Pad the SDF data
-        padded_sdfs = []
-        sdf_masks = []
-        for sdf in sdf_data:
-            pad_len = max_sdf_len - sdf.size(0)
-            if pad_len > 0:
-                padding = torch.zeros(pad_len, sdf.size(1),
-                                      dtype=sdf.dtype, device=sdf.device)
-                padded_sdf = torch.cat([sdf, padding], dim=0)
-                mask = torch.cat([
-                    torch.ones(sdf.size(0), dtype=torch.bool),
-                    torch.zeros(pad_len, dtype=torch.bool)
-                ])
-            else:
-                padded_sdf = sdf
-                mask = torch.ones(sdf.size(0), dtype=torch.bool)
-            padded_sdfs.append(padded_sdf)
-            sdf_masks.append(mask)
-
-        # Merge everything into one batch dict
-        batch_data = {
-            'name': names,
-            'sdf': torch.stack(padded_sdfs),
-            'sdf_mask': torch.stack(sdf_masks),
-            **tensors
-        }
-        return batch_data
 
 def test_dataset():
     """Test the dataset functionality"""
@@ -381,6 +229,7 @@ def test_dataset():
     max_edge = config.data.max_edge
     num_edge_points = config.model.num_edge_points
     num_surf_points = config.model.num_surf_points
+    num_query_points = config.data.num_query_points
 
     # Define the expected data dimensions using parameters from the config
     expected_shapes = {
@@ -390,7 +239,8 @@ def test_dataset():
         'surf_ncs': (max_face, num_surf_points, 3),
         'surf_pos': (max_face, 6),
         'vertex_pos': (max_face, max_edge, 2, 3),
-        'sdf': (2097152, 4)
+        'points': (num_query_points, 3),
+        'sdf': (num_query_points, 1)
     }
 
     logger.info("="*50)
@@ -424,6 +274,7 @@ def test_dataset():
                 logger.info(f" Expected: {expected_shape}")
                 logger.info(f" Match: {shape_match}")
                 logger.info(f" dtype: {value.dtype}")
+                logger.info(f" grad: {value.requires_grad}")
                 if not shape_match:
                     logger.warning(f" Shape mismatch for {key}!")
