You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
402 lines
18 KiB
402 lines
18 KiB
from typing import Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
from brep2sdf.data.utils import process_surf_ncs_with_dynamic_padding
|
|
from brep2sdf.networks.patch_graph import PatchGraph
|
|
from brep2sdf.utils.logger import logger
|
|
|
|
|
|
|
|
|
|
def bbox_intersect_(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
|
|
"""判断两个轴对齐包围盒(AABB)是否相交
|
|
|
|
参数:
|
|
bbox1: 形状为 (6,) 的张量,格式 [min_x, min_y, min_z, max_x, max_y, max_z]
|
|
bbox2: 同bbox1格式
|
|
|
|
返回:
|
|
torch.Tensor: 两包围盒是否相交(包括刚好接触的情况)
|
|
"""
|
|
assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的张量"
|
|
|
|
# 提取min和max坐标
|
|
min1, max1 = bbox1[:3], bbox1[3:]
|
|
min2, max2 = bbox2[:3], bbox2[3:]
|
|
|
|
# 向量化比较
|
|
return torch.all((max1 >= min2) & (max2 >= min1))
|
|
|
|
def if_points_in_box(points: np.ndarray, bbox: torch.Tensor) -> bool:
|
|
"""判断点是否在AABB包围盒内
|
|
|
|
参数:
|
|
points: 形状为 (N, 3) 的数组,表示N个点的坐标
|
|
bbox: 形状为 (6,) 的张量,表示AABB包围盒的坐标
|
|
|
|
返回:
|
|
bool: 如果所有点都在包围盒内,返回True,否则返回False
|
|
"""
|
|
# 将 points 转换为 torch.Tensor
|
|
points_tensor = torch.tensor(points, dtype=torch.float32, device=bbox.device)
|
|
|
|
# 提取min和max坐标
|
|
min_coords = bbox[:3]
|
|
max_coords = bbox[3:]
|
|
#logger.debug(f"min_coords: {min_coords}, max_coords: {max_coords}")
|
|
# 向量化比较
|
|
return torch.any((points_tensor >= min_coords) & (points_tensor <= max_coords)).item()
|
|
|
|
def bbox_intersect(
|
|
surf_bboxes: torch.Tensor,
|
|
indices: torch.Tensor,
|
|
child_bboxes: torch.Tensor,
|
|
surf_points: torch.Tensor = None
|
|
) -> torch.Tensor:
|
|
'''
|
|
args:
|
|
surf_bboxes: [B, 6] - 表示多个包围盒的张量,每个包围盒由其最小和最大坐标定义。
|
|
indices: 选择bbox, [N], N <= B - 用于选择特定包围盒的索引张量。
|
|
child_bboxes: [8, 6] - 一个包围盒被分成八个子包围盒后的结果。
|
|
surf_points: [B, M_max, 3] - 每个包围盒对应的点云数据(可选)。
|
|
return:
|
|
result_mask: [8, B] - 布尔掩码,表示每个子边界框与所有包围盒是否相交,
|
|
且是否包含至少一个点(如果提供了点云)。
|
|
'''
|
|
# 初始化全为 False 的结果掩码 [8, B]
|
|
B = surf_bboxes.size(0)
|
|
result_mask = torch.zeros((8, B), dtype=torch.bool).to(surf_bboxes.device)
|
|
logger.debug(result_mask.shape)
|
|
logger.debug(indices.shape)
|
|
|
|
# 提取选中的边界框
|
|
selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6]
|
|
min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3]
|
|
min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3]
|
|
|
|
logger.debug(selected_bboxes.shape)
|
|
# 计算子包围盒与选中包围盒的交集
|
|
intersect_mask = torch.all(
|
|
(max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3]
|
|
(max2.unsqueeze(1) >= min1.unsqueeze(0)), # 形状为 [8, N, 3]
|
|
dim=-1
|
|
) # 最终形状为 [8, N]
|
|
|
|
# 更新结果掩码中选中的部分
|
|
result_mask[:, indices] = intersect_mask
|
|
|
|
# 如果提供了点云,进一步检查点是否在子包围盒内
|
|
if surf_points is not None:
|
|
# 提取选中的点云
|
|
selected_points = surf_points[indices] # 形状为 [N, M_max, 3]
|
|
|
|
# 将点云广播到子边界框的维度
|
|
points_expanded = selected_points.unsqueeze(1) # 形状为 [N, 1, M_max, 3]
|
|
min2_expanded = min2.unsqueeze(0).unsqueeze(2) # 形状为 [1, 8, 1, 3]
|
|
max2_expanded = max2.unsqueeze(0).unsqueeze(2) # 形状为 [1, 8, 1, 3]
|
|
|
|
# 判断点是否在子边界框内
|
|
point_in_box_mask = (
|
|
(points_expanded >= min2_expanded) & # 形状为 [N, 8, M_max, 3]
|
|
(points_expanded <= max2_expanded) # 形状为 [N, 8, M_max, 3]
|
|
).all(dim=-1) # 最终形状为 [N, 8, M_max]
|
|
|
|
# 检查每个子边界框是否包含至少一个点
|
|
points_in_boxes_mask = point_in_box_mask.any(dim=-1).permute(1, 0) # 形状为 [8, N]
|
|
|
|
# 合并交集条件和点云条件
|
|
result_mask[:, indices] = result_mask[:, indices] & points_in_boxes_mask
|
|
logger.debug(result_mask.shape)
|
|
return result_mask
|
|
|
|
class OctreeNode(nn.Module):
|
|
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,surf_ncs:np.ndarray = None,device=None):
|
|
super().__init__()
|
|
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
# 改为普通张量属性
|
|
self.bbox = bbox.to(self.device) # 显式设备管理
|
|
self.node_bboxes = None
|
|
self.parent_indices = None
|
|
self.child_indices = None
|
|
self.is_leaf_mask = None
|
|
# 面片索引张量
|
|
self.all_face_indices = torch.from_numpy(face_indices).to(self.device)
|
|
self.surf_bbox = surf_bbox.to(self.device) if surf_bbox is not None else None
|
|
self.surf_ncs = process_surf_ncs_with_dynamic_padding(surf_ncs).to(self.device)
|
|
# PatchGraph作为普通属性
|
|
self.patch_graph = patch_graph.to(self.device) if patch_graph is not None else None
|
|
|
|
self.max_depth = max_depth
|
|
self._is_leaf = True
|
|
|
|
@torch.jit.export
|
|
def build_static_tree(self) -> None:
|
|
"""构建静态八叉树结构"""
|
|
# 预计算所有可能的节点数量,确保结果为整数
|
|
total_nodes = int(sum(8**i for i in range(self.max_depth + 1)))
|
|
num_faces = self.all_face_indices.shape[0]
|
|
|
|
# 初始化静态张量,使用整数列表作为形状参数
|
|
self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device)
|
|
self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device)
|
|
self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device)
|
|
self.face_indices_mask = torch.zeros([int(total_nodes),num_faces], dtype=torch.bool, device=self.device) # 1 代表有
|
|
self.is_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device)
|
|
# 使用队列进行广度优先遍历
|
|
queue = [(0, self.bbox, self.all_face_indices)] # (node_idx, bbox, face_indices)
|
|
current_idx = 0
|
|
|
|
while queue:
|
|
node_idx, bbox, faces = queue.pop(0)
|
|
|
|
#logger.debug(f"Processing node {node_idx} with {len(faces)} faces.")
|
|
self.node_bboxes[node_idx] = bbox
|
|
|
|
# 判断 要不要继续分裂
|
|
if not self._should_split_node(current_idx, faces, total_nodes):
|
|
continue
|
|
|
|
self.is_leaf_mask[node_idx] = 0
|
|
# 计算子节点边界框
|
|
min_coords = bbox[:3]
|
|
max_coords = bbox[3:]
|
|
mid_coords = (min_coords + max_coords) / 2
|
|
|
|
# 生成8个子节点
|
|
child_bboxes = self._generate_child_bboxes(min_coords, mid_coords, max_coords)
|
|
intersect_mask = bbox_intersect(self.surf_bbox, faces, child_bboxes)
|
|
self.face_indices_mask[current_idx + 1:current_idx + 9, :] = intersect_mask
|
|
|
|
# 为每个子节点分配面片
|
|
for i, child_bbox in enumerate(child_bboxes):
|
|
child_idx = child_idx = current_idx + i + 1
|
|
|
|
intersecting_faces = intersect_mask[i].nonzero().flatten()
|
|
#logger.debug(f"Node {child_idx} has {len(intersecting_faces)} intersecting faces.")
|
|
# 更新节点关系
|
|
self.parent_indices[child_idx] = node_idx
|
|
self.child_indices[node_idx, i] = child_idx
|
|
|
|
# 将子节点加入队列
|
|
if len(intersecting_faces) > 0:
|
|
queue.append((child_idx, child_bbox, intersecting_faces.clone().detach()))
|
|
current_idx += 8
|
|
|
|
def _should_split_node(self, current_idx: int,face_indices,max_node:int) -> bool:
|
|
"""判断节点是否需要分裂"""
|
|
# 检查是否达到最大深度
|
|
if current_idx + 8 >= max_node:
|
|
return False
|
|
|
|
# 检查是否为完全图
|
|
#is_clique = self.patch_graph.is_clique(face_indices)
|
|
is_clique = face_indices.shape[0] < 2
|
|
if is_clique:
|
|
#logger.debug(f"Node {current_idx} is a clique. Stopping split.")
|
|
return False
|
|
|
|
return True
|
|
|
|
@torch.jit.export
|
|
def _generate_child_bboxes(self, min_coords: torch.Tensor, mid_coords: torch.Tensor, max_coords: torch.Tensor) -> torch.Tensor:
|
|
# 使用 with torch.no_grad() 减少梯度计算的内存占用
|
|
with torch.no_grad():
|
|
child_bboxes = torch.zeros([8, 6], device=self.device)
|
|
|
|
# 使用向量化操作生成所有子节点边界框
|
|
child_bboxes[0] = torch.cat([min_coords, mid_coords]) # 前下左
|
|
child_bboxes[1] = torch.cat([torch.stack([mid_coords[0], min_coords[1], min_coords[2]]),
|
|
torch.stack([max_coords[0], mid_coords[1], mid_coords[2]])]) # 前下右
|
|
child_bboxes[2] = torch.cat([torch.stack([min_coords[0], mid_coords[1], min_coords[2]]),
|
|
torch.stack([mid_coords[0], max_coords[1], mid_coords[2]])]) # 前上左
|
|
child_bboxes[3] = torch.cat([torch.stack([mid_coords[0], mid_coords[1], min_coords[2]]),
|
|
torch.stack([max_coords[0], max_coords[1], mid_coords[2]])]) # 前上右
|
|
child_bboxes[4] = torch.cat([torch.stack([min_coords[0], min_coords[1], mid_coords[2]]),
|
|
torch.stack([mid_coords[0], mid_coords[1], max_coords[2]])]) # 后下左
|
|
child_bboxes[5] = torch.cat([torch.stack([mid_coords[0], min_coords[1], mid_coords[2]]),
|
|
torch.stack([max_coords[0], mid_coords[1], max_coords[2]])]) # 后下右
|
|
child_bboxes[6] = torch.cat([torch.stack([min_coords[0], mid_coords[1], mid_coords[2]]),
|
|
torch.stack([mid_coords[0], max_coords[1], max_coords[2]])]) # 后上左
|
|
child_bboxes[7] = torch.cat([mid_coords, max_coords]) # 后上右
|
|
|
|
return child_bboxes
|
|
|
|
@torch.jit.export
|
|
def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]:
|
|
"""
|
|
修改后的查找叶子节点方法,返回face indices
|
|
:param query_points: 待查找的点,形状为 (3,)
|
|
:return: (bbox, param_key, face_indices, is_leaf)
|
|
"""
|
|
# 确保输入是单个点
|
|
if query_points.dim() != 1 or query_points.shape[0] != 3:
|
|
raise ValueError(f"query_points 必须是形状为 (3,) 的张量,但得到 {query_points.shape}")
|
|
|
|
current_idx = torch.tensor(0, dtype=torch.long, device=query_points.device)
|
|
max_iterations = 1000 # 防止无限循环
|
|
iteration = 0
|
|
|
|
while iteration < max_iterations:
|
|
# 获取当前节点的叶子状态
|
|
if self.is_leaf_mask[current_idx].item():
|
|
#logger.debug(f"Reached leaf node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.")
|
|
if self.face_indices_mask[current_idx].sum() == 0:
|
|
parent_idx = self.parent_indices[current_idx]
|
|
#logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.")
|
|
if parent_idx == -1:
|
|
# 根节点没有父节点,返回根节点的信息
|
|
#logger.warning(f"Reached root node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.")
|
|
return (
|
|
self.node_bboxes[current_idx],
|
|
None, # 新增返回face indices
|
|
False
|
|
)
|
|
return (
|
|
self.node_bboxes[parent_idx],
|
|
self.face_indices_mask[parent_idx], # 新增返回face indices
|
|
False
|
|
)
|
|
return (
|
|
self.node_bboxes[current_idx],
|
|
self.face_indices_mask[current_idx], # 新增返回face indices
|
|
True
|
|
)
|
|
|
|
# 计算子节点索引
|
|
child_idx = self._get_child_indices(query_points.unsqueeze(0),
|
|
self.node_bboxes[current_idx].unsqueeze(0))
|
|
|
|
# 获取下一个要访问的节点
|
|
next_idx = self.child_indices[current_idx, child_idx[0]]
|
|
|
|
# 检查索引是否有效
|
|
if next_idx == -1:
|
|
raise IndexError(f"Invalid child node index: {child_idx[0]}")
|
|
|
|
current_idx = next_idx
|
|
iteration += 1
|
|
|
|
# 如果达到最大迭代次数,返回当前节点的信息
|
|
return self.node_bboxes[current_idx], None,bool(self.is_leaf_mask[current_idx].item())
|
|
|
|
@torch.jit.export
|
|
def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
|
|
"""批量计算点所在的子节点索引"""
|
|
mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2
|
|
return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1)
|
|
|
|
def forward(self, query_points):
|
|
with torch.no_grad():
|
|
bboxes, face_indices_mask, csg_trees = [], [], []
|
|
for point in query_points:
|
|
bbox, faces_mask, _ = self.find_leaf(point)
|
|
bboxes.append(bbox)
|
|
face_indices_mask.append(faces_mask)
|
|
# 获取当前节点的CSG树结构
|
|
csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None
|
|
csg_trees.append(csg_tree) # 保持原始列表结构
|
|
return (
|
|
torch.stack(bboxes),
|
|
torch.stack(face_indices_mask),
|
|
csg_trees # 直接返回列表,不转换为张量
|
|
)
|
|
|
|
def print_tree(self, max_print_depth: int = None) -> None:
|
|
"""
|
|
使用深度优先遍历 (DFS) 打印树结构,父子关系通过缩进体现。
|
|
|
|
参数:
|
|
max_print_depth (int): 最大打印深度 (None 表示打印全部)
|
|
"""
|
|
def dfs(node_idx: int, depth: int):
|
|
"""
|
|
深度优先遍历辅助函数。
|
|
|
|
参数:
|
|
node_idx (int): 当前节点索引
|
|
depth (int): 当前节点的深度
|
|
"""
|
|
# 如果超过最大打印深度,跳过当前节点及其子节点
|
|
if max_print_depth is not None and depth > max_print_depth:
|
|
return
|
|
|
|
indent = " " * depth # 根据深度生成缩进
|
|
is_leaf = self.is_leaf_mask[node_idx].item() # 判断是否为叶子节点
|
|
bbox = self.node_bboxes[node_idx].cpu().numpy().tolist() # 获取边界框信息
|
|
|
|
# 打印当前节点的基本信息
|
|
node_type = "Leaf" if is_leaf else "Internal"
|
|
log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}")
|
|
if self.face_indices_mask is not None:
|
|
face_indices = self.face_indices_mask[node_idx].nonzero().cpu().numpy().flatten().tolist()
|
|
log_lines.append(f"{indent} Face Indices: {face_indices}")
|
|
# 如果是叶子节点,打印额外信息
|
|
if is_leaf:
|
|
|
|
child_indices = self.child_indices[node_idx].cpu().numpy().tolist()
|
|
log_lines.append(f"{indent} Child Indices: {child_indices}")
|
|
|
|
# 如果不是叶子节点,递归处理子节点
|
|
if not is_leaf:
|
|
for i in range(8): # 遍历所有子节点
|
|
child_idx = self.child_indices[node_idx, i].item()
|
|
if child_idx != -1: # 忽略无效的子节点索引
|
|
dfs(child_idx, depth + 1)
|
|
|
|
# 初始化日志行列表
|
|
log_lines = []
|
|
|
|
# 从根节点开始深度优先遍历
|
|
dfs(0, 0)
|
|
|
|
# 统一输出所有日志
|
|
logger.debug("\n".join(log_lines))
|
|
|
|
def __getstate__(self):
|
|
"""支持pickle序列化"""
|
|
state = {
|
|
'bbox': self.bbox,
|
|
'node_bboxes': self.node_bboxes,
|
|
'parent_indices': self.parent_indices,
|
|
'child_indices': self.child_indices,
|
|
'is_leaf_mask': self.is_leaf_mask,
|
|
'face_indices': self.face_indices,
|
|
'surf_bbox': self.surf_bbox,
|
|
'patch_graph': self.patch_graph,
|
|
'max_depth': self.max_depth,
|
|
'_is_leaf': self._is_leaf
|
|
}
|
|
return state
|
|
|
|
def __setstate__(self, state):
|
|
"""支持pickle反序列化"""
|
|
self.bbox = state['bbox']
|
|
self.node_bboxes = state['node_bboxes']
|
|
self.parent_indices = state['parent_indices']
|
|
self.child_indices = state['child_indices']
|
|
self.is_leaf_mask = state['is_leaf_mask']
|
|
self.face_indices = state['face_indices']
|
|
self.surf_bbox = state['surf_bbox']
|
|
self.patch_graph = state['patch_graph']
|
|
self.max_depth = state['max_depth']
|
|
self._is_leaf = state['_is_leaf']
|
|
|
|
def to(self, device=None, dtype=None, non_blocking=False):
|
|
# 调用父类方法迁移基础参数
|
|
super().to(device, dtype, non_blocking)
|
|
|
|
# 迁移自定义属性
|
|
if self.patch_graph is not None:
|
|
if hasattr(self.patch_graph, 'to'):
|
|
self.patch_graph = self.patch_graph.to(device=device, dtype=dtype)
|
|
else:
|
|
# 手动移动非Module属性
|
|
for attr in ['edge_index', 'edge_type', 'patch_features']:
|
|
tensor = getattr(self.patch_graph, attr, None)
|
|
if tensor is not None:
|
|
setattr(self.patch_graph, attr, tensor.to(device=device, dtype=dtype))
|
|
|
|
return self
|