|
@ -1,6 +1,6 @@ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple, List |
|
|
from typing import Tuple, List, cast, Dict, Any, Tuple |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn as nn |
|
@ -35,9 +35,9 @@ class OctreeNode(nn.Module): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
self.bbox = bbox # 节点的边界框 |
|
|
self.bbox = bbox # 节点的边界框 |
|
|
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 |
|
|
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 |
|
|
self.child_nodes: List['OctreeNode'] = [] # 子节点列表 |
|
|
self.child_nodes: torch.nn.ModuleList = torch.nn.ModuleList() # 子节点列表 |
|
|
self.face_indices = face_indices |
|
|
self.face_indices = face_indices |
|
|
self.param_key = None |
|
|
self.param_key = "" |
|
|
#self.patch_feature_volume = None # 补丁特征体积,only leaf has |
|
|
#self.patch_feature_volume = None # 补丁特征体积,only leaf has |
|
|
self._is_leaf = True |
|
|
self._is_leaf = True |
|
|
#print(f"box shape: {self.bbox.shape}") |
|
|
#print(f"box shape: {self.bbox.shape}") |
|
@ -103,7 +103,6 @@ class OctreeNode(nn.Module): |
|
|
]) |
|
|
]) |
|
|
|
|
|
|
|
|
# 为每个子包围盒创建子节点,并分配相交的面 |
|
|
# 为每个子包围盒创建子节点,并分配相交的面 |
|
|
self.child_nodes = [] |
|
|
|
|
|
for bbox in child_bboxes: |
|
|
for bbox in child_bboxes: |
|
|
# 找到与子包围盒相交的面 |
|
|
# 找到与子包围盒相交的面 |
|
|
intersecting_faces = [] |
|
|
intersecting_faces = [] |
|
@ -142,14 +141,35 @@ class OctreeNode(nn.Module): |
|
|
# 使用布尔比较结果计算索引 |
|
|
# 使用布尔比较结果计算索引 |
|
|
index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum() |
|
|
index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum() |
|
|
|
|
|
|
|
|
return index.unsqueeze(0) |
|
|
return index.item() |
|
|
|
|
|
|
|
|
def find_leaf(self, query_point:torch.Tensor): |
|
|
def find_leaf(self, query_point: torch.Tensor) -> Tuple[torch.Tensor, str, bool]: |
|
|
# 从根节点开始递归查找包含该点的叶子节点 |
|
|
""" |
|
|
|
|
|
查找包含给定点的叶子节点,并返回其信息(以元组形式) |
|
|
|
|
|
:param query_point: 待查找的点 |
|
|
|
|
|
:return: 包含叶子节点信息的元组 (bbox, param_key, is_leaf) |
|
|
|
|
|
""" |
|
|
|
|
|
# 如果当前节点是叶子节点,返回其信息 |
|
|
if self._is_leaf: |
|
|
if self._is_leaf: |
|
|
return self |
|
|
#logger.info(f"{self.bbox}, {self.param_key}, {True}") |
|
|
else: |
|
|
return (self.bbox, self.param_key, True) |
|
|
index = self.get_child_index(query_point) |
|
|
|
|
|
|
|
|
# 计算查询点所在的子节点索引 |
|
|
|
|
|
index = self.get_child_index(query_point) |
|
|
|
|
|
|
|
|
|
|
|
# 遍历子节点列表,找到对应的子节点 |
|
|
|
|
|
for i, child_node in enumerate(self.child_nodes): |
|
|
|
|
|
if i == index and child_node is not None: |
|
|
|
|
|
# 递归调用子节点的 find_leaf 方法 |
|
|
|
|
|
result = child_node.find_leaf(query_point) |
|
|
|
|
|
|
|
|
|
|
|
# 确保返回值是一个元组 |
|
|
|
|
|
assert isinstance(result, tuple), f"Unexpected return type: {type(result)}" |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
# 如果找不到有效的子节点,抛出异常 |
|
|
|
|
|
raise IndexError(f"Invalid child node index: {index}") |
|
|
|
|
|
''' |
|
|
try: |
|
|
try: |
|
|
# 直接访问子节点,不进行显式检查 |
|
|
# 直接访问子节点,不进行显式检查 |
|
|
return self.child_nodes[index].find_leaf(query_point) |
|
|
return self.child_nodes[index].find_leaf(query_point) |
|
@ -162,7 +182,9 @@ class OctreeNode(nn.Module): |
|
|
f"Depth info: {self.max_depth}" |
|
|
f"Depth info: {self.max_depth}" |
|
|
) |
|
|
) |
|
|
raise e |
|
|
raise e |
|
|
|
|
|
''' |
|
|
|
|
|
|
|
|
|
|
|
''' |
|
|
def get_feature_vector(self, query_point:torch.Tensor): |
|
|
def get_feature_vector(self, query_point:torch.Tensor): |
|
|
""" |
|
|
""" |
|
|
预测给定点的 SDF 值 |
|
|
预测给定点的 SDF 值 |
|
@ -188,7 +210,7 @@ class OctreeNode(nn.Module): |
|
|
f"Depth info: {self.max_depth}" |
|
|
f"Depth info: {self.max_depth}" |
|
|
) |
|
|
) |
|
|
raise e |
|
|
raise e |
|
|
|
|
|
''' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|