|
|
@ -1,5 +1,3 @@ |
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple, List, cast, Dict, Any, Tuple |
|
|
|
|
|
|
|
import torch |
|
|
@ -9,7 +7,7 @@ import numpy as np |
|
|
|
|
|
|
|
from brep2sdf.utils.logger import logger |
|
|
|
|
|
|
|
def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool: |
|
|
|
def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor: |
|
|
|
"""判断两个轴对齐包围盒(AABB)是否相交 |
|
|
|
|
|
|
|
参数: |
|
|
@ -17,7 +15,7 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool: |
|
|
|
bbox2: 同bbox1格式 |
|
|
|
|
|
|
|
返回: |
|
|
|
bool: 两包围盒是否相交(包括刚好接触的情况) |
|
|
|
torch.Tensor: 两包围盒是否相交(包括刚好接触的情况) |
|
|
|
""" |
|
|
|
assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的张量" |
|
|
|
|
|
|
@ -29,191 +27,142 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool: |
|
|
|
return torch.all((max1 >= min2) & (max2 >= min1)) |
|
|
|
|
|
|
|
class OctreeNode(nn.Module): |
|
|
|
device=None |
|
|
|
surf_bbox = None |
|
|
|
def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, surf_bbox:torch.Tensor = None): |
|
|
|
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None): |
|
|
|
super().__init__() |
|
|
|
self.bbox = bbox # 节点的边界框 |
|
|
|
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 |
|
|
|
self.child_nodes: torch.nn.ModuleList = torch.nn.ModuleList() # 子节点列表 |
|
|
|
self.face_indices = face_indices |
|
|
|
# 静态张量存储节点信息 |
|
|
|
self.register_buffer('bbox', bbox) # 当前节点的边界框 |
|
|
|
self.register_buffer('node_bboxes', None) # 所有节点的边界框 |
|
|
|
self.register_buffer('parent_indices', None) # 父节点索引 |
|
|
|
self.register_buffer('child_indices', None) # 子节点索引 |
|
|
|
self.register_buffer('is_leaf_mask', None) # 叶子节点标记 |
|
|
|
self.register_buffer('face_indices', torch.from_numpy(face_indices).to(bbox.device)) # 面片索引张量 |
|
|
|
self.register_buffer('surf_bbox', surf_bbox) # 面片边界框 |
|
|
|
|
|
|
|
self.max_depth = max_depth |
|
|
|
self.param_key = "" |
|
|
|
#self.patch_feature_volume = None # 补丁特征体积,only leaf has |
|
|
|
self._is_leaf = True |
|
|
|
#print(f"box shape: {self.bbox.shape}") |
|
|
|
|
|
|
|
if surf_bbox is not None: |
|
|
|
if not isinstance(surf_bbox, torch.Tensor): |
|
|
|
raise TypeError( |
|
|
|
f"surf_bbox 必须是 torch.Tensor 类型,但得到 {type(surf_bbox)}" |
|
|
|
) |
|
|
|
if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6: |
|
|
|
raise ValueError( |
|
|
|
f"surf_bbox 应为二维张量且形状为 (N,6),但得到 {surf_bbox.shape}" |
|
|
|
) |
|
|
|
OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建 |
|
|
|
OctreeNode.device = bbox.device |
|
|
|
|
|
|
|
def is_leaf(self): |
|
|
|
# Check if self.child——nodes is None before calling len() |
|
|
|
return self._is_leaf |
|
|
|
|
|
|
|
def set_param_key(self, k): |
|
|
|
self.param_key = k |
|
|
|
|
|
|
|
def conduct_tree(self): |
|
|
|
if self.max_depth <= 0 or self.face_indices.shape[0] <= 2: |
|
|
|
# 达到最大深度 or 一个单元格至多只有两个面 |
|
|
|
return |
|
|
|
self.subdivide() |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def set_param_key(self, k: str) -> None: |
|
|
|
self.param_key = k |
|
|
|
|
|
|
|
def subdivide(self): |
|
|
|
@torch.jit.export |
|
|
|
def build_static_tree(self) -> None: |
|
|
|
"""构建静态八叉树结构""" |
|
|
|
# 预计算所有可能的节点数量,确保结果为整数 |
|
|
|
total_nodes = int(sum(8**i for i in range(self.max_depth + 1))) |
|
|
|
|
|
|
|
# 初始化静态张量,使用整数列表作为形状参数 |
|
|
|
self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.bbox.device) |
|
|
|
self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.bbox.device) |
|
|
|
self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.bbox.device) |
|
|
|
self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.bbox.device) |
|
|
|
|
|
|
|
# 使用队列进行广度优先遍历 |
|
|
|
queue = [(0, self.bbox, self.face_indices)] # (node_idx, bbox, face_indices) |
|
|
|
current_idx = 0 |
|
|
|
|
|
|
|
while queue: |
|
|
|
node_idx, bbox, faces = queue.pop(0) |
|
|
|
self.node_bboxes[node_idx] = bbox |
|
|
|
|
|
|
|
if faces.shape[0] <= 2 or current_idx >= self.max_depth: |
|
|
|
self.is_leaf_mask[node_idx] = True |
|
|
|
continue |
|
|
|
|
|
|
|
# 计算子节点边界框 |
|
|
|
min_coords = bbox[:3] |
|
|
|
max_coords = bbox[3:] |
|
|
|
mid_coords = (min_coords + max_coords) / 2 |
|
|
|
|
|
|
|
#min_x, min_y, min_z, max_x, max_y, max_z = self.bbox |
|
|
|
# 使用索引操作替代解包 |
|
|
|
min_coords = self.bbox[:3] # [min_x, min_y, min_z] |
|
|
|
max_coords = self.bbox[3:] # [max_x, max_y, max_z] |
|
|
|
# 生成8个子节点 |
|
|
|
child_bboxes = self._generate_child_bboxes(min_coords, mid_coords, max_coords) |
|
|
|
|
|
|
|
# 计算中间点 |
|
|
|
mid_coords = (min_coords + max_coords) / 2 |
|
|
|
# 为每个子节点分配面片 |
|
|
|
for i, child_bbox in enumerate(child_bboxes): |
|
|
|
child_idx = current_idx + 1 |
|
|
|
current_idx += 1 |
|
|
|
|
|
|
|
# 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z |
|
|
|
min_x, min_y, min_z = min_coords[0], min_coords[1], min_coords[2] |
|
|
|
mid_x, mid_y, mid_z = mid_coords[0], mid_coords[1], mid_coords[2] |
|
|
|
max_x, max_y, max_z = max_coords[0], max_coords[1], max_coords[2] |
|
|
|
|
|
|
|
# 生成 8 个子包围盒 |
|
|
|
child_bboxes = torch.stack([ |
|
|
|
torch.cat([min_coords, mid_coords]), # 前下左 |
|
|
|
torch.cat([torch.tensor([mid_x, min_y, min_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, mid_y, mid_z], device=self.bbox.device)]), # 前下右 |
|
|
|
torch.cat([torch.tensor([min_x, mid_y, min_z], device=self.bbox.device), |
|
|
|
torch.tensor([mid_x, max_y, mid_z], device=self.bbox.device)]), # 前上左 |
|
|
|
torch.cat([torch.tensor([mid_x, mid_y, min_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, max_y, mid_z], device=self.bbox.device)]), # 前上右 |
|
|
|
torch.cat([torch.tensor([min_x, min_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([mid_x, mid_y, max_z], device=self.bbox.device)]), # 后下左 |
|
|
|
torch.cat([torch.tensor([mid_x, min_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, mid_y, max_z], device=self.bbox.device)]), # 后下右 |
|
|
|
torch.cat([torch.tensor([min_x, mid_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([mid_x, max_y, max_z], device=self.bbox.device)]), # 后上左 |
|
|
|
torch.cat([torch.tensor([mid_x, mid_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, max_y, max_z], device=self.bbox.device)]) # 后上右 |
|
|
|
]) |
|
|
|
|
|
|
|
# 为每个子包围盒创建子节点,并分配相交的面 |
|
|
|
for bbox in child_bboxes: |
|
|
|
# 找到与子包围盒相交的面 |
|
|
|
intersecting_faces = [] |
|
|
|
for face_idx in self.face_indices: |
|
|
|
face_bbox = OctreeNode.surf_bbox[face_idx] |
|
|
|
if bbox_intersect(bbox, face_bbox): |
|
|
|
for face_idx in faces: |
|
|
|
face_bbox = self.surf_bbox[face_idx] |
|
|
|
if bbox_intersect(child_bbox, face_bbox).item(): |
|
|
|
intersecting_faces.append(face_idx) |
|
|
|
#print(f"{bbox}: {intersecting_faces}") |
|
|
|
|
|
|
|
child_node = OctreeNode( |
|
|
|
bbox=bbox, |
|
|
|
face_indices=np.array(intersecting_faces), |
|
|
|
max_depth=self.max_depth - 1 |
|
|
|
) |
|
|
|
child_node.conduct_tree() |
|
|
|
self.child_nodes.append(child_node) |
|
|
|
|
|
|
|
self._is_leaf = False |
|
|
|
|
|
|
|
def get_child_index(self, query_point: torch.Tensor) -> int: |
|
|
|
# 更新节点关系 |
|
|
|
self.parent_indices[child_idx] = node_idx |
|
|
|
self.child_indices[node_idx, i] = child_idx |
|
|
|
|
|
|
|
# 将子节点加入队列 |
|
|
|
if intersecting_faces: |
|
|
|
queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.bbox.device))) |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def _generate_child_bboxes(self, min_coords: torch.Tensor, mid_coords: torch.Tensor, max_coords: torch.Tensor) -> torch.Tensor: |
|
|
|
"""生成8个子节点的边界框""" |
|
|
|
child_bboxes = torch.zeros([8, 6], device=self.bbox.device) |
|
|
|
|
|
|
|
# 使用向量化操作生成所有子节点边界框 |
|
|
|
child_bboxes[0] = torch.cat([min_coords, mid_coords]) # 前下左 |
|
|
|
child_bboxes[1] = torch.cat([torch.stack([mid_coords[0], min_coords[1], min_coords[2]]), |
|
|
|
torch.stack([max_coords[0], mid_coords[1], mid_coords[2]])]) # 前下右 |
|
|
|
child_bboxes[2] = torch.cat([torch.stack([min_coords[0], mid_coords[1], min_coords[2]]), |
|
|
|
torch.stack([mid_coords[0], max_coords[1], mid_coords[2]])]) # 前上左 |
|
|
|
child_bboxes[3] = torch.cat([torch.stack([mid_coords[0], mid_coords[1], min_coords[2]]), |
|
|
|
torch.stack([max_coords[0], max_coords[1], mid_coords[2]])]) # 前上右 |
|
|
|
child_bboxes[4] = torch.cat([torch.stack([min_coords[0], min_coords[1], mid_coords[2]]), |
|
|
|
torch.stack([mid_coords[0], mid_coords[1], max_coords[2]])]) # 后下左 |
|
|
|
child_bboxes[5] = torch.cat([torch.stack([mid_coords[0], min_coords[1], mid_coords[2]]), |
|
|
|
torch.stack([max_coords[0], mid_coords[1], max_coords[2]])]) # 后下右 |
|
|
|
child_bboxes[6] = torch.cat([torch.stack([min_coords[0], mid_coords[1], mid_coords[2]]), |
|
|
|
torch.stack([mid_coords[0], max_coords[1], max_coords[2]])]) # 后上左 |
|
|
|
child_bboxes[7] = torch.cat([mid_coords, max_coords]) # 后上右 |
|
|
|
|
|
|
|
return child_bboxes |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, str, bool]: |
|
|
|
""" |
|
|
|
计算点所在子节点的索引 |
|
|
|
:param query_point: 待检查的点,格式为 (x, y, z) |
|
|
|
:return: 子节点的索引,范围从 0 到 7 |
|
|
|
查找包含给定点的叶子节点,并返回其信息 |
|
|
|
:param query_points: 待查找的点,形状为 (3,) |
|
|
|
:return: 包含叶子节点信息的元组 (bbox, param_key, is_leaf) |
|
|
|
""" |
|
|
|
# 确保 query_point 和 bbox 在同一设备上 |
|
|
|
query_point = query_point.to(self.bbox.device) |
|
|
|
# 确保输入是单个点 |
|
|
|
if query_points.dim() != 1 or query_points.shape[0] != 3: |
|
|
|
raise ValueError(f"query_points 必须是形状为 (3,) 的张量,但得到 {query_points.shape}") |
|
|
|
|
|
|
|
# 提取 bbox 的最小和最大坐标 |
|
|
|
min_coords = self.bbox[:3] # [min_x, min_y, min_z] |
|
|
|
max_coords = self.bbox[3:] # [max_x, max_y, max_z] |
|
|
|
current_idx = torch.tensor(0, dtype=torch.long, device=query_points.device) |
|
|
|
max_iterations = 1000 # 防止无限循环 |
|
|
|
iteration = 0 |
|
|
|
|
|
|
|
# 计算中间点 |
|
|
|
mid_coords = (min_coords + max_coords) / 2 |
|
|
|
while iteration < max_iterations: |
|
|
|
# 获取当前节点的叶子状态 |
|
|
|
if self.is_leaf_mask[current_idx].item(): |
|
|
|
return self.node_bboxes[current_idx], self.param_key, True |
|
|
|
|
|
|
|
# 使用布尔比较结果计算索引 |
|
|
|
index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum() |
|
|
|
# 计算子节点索引 |
|
|
|
child_idx = self._get_child_indices(query_points.unsqueeze(0), |
|
|
|
self.node_bboxes[current_idx].unsqueeze(0)) |
|
|
|
|
|
|
|
return index.item() |
|
|
|
# 获取下一个要访问的节点 |
|
|
|
next_idx = self.child_indices[current_idx, child_idx[0]] |
|
|
|
|
|
|
|
def find_leaf(self, query_point: torch.Tensor) -> Tuple[torch.Tensor, str, bool]: |
|
|
|
""" |
|
|
|
查找包含给定点的叶子节点,并返回其信息(以元组形式) |
|
|
|
:param query_point: 待查找的点 |
|
|
|
:return: 包含叶子节点信息的元组 (bbox, param_key, is_leaf) |
|
|
|
""" |
|
|
|
# 如果当前节点是叶子节点,返回其信息 |
|
|
|
if self._is_leaf: |
|
|
|
#logger.info(f"{self.bbox}, {self.param_key}, {True}") |
|
|
|
return (self.bbox, self.param_key, True) |
|
|
|
|
|
|
|
# 计算查询点所在的子节点索引 |
|
|
|
index = self.get_child_index(query_point) |
|
|
|
|
|
|
|
# 遍历子节点列表,找到对应的子节点 |
|
|
|
for i, child_node in enumerate(self.child_nodes): |
|
|
|
if i == index and child_node is not None: |
|
|
|
# 递归调用子节点的 find_leaf 方法 |
|
|
|
result = child_node.find_leaf(query_point) |
|
|
|
|
|
|
|
# 确保返回值是一个元组 |
|
|
|
assert isinstance(result, tuple), f"Unexpected return type: {type(result)}" |
|
|
|
return result |
|
|
|
|
|
|
|
# 如果找不到有效的子节点,抛出异常 |
|
|
|
raise IndexError(f"Invalid child node index: {index}") |
|
|
|
''' |
|
|
|
try: |
|
|
|
# 直接访问子节点,不进行显式检查 |
|
|
|
return self.child_nodes[index].find_leaf(query_point) |
|
|
|
except IndexError as e: |
|
|
|
# 记录错误日志并重新抛出异常 |
|
|
|
logger.error( |
|
|
|
f"Error accessing child node: {e}. " |
|
|
|
f"Query point: {query_point.cpu().numpy().tolist()}, " |
|
|
|
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " |
|
|
|
f"Depth info: {self.max_depth}" |
|
|
|
) |
|
|
|
raise e |
|
|
|
''' |
|
|
|
|
|
|
|
''' |
|
|
|
def get_feature_vector(self, query_point:torch.Tensor): |
|
|
|
""" |
|
|
|
预测给定点的 SDF 值 |
|
|
|
:param point: 待预测的点,格式为 (x, y, z) |
|
|
|
:return: 预测的 SDF 值 |
|
|
|
""" |
|
|
|
# 将点转换为 numpy 数组 |
|
|
|
|
|
|
|
# 从根节点开始递归查找包含该点的叶子节点 |
|
|
|
if self._is_leaf: |
|
|
|
return self.trilinear_interpolation(query_point) |
|
|
|
else: |
|
|
|
index = self.get_child_index(query_point) |
|
|
|
try: |
|
|
|
# 直接访问子节点,不进行显式检查 |
|
|
|
return self.child_nodes[index].get_feature_vector(query_point) |
|
|
|
except IndexError as e: |
|
|
|
# 记录错误日志并重新抛出异常 |
|
|
|
logger.error( |
|
|
|
f"Error accessing child node: {e}. " |
|
|
|
f"Query point: {query_point.cpu().numpy().tolist()}, " |
|
|
|
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " |
|
|
|
f"Depth info: {self.max_depth}" |
|
|
|
) |
|
|
|
raise e |
|
|
|
''' |
|
|
|
# 检查索引是否有效 |
|
|
|
if next_idx == -1: |
|
|
|
raise IndexError(f"Invalid child node index: {child_idx[0]}") |
|
|
|
|
|
|
|
current_idx = next_idx |
|
|
|
iteration += 1 |
|
|
|
|
|
|
|
# 如果达到最大迭代次数,返回当前节点的信息 |
|
|
|
return self.node_bboxes[current_idx], self.param_key, bool(self.is_leaf_mask[current_idx].item()) |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor: |
|
|
|
"""批量计算点所在的子节点索引""" |
|
|
|
mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2 |
|
|
|
return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1) |
|
|
|
|
|
|
|
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: |
|
|
|
""" |
|
|
@ -233,33 +182,41 @@ class OctreeNode(nn.Module): |
|
|
|
|
|
|
|
# 打印面片信息(如果有) |
|
|
|
if self.face_indices is not None: |
|
|
|
print(f"{indent} Face indices: {self.face_indices.tolist()}") |
|
|
|
print(f"{indent} len child_nodes: {len(self.child_nodes)}") |
|
|
|
print(f"{indent} Face indices: {self.face_indices.cpu().numpy().tolist()}") |
|
|
|
print(f"{indent} Child indices: {self.child_indices.cpu().numpy().tolist()}") |
|
|
|
|
|
|
|
# 递归打印子节点 |
|
|
|
for i, child in enumerate(self.child_nodes): |
|
|
|
print(f"{indent} Child {i}:") |
|
|
|
child.print_tree(depth + 1, max_print_depth) |
|
|
|
# 打印子节点信息 |
|
|
|
if self.child_indices is not None: |
|
|
|
for i in range(8): |
|
|
|
child_idx = self.child_indices[0, i].item() |
|
|
|
if child_idx != -1: |
|
|
|
print(f"{indent} Child {i}: Node {child_idx}") |
|
|
|
|
|
|
|
def __getstate__(self): |
|
|
|
"""支持pickle序列化""" |
|
|
|
return self._serialize_node(self) |
|
|
|
state = { |
|
|
|
'bbox': self.bbox, |
|
|
|
'node_bboxes': self.node_bboxes, |
|
|
|
'parent_indices': self.parent_indices, |
|
|
|
'child_indices': self.child_indices, |
|
|
|
'is_leaf_mask': self.is_leaf_mask, |
|
|
|
'face_indices': self.face_indices, |
|
|
|
'surf_bbox': self.surf_bbox, |
|
|
|
'max_depth': self.max_depth, |
|
|
|
'param_key': self.param_key, |
|
|
|
'_is_leaf': self._is_leaf |
|
|
|
} |
|
|
|
return state |
|
|
|
|
|
|
|
def __setstate__(self, state): |
|
|
|
"""支持pickle反序列化""" |
|
|
|
self = self._deserialize_node(state) |
|
|
|
|
|
|
|
def _serialize_node(self, node): |
|
|
|
return { |
|
|
|
'bbox': node.bbox, |
|
|
|
'is_leaf': node._is_leaf, |
|
|
|
'child_nodes': [self._serialize_node(c) for c in node.child_nodes], |
|
|
|
'param_key': node.param_key |
|
|
|
} |
|
|
|
|
|
|
|
def _deserialize_node(self, data): |
|
|
|
node = OctreeNode(data['bbox'], 0) # max_depth会在encoder中重建 |
|
|
|
node._is_leaf = data['is_leaf'] |
|
|
|
node.param_key = data['param_key'] |
|
|
|
node.child_nodes = [self._deserialize_node(c) for c in data['child_nodes']] |
|
|
|
return node |
|
|
|
self.bbox = state['bbox'] |
|
|
|
self.node_bboxes = state['node_bboxes'] |
|
|
|
self.parent_indices = state['parent_indices'] |
|
|
|
self.child_indices = state['child_indices'] |
|
|
|
self.is_leaf_mask = state['is_leaf_mask'] |
|
|
|
self.face_indices = state['face_indices'] |
|
|
|
self.surf_bbox = state['surf_bbox'] |
|
|
|
self.max_depth = state['max_depth'] |
|
|
|
self.param_key = state['param_key'] |
|
|
|
self._is_leaf = state['_is_leaf'] |