Browse Source

优化八叉树encoder,使用long直接索引 特征向量

final
mckay 2 months ago
parent
commit
e20686242f
  1. 93
      brep2sdf/networks/encoder.py
  2. 16
      brep2sdf/networks/octree.py

93
brep2sdf/networks/encoder.py

@ -8,58 +8,7 @@ from .octree import OctreeNode
from brep2sdf.config.default_config import get_default_config
from brep2sdf.utils.logger import logger
import numpy as np
'''
class Encoder:
def __init__(self, surf_bbox: torch.Tensor, origin_bbox: torch.Tensor, max_depth: int, feature_dim:int = 64):
"""
初始化表面八叉树管理器
参数:
surf_bbox: 表面包围盒的世界坐标形状为 (num_edges, 6), dtype=float32
origin_bbox: 原点包围盒的世界坐标形状为 (6), dtype=float32
max_depth: 八叉树的最大深度
"""
self.max_depth = max_depth
# 初始化根节点,包含所有面索引。NOTE: 实际上需要判断 aabb 包含问题,但是这里默认 origin_bbox 由这些 face 计算,所以不再重复判断
num_faces = surf_bbox.shape[0]
#print(f"surf_bbox: {surf_bbox.shape}")
#print(f"origin_bbox: {origin_bbox.shape}")
self.root = OctreeNode(
bbox=origin_bbox,
face_indices=np.arange(num_faces), # 初始包含所有面
max_depth=self.max_depth,
feature_dim=feature_dim,
surf_bbox=surf_bbox
)
#print(surf_bbox)
logger.info("starting octree conduction")
self.root.conduct_tree()
logger.info("complete octree conduction")
#self.root.print_tree(0)
def get_feature_vector(self, query_point):
return self.root.get_feature_vector(query_point)
def forward(self, query_points):
"""
前向传播处理批量查询点
参数:
query_points: 查询点的位置坐标形状为(batch_size, 3)
返回:
feature_vectors: 查询点的特征向量形状为(batch_size, feature_dim)
"""
batch_size = query_points.shape[0]
feature_vectors = []
for i in range(batch_size):
feature_vector = self.get_feature_vector(query_points[i])
feature_vectors.append(feature_vector)
return torch.stack(feature_vectors, dim=0)
'''
class Encoder(nn.Module):
def __init__(self, octree: OctreeNode, feature_dim: int = 32):
"""
@ -74,49 +23,50 @@ class Encoder(nn.Module):
self.feature_dim = feature_dim
# 初始化叶子节点参数
self.register_buffer('num_parameters', torch.tensor(0, dtype=torch.long))
self._leaf_parameters = nn.ParameterList() # 使用 ParameterList 存储参数
self.param_key_to_index: Dict[str, int] = {} # 字典映射:param_key -> index
# 为所有叶子节点注册可学习参数
# 不再需要 param_key_to_index 字典
self._init_parameters()
def _init_parameters(self):
"""为所有叶子节点初始化特征参数"""
stack = [(self.octree, "root")] # (当前节点, 当前路径)
stack = [(self.octree, 0)] # (当前节点, 参数索引)
param_index = 0 # 参数索引计数器
while stack:
node, path = stack.pop()
node, param_index = stack.pop()
if node._is_leaf:
# 如果是叶子节点,初始化参数
param_name = f"leaf_{path}"
self._leaf_parameters.append(nn.Parameter(torch.randn(8, self.feature_dim))) # 8个顶点的特征
self.param_key_to_index[param_name] = param_index # 记录索引
node.set_param_key(param_name) # 为节点存储参数键
node.set_param_key(param_index) # 直接使用索引作为键
param_index += 1
else:
# 如果不是叶子节点,继续遍历子节点
for i, child in enumerate(node.child_nodes):
if child is not None:
stack.append((child, f"{path}_{i}"))
stack.append((child, param_index))
def get_leaf_parameter(self, param_key: str) -> torch.Tensor:
# 保存参数总数
self.num_parameters.fill_(param_index)
@torch.jit.export
def get_leaf_parameter(self, param_index: torch.Tensor) -> torch.Tensor:
"""
获取叶子节点的特征参数
:param param_key: 叶子节点的参数键
:param param_index: 叶子节点的参数索引
:return: 对应的参数
"""
if param_key not in self.param_key_to_index:
raise KeyError(f"Invalid param_key: {param_key}")
target_index = self.param_key_to_index[param_key]
index = param_index.item() # 转换为Python整数
if index < 0 or index >= self.num_parameters.item():
raise IndexError(f"Parameter index {index} out of range")
# 使用枚举代替动态索引
for index, param in enumerate(self._leaf_parameters):
if index == target_index:
# 使用列表推导代替直接索引
for i, param in enumerate(self._leaf_parameters):
if i == index:
return param
raise IndexError(f"Index {target_index} not found in ParameterList")
raise IndexError(f"Parameter index {index} not found")
def forward(self, query_points: torch.Tensor) -> torch.Tensor:
"""
@ -132,11 +82,10 @@ class Encoder(nn.Module):
for i in range(batch_size):
# 1. 在八叉树中查找包含该点的叶子节点
bbox, param_key, _ = self.octree.find_leaf(query_points[i])
#logger.debug(leaf_node.param_key)
bbox, param_index, _ = self.octree.find_leaf(query_points[i])
# 2. 获取该节点的特征参数
node_features = self.get_leaf_parameter(param_key)
node_features = self.get_leaf_parameter(param_index)
# 3. 使用三线性插值计算特征
# (这里需要实现你的插值逻辑)

16
brep2sdf/networks/octree.py

@ -1,4 +1,4 @@
from typing import Tuple, List, cast, Dict, Any, Tuple
from typing import Tuple, List, cast, Dict, Any
import torch
import torch.nn as nn
@ -39,12 +39,18 @@ class OctreeNode(nn.Module):
self.register_buffer('surf_bbox', surf_bbox) # 面片边界框
self.max_depth = max_depth
self.param_key = ""
# 将param_key改为张量
self.register_buffer('param_key', torch.tensor(-1, dtype=torch.long))
self._is_leaf = True
@torch.jit.export
def set_param_key(self, k: str) -> None:
self.param_key = k
def set_param_key(self, k: int) -> None:
"""设置参数键值
参数:
k: 参数索引值
"""
self.param_key.fill_(k)
@torch.jit.export
def build_static_tree(self) -> None:
@ -122,7 +128,7 @@ class OctreeNode(nn.Module):
return child_bboxes
@torch.jit.export
def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, str, bool]:
def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]:
"""
查找包含给定点的叶子节点并返回其信息
:param query_points: 待查找的点形状为 (3,)

Loading…
Cancel
Save