|
|
@ -1,82 +1,77 @@ |
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple, List |
|
|
|
from typing import Tuple, List, cast, Dict, Any, Tuple |
|
|
|
|
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import torch.nn.functional as F |
|
|
|
import numpy as np |
|
|
|
import pickle |
|
|
|
|
|
|
|
from brep2sdf.utils.logger import logger |
|
|
|
|
|
|
|
def bbox_intersect(bbox1: np.ndarray, bbox2: np.ndarray) -> bool: |
|
|
|
def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool: |
|
|
|
"""判断两个轴对齐包围盒(AABB)是否相交 |
|
|
|
|
|
|
|
参数: |
|
|
|
bbox1: 形状为 (6,) 的数组,格式 [min_x, min_y, min_z, max_x, max_y, max_z] |
|
|
|
bbox1: 形状为 (6,) 的张量,格式 [min_x, min_y, min_z, max_x, max_y, max_z] |
|
|
|
bbox2: 同bbox1格式 |
|
|
|
|
|
|
|
返回: |
|
|
|
bool: 两包围盒是否相交(包括刚好接触的情况) |
|
|
|
""" |
|
|
|
assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的数组" |
|
|
|
assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的张量" |
|
|
|
|
|
|
|
# 提取min和max坐标 |
|
|
|
min1, max1 = bbox1[:3], bbox1[3:] |
|
|
|
min2, max2 = bbox2[:3], bbox2[3:] |
|
|
|
|
|
|
|
# 向量化比较 |
|
|
|
return np.all((max1 >= min2) & (max2 >= min1)) |
|
|
|
return torch.all((max1 >= min2) & (max2 >= min1)) |
|
|
|
|
|
|
|
|
|
|
|
class OctreeNode: |
|
|
|
feature_dim = None |
|
|
|
class OctreeNode(nn.Module): |
|
|
|
device=None |
|
|
|
surf_bbox = None |
|
|
|
|
|
|
|
def __init__(self, bbox: np.ndarray, face_indices: np.ndarray, max_depth: int = 5, feature_dim: int = None, surf_bbox: np.ndarray = None): |
|
|
|
""" |
|
|
|
初始化八叉树节点。 |
|
|
|
:param bbox: 节点的边界框,格式为 [min_x, min_y, min_z, max_x, max_y, max_z] (形状为 (6,)) |
|
|
|
:param face_indices: 当前节点包含的面索引数组 |
|
|
|
:param max_depth: 八叉树的最大深度 |
|
|
|
:param feature_dim: 特征维度(仅在叶子节点时使用) |
|
|
|
:param surf_bbox: 面的包围盒数组,形状为 (N, 6) |
|
|
|
""" |
|
|
|
def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, surf_bbox:torch.Tensor = None): |
|
|
|
super().__init__() |
|
|
|
self.bbox = bbox # 节点的边界框 |
|
|
|
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 |
|
|
|
self.children: List['OctreeNode'] = [] # 子节点列表 |
|
|
|
self.child_nodes: torch.nn.ModuleList = torch.nn.ModuleList() # 子节点列表 |
|
|
|
self.face_indices = face_indices |
|
|
|
self.param_key = "" |
|
|
|
#self.patch_feature_volume = None # 补丁特征体积,only leaf has |
|
|
|
self._is_leaf = True |
|
|
|
#print(f"box shape: {self.bbox.shape}") |
|
|
|
|
|
|
|
if feature_dim is not None: |
|
|
|
OctreeNode.feature_dim = feature_dim |
|
|
|
if surf_bbox is not None: |
|
|
|
if not isinstance(surf_bbox, np.ndarray): |
|
|
|
raise TypeError(f"surf_bbox 必须是 numpy.ndarray 类型,但得到 {type(surf_bbox)}") |
|
|
|
if surf_bbox.ndim != 2 or surf_bbox.shape[1] != 6: |
|
|
|
raise ValueError(f"surf_bbox 应为二维数组且形状为 (N,6),但得到 {surf_bbox.shape}") |
|
|
|
OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建 |
|
|
|
if surf_bbox is not None: |
|
|
|
if not isinstance(surf_bbox, torch.Tensor): |
|
|
|
raise TypeError( |
|
|
|
f"surf_bbox 必须是 torch.Tensor 类型,但得到 {type(surf_bbox)}" |
|
|
|
) |
|
|
|
if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6: |
|
|
|
raise ValueError( |
|
|
|
f"surf_bbox 应为二维张量且形状为 (N,6),但得到 {surf_bbox.shape}" |
|
|
|
) |
|
|
|
OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建 |
|
|
|
OctreeNode.device = bbox.device |
|
|
|
|
|
|
|
def is_leaf(self): |
|
|
|
# Check if self.children is None before calling len() |
|
|
|
# Check if self.child——nodes is None before calling len() |
|
|
|
return self._is_leaf |
|
|
|
|
|
|
|
def set_param_key(self, k): |
|
|
|
self.param_key = k |
|
|
|
|
|
|
|
def conduct_tree(self): |
|
|
|
""" |
|
|
|
构建八叉树:如果达到最大深度或当前节点包含的面数小于等于2,则停止划分。 |
|
|
|
""" |
|
|
|
if self.max_depth <= 0 or len(self.face_indices) <= 2: |
|
|
|
if self.max_depth <= 0 or self.face_indices.shape[0] <= 2: |
|
|
|
# 达到最大深度 or 一个单元格至多只有两个面 |
|
|
|
#self.patch_feature_volume = np.random.randn(8, OctreeNode.feature_dim) |
|
|
|
return |
|
|
|
self.subdivide() |
|
|
|
|
|
|
|
|
|
|
|
def subdivide(self): |
|
|
|
""" |
|
|
|
将当前节点划分为8个子节点,并分配相交的面。 |
|
|
|
""" |
|
|
|
|
|
|
|
#min_x, min_y, min_z, max_x, max_y, max_z = self.bbox |
|
|
|
# 使用索引操作替代解包 |
|
|
|
min_coords = self.bbox[:3] # [min_x, min_y, min_z] |
|
|
|
max_coords = self.bbox[3:] # [max_x, max_y, max_z] |
|
|
|
|
|
|
@ -84,46 +79,58 @@ class OctreeNode: |
|
|
|
mid_coords = (min_coords + max_coords) / 2 |
|
|
|
|
|
|
|
# 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z |
|
|
|
min_x, min_y, min_z = min_coords |
|
|
|
mid_x, mid_y, mid_z = mid_coords |
|
|
|
max_x, max_y, max_z = max_coords |
|
|
|
min_x, min_y, min_z = min_coords[0], min_coords[1], min_coords[2] |
|
|
|
mid_x, mid_y, mid_z = mid_coords[0], mid_coords[1], mid_coords[2] |
|
|
|
max_x, max_y, max_z = max_coords[0], max_coords[1], max_coords[2] |
|
|
|
|
|
|
|
# 生成 8 个子包围盒 |
|
|
|
child_bboxes = np.array([ |
|
|
|
[*min_coords, *mid_coords], # 前下左 |
|
|
|
[mid_x, min_y, min_z, max_x, mid_y, mid_z], # 前下右 |
|
|
|
[min_x, mid_y, min_z, mid_x, max_y, mid_z], # 前上左 |
|
|
|
[mid_x, mid_y, min_z, max_x, max_y, mid_z], # 前上右 |
|
|
|
[min_x, min_y, mid_z, mid_x, mid_y, max_z], # 后下左 |
|
|
|
[mid_x, min_y, mid_z, max_x, mid_y, max_z], # 后下右 |
|
|
|
[min_x, mid_y, mid_z, mid_x, max_y, max_z], # 后上左 |
|
|
|
[mid_x, mid_y, mid_z, max_x, max_y, max_z] # 后上右 |
|
|
|
child_bboxes = torch.stack([ |
|
|
|
torch.cat([min_coords, mid_coords]), # 前下左 |
|
|
|
torch.cat([torch.tensor([mid_x, min_y, min_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, mid_y, mid_z], device=self.bbox.device)]), # 前下右 |
|
|
|
torch.cat([torch.tensor([min_x, mid_y, min_z], device=self.bbox.device), |
|
|
|
torch.tensor([mid_x, max_y, mid_z], device=self.bbox.device)]), # 前上左 |
|
|
|
torch.cat([torch.tensor([mid_x, mid_y, min_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, max_y, mid_z], device=self.bbox.device)]), # 前上右 |
|
|
|
torch.cat([torch.tensor([min_x, min_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([mid_x, mid_y, max_z], device=self.bbox.device)]), # 后下左 |
|
|
|
torch.cat([torch.tensor([mid_x, min_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, mid_y, max_z], device=self.bbox.device)]), # 后下右 |
|
|
|
torch.cat([torch.tensor([min_x, mid_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([mid_x, max_y, max_z], device=self.bbox.device)]), # 后上左 |
|
|
|
torch.cat([torch.tensor([mid_x, mid_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, max_y, max_z], device=self.bbox.device)]) # 后上右 |
|
|
|
]) |
|
|
|
|
|
|
|
# 为每个子包围盒创建子节点,并分配相交的面 |
|
|
|
self.children = [] |
|
|
|
for bbox in child_bboxes: |
|
|
|
# 找到与子包围盒相交的面 |
|
|
|
intersecting_faces = [ |
|
|
|
face_idx for face_idx in self.face_indices |
|
|
|
if bbox_intersect(bbox, OctreeNode.surf_bbox[face_idx]) |
|
|
|
] |
|
|
|
intersecting_faces = [] |
|
|
|
for face_idx in self.face_indices: |
|
|
|
face_bbox = OctreeNode.surf_bbox[face_idx] |
|
|
|
if bbox_intersect(bbox, face_bbox): |
|
|
|
intersecting_faces.append(face_idx) |
|
|
|
#print(f"{bbox}: {intersecting_faces}") |
|
|
|
|
|
|
|
child_node = OctreeNode( |
|
|
|
bbox=bbox, |
|
|
|
face_indices=np.array(intersecting_faces), |
|
|
|
max_depth=self.max_depth - 1 |
|
|
|
) |
|
|
|
child_node.conduct_tree() |
|
|
|
self.children.append(child_node) |
|
|
|
self.child_nodes.append(child_node) |
|
|
|
|
|
|
|
self._is_leaf = False |
|
|
|
|
|
|
|
def get_child_index(self, query_point: np.ndarray) -> int: |
|
|
|
def get_child_index(self, query_point: torch.Tensor) -> int: |
|
|
|
""" |
|
|
|
计算点所在子节点的索引。 |
|
|
|
计算点所在子节点的索引 |
|
|
|
:param query_point: 待检查的点,格式为 (x, y, z) |
|
|
|
:return: 子节点的索引,范围从 0 到 7 |
|
|
|
""" |
|
|
|
# 确保 query_point 和 bbox 在同一设备上 |
|
|
|
query_point = query_point.to(self.bbox.device) |
|
|
|
|
|
|
|
# 提取 bbox 的最小和最大坐标 |
|
|
|
min_coords = self.bbox[:3] # [min_x, min_y, min_z] |
|
|
|
max_coords = self.bbox[3:] # [max_x, max_y, max_z] |
|
|
@ -132,11 +139,11 @@ class OctreeNode: |
|
|
|
mid_coords = (min_coords + max_coords) / 2 |
|
|
|
|
|
|
|
# 使用布尔比较结果计算索引 |
|
|
|
index = ((query_point >= mid_coords) << np.arange(3)).sum() |
|
|
|
index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum() |
|
|
|
|
|
|
|
return index |
|
|
|
return index.item() |
|
|
|
|
|
|
|
def find_leaf(self, query_point: np.ndarray) -> np.ndarray: |
|
|
|
def find_leaf(self, query_point: torch.Tensor) -> Tuple[torch.Tensor, str, bool]: |
|
|
|
""" |
|
|
|
查找包含给定点的叶子节点,并返回其信息(以元组形式) |
|
|
|
:param query_point: 待查找的点 |
|
|
@ -145,22 +152,66 @@ class OctreeNode: |
|
|
|
# 如果当前节点是叶子节点,返回其信息 |
|
|
|
if self._is_leaf: |
|
|
|
#logger.info(f"{self.bbox}, {self.param_key}, {True}") |
|
|
|
return self.face_indices |
|
|
|
return (self.bbox, self.param_key, True) |
|
|
|
|
|
|
|
# 计算查询点所在的子节点索引 |
|
|
|
index = self.get_child_index(query_point) |
|
|
|
try: |
|
|
|
# 直接访问子节点,不进行显式检查 |
|
|
|
return self.children[index].find_leaf(query_point) |
|
|
|
except IndexError as e: |
|
|
|
# 记录错误日志并重新抛出异常 |
|
|
|
logger.error( |
|
|
|
f"Error accessing child node: {e}. " |
|
|
|
f"Query point: {query_point.tolist()}, " |
|
|
|
f"Node bbox: {self.bbox.tolist()}, " |
|
|
|
f"Depth info: {self.max_depth}" |
|
|
|
) |
|
|
|
raise e |
|
|
|
|
|
|
|
# 遍历子节点列表,找到对应的子节点 |
|
|
|
for i, child_node in enumerate(self.child_nodes): |
|
|
|
if i == index and child_node is not None: |
|
|
|
# 递归调用子节点的 find_leaf 方法 |
|
|
|
result = child_node.find_leaf(query_point) |
|
|
|
|
|
|
|
# 确保返回值是一个元组 |
|
|
|
assert isinstance(result, tuple), f"Unexpected return type: {type(result)}" |
|
|
|
return result |
|
|
|
|
|
|
|
# 如果找不到有效的子节点,抛出异常 |
|
|
|
raise IndexError(f"Invalid child node index: {index}") |
|
|
|
''' |
|
|
|
try: |
|
|
|
# 直接访问子节点,不进行显式检查 |
|
|
|
return self.child_nodes[index].find_leaf(query_point) |
|
|
|
except IndexError as e: |
|
|
|
# 记录错误日志并重新抛出异常 |
|
|
|
logger.error( |
|
|
|
f"Error accessing child node: {e}. " |
|
|
|
f"Query point: {query_point.cpu().numpy().tolist()}, " |
|
|
|
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " |
|
|
|
f"Depth info: {self.max_depth}" |
|
|
|
) |
|
|
|
raise e |
|
|
|
''' |
|
|
|
|
|
|
|
''' |
|
|
|
def get_feature_vector(self, query_point:torch.Tensor): |
|
|
|
""" |
|
|
|
预测给定点的 SDF 值 |
|
|
|
:param point: 待预测的点,格式为 (x, y, z) |
|
|
|
:return: 预测的 SDF 值 |
|
|
|
""" |
|
|
|
# 将点转换为 numpy 数组 |
|
|
|
|
|
|
|
# 从根节点开始递归查找包含该点的叶子节点 |
|
|
|
if self._is_leaf: |
|
|
|
return self.trilinear_interpolation(query_point) |
|
|
|
else: |
|
|
|
index = self.get_child_index(query_point) |
|
|
|
try: |
|
|
|
# 直接访问子节点,不进行显式检查 |
|
|
|
return self.child_nodes[index].get_feature_vector(query_point) |
|
|
|
except IndexError as e: |
|
|
|
# 记录错误日志并重新抛出异常 |
|
|
|
logger.error( |
|
|
|
f"Error accessing child node: {e}. " |
|
|
|
f"Query point: {query_point.cpu().numpy().tolist()}, " |
|
|
|
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " |
|
|
|
f"Depth info: {self.max_depth}" |
|
|
|
) |
|
|
|
raise e |
|
|
|
''' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -178,236 +229,37 @@ class OctreeNode: |
|
|
|
# 打印当前节点信息 |
|
|
|
indent = " " * depth |
|
|
|
node_type = "Leaf" if self._is_leaf else "Internal" |
|
|
|
print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.tolist()}") |
|
|
|
print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.cpu().numpy().tolist()}") |
|
|
|
|
|
|
|
# 打印面片信息(如果有) |
|
|
|
if self.face_indices is not None: |
|
|
|
print(f"{indent} Face indices: {self.face_indices.tolist()}") |
|
|
|
print(f"{indent} len children: {len(self.children)}") |
|
|
|
print(f"{indent} len child_nodes: {len(self.child_nodes)}") |
|
|
|
|
|
|
|
# 递归打印子节点 |
|
|
|
for i, child in enumerate(self.children): |
|
|
|
for i, child in enumerate(self.child_nodes): |
|
|
|
print(f"{indent} Child {i}:") |
|
|
|
child.print_tree(depth + 1, max_print_depth) |
|
|
|
|
|
|
|
# 保存 |
|
|
|
|
|
|
|
def save_tree_to_file(self, file_path: str): |
|
|
|
""" |
|
|
|
将八叉树保存到文件 |
|
|
|
:param file_path: 要保存的文件路径 |
|
|
|
""" |
|
|
|
|
|
|
|
# 获取完整状态字典 |
|
|
|
state = self.state_dict() |
|
|
|
|
|
|
|
# 添加类级别的静态变量 |
|
|
|
state['feature_dim'] = OctreeNode.feature_dim |
|
|
|
state['surf_bbox'] = OctreeNode.surf_bbox |
|
|
|
|
|
|
|
# 保存到文件 |
|
|
|
with open(file_path, 'wb') as f: |
|
|
|
pickle.dump(state, f) |
|
|
|
def __getstate__(self): |
|
|
|
"""支持pickle序列化""" |
|
|
|
return self._serialize_node(self) |
|
|
|
|
|
|
|
print(f"八叉树已成功保存到 {file_path}") |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def load_tree_from_file(cls, file_path: str) -> 'OctreeNode': |
|
|
|
""" |
|
|
|
从文件加载八叉树 |
|
|
|
:param file_path: 要加载的文件路径 |
|
|
|
:return: 恢复的八叉树根节点 |
|
|
|
""" |
|
|
|
def __setstate__(self, state): |
|
|
|
"""支持pickle反序列化""" |
|
|
|
self = self._deserialize_node(state) |
|
|
|
|
|
|
|
|
|
|
|
with open(file_path, 'rb') as f: |
|
|
|
state = pickle.load(f) |
|
|
|
|
|
|
|
# 恢复类级别的静态变量 |
|
|
|
cls.feature_dim = state.pop('feature_dim') |
|
|
|
cls.surf_bbox = state.pop('surf_bbox') |
|
|
|
|
|
|
|
# 创建根节点 |
|
|
|
root = cls( |
|
|
|
bbox=state['bbox'], |
|
|
|
face_indices=state['face_indices'], |
|
|
|
max_depth=state['max_depth'] |
|
|
|
) |
|
|
|
|
|
|
|
# 加载状态 |
|
|
|
root.load_state_dict(state) |
|
|
|
|
|
|
|
print(f"八叉树已从 {file_path} 成功加载") |
|
|
|
return root |
|
|
|
def state_dict(self): |
|
|
|
"""返回节点及其子树的state_dict""" |
|
|
|
state = { |
|
|
|
'bbox': self.bbox, |
|
|
|
'max_depth': self.max_depth, |
|
|
|
'face_indices': self.face_indices, |
|
|
|
'is_leaf': self._is_leaf |
|
|
|
def _serialize_node(self, node): |
|
|
|
return { |
|
|
|
'bbox': node.bbox, |
|
|
|
'is_leaf': node._is_leaf, |
|
|
|
'child_nodes': [self._serialize_node(c) for c in node.child_nodes], |
|
|
|
'param_key': node.param_key |
|
|
|
} |
|
|
|
|
|
|
|
if self._is_leaf: |
|
|
|
pass |
|
|
|
else: |
|
|
|
state['children'] = [child.state_dict() for child in self.children] |
|
|
|
|
|
|
|
return state |
|
|
|
|
|
|
|
def load_state_dict(self, state_dict): |
|
|
|
"""从state_dict加载节点状态""" |
|
|
|
self.bbox = state_dict['bbox'] |
|
|
|
self.max_depth = state_dict['max_depth'] |
|
|
|
self.face_indices = state_dict['face_indices'] |
|
|
|
self._is_leaf = state_dict['is_leaf'] |
|
|
|
|
|
|
|
if self._is_leaf: |
|
|
|
return |
|
|
|
else: |
|
|
|
self.children = [] |
|
|
|
for child_state in state_dict['children']: |
|
|
|
child = OctreeNode( |
|
|
|
bbox=child_state['bbox'], |
|
|
|
face_indices=child_state['face_indices'], |
|
|
|
max_depth=child_state['max_depth'] |
|
|
|
) |
|
|
|
child.load_state_dict(child_state) |
|
|
|
self.children.append(child) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_octree(): |
|
|
|
# 1. 测试bbox_intersect函数 |
|
|
|
print("测试bbox_intersect函数...") |
|
|
|
bbox1 = np.array([0, 0, 0, 1, 1, 1]) |
|
|
|
bbox2 = np.array([0.5, 0.5, 0.5, 1.5, 1.5, 1.5]) |
|
|
|
assert bbox_intersect(bbox1, bbox2), "相交测试失败" |
|
|
|
|
|
|
|
bbox3 = np.array([2, 2, 2, 3, 3, 3]) |
|
|
|
assert not bbox_intersect(bbox1, bbox3), "不相交测试失败" |
|
|
|
print("bbox_intersect测试通过!\n") |
|
|
|
|
|
|
|
# 2. 创建测试用的面包围盒 |
|
|
|
# 假设有4个面,每个面有一个包围盒 |
|
|
|
surf_bbox = np.array([ |
|
|
|
[0.1, 0.1, 0.1, 0.4, 0.4, 0.4], # 面0 |
|
|
|
[0.6, 0.1, 0.1, 0.9, 0.4, 0.4], # 面1 |
|
|
|
[0.1, 0.6, 0.1, 0.4, 0.9, 0.4], # 面2 |
|
|
|
[0.6, 0.6, 0.1, 0.9, 0.9, 0.4] # 面3 |
|
|
|
]) |
|
|
|
|
|
|
|
# 3. 创建根节点 |
|
|
|
root_bbox = np.array([0, 0, 0, 1, 1, 1]) |
|
|
|
face_indices = np.arange(len(surf_bbox)) # 初始包含所有面 |
|
|
|
root = OctreeNode( |
|
|
|
bbox=root_bbox, |
|
|
|
face_indices=face_indices, |
|
|
|
max_depth=2, |
|
|
|
feature_dim=32, |
|
|
|
surf_bbox=surf_bbox |
|
|
|
) |
|
|
|
|
|
|
|
# 4. 构建八叉树 |
|
|
|
root.conduct_tree() |
|
|
|
|
|
|
|
# 5. 打印树结构(只打印前2层) |
|
|
|
print("八叉树结构:") |
|
|
|
root.print_tree(max_print_depth=2) |
|
|
|
|
|
|
|
# 6. 测试子节点索引计算 |
|
|
|
print("\n测试子节点索引计算...") |
|
|
|
test_points = [ |
|
|
|
([0.25, 0.25, 0.25], "应在前下左子节点"), |
|
|
|
([0.75, 0.25, 0.25], "应在前下右子节点"), |
|
|
|
([0.25, 0.75, 0.25], "应在前上左子节点"), |
|
|
|
([0.75, 0.75, 0.25], "应在前上右子节点") |
|
|
|
] |
|
|
|
|
|
|
|
for point, desc in test_points: |
|
|
|
idx = root.get_child_index(np.array(point)) |
|
|
|
print(f"点 {point} {desc}, 计算得到的索引: {idx}") |
|
|
|
|
|
|
|
# 7. 验证叶子节点特征 |
|
|
|
print("\n验证叶子节点特征:") |
|
|
|
for i, child in enumerate(root.children): |
|
|
|
if child.is_leaf(): |
|
|
|
print(f"子节点 {i} 是叶子节点,") |
|
|
|
else: |
|
|
|
print(f"子节点 {i} 不是叶子节点") |
|
|
|
|
|
|
|
print("\n所有测试完成!") |
|
|
|
|
|
|
|
# ... existing code ... |
|
|
|
|
|
|
|
def test_octree_save_load(): |
|
|
|
print("\n测试八叉树的保存和加载功能...") |
|
|
|
|
|
|
|
# 1. 创建测试数据 |
|
|
|
surf_bbox = np.array([ |
|
|
|
[0.1, 0.1, 0.1, 0.4, 0.4, 0.4], # 面0 |
|
|
|
[0.6, 0.1, 0.1, 0.9, 0.4, 0.4], # 面1 |
|
|
|
[0.1, 0.6, 0.1, 0.4, 0.9, 0.4], # 面2 |
|
|
|
[0.6, 0.6, 0.1, 0.9, 0.9, 0.4] # 面3 |
|
|
|
]) |
|
|
|
|
|
|
|
# 2. 创建原始树 |
|
|
|
root = OctreeNode( |
|
|
|
bbox=np.array([0, 0, 0, 1, 1, 1]), |
|
|
|
face_indices=np.arange(len(surf_bbox)), |
|
|
|
max_depth=2, |
|
|
|
feature_dim=32, |
|
|
|
surf_bbox=surf_bbox |
|
|
|
) |
|
|
|
root.conduct_tree() |
|
|
|
|
|
|
|
# 3. 保存树状态 |
|
|
|
test_file = 'test_octree.pkl' |
|
|
|
root.save_tree_to_file(test_file) |
|
|
|
|
|
|
|
# 4. 从文件加载树 |
|
|
|
new_root = OctreeNode.load_tree_from_file(test_file) |
|
|
|
print("树状态加载成功!") |
|
|
|
|
|
|
|
# 5. 验证加载后的树结构 |
|
|
|
print("\n验证加载后的树结构:") |
|
|
|
|
|
|
|
# 5.1 验证根节点属性 |
|
|
|
assert np.allclose(root.bbox, new_root.bbox), "bbox不匹配" |
|
|
|
assert root.max_depth == new_root.max_depth, "max_depth不匹配" |
|
|
|
assert np.array_equal(root.face_indices, new_root.face_indices), "face_indices不匹配" |
|
|
|
assert root._is_leaf == new_root._is_leaf, "is_leaf不匹配" |
|
|
|
print("根节点属性验证通过!") |
|
|
|
|
|
|
|
# 5.2 验证叶子节点特征 |
|
|
|
if root._is_leaf: |
|
|
|
#assert np.allclose(root.patch_feature_volume, new_root.patch_feature_volume), "特征不匹配" |
|
|
|
print("叶子节点特征验证通过!") |
|
|
|
else: |
|
|
|
# 递归验证子节点 |
|
|
|
for i, (orig_child, new_child) in enumerate(zip(root.children, new_root.children)): |
|
|
|
print(f"\n验证子节点 {i}:") |
|
|
|
assert np.allclose(orig_child.bbox, new_child.bbox), f"子节点{i} bbox不匹配" |
|
|
|
assert orig_child.max_depth == new_child.max_depth, f"子节点{i} max_depth不匹配" |
|
|
|
assert np.array_equal(orig_child.face_indices, new_child.face_indices), f"子节点{i} face_indices不匹配" |
|
|
|
assert orig_child._is_leaf == new_child._is_leaf, f"子节点{i} is_leaf不匹配" |
|
|
|
|
|
|
|
if orig_child._is_leaf: |
|
|
|
#assert np.allclose(orig_child.patch_feature_volume, new_child.patch_feature_volume), f"子节点{i} 特征不匹配" |
|
|
|
print(f"子节点{i} 叶子节点特征验证通过!") |
|
|
|
else: |
|
|
|
print(f"子节点{i} 是非叶子节点,继续验证其子节点...") |
|
|
|
|
|
|
|
# 6. 打印部分树结构对比 |
|
|
|
print("\n原始树结构(前2层):") |
|
|
|
root.print_tree(max_print_depth=2) |
|
|
|
|
|
|
|
print("\n加载后的树结构(前2层):") |
|
|
|
new_root.print_tree(max_print_depth=2) |
|
|
|
|
|
|
|
print("\n八叉树保存和加载测试全部通过!") |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
test_octree() # 运行基本功能测试 |
|
|
|
test_octree_save_load() # 运行保存加载测试 |
|
|
|
def _deserialize_node(self, data): |
|
|
|
node = OctreeNode(data['bbox'], 0) # max_depth会在encoder中重建 |
|
|
|
node._is_leaf = data['is_leaf'] |
|
|
|
node.param_key = data['param_key'] |
|
|
|
node.child_nodes = [self._deserialize_node(c) for c in data['child_nodes']] |
|
|
|
return node |