You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
243 lines
9.6 KiB
243 lines
9.6 KiB
|
|
|
|
from typing import Tuple, List
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
|
|
from brep2sdf.utils.logger import logger
|
|
|
|
def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool:
|
|
"""判断两个轴对齐包围盒(AABB)是否相交
|
|
|
|
参数:
|
|
bbox1: 形状为 (6,) 的张量,格式 [min_x, min_y, min_z, max_x, max_y, max_z]
|
|
bbox2: 同bbox1格式
|
|
|
|
返回:
|
|
bool: 两包围盒是否相交(包括刚好接触的情况)
|
|
"""
|
|
assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的张量"
|
|
|
|
# 提取min和max坐标
|
|
min1, max1 = bbox1[:3], bbox1[3:]
|
|
min2, max2 = bbox2[:3], bbox2[3:]
|
|
|
|
# 向量化比较
|
|
return torch.all((max1 >= min2) & (max2 >= min1))
|
|
|
|
class OctreeNode(nn.Module):
|
|
device=None
|
|
surf_bbox = None
|
|
def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, surf_bbox:torch.Tensor = None):
|
|
super().__init__()
|
|
self.bbox = bbox # 节点的边界框
|
|
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点
|
|
self.child_nodes: List['OctreeNode'] = [] # 子节点列表
|
|
self.face_indices = face_indices
|
|
self.param_key = None
|
|
#self.patch_feature_volume = None # 补丁特征体积,only leaf has
|
|
self._is_leaf = True
|
|
#print(f"box shape: {self.bbox.shape}")
|
|
|
|
if surf_bbox is not None:
|
|
if not isinstance(surf_bbox, torch.Tensor):
|
|
raise TypeError(
|
|
f"surf_bbox 必须是 torch.Tensor 类型,但得到 {type(surf_bbox)}"
|
|
)
|
|
if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
|
|
raise ValueError(
|
|
f"surf_bbox 应为二维张量且形状为 (N,6),但得到 {surf_bbox.shape}"
|
|
)
|
|
OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建
|
|
OctreeNode.device = bbox.device
|
|
|
|
def is_leaf(self):
|
|
# Check if self.child——nodes is None before calling len()
|
|
return self._is_leaf
|
|
|
|
def set_param_key(self, k):
|
|
self.param_key = k
|
|
|
|
def conduct_tree(self):
|
|
if self.max_depth <= 0 or self.face_indices.shape[0] <= 2:
|
|
# 达到最大深度 or 一个单元格至多只有两个面
|
|
return
|
|
self.subdivide()
|
|
|
|
|
|
def subdivide(self):
|
|
|
|
#min_x, min_y, min_z, max_x, max_y, max_z = self.bbox
|
|
# 使用索引操作替代解包
|
|
min_coords = self.bbox[:3] # [min_x, min_y, min_z]
|
|
max_coords = self.bbox[3:] # [max_x, max_y, max_z]
|
|
|
|
# 计算中间点
|
|
mid_coords = (min_coords + max_coords) / 2
|
|
|
|
# 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z
|
|
min_x, min_y, min_z = min_coords[0], min_coords[1], min_coords[2]
|
|
mid_x, mid_y, mid_z = mid_coords[0], mid_coords[1], mid_coords[2]
|
|
max_x, max_y, max_z = max_coords[0], max_coords[1], max_coords[2]
|
|
|
|
# 生成 8 个子包围盒
|
|
child_bboxes = torch.stack([
|
|
torch.cat([min_coords, mid_coords]), # 前下左
|
|
torch.cat([torch.tensor([mid_x, min_y, min_z], device=self.bbox.device),
|
|
torch.tensor([max_x, mid_y, mid_z], device=self.bbox.device)]), # 前下右
|
|
torch.cat([torch.tensor([min_x, mid_y, min_z], device=self.bbox.device),
|
|
torch.tensor([mid_x, max_y, mid_z], device=self.bbox.device)]), # 前上左
|
|
torch.cat([torch.tensor([mid_x, mid_y, min_z], device=self.bbox.device),
|
|
torch.tensor([max_x, max_y, mid_z], device=self.bbox.device)]), # 前上右
|
|
torch.cat([torch.tensor([min_x, min_y, mid_z], device=self.bbox.device),
|
|
torch.tensor([mid_x, mid_y, max_z], device=self.bbox.device)]), # 后下左
|
|
torch.cat([torch.tensor([mid_x, min_y, mid_z], device=self.bbox.device),
|
|
torch.tensor([max_x, mid_y, max_z], device=self.bbox.device)]), # 后下右
|
|
torch.cat([torch.tensor([min_x, mid_y, mid_z], device=self.bbox.device),
|
|
torch.tensor([mid_x, max_y, max_z], device=self.bbox.device)]), # 后上左
|
|
torch.cat([torch.tensor([mid_x, mid_y, mid_z], device=self.bbox.device),
|
|
torch.tensor([max_x, max_y, max_z], device=self.bbox.device)]) # 后上右
|
|
])
|
|
|
|
# 为每个子包围盒创建子节点,并分配相交的面
|
|
self.child_nodes = []
|
|
for bbox in child_bboxes:
|
|
# 找到与子包围盒相交的面
|
|
intersecting_faces = []
|
|
for face_idx in self.face_indices:
|
|
face_bbox = OctreeNode.surf_bbox[face_idx]
|
|
if bbox_intersect(bbox, face_bbox):
|
|
intersecting_faces.append(face_idx)
|
|
#print(f"{bbox}: {intersecting_faces}")
|
|
|
|
child_node = OctreeNode(
|
|
bbox=bbox,
|
|
face_indices=np.array(intersecting_faces),
|
|
max_depth=self.max_depth - 1
|
|
)
|
|
child_node.conduct_tree()
|
|
self.child_nodes.append(child_node)
|
|
|
|
self._is_leaf = False
|
|
|
|
def get_child_index(self, query_point: torch.Tensor) -> int:
|
|
"""
|
|
计算点所在子节点的索引
|
|
:param query_point: 待检查的点,格式为 (x, y, z)
|
|
:return: 子节点的索引,范围从 0 到 7
|
|
"""
|
|
# 确保 query_point 和 bbox 在同一设备上
|
|
query_point = query_point.to(self.bbox.device)
|
|
|
|
# 提取 bbox 的最小和最大坐标
|
|
min_coords = self.bbox[:3] # [min_x, min_y, min_z]
|
|
max_coords = self.bbox[3:] # [max_x, max_y, max_z]
|
|
|
|
# 计算中间点
|
|
mid_coords = (min_coords + max_coords) / 2
|
|
|
|
# 使用布尔比较结果计算索引
|
|
index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum()
|
|
|
|
return index.unsqueeze(0)
|
|
|
|
def find_leaf(self, query_point:torch.Tensor):
|
|
# 从根节点开始递归查找包含该点的叶子节点
|
|
if self._is_leaf:
|
|
return self
|
|
else:
|
|
index = self.get_child_index(query_point)
|
|
try:
|
|
# 直接访问子节点,不进行显式检查
|
|
return self.child_nodes[index].find_leaf(query_point)
|
|
except IndexError as e:
|
|
# 记录错误日志并重新抛出异常
|
|
logger.error(
|
|
f"Error accessing child node: {e}. "
|
|
f"Query point: {query_point.cpu().numpy().tolist()}, "
|
|
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, "
|
|
f"Depth info: {self.max_depth}"
|
|
)
|
|
raise e
|
|
|
|
def get_feature_vector(self, query_point:torch.Tensor):
|
|
"""
|
|
预测给定点的 SDF 值
|
|
:param point: 待预测的点,格式为 (x, y, z)
|
|
:return: 预测的 SDF 值
|
|
"""
|
|
# 将点转换为 numpy 数组
|
|
|
|
# 从根节点开始递归查找包含该点的叶子节点
|
|
if self._is_leaf:
|
|
return self.trilinear_interpolation(query_point)
|
|
else:
|
|
index = self.get_child_index(query_point)
|
|
try:
|
|
# 直接访问子节点,不进行显式检查
|
|
return self.child_nodes[index].get_feature_vector(query_point)
|
|
except IndexError as e:
|
|
# 记录错误日志并重新抛出异常
|
|
logger.error(
|
|
f"Error accessing child node: {e}. "
|
|
f"Query point: {query_point.cpu().numpy().tolist()}, "
|
|
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, "
|
|
f"Depth info: {self.max_depth}"
|
|
)
|
|
raise e
|
|
|
|
|
|
|
|
|
|
|
|
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None:
|
|
"""
|
|
递归打印八叉树结构
|
|
|
|
参数:
|
|
depth: 当前深度 (内部使用)
|
|
max_print_depth: 最大打印深度 (None表示打印全部)
|
|
"""
|
|
if max_print_depth is not None and depth > max_print_depth:
|
|
return
|
|
|
|
# 打印当前节点信息
|
|
indent = " " * depth
|
|
node_type = "Leaf" if self._is_leaf else "Internal"
|
|
print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.cpu().numpy().tolist()}")
|
|
|
|
# 打印面片信息(如果有)
|
|
if self.face_indices is not None:
|
|
print(f"{indent} Face indices: {self.face_indices.tolist()}")
|
|
print(f"{indent} len child_nodes: {len(self.child_nodes)}")
|
|
|
|
# 递归打印子节点
|
|
for i, child in enumerate(self.child_nodes):
|
|
print(f"{indent} Child {i}:")
|
|
child.print_tree(depth + 1, max_print_depth)
|
|
|
|
def __getstate__(self):
|
|
"""支持pickle序列化"""
|
|
return self._serialize_node(self)
|
|
|
|
def __setstate__(self, state):
|
|
"""支持pickle反序列化"""
|
|
self = self._deserialize_node(state)
|
|
|
|
def _serialize_node(self, node):
|
|
return {
|
|
'bbox': node.bbox,
|
|
'is_leaf': node._is_leaf,
|
|
'child_nodes': [self._serialize_node(c) for c in node.child_nodes],
|
|
'param_key': node.param_key
|
|
}
|
|
|
|
def _deserialize_node(self, data):
|
|
node = OctreeNode(data['bbox'], 0) # max_depth会在encoder中重建
|
|
node._is_leaf = data['is_leaf']
|
|
node.param_key = data['param_key']
|
|
node.child_nodes = [self._deserialize_node(c) for c in data['child_nodes']]
|
|
return node
|