Browse Source

refactor: 用config替代硬编码

main
王琛涵 4 months ago
parent
commit
84fc447adb
  1. 4
      brep2sdf/config/default_config.py
  2. 41
      brep2sdf/data/data.py
  3. 80
      brep2sdf/data/utils.py
  4. 29
      brep2sdf/networks/encoder.py

4
brep2sdf/config/default_config.py

@ -16,8 +16,8 @@ class ModelConfig:
@dataclass
class DataConfig:
"""数据相关配置"""
max_face: int = 70
max_edge: int = 70
max_face: int = 64
max_edge: int = 64
bbox_scaled: float = 1.0
# 数据路径

41
brep2sdf/data/data.py

@ -5,6 +5,7 @@ import numpy as np
import pickle
from brep2sdf.utils.logger import logger
from brep2sdf.data.utils import process_brep_data
from brep2sdf.config.default_config import get_default_config
@ -21,14 +22,17 @@ class BRepSDFDataset(Dataset):
split: 数据集分割('train', 'val', 'test')
"""
super().__init__()
# 使用配置文件
self.config = get_default_config()
self.brep_dir = os.path.join(brep_dir, split)
self.sdf_dir = os.path.join(sdf_dir, split)
self.split = split
# 设置固定参数
self.max_face = 70
self.max_edge = 70
self.bbox_scaled = 1.0
# 使用配置文件中的参数替换固定参数
self.max_face = self.config.data.max_face
self.max_edge = self.config.data.max_edge
self.bbox_scaled = self.config.data.bbox_scaled
# 检查目录是否存在
if not os.path.exists(self.brep_dir):
@ -334,21 +338,26 @@ class BRepSDFDataset(Dataset):
def test_dataset():
"""测试数据集功能"""
try:
# 1. 设置测试路径和预期的数据维度
brep_dir = '/home/wch/brep2sdf/test_data/pkl'
sdf_dir = '/home/wch/brep2sdf/test_data/sdf'
valid_data_dir = "/home/wch/brep2sdf/test_data/result/pkl"
# 获取配置
config = get_default_config()
brep_dir = config.data.brep_dir
sdf_dir = config.data.sdf_dir
valid_data_dir = config.data.valid_data_dir
split = 'train'
max_face = config.data.max_face
max_edge = config.data.max_edge
num_edge_points = config.model.num_edge_points
num_surf_points = config.model.num_surf_points
# 定义预期的数据维度
# 定义预期的数据维度,使用配置中的参数
expected_shapes = {
'edge_ncs': (70, 70, 10, 3), # [max_face, max_edge, sample_points, xyz]
'edge_pos': (70, 70, 6), # [max_face, max_edge, bbox]
'edge_mask': (70, 70), # [max_face, max_edge]
'surf_ncs': (70, 100, 3), # [max_face, sample_points, xyz]
'surf_pos': (70, 6), # [max_face, bbox]
'vertex_pos': (70, 70, 2, 3), # [max_face, max_edge, 2_points, xyz]
'sdf': (2097152, 4) # [num_points, xyz+sdf]
'edge_ncs': (max_face, max_edge, num_edge_points, 3), # [max_face, max_edge, sample_points, xyz]
'edge_pos': (max_face, max_edge, 6),
'edge_mask': (max_face, max_edge),
'surf_ncs': (max_face, num_surf_points, 3),
'surf_pos': (max_face, 6),
'vertex_pos': (max_face, max_edge, 2, 3),
'sdf': (2097152, 4)
}
logger.info("="*50)

80
brep2sdf/data/utils.py

@ -9,6 +9,7 @@ from chamferdist import ChamferDistance
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from typing import List, Optional, Tuple, Union
from brep2sdf.utils.logger import logger
from brep2sdf.config.default_config import get_default_config
from OCC.Core.gp import gp_Pnt, gp_Pnt
@ -1000,25 +1001,25 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo
Args:
data (dict): 包含B-rep数据的字典结构如下:
{
'surf_ncs': np.ndarray, # 面归一化点云 [num_faces, 100, 3]
'edge_ncs': np.ndarray, # 边归一化点云 [num_edges, 10, 3]
'surf_ncs': np.ndarray, # 面归一化点云 [num_faces, num_surf_sample_points, 3]
'edge_ncs': np.ndarray, # 边归一化点云 [num_edges, num_edge_sample_points, 3]
'corner_wcs': np.ndarray, # 顶点坐标 [num_edges, 2, 3] - 每条边的两个端点
'faceEdge_adj': np.ndarray, # 面-边邻接矩阵 [num_faces, num_edges]
'surf_pos': np.ndarray, # 面位置(包围盒) [num_faces, 6]
'edge_pos': np.ndarray, # 边位置(包围盒) [num_edges, 6]
}
max_face (int): 最大面数用于填充
max_edge (int): 最大边数用于填充
bbox_scaled (float): 边界框缩放因子
aug (bool): 是否使用数据增强
max_face (int): 最大面数用于填充, config.data.max_face
max_edge (int): 最大边数用于填充, config.data.max_edge
bbox_scaled (float): 边界框缩放因子, config.data.bbox_scaled
aug (bool): 是否使用数据增强, config.data.aug
data_class (Optional[int]): 数据类别标签
Returns:
Tuple[torch.Tensor, ...]: 包含以下张量的元组:
- edge_ncs: 边归一化特征 [max_face, max_edge, 10, 3]
- edge_ncs: 边归一化特征 [max_face, max_edge, num_edge_sample_points, 3]
- edge_pos: 边位置 [max_face, max_edge, 6]
- edge_mask: 边掩码 [max_face, max_edge]
- surf_ncs: 面归一化特征 [max_face, 100, 3]
- surf_ncs: 面归一化特征 [max_face, num_surf_sample_points, 3]
- surf_pos: 面位置 [max_face, 6]
- vertex_pos: 顶点位置 [max_face, max_edge, 2, 3] - 每个面的每条边的两个端点
- data_class: (可选) 类别标签 [1]
@ -1038,11 +1039,16 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo
- 填充到最大面数
5. 转换为张量
"""
# 获取配置
config = get_default_config()
num_surf_points = config.model.num_surf_points # 16
num_edge_points = config.model.num_edge_points # 4
# 解包数据
#_, _, surf_ncs, edge_ncs, corner_wcs, _, _, faceEdge_adj, surf_pos, edge_pos, _ = data.values()
# 直接获取需要的数据
surf_ncs = data['surf_ncs'] # (num_faces,) -> 每个元素形状 (N, 3)
edge_ncs = data['edge_ncs'] # (num_edges, 100, 3)
edge_ncs = data['edge_ncs'] # (num_edges, num_edge_points, 3)
corner_wcs = data['corner_wcs'] # (num_edges, 2, 3)
faceEdge_adj = data['faceEdge_adj'] # (num_faces, num_edges)
edgeCorner_adj = data['edgeCorner_adj'] # (num_edges, 2) 每条边连接2个顶点
@ -1079,13 +1085,13 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo
# 特征复制
edge_pos_duplicated = [] # [num_edges_per_face, 6]
vertex_pos_duplicated = [] # [num_edges_per_face, 2, 3]
edge_ncs_duplicated = [] # [num_edges_per_face, 10, 3]
edge_ncs_duplicated = [] # [num_edges_per_face, num_edge_points, 3]
for adj in faceEdge_adj: # [num_faces, num_edges]
edge_indices = np.where(adj)[0] # 获取当前面的边索引
# 复制边的特征
edge_ncs_duplicated.append(edge_ncs[edge_indices]) # [num_edges_per_face, 10, 3]
edge_ncs_duplicated.append(edge_ncs[edge_indices]) # [num_edges_per_face, num_edge_points, 3]
edge_pos_duplicated.append(edge_pos[edge_indices]) # [num_edges_per_face, 6]
# 直接获取对应边的顶点对
@ -1094,7 +1100,7 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo
# 边特征打乱和填充
edge_pos_new = [] # 最终形状: [num_faces, max_edge, 6]
edge_ncs_new = [] # 最终形状: [num_faces, max_edge, 10, 3]
edge_ncs_new = [] # 最终形状: [num_faces, max_edge, num_edge_points, 3]
vert_pos_new = [] # 最终形状: [num_faces, max_edge, 6]
edge_mask = [] # 最终形状: [num_faces, max_edge]
@ -1105,12 +1111,12 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo
# 同时打乱所有特征
pos = pos[random_indices] # [num_edges_per_face, 6]
ncs = ncs[random_indices] # [num_edges_per_face, 10, 3]
ncs = ncs[random_indices] # [num_edges_per_face, num_edge_points, 3]
vert = vert[random_indices] # [num_edges_per_face, 2, 3]
# 填充到最大边数
pos, mask = pad_zero(pos, max_edge, return_mask=True) # [max_edge, 6], [max_edge]
ncs = pad_zero(ncs, max_edge) # [max_edge, 10, 3]
ncs = pad_zero(ncs, max_edge) # [max_edge, num_edge_points, 3]
vert = pad_zero(vert, max_edge) # [max_edge, 2, 3]
edge_pos_new.append(pos)
@ -1119,60 +1125,60 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo
vert_pos_new.append(vert)
edge_pos = np.stack(edge_pos_new) # [num_faces, max_edge, 6]
edge_ncs = np.stack(edge_ncs_new) # [num_faces, max_edge, 10, 3]
edge_ncs = np.stack(edge_ncs_new) # [num_faces, max_edge, num_edge_points, 3]
edge_mask = np.stack(edge_mask) # [num_faces, max_edge]
vertex_pos = np.stack(vert_pos_new) # [num_faces, max_edge, 2, 3]
# 处理edge_ncs (确保形状为 [num_faces, max_edge, 10, 3])
# 处理edge_ncs (确保形状为 [num_faces, max_edge, num_edge_points, 3])
edge_ncs_list = []
for face_idx in range(len(edge_ncs)):
face_edges = []
for edge_idx in range(len(edge_ncs[face_idx])):
edge_points = edge_ncs[face_idx][edge_idx]
# 确保每条边有10个点
if len(edge_points) > 10:
indices = np.random.choice(len(edge_points), 10, replace=False)
# 确保每条边有num_edge_points个点
if len(edge_points) > num_edge_points:
indices = np.random.choice(len(edge_points), num_edge_points, replace=False)
edge_points = edge_points[indices]
elif len(edge_points) < 10:
indices = np.random.choice(len(edge_points), 10-len(edge_points))
elif len(edge_points) < num_edge_points:
indices = np.random.choice(len(edge_points), num_edge_points-len(edge_points))
edge_points = np.concatenate([edge_points, edge_points[indices]], axis=0)
face_edges.append(edge_points)
# 填充到最大边数
while len(face_edges) < max_edge:
face_edges.append(np.zeros((10, 3), dtype=np.float32))
face_edges.append(np.zeros((num_edge_points, 3), dtype=np.float32))
edge_ncs_list.append(np.stack(face_edges))
edge_ncs = np.stack(edge_ncs_list).astype(np.float32) # [num_faces, max_edge, 10, 3]
edge_ncs = np.stack(edge_ncs_list).astype(np.float32) # [num_faces, max_edge, num_edge_points, 3]
# 处理surf_ncs (确保形状为 [num_faces, 100, 3])
# 处理surf_ncs (确保形状为 [num_faces, num_surf_points, 3])
surf_ncs_list = []
for points in surf_ncs:
if len(points) > 100:
indices = np.random.choice(len(points), 100, replace=False)
if len(points) > num_surf_points:
indices = np.random.choice(len(points), num_surf_points, replace=False)
points = points[indices]
elif len(points) < 100:
indices = np.random.choice(len(points), 100-len(points))
elif len(points) < num_surf_points:
indices = np.random.choice(len(points), num_surf_points-len(points))
points = np.concatenate([points, points[indices]], axis=0)
surf_ncs_list.append(points)
surf_ncs = np.stack(surf_ncs_list).astype(np.float32) # [num_faces, 100, 3]
surf_ncs = np.stack(surf_ncs_list).astype(np.float32) # [num_faces, num_surf_points, 3]
# 面特征打乱
random_indices = np.random.permutation(surf_pos.shape[0])
surf_pos = surf_pos[random_indices] # [num_faces, 6]
edge_pos = edge_pos[random_indices] # [num_faces, max_edge, 6]
surf_ncs = surf_ncs[random_indices] # [num_faces, 100, 3]
edge_ncs = edge_ncs[random_indices] # [num_faces, max_edge, 10, 3]
surf_ncs = surf_ncs[random_indices] # [num_faces, num_surf_points, 3]
edge_ncs = edge_ncs[random_indices] # [num_faces, max_edge, num_edge_points, 3]
edge_mask = edge_mask[random_indices] # [num_faces, max_edge]
vertex_pos = vertex_pos[random_indices] # [num_faces, max_edge, 2, 3]
# 填充到最大面数
surf_pos, face_mask = pad_zero(surf_pos, max_face, return_mask=True) # [max_face, 6]
surf_ncs = pad_zero(surf_ncs, max_face) # [max_face, 100, 3]
surf_ncs = pad_zero(surf_ncs, max_face) # [max_face, num_surf_points, 3]
edge_pos = pad_zero(edge_pos, max_face) # [max_face, max_edge, 6]
edge_ncs = pad_zero(edge_ncs, max_face) # [max_face, max_edge, 10, 3]
edge_ncs = pad_zero(edge_ncs, max_face) # [max_face, max_edge, num_edge_points, 3]
vertex_pos = pad_zero(vertex_pos, max_face) # [max_face, max_edge, 2, 3]
# 扩展边掩码 - 使用face_mask来创建新的edge_mask
@ -1186,20 +1192,20 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo
# 转换为张量并返回
if data_class is not None:
return (
torch.FloatTensor(edge_ncs), # [max_face, max_edge, 10, 3]
torch.FloatTensor(edge_ncs), # [max_face, max_edge, num_edge_points, 3]
torch.FloatTensor(edge_pos), # [max_face, max_edge, 6]
torch.BoolTensor(edge_mask), # [max_face, max_edge]
torch.FloatTensor(surf_ncs), # [max_face, 100, 3]
torch.FloatTensor(surf_ncs), # [max_face, num_surf_points, 3]
torch.FloatTensor(surf_pos), # [max_face, 6]
torch.FloatTensor(vertex_pos), # [max_face, max_edge, 2, 3]
torch.LongTensor([data_class+1]) # [1]
)
else:
return (
torch.FloatTensor(edge_ncs), # [max_face, max_edge, 10, 3]
torch.FloatTensor(edge_ncs), # [max_face, max_edge, num_edge_points, 3]
torch.FloatTensor(edge_pos), # [max_face, max_edge, 6]
torch.BoolTensor(edge_mask), # [max_face, max_edge]
torch.FloatTensor(surf_ncs), # [max_face, 100, 3]
torch.FloatTensor(surf_ncs), # [max_face, num_surf_points, 3]
torch.FloatTensor(surf_pos), # [max_face, 6]
torch.FloatTensor(vertex_pos) # [max_face, max_edge, 2, 3]
)

