diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 81e5337..cddc3d1 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -8,58 +8,7 @@ from .octree import OctreeNode from brep2sdf.config.default_config import get_default_config from brep2sdf.utils.logger import logger import numpy as np -''' -class Encoder: - def __init__(self, surf_bbox: torch.Tensor, origin_bbox: torch.Tensor, max_depth: int, feature_dim:int = 64): - """ - 初始化表面八叉树管理器 - - 参数: - surf_bbox: 表面包围盒的世界坐标,形状为 (num_edges, 6), dtype=float32 - origin_bbox: 原点包围盒的世界坐标,形状为 (6), dtype=float32 - max_depth: 八叉树的最大深度 - """ - self.max_depth = max_depth - # 初始化根节点,包含所有面索引。NOTE: 实际上需要判断 aabb 包含问题,但是这里默认 origin_bbox 由这些 face 计算,所以不再重复判断 - num_faces = surf_bbox.shape[0] - - #print(f"surf_bbox: {surf_bbox.shape}") - #print(f"origin_bbox: {origin_bbox.shape}") - self.root = OctreeNode( - bbox=origin_bbox, - face_indices=np.arange(num_faces), # 初始包含所有面 - max_depth=self.max_depth, - feature_dim=feature_dim, - surf_bbox=surf_bbox - ) - #print(surf_bbox) - logger.info("starting octree conduction") - self.root.conduct_tree() - logger.info("complete octree conduction") - #self.root.print_tree(0) - - def get_feature_vector(self, query_point): - return self.root.get_feature_vector(query_point) - - def forward(self, query_points): - """ - 前向传播,处理批量查询点 - - 参数: - query_points: 查询点的位置坐标,形状为(batch_size, 3) - 返回: - feature_vectors: 查询点的特征向量,形状为(batch_size, feature_dim) - """ - batch_size = query_points.shape[0] - feature_vectors = [] - for i in range(batch_size): - feature_vector = self.get_feature_vector(query_points[i]) - feature_vectors.append(feature_vector) - return torch.stack(feature_vectors, dim=0) - - -''' class Encoder(nn.Module): def __init__(self, octree: OctreeNode, feature_dim: int = 32): """ @@ -74,49 +23,50 @@ class Encoder(nn.Module): self.feature_dim = feature_dim # 初始化叶子节点参数 + self.register_buffer('num_parameters', torch.tensor(0, dtype=torch.long)) self._leaf_parameters = nn.ParameterList() # 使用 ParameterList 存储参数 - self.param_key_to_index: Dict[str, int] = {} # 字典映射:param_key -> index - # 为所有叶子节点注册可学习参数 + # 不再需要 param_key_to_index 字典 self._init_parameters() def _init_parameters(self): """为所有叶子节点初始化特征参数""" - stack = [(self.octree, "root")] # (当前节点, 当前路径) + stack = [(self.octree, 0)] # (当前节点, 参数索引) param_index = 0 # 参数索引计数器 while stack: - node, path = stack.pop() + node, param_index = stack.pop() if node._is_leaf: # 如果是叶子节点,初始化参数 - param_name = f"leaf_{path}" self._leaf_parameters.append(nn.Parameter(torch.randn(8, self.feature_dim))) # 8个顶点的特征 - self.param_key_to_index[param_name] = param_index # 记录索引 - node.set_param_key(param_name) # 为节点存储参数键 + node.set_param_key(param_index) # 直接使用索引作为键 param_index += 1 else: # 如果不是叶子节点,继续遍历子节点 for i, child in enumerate(node.child_nodes): if child is not None: - stack.append((child, f"{path}_{i}")) + stack.append((child, param_index)) + + # 保存参数总数 + self.num_parameters.fill_(param_index) - def get_leaf_parameter(self, param_key: str) -> torch.Tensor: + @torch.jit.export + def get_leaf_parameter(self, param_index: torch.Tensor) -> torch.Tensor: """ 获取叶子节点的特征参数 - :param param_key: 叶子节点的参数键 + :param param_index: 叶子节点的参数索引 :return: 对应的参数 """ - if param_key not in self.param_key_to_index: - raise KeyError(f"Invalid param_key: {param_key}") - - target_index = self.param_key_to_index[param_key] - - # 使用枚举代替动态索引 - for index, param in enumerate(self._leaf_parameters): - if index == target_index: + index = param_index.item() # 转换为Python整数 + if index < 0 or index >= self.num_parameters.item(): + raise IndexError(f"Parameter index {index} out of range") + + # 使用列表推导代替直接索引 + for i, param in enumerate(self._leaf_parameters): + if i == index: return param - - raise IndexError(f"Index {target_index} not found in ParameterList") + + raise IndexError(f"Parameter index {index} not found") def forward(self, query_points: torch.Tensor) -> torch.Tensor: """ @@ -132,11 +82,10 @@ class Encoder(nn.Module): for i in range(batch_size): # 1. 在八叉树中查找包含该点的叶子节点 - bbox, param_key, _ = self.octree.find_leaf(query_points[i]) - #logger.debug(leaf_node.param_key) + bbox, param_index, _ = self.octree.find_leaf(query_points[i]) # 2. 获取该节点的特征参数 - node_features = self.get_leaf_parameter(param_key) + node_features = self.get_leaf_parameter(param_index) # 3. 使用三线性插值计算特征 # (这里需要实现你的插值逻辑) diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index 2d3929f..28ac6d7 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -1,4 +1,4 @@ -from typing import Tuple, List, cast, Dict, Any, Tuple +from typing import Tuple, List, cast, Dict, Any import torch import torch.nn as nn @@ -39,12 +39,18 @@ class OctreeNode(nn.Module): self.register_buffer('surf_bbox', surf_bbox) # 面片边界框 self.max_depth = max_depth - self.param_key = "" + # 将param_key改为张量 + self.register_buffer('param_key', torch.tensor(-1, dtype=torch.long)) self._is_leaf = True @torch.jit.export - def set_param_key(self, k: str) -> None: - self.param_key = k + def set_param_key(self, k: int) -> None: + """设置参数键值 + + 参数: + k: 参数索引值 + """ + self.param_key.fill_(k) @torch.jit.export def build_static_tree(self) -> None: @@ -122,7 +128,7 @@ class OctreeNode(nn.Module): return child_bboxes @torch.jit.export - def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, str, bool]: + def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]: """ 查找包含给定点的叶子节点,并返回其信息 :param query_points: 待查找的点,形状为 (3,)