|
|
@ -8,58 +8,7 @@ from .octree import OctreeNode |
|
|
|
from brep2sdf.config.default_config import get_default_config |
|
|
|
from brep2sdf.utils.logger import logger |
|
|
|
import numpy as np |
|
|
|
''' |
|
|
|
class Encoder: |
|
|
|
def __init__(self, surf_bbox: torch.Tensor, origin_bbox: torch.Tensor, max_depth: int, feature_dim:int = 64): |
|
|
|
""" |
|
|
|
初始化表面八叉树管理器 |
|
|
|
|
|
|
|
参数: |
|
|
|
surf_bbox: 表面包围盒的世界坐标,形状为 (num_edges, 6), dtype=float32 |
|
|
|
origin_bbox: 原点包围盒的世界坐标,形状为 (6), dtype=float32 |
|
|
|
max_depth: 八叉树的最大深度 |
|
|
|
""" |
|
|
|
self.max_depth = max_depth |
|
|
|
# 初始化根节点,包含所有面索引。NOTE: 实际上需要判断 aabb 包含问题,但是这里默认 origin_bbox 由这些 face 计算,所以不再重复判断 |
|
|
|
num_faces = surf_bbox.shape[0] |
|
|
|
|
|
|
|
#print(f"surf_bbox: {surf_bbox.shape}") |
|
|
|
#print(f"origin_bbox: {origin_bbox.shape}") |
|
|
|
self.root = OctreeNode( |
|
|
|
bbox=origin_bbox, |
|
|
|
face_indices=np.arange(num_faces), # 初始包含所有面 |
|
|
|
max_depth=self.max_depth, |
|
|
|
feature_dim=feature_dim, |
|
|
|
surf_bbox=surf_bbox |
|
|
|
) |
|
|
|
#print(surf_bbox) |
|
|
|
logger.info("starting octree conduction") |
|
|
|
self.root.conduct_tree() |
|
|
|
logger.info("complete octree conduction") |
|
|
|
#self.root.print_tree(0) |
|
|
|
|
|
|
|
def get_feature_vector(self, query_point): |
|
|
|
return self.root.get_feature_vector(query_point) |
|
|
|
|
|
|
|
def forward(self, query_points): |
|
|
|
""" |
|
|
|
前向传播,处理批量查询点 |
|
|
|
|
|
|
|
参数: |
|
|
|
query_points: 查询点的位置坐标,形状为(batch_size, 3) |
|
|
|
返回: |
|
|
|
feature_vectors: 查询点的特征向量,形状为(batch_size, feature_dim) |
|
|
|
""" |
|
|
|
batch_size = query_points.shape[0] |
|
|
|
feature_vectors = [] |
|
|
|
for i in range(batch_size): |
|
|
|
feature_vector = self.get_feature_vector(query_points[i]) |
|
|
|
feature_vectors.append(feature_vector) |
|
|
|
return torch.stack(feature_vectors, dim=0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
''' |
|
|
|
class Encoder(nn.Module): |
|
|
|
def __init__(self, octree: OctreeNode, feature_dim: int = 32): |
|
|
|
""" |
|
|
@ -74,49 +23,50 @@ class Encoder(nn.Module): |
|
|
|
self.feature_dim = feature_dim |
|
|
|
|
|
|
|
# 初始化叶子节点参数 |
|
|
|
self.register_buffer('num_parameters', torch.tensor(0, dtype=torch.long)) |
|
|
|
self._leaf_parameters = nn.ParameterList() # 使用 ParameterList 存储参数 |
|
|
|
self.param_key_to_index: Dict[str, int] = {} # 字典映射:param_key -> index |
|
|
|
# 为所有叶子节点注册可学习参数 |
|
|
|
# 不再需要 param_key_to_index 字典 |
|
|
|
self._init_parameters() |
|
|
|
|
|
|
|
def _init_parameters(self): |
|
|
|
"""为所有叶子节点初始化特征参数""" |
|
|
|
stack = [(self.octree, "root")] # (当前节点, 当前路径) |
|
|
|
stack = [(self.octree, 0)] # (当前节点, 参数索引) |
|
|
|
param_index = 0 # 参数索引计数器 |
|
|
|
|
|
|
|
while stack: |
|
|
|
node, path = stack.pop() |
|
|
|
node, param_index = stack.pop() |
|
|
|
|
|
|
|
if node._is_leaf: |
|
|
|
# 如果是叶子节点,初始化参数 |
|
|
|
param_name = f"leaf_{path}" |
|
|
|
self._leaf_parameters.append(nn.Parameter(torch.randn(8, self.feature_dim))) # 8个顶点的特征 |
|
|
|
self.param_key_to_index[param_name] = param_index # 记录索引 |
|
|
|
node.set_param_key(param_name) # 为节点存储参数键 |
|
|
|
node.set_param_key(param_index) # 直接使用索引作为键 |
|
|
|
param_index += 1 |
|
|
|
else: |
|
|
|
# 如果不是叶子节点,继续遍历子节点 |
|
|
|
for i, child in enumerate(node.child_nodes): |
|
|
|
if child is not None: |
|
|
|
stack.append((child, f"{path}_{i}")) |
|
|
|
stack.append((child, param_index)) |
|
|
|
|
|
|
|
def get_leaf_parameter(self, param_key: str) -> torch.Tensor: |
|
|
|
# 保存参数总数 |
|
|
|
self.num_parameters.fill_(param_index) |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def get_leaf_parameter(self, param_index: torch.Tensor) -> torch.Tensor: |
|
|
|
""" |
|
|
|
获取叶子节点的特征参数 |
|
|
|
:param param_key: 叶子节点的参数键 |
|
|
|
:param param_index: 叶子节点的参数索引 |
|
|
|
:return: 对应的参数 |
|
|
|
""" |
|
|
|
if param_key not in self.param_key_to_index: |
|
|
|
raise KeyError(f"Invalid param_key: {param_key}") |
|
|
|
|
|
|
|
target_index = self.param_key_to_index[param_key] |
|
|
|
index = param_index.item() # 转换为Python整数 |
|
|
|
if index < 0 or index >= self.num_parameters.item(): |
|
|
|
raise IndexError(f"Parameter index {index} out of range") |
|
|
|
|
|
|
|
# 使用枚举代替动态索引 |
|
|
|
for index, param in enumerate(self._leaf_parameters): |
|
|
|
if index == target_index: |
|
|
|
# 使用列表推导代替直接索引 |
|
|
|
for i, param in enumerate(self._leaf_parameters): |
|
|
|
if i == index: |
|
|
|
return param |
|
|
|
|
|
|
|
raise IndexError(f"Index {target_index} not found in ParameterList") |
|
|
|
raise IndexError(f"Parameter index {index} not found") |
|
|
|
|
|
|
|
def forward(self, query_points: torch.Tensor) -> torch.Tensor: |
|
|
|
""" |
|
|
@ -132,11 +82,10 @@ class Encoder(nn.Module): |
|
|
|
|
|
|
|
for i in range(batch_size): |
|
|
|
# 1. 在八叉树中查找包含该点的叶子节点 |
|
|
|
bbox, param_key, _ = self.octree.find_leaf(query_points[i]) |
|
|
|
#logger.debug(leaf_node.param_key) |
|
|
|
bbox, param_index, _ = self.octree.find_leaf(query_points[i]) |
|
|
|
|
|
|
|
# 2. 获取该节点的特征参数 |
|
|
|
node_features = self.get_leaf_parameter(param_key) |
|
|
|
node_features = self.get_leaf_parameter(param_index) |
|
|
|
|
|
|
|
# 3. 使用三线性插值计算特征 |
|
|
|
# (这里需要实现你的插值逻辑) |
|
|
|