You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

243 lines
9.6 KiB

from typing import Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from brep2sdf.utils.logger import logger
def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool:
"""判断两个轴对齐包围盒(AABB)是否相交
参数:
bbox1: 形状为 (6,) 的张量,格式 [min_x, min_y, min_z, max_x, max_y, max_z]
bbox2: 同bbox1格式
返回:
bool: 两包围盒是否相交(包括刚好接触的情况)
"""
assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的张量"
# 提取min和max坐标
min1, max1 = bbox1[:3], bbox1[3:]
min2, max2 = bbox2[:3], bbox2[3:]
# 向量化比较
return torch.all((max1 >= min2) & (max2 >= min1))
class OctreeNode(nn.Module):
device=None
surf_bbox = None
def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, surf_bbox:torch.Tensor = None):
super().__init__()
self.bbox = bbox # 节点的边界框
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点
self.child_nodes: List['OctreeNode'] = [] # 子节点列表
self.face_indices = face_indices
self.param_key = None
#self.patch_feature_volume = None # 补丁特征体积,only leaf has
self._is_leaf = True
#print(f"box shape: {self.bbox.shape}")
if surf_bbox is not None:
if not isinstance(surf_bbox, torch.Tensor):
raise TypeError(
f"surf_bbox 必须是 torch.Tensor 类型,但得到 {type(surf_bbox)}"
)
if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
raise ValueError(
f"surf_bbox 应为二维张量且形状为 (N,6),但得到 {surf_bbox.shape}"
)
OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建
OctreeNode.device = bbox.device
def is_leaf(self):
# Check if self.child——nodes is None before calling len()
return self._is_leaf
def set_param_key(self, k):
self.param_key = k
def conduct_tree(self):
if self.max_depth <= 0 or self.face_indices.shape[0] <= 2:
# 达到最大深度 or 一个单元格至多只有两个面
return
self.subdivide()
def subdivide(self):
#min_x, min_y, min_z, max_x, max_y, max_z = self.bbox
# 使用索引操作替代解包
min_coords = self.bbox[:3] # [min_x, min_y, min_z]
max_coords = self.bbox[3:] # [max_x, max_y, max_z]
# 计算中间点
mid_coords = (min_coords + max_coords) / 2
# 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z
min_x, min_y, min_z = min_coords[0], min_coords[1], min_coords[2]
mid_x, mid_y, mid_z = mid_coords[0], mid_coords[1], mid_coords[2]
max_x, max_y, max_z = max_coords[0], max_coords[1], max_coords[2]
# 生成 8 个子包围盒
child_bboxes = torch.stack([
torch.cat([min_coords, mid_coords]), # 前下左
torch.cat([torch.tensor([mid_x, min_y, min_z], device=self.bbox.device),
torch.tensor([max_x, mid_y, mid_z], device=self.bbox.device)]), # 前下右
torch.cat([torch.tensor([min_x, mid_y, min_z], device=self.bbox.device),
torch.tensor([mid_x, max_y, mid_z], device=self.bbox.device)]), # 前上左
torch.cat([torch.tensor([mid_x, mid_y, min_z], device=self.bbox.device),
torch.tensor([max_x, max_y, mid_z], device=self.bbox.device)]), # 前上右
torch.cat([torch.tensor([min_x, min_y, mid_z], device=self.bbox.device),
torch.tensor([mid_x, mid_y, max_z], device=self.bbox.device)]), # 后下左
torch.cat([torch.tensor([mid_x, min_y, mid_z], device=self.bbox.device),
torch.tensor([max_x, mid_y, max_z], device=self.bbox.device)]), # 后下右
torch.cat([torch.tensor([min_x, mid_y, mid_z], device=self.bbox.device),
torch.tensor([mid_x, max_y, max_z], device=self.bbox.device)]), # 后上左
torch.cat([torch.tensor([mid_x, mid_y, mid_z], device=self.bbox.device),
torch.tensor([max_x, max_y, max_z], device=self.bbox.device)]) # 后上右
])
# 为每个子包围盒创建子节点,并分配相交的面
self.child_nodes = []
for bbox in child_bboxes:
# 找到与子包围盒相交的面
intersecting_faces = []
for face_idx in self.face_indices:
face_bbox = OctreeNode.surf_bbox[face_idx]
if bbox_intersect(bbox, face_bbox):
intersecting_faces.append(face_idx)
#print(f"{bbox}: {intersecting_faces}")
child_node = OctreeNode(
bbox=bbox,
face_indices=np.array(intersecting_faces),
max_depth=self.max_depth - 1
)
child_node.conduct_tree()
self.child_nodes.append(child_node)
self._is_leaf = False
def get_child_index(self, query_point: torch.Tensor) -> int:
"""
计算点所在子节点的索引
:param query_point: 待检查的点,格式为 (x, y, z)
:return: 子节点的索引,范围从 0 到 7
"""
# 确保 query_point 和 bbox 在同一设备上
query_point = query_point.to(self.bbox.device)
# 提取 bbox 的最小和最大坐标
min_coords = self.bbox[:3] # [min_x, min_y, min_z]
max_coords = self.bbox[3:] # [max_x, max_y, max_z]
# 计算中间点
mid_coords = (min_coords + max_coords) / 2
# 使用布尔比较结果计算索引
index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum()
return index.unsqueeze(0)
def find_leaf(self, query_point:torch.Tensor):
# 从根节点开始递归查找包含该点的叶子节点
if self._is_leaf:
return self
else:
index = self.get_child_index(query_point)
try:
# 直接访问子节点,不进行显式检查
return self.child_nodes[index].find_leaf(query_point)
except IndexError as e:
# 记录错误日志并重新抛出异常
logger.error(
f"Error accessing child node: {e}. "
f"Query point: {query_point.cpu().numpy().tolist()}, "
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, "
f"Depth info: {self.max_depth}"
)
raise e
def get_feature_vector(self, query_point:torch.Tensor):
"""
预测给定点的 SDF 值
:param point: 待预测的点,格式为 (x, y, z)
:return: 预测的 SDF 值
"""
# 将点转换为 numpy 数组
# 从根节点开始递归查找包含该点的叶子节点
if self._is_leaf:
return self.trilinear_interpolation(query_point)
else:
index = self.get_child_index(query_point)
try:
# 直接访问子节点,不进行显式检查
return self.child_nodes[index].get_feature_vector(query_point)
except IndexError as e:
# 记录错误日志并重新抛出异常
logger.error(
f"Error accessing child node: {e}. "
f"Query point: {query_point.cpu().numpy().tolist()}, "
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, "
f"Depth info: {self.max_depth}"
)
raise e
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None:
"""
递归打印八叉树结构
参数:
depth: 当前深度 (内部使用)
max_print_depth: 最大打印深度 (None表示打印全部)
"""
if max_print_depth is not None and depth > max_print_depth:
return
# 打印当前节点信息
indent = " " * depth
node_type = "Leaf" if self._is_leaf else "Internal"
print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.cpu().numpy().tolist()}")
# 打印面片信息(如果有)
if self.face_indices is not None:
print(f"{indent} Face indices: {self.face_indices.tolist()}")
print(f"{indent} len child_nodes: {len(self.child_nodes)}")
# 递归打印子节点
for i, child in enumerate(self.child_nodes):
print(f"{indent} Child {i}:")
child.print_tree(depth + 1, max_print_depth)
def __getstate__(self):
"""支持pickle序列化"""
return self._serialize_node(self)
def __setstate__(self, state):
"""支持pickle反序列化"""
self = self._deserialize_node(state)
def _serialize_node(self, node):
return {
'bbox': node.bbox,
'is_leaf': node._is_leaf,
'child_nodes': [self._serialize_node(c) for c in node.child_nodes],
'param_key': node.param_key
}
def _deserialize_node(self, data):
node = OctreeNode(data['bbox'], 0) # max_depth会在encoder中重建
node._is_leaf = data['is_leaf']
node.param_key = data['param_key']
node.child_nodes = [self._deserialize_node(c) for c in data['child_nodes']]
return node