29
brep2sdf/networks/encoder.py

@ -4,6 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from brep2sdf.config.default_config import get_default_config
class ResConvBlock(nn.Module):
"""残差卷积块"""
@ -120,9 +121,15 @@ class BRepFeatureEmbedder(nn.Module):
"""B-rep特征嵌入器"""
def __init__(self, use_cf: bool = True):
super().__init__()
# 获取配置
self.config = get_default_config()
self.embed_dim = 768
self.use_cf = use_cf
# 使用配置中的采样点数
self.num_surf_points = self.config.model.num_surf_points # 16
self.num_edge_points = self.config.model.num_edge_points # 4
layer = nn.TransformerEncoderLayer(
d_model=self.embed_dim,
nhead=12,
@ -134,18 +141,19 @@ class BRepFeatureEmbedder(nn.Module):
layer,
num_layers=12,
norm=nn.LayerNorm(self.embed_dim),
enable_nested_tensor=False # 添加这个参数
enable_nested_tensor=False
)
# 修改输入维度以匹配采样点数
self.surfz_embed = nn.Sequential(
nn.Linear(3*16, self.embed_dim),
nn.Linear(3 * self.num_surf_points, self.embed_dim), # 3 * 16
nn.LayerNorm(self.embed_dim),
nn.SiLU(),
nn.Linear(self.embed_dim, self.embed_dim),
)
self.edgez_embed = nn.Sequential(
nn.Linear(3*4, self.embed_dim),
nn.Linear(3 * self.num_edge_points, self.embed_dim), # 3 * 4
nn.LayerNorm(self.embed_dim),
nn.SiLU(),
nn.Linear(self.embed_dim, self.embed_dim),
@ -173,6 +181,15 @@ class BRepFeatureEmbedder(nn.Module):
)
def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None):
"""
Args:
surf_z: 表面特征 [B, N, num_surf_points*3]
edge_z: 边特征 [B, M, num_edge_points*3]
surf_p: 表面点 [B, N, 6]
edge_p: 边点 [B, M, 6]
vert_p: 顶点点 [B, K, 6]
mask: 注意力掩码
"""
# 特征嵌入
surf_embeds = self.surfz_embed(surf_z)
edge_embeds = self.edgez_embed(edge_z)
@ -243,6 +260,8 @@ class BRepToSDF(nn.Module):
latent_dim: int = 256
):
super().__init__()
# 获取配置
self.config = get_default_config()
self.embed_dim = embed_dim
# 1. 查询点编码器
@ -271,8 +290,8 @@ class BRepToSDF(nn.Module):
def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None):
"""
Args:
surf_z: 表面特征 [B, N, 48]
edge_z: 边特征 [B, M, 12]
surf_z: 表面特征 [B, N, num_surf_points*3]
edge_z: 边特征 [B, M, num_edge_points*3]
surf_p: 表面点 [B, N, 6]
edge_p: 边点 [B, M, 6]
vert_p: 顶点点 [B, K, 6]

Loading…
Cancel
Save