From 84fc447adb6a11071a65e65fc14727c760848e42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=90=9B=E6=B6=B5?= Date: Tue, 19 Nov 2024 01:17:54 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E7=94=A8config=E6=9B=BF=E4=BB=A3?= =?UTF-8?q?=E7=A1=AC=E7=BC=96=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/config/default_config.py | 4 +- brep2sdf/data/data.py | 41 +++++++++------- brep2sdf/data/utils.py | 80 +++++++++++++++++-------------- brep2sdf/networks/encoder.py | 29 +++++++++-- 4 files changed, 94 insertions(+), 60 deletions(-) diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index fa473b2..ac31944 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -16,8 +16,8 @@ class ModelConfig: @dataclass class DataConfig: """数据相关配置""" - max_face: int = 70 - max_edge: int = 70 + max_face: int = 64 + max_edge: int = 64 bbox_scaled: float = 1.0 # 数据路径 diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index 05dd07f..8f7798e 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -5,6 +5,7 @@ import numpy as np import pickle from brep2sdf.utils.logger import logger from brep2sdf.data.utils import process_brep_data +from brep2sdf.config.default_config import get_default_config @@ -21,14 +22,17 @@ class BRepSDFDataset(Dataset): split: 数据集分割('train', 'val', 'test') """ super().__init__() + # 使用配置文件 + self.config = get_default_config() + self.brep_dir = os.path.join(brep_dir, split) self.sdf_dir = os.path.join(sdf_dir, split) self.split = split - # 设置固定参数 - self.max_face = 70 - self.max_edge = 70 - self.bbox_scaled = 1.0 + # 使用配置文件中的参数替换固定参数 + self.max_face = self.config.data.max_face + self.max_edge = self.config.data.max_edge + self.bbox_scaled = self.config.data.bbox_scaled # 检查目录是否存在 if not os.path.exists(self.brep_dir): @@ -334,21 +338,26 @@ class BRepSDFDataset(Dataset): def test_dataset(): """测试数据集功能""" try: - # 1. 设置测试路径和预期的数据维度 - brep_dir = '/home/wch/brep2sdf/test_data/pkl' - sdf_dir = '/home/wch/brep2sdf/test_data/sdf' - valid_data_dir = "/home/wch/brep2sdf/test_data/result/pkl" + # 获取配置 + config = get_default_config() + brep_dir = config.data.brep_dir + sdf_dir = config.data.sdf_dir + valid_data_dir = config.data.valid_data_dir split = 'train' + max_face = config.data.max_face + max_edge = config.data.max_edge + num_edge_points = config.model.num_edge_points + num_surf_points = config.model.num_surf_points - # 定义预期的数据维度 + # 定义预期的数据维度,使用配置中的参数 expected_shapes = { - 'edge_ncs': (70, 70, 10, 3), # [max_face, max_edge, sample_points, xyz] - 'edge_pos': (70, 70, 6), # [max_face, max_edge, bbox] - 'edge_mask': (70, 70), # [max_face, max_edge] - 'surf_ncs': (70, 100, 3), # [max_face, sample_points, xyz] - 'surf_pos': (70, 6), # [max_face, bbox] - 'vertex_pos': (70, 70, 2, 3), # [max_face, max_edge, 2_points, xyz] - 'sdf': (2097152, 4) # [num_points, xyz+sdf] + 'edge_ncs': (max_face, max_edge, num_edge_points, 3), # [max_face, max_edge, sample_points, xyz] + 'edge_pos': (max_face, max_edge, 6), + 'edge_mask': (max_face, max_edge), + 'surf_ncs': (max_face, num_surf_points, 3), + 'surf_pos': (max_face, 6), + 'vertex_pos': (max_face, max_edge, 2, 3), + 'sdf': (2097152, 4) } logger.info("="*50) diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py index 354aa83..00f3271 100644 --- a/brep2sdf/data/utils.py +++ b/brep2sdf/data/utils.py @@ -9,6 +9,7 @@ from chamferdist import ChamferDistance from mpl_toolkits.mplot3d.art3d import Poly3DCollection from typing import List, Optional, Tuple, Union from brep2sdf.utils.logger import logger +from brep2sdf.config.default_config import get_default_config from OCC.Core.gp import gp_Pnt, gp_Pnt @@ -1000,25 +1001,25 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo Args: data (dict): 包含B-rep数据的字典,结构如下: { - 'surf_ncs': np.ndarray, # 面归一化点云 [num_faces, 100, 3] - 'edge_ncs': np.ndarray, # 边归一化点云 [num_edges, 10, 3] + 'surf_ncs': np.ndarray, # 面归一化点云 [num_faces, num_surf_sample_points, 3] + 'edge_ncs': np.ndarray, # 边归一化点云 [num_edges, num_edge_sample_points, 3] 'corner_wcs': np.ndarray, # 顶点坐标 [num_edges, 2, 3] - 每条边的两个端点 'faceEdge_adj': np.ndarray, # 面-边邻接矩阵 [num_faces, num_edges] 'surf_pos': np.ndarray, # 面位置(包围盒) [num_faces, 6] 'edge_pos': np.ndarray, # 边位置(包围盒) [num_edges, 6] } - max_face (int): 最大面数,用于填充 - max_edge (int): 最大边数,用于填充 - bbox_scaled (float): 边界框缩放因子 - aug (bool): 是否使用数据增强 + max_face (int): 最大面数,用于填充, config.data.max_face + max_edge (int): 最大边数,用于填充, config.data.max_edge + bbox_scaled (float): 边界框缩放因子, config.data.bbox_scaled + aug (bool): 是否使用数据增强, config.data.aug data_class (Optional[int]): 数据类别标签 Returns: Tuple[torch.Tensor, ...]: 包含以下张量的元组: - - edge_ncs: 边归一化特征 [max_face, max_edge, 10, 3] + - edge_ncs: 边归一化特征 [max_face, max_edge, num_edge_sample_points, 3] - edge_pos: 边位置 [max_face, max_edge, 6] - edge_mask: 边掩码 [max_face, max_edge] - - surf_ncs: 面归一化特征 [max_face, 100, 3] + - surf_ncs: 面归一化特征 [max_face, num_surf_sample_points, 3] - surf_pos: 面位置 [max_face, 6] - vertex_pos: 顶点位置 [max_face, max_edge, 2, 3] - 每个面的每条边的两个端点 - data_class: (可选) 类别标签 [1] @@ -1038,11 +1039,16 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo - 填充到最大面数 5. 转换为张量 """ + # 获取配置 + config = get_default_config() + num_surf_points = config.model.num_surf_points # 16 + num_edge_points = config.model.num_edge_points # 4 + # 解包数据 #_, _, surf_ncs, edge_ncs, corner_wcs, _, _, faceEdge_adj, surf_pos, edge_pos, _ = data.values() # 直接获取需要的数据 surf_ncs = data['surf_ncs'] # (num_faces,) -> 每个元素形状 (N, 3) - edge_ncs = data['edge_ncs'] # (num_edges, 100, 3) + edge_ncs = data['edge_ncs'] # (num_edges, num_edge_points, 3) corner_wcs = data['corner_wcs'] # (num_edges, 2, 3) faceEdge_adj = data['faceEdge_adj'] # (num_faces, num_edges) edgeCorner_adj = data['edgeCorner_adj'] # (num_edges, 2) 每条边连接2个顶点 @@ -1079,13 +1085,13 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo # 特征复制 edge_pos_duplicated = [] # [num_edges_per_face, 6] vertex_pos_duplicated = [] # [num_edges_per_face, 2, 3] - edge_ncs_duplicated = [] # [num_edges_per_face, 10, 3] + edge_ncs_duplicated = [] # [num_edges_per_face, num_edge_points, 3] for adj in faceEdge_adj: # [num_faces, num_edges] edge_indices = np.where(adj)[0] # 获取当前面的边索引 # 复制边的特征 - edge_ncs_duplicated.append(edge_ncs[edge_indices]) # [num_edges_per_face, 10, 3] + edge_ncs_duplicated.append(edge_ncs[edge_indices]) # [num_edges_per_face, num_edge_points, 3] edge_pos_duplicated.append(edge_pos[edge_indices]) # [num_edges_per_face, 6] # 直接获取对应边的顶点对 @@ -1094,7 +1100,7 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo # 边特征打乱和填充 edge_pos_new = [] # 最终形状: [num_faces, max_edge, 6] - edge_ncs_new = [] # 最终形状: [num_faces, max_edge, 10, 3] + edge_ncs_new = [] # 最终形状: [num_faces, max_edge, num_edge_points, 3] vert_pos_new = [] # 最终形状: [num_faces, max_edge, 6] edge_mask = [] # 最终形状: [num_faces, max_edge] @@ -1105,12 +1111,12 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo # 同时打乱所有特征 pos = pos[random_indices] # [num_edges_per_face, 6] - ncs = ncs[random_indices] # [num_edges_per_face, 10, 3] + ncs = ncs[random_indices] # [num_edges_per_face, num_edge_points, 3] vert = vert[random_indices] # [num_edges_per_face, 2, 3] # 填充到最大边数 pos, mask = pad_zero(pos, max_edge, return_mask=True) # [max_edge, 6], [max_edge] - ncs = pad_zero(ncs, max_edge) # [max_edge, 10, 3] + ncs = pad_zero(ncs, max_edge) # [max_edge, num_edge_points, 3] vert = pad_zero(vert, max_edge) # [max_edge, 2, 3] edge_pos_new.append(pos) @@ -1119,60 +1125,60 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo vert_pos_new.append(vert) edge_pos = np.stack(edge_pos_new) # [num_faces, max_edge, 6] - edge_ncs = np.stack(edge_ncs_new) # [num_faces, max_edge, 10, 3] + edge_ncs = np.stack(edge_ncs_new) # [num_faces, max_edge, num_edge_points, 3] edge_mask = np.stack(edge_mask) # [num_faces, max_edge] vertex_pos = np.stack(vert_pos_new) # [num_faces, max_edge, 2, 3] - # 处理edge_ncs (确保形状为 [num_faces, max_edge, 10, 3]) + # 处理edge_ncs (确保形状为 [num_faces, max_edge, num_edge_points, 3]) edge_ncs_list = [] for face_idx in range(len(edge_ncs)): face_edges = [] for edge_idx in range(len(edge_ncs[face_idx])): edge_points = edge_ncs[face_idx][edge_idx] - # 确保每条边有10个点 - if len(edge_points) > 10: - indices = np.random.choice(len(edge_points), 10, replace=False) + # 确保每条边有num_edge_points个点 + if len(edge_points) > num_edge_points: + indices = np.random.choice(len(edge_points), num_edge_points, replace=False) edge_points = edge_points[indices] - elif len(edge_points) < 10: - indices = np.random.choice(len(edge_points), 10-len(edge_points)) + elif len(edge_points) < num_edge_points: + indices = np.random.choice(len(edge_points), num_edge_points-len(edge_points)) edge_points = np.concatenate([edge_points, edge_points[indices]], axis=0) face_edges.append(edge_points) # 填充到最大边数 while len(face_edges) < max_edge: - face_edges.append(np.zeros((10, 3), dtype=np.float32)) + face_edges.append(np.zeros((num_edge_points, 3), dtype=np.float32)) edge_ncs_list.append(np.stack(face_edges)) - edge_ncs = np.stack(edge_ncs_list).astype(np.float32) # [num_faces, max_edge, 10, 3] + edge_ncs = np.stack(edge_ncs_list).astype(np.float32) # [num_faces, max_edge, num_edge_points, 3] - # 处理surf_ncs (确保形状为 [num_faces, 100, 3]) + # 处理surf_ncs (确保形状为 [num_faces, num_surf_points, 3]) surf_ncs_list = [] for points in surf_ncs: - if len(points) > 100: - indices = np.random.choice(len(points), 100, replace=False) + if len(points) > num_surf_points: + indices = np.random.choice(len(points), num_surf_points, replace=False) points = points[indices] - elif len(points) < 100: - indices = np.random.choice(len(points), 100-len(points)) + elif len(points) < num_surf_points: + indices = np.random.choice(len(points), num_surf_points-len(points)) points = np.concatenate([points, points[indices]], axis=0) surf_ncs_list.append(points) - surf_ncs = np.stack(surf_ncs_list).astype(np.float32) # [num_faces, 100, 3] + surf_ncs = np.stack(surf_ncs_list).astype(np.float32) # [num_faces, num_surf_points, 3] # 面特征打乱 random_indices = np.random.permutation(surf_pos.shape[0]) surf_pos = surf_pos[random_indices] # [num_faces, 6] edge_pos = edge_pos[random_indices] # [num_faces, max_edge, 6] - surf_ncs = surf_ncs[random_indices] # [num_faces, 100, 3] - edge_ncs = edge_ncs[random_indices] # [num_faces, max_edge, 10, 3] + surf_ncs = surf_ncs[random_indices] # [num_faces, num_surf_points, 3] + edge_ncs = edge_ncs[random_indices] # [num_faces, max_edge, num_edge_points, 3] edge_mask = edge_mask[random_indices] # [num_faces, max_edge] vertex_pos = vertex_pos[random_indices] # [num_faces, max_edge, 2, 3] # 填充到最大面数 surf_pos, face_mask = pad_zero(surf_pos, max_face, return_mask=True) # [max_face, 6] - surf_ncs = pad_zero(surf_ncs, max_face) # [max_face, 100, 3] + surf_ncs = pad_zero(surf_ncs, max_face) # [max_face, num_surf_points, 3] edge_pos = pad_zero(edge_pos, max_face) # [max_face, max_edge, 6] - edge_ncs = pad_zero(edge_ncs, max_face) # [max_face, max_edge, 10, 3] + edge_ncs = pad_zero(edge_ncs, max_face) # [max_face, max_edge, num_edge_points, 3] vertex_pos = pad_zero(vertex_pos, max_face) # [max_face, max_edge, 2, 3] # 扩展边掩码 - 使用face_mask来创建新的edge_mask @@ -1186,20 +1192,20 @@ def process_brep_data(data: dict, max_face: int, max_edge: int, bbox_scaled: flo # 转换为张量并返回 if data_class is not None: return ( - torch.FloatTensor(edge_ncs), # [max_face, max_edge, 10, 3] + torch.FloatTensor(edge_ncs), # [max_face, max_edge, num_edge_points, 3] torch.FloatTensor(edge_pos), # [max_face, max_edge, 6] torch.BoolTensor(edge_mask), # [max_face, max_edge] - torch.FloatTensor(surf_ncs), # [max_face, 100, 3] + torch.FloatTensor(surf_ncs), # [max_face, num_surf_points, 3] torch.FloatTensor(surf_pos), # [max_face, 6] torch.FloatTensor(vertex_pos), # [max_face, max_edge, 2, 3] torch.LongTensor([data_class+1]) # [1] ) else: return ( - torch.FloatTensor(edge_ncs), # [max_face, max_edge, 10, 3] + torch.FloatTensor(edge_ncs), # [max_face, max_edge, num_edge_points, 3] torch.FloatTensor(edge_pos), # [max_face, max_edge, 6] torch.BoolTensor(edge_mask), # [max_face, max_edge] - torch.FloatTensor(surf_ncs), # [max_face, 100, 3] + torch.FloatTensor(surf_ncs), # [max_face, num_surf_points, 3] torch.FloatTensor(surf_pos), # [max_face, 6] torch.FloatTensor(vertex_pos) # [max_face, max_edge, 2, 3] ) \ No newline at end of file diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 5a956ce..d4e8726 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -4,6 +4,7 @@ import torch.nn as nn import torch.nn.functional as F from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union +from brep2sdf.config.default_config import get_default_config class ResConvBlock(nn.Module): """残差卷积块""" @@ -120,9 +121,15 @@ class BRepFeatureEmbedder(nn.Module): """B-rep特征嵌入器""" def __init__(self, use_cf: bool = True): super().__init__() + # 获取配置 + self.config = get_default_config() self.embed_dim = 768 self.use_cf = use_cf + # 使用配置中的采样点数 + self.num_surf_points = self.config.model.num_surf_points # 16 + self.num_edge_points = self.config.model.num_edge_points # 4 + layer = nn.TransformerEncoderLayer( d_model=self.embed_dim, nhead=12, @@ -134,18 +141,19 @@ class BRepFeatureEmbedder(nn.Module): layer, num_layers=12, norm=nn.LayerNorm(self.embed_dim), - enable_nested_tensor=False # 添加这个参数 + enable_nested_tensor=False ) + # 修改输入维度以匹配采样点数 self.surfz_embed = nn.Sequential( - nn.Linear(3*16, self.embed_dim), + nn.Linear(3 * self.num_surf_points, self.embed_dim), # 3 * 16 nn.LayerNorm(self.embed_dim), nn.SiLU(), nn.Linear(self.embed_dim, self.embed_dim), ) self.edgez_embed = nn.Sequential( - nn.Linear(3*4, self.embed_dim), + nn.Linear(3 * self.num_edge_points, self.embed_dim), # 3 * 4 nn.LayerNorm(self.embed_dim), nn.SiLU(), nn.Linear(self.embed_dim, self.embed_dim), @@ -173,6 +181,15 @@ class BRepFeatureEmbedder(nn.Module): ) def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None): + """ + Args: + surf_z: 表面特征 [B, N, num_surf_points*3] + edge_z: 边特征 [B, M, num_edge_points*3] + surf_p: 表面点 [B, N, 6] + edge_p: 边点 [B, M, 6] + vert_p: 顶点点 [B, K, 6] + mask: 注意力掩码 + """ # 特征嵌入 surf_embeds = self.surfz_embed(surf_z) edge_embeds = self.edgez_embed(edge_z) @@ -243,6 +260,8 @@ class BRepToSDF(nn.Module): latent_dim: int = 256 ): super().__init__() + # 获取配置 + self.config = get_default_config() self.embed_dim = embed_dim # 1. 查询点编码器 @@ -271,8 +290,8 @@ class BRepToSDF(nn.Module): def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None): """ Args: - surf_z: 表面特征 [B, N, 48] - edge_z: 边特征 [B, M, 12] + surf_z: 表面特征 [B, N, num_surf_points*3] + edge_z: 边特征 [B, M, num_edge_points*3] surf_p: 表面点 [B, N, 6] edge_p: 边点 [B, M, 6] vert_p: 顶点点 [B, K, 6]