|
|
@ -1,4 +1,4 @@ |
|
|
|
from typing import Tuple |
|
|
|
from typing import Tuple,List, Optional |
|
|
|
|
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
@ -54,7 +54,7 @@ def bbox_intersect( |
|
|
|
surf_bboxes: torch.Tensor, |
|
|
|
indices: torch.Tensor, |
|
|
|
child_bboxes: torch.Tensor, |
|
|
|
surf_points: torch.Tensor = None |
|
|
|
surf_points: Optional[torch.Tensor]=None |
|
|
|
) -> torch.Tensor: |
|
|
|
''' |
|
|
|
args: |
|
|
@ -69,15 +69,15 @@ def bbox_intersect( |
|
|
|
# 初始化全为 False 的结果掩码 [8, B] |
|
|
|
B = surf_bboxes.size(0) |
|
|
|
result_mask = torch.zeros((8, B), dtype=torch.bool).to(surf_bboxes.device) |
|
|
|
logger.debug(result_mask.shape) |
|
|
|
logger.debug(indices.shape) |
|
|
|
#logger.debug(result_mask.shape) |
|
|
|
#logger.debug(indices.shape) |
|
|
|
|
|
|
|
# 提取选中的边界框 |
|
|
|
selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6] |
|
|
|
min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3] |
|
|
|
min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3] |
|
|
|
|
|
|
|
logger.debug(selected_bboxes.shape) |
|
|
|
#logger.debug(selected_bboxes.shape) |
|
|
|
# 计算子包围盒与选中包围盒的交集 |
|
|
|
intersect_mask = torch.all( |
|
|
|
(max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3] |
|
|
@ -109,11 +109,11 @@ def bbox_intersect( |
|
|
|
|
|
|
|
# 合并交集条件和点云条件 |
|
|
|
result_mask[:, indices] = result_mask[:, indices] & points_in_boxes_mask |
|
|
|
logger.debug(result_mask.shape) |
|
|
|
#logger.debug(result_mask.shape) |
|
|
|
return result_mask |
|
|
|
|
|
|
|
class OctreeNode(nn.Module): |
|
|
|
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,surf_ncs:np.ndarray = None,device=None): |
|
|
|
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: Optional[torch.Tensor]=None, patch_graph: Optional[PatchGraph] = None,surf_ncs:Optional[np.ndarray] = None,device:Optional[torch.device]=None): |
|
|
|
super().__init__() |
|
|
|
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
# 改为普通张量属性 |
|
|
@ -185,7 +185,7 @@ class OctreeNode(nn.Module): |
|
|
|
queue.append((child_idx, child_bbox, intersecting_faces.clone().detach())) |
|
|
|
current_idx += 8 |
|
|
|
|
|
|
|
def _should_split_node(self, current_idx: int,face_indices,max_node:int) -> bool: |
|
|
|
def _should_split_node(self, current_idx: int,face_indices:torch.Tensor,max_node:int) -> bool: |
|
|
|
"""判断节点是否需要分裂""" |
|
|
|
# 检查是否达到最大深度 |
|
|
|
if current_idx + 8 >= max_node: |
|
|
@ -229,7 +229,7 @@ class OctreeNode(nn.Module): |
|
|
|
""" |
|
|
|
修改后的查找叶子节点方法,返回face indices |
|
|
|
:param query_points: 待查找的点,形状为 (3,) |
|
|
|
:return: (bbox, param_key, face_indices, is_leaf) |
|
|
|
:return: (bbox, face_indices, is_leaf) |
|
|
|
""" |
|
|
|
# 确保输入是单个点 |
|
|
|
if query_points.dim() != 1 or query_points.shape[0] != 3: |
|
|
@ -248,10 +248,9 @@ class OctreeNode(nn.Module): |
|
|
|
#logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.") |
|
|
|
if parent_idx == -1: |
|
|
|
# 根节点没有父节点,返回根节点的信息 |
|
|
|
#logger.warning(f"Reached root node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.") |
|
|
|
return ( |
|
|
|
self.node_bboxes[current_idx], |
|
|
|
None, # 新增返回face indices |
|
|
|
torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device), # 新增返回face indices |
|
|
|
False |
|
|
|
) |
|
|
|
return ( |
|
|
@ -280,7 +279,7 @@ class OctreeNode(nn.Module): |
|
|
|
iteration += 1 |
|
|
|
|
|
|
|
# 如果达到最大迭代次数,返回当前节点的信息 |
|
|
|
return self.node_bboxes[current_idx], None,bool(self.is_leaf_mask[current_idx].item()) |
|
|
|
return self.node_bboxes[current_idx], torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device),bool(self.is_leaf_mask[current_idx].item()) |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor: |
|
|
@ -288,23 +287,26 @@ class OctreeNode(nn.Module): |
|
|
|
mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2 |
|
|
|
return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1) |
|
|
|
|
|
|
|
def forward(self, query_points): |
|
|
|
def forward(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
|
with torch.no_grad(): |
|
|
|
bboxes, face_indices_mask, csg_trees = [], [], [] |
|
|
|
bboxes: List[torch.Tensor] = [] |
|
|
|
face_indices_mask: List[torch.Tensor] = [] |
|
|
|
operator: List[int] = [] |
|
|
|
for point in query_points: |
|
|
|
bbox, faces_mask, _ = self.find_leaf(point) |
|
|
|
bbox, faces_mask, _ = self.find_leaf(point) |
|
|
|
bboxes.append(bbox) |
|
|
|
face_indices_mask.append(faces_mask) |
|
|
|
# 获取当前节点的CSG树结构 |
|
|
|
csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None |
|
|
|
csg_trees.append(csg_tree) # 保持原始列表结构 |
|
|
|
#csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None |
|
|
|
#csg_trees.append(csg_tree) # 保持原始列表结构 |
|
|
|
operator.append(self.patch_graph.get_operator(faces_mask.nonzero()) if self.patch_graph is not None else -1) |
|
|
|
return ( |
|
|
|
torch.stack(bboxes), |
|
|
|
torch.stack(face_indices_mask), |
|
|
|
csg_trees # 直接返回列表,不转换为张量 |
|
|
|
torch.tensor(operator, dtype=torch.int) # 直接返回列表,不转换为张量 |
|
|
|
) |
|
|
|
|
|
|
|
def print_tree(self, max_print_depth: int = None) -> None: |
|
|
|
def print_tree(self, max_print_depth: Optional[int] = None) -> None: |
|
|
|
""" |
|
|
|
使用深度优先遍历 (DFS) 打印树结构,父子关系通过缩进体现。 |
|
|
|
|
|
|
@ -353,36 +355,45 @@ class OctreeNode(nn.Module): |
|
|
|
dfs(0, 0) |
|
|
|
|
|
|
|
# 统一输出所有日志 |
|
|
|
logger.debug("\n".join(log_lines)) |
|
|
|
#logger.debug("\n".join(log_lines)) |
|
|
|
|
|
|
|
def __getstate__(self): |
|
|
|
"""支持pickle序列化""" |
|
|
|
state = { |
|
|
|
'bbox': self.bbox, |
|
|
|
'node_bboxes': self.node_bboxes, |
|
|
|
'parent_indices': self.parent_indices, |
|
|
|
'child_indices': self.child_indices, |
|
|
|
'is_leaf_mask': self.is_leaf_mask, |
|
|
|
'face_indices': self.face_indices, |
|
|
|
'surf_bbox': self.surf_bbox, |
|
|
|
'patch_graph': self.patch_graph, |
|
|
|
'bbox': self.bbox.cpu(), # 转换为CPU张量 |
|
|
|
'node_bboxes': self.node_bboxes.cpu() if self.node_bboxes is not None else None, |
|
|
|
'parent_indices': self.parent_indices.cpu() if self.parent_indices is not None else None, |
|
|
|
'child_indices': self.child_indices.cpu() if self.child_indices is not None else None, |
|
|
|
'is_leaf_mask': self.is_leaf_mask.cpu() if self.is_leaf_mask is not None else None, |
|
|
|
'all_face_indices': self.all_face_indices.cpu(), |
|
|
|
'face_indices_mask':self.face_indices_mask.cpu() if self.face_indices_mask is not None else None, |
|
|
|
'surf_bbox': self.surf_bbox.cpu() if self.surf_bbox is not None else None, |
|
|
|
'surf_ncs': self.surf_ncs.cpu() if self.surf_ncs is not None else None, |
|
|
|
'patch_graph': self.patch_graph, # 假设PatchGraph支持序列化 |
|
|
|
'max_depth': self.max_depth, |
|
|
|
'_is_leaf': self._is_leaf |
|
|
|
'device': str(self.device) # 保存设备信息 |
|
|
|
} |
|
|
|
return state |
|
|
|
|
|
|
|
def __setstate__(self, state): |
|
|
|
"""支持pickle反序列化""" |
|
|
|
self.bbox = state['bbox'] |
|
|
|
self.node_bboxes = state['node_bboxes'] |
|
|
|
self.parent_indices = state['parent_indices'] |
|
|
|
self.child_indices = state['child_indices'] |
|
|
|
self.is_leaf_mask = state['is_leaf_mask'] |
|
|
|
self.face_indices = state['face_indices'] |
|
|
|
self.surf_bbox = state['surf_bbox'] |
|
|
|
self.patch_graph = state['patch_graph'] |
|
|
|
self.max_depth = state['max_depth'] |
|
|
|
self._is_leaf = state['_is_leaf'] |
|
|
|
# 手动调用 __init__ 方法 |
|
|
|
self.__init__( |
|
|
|
bbox=state['bbox'], |
|
|
|
face_indices=state['all_face_indices'].cpu().numpy(), |
|
|
|
patch_graph=state['patch_graph'], |
|
|
|
max_depth=state['max_depth'], |
|
|
|
surf_bbox=state['surf_bbox'], |
|
|
|
surf_ncs=state['surf_ncs'], |
|
|
|
device=torch.device(state['device']) |
|
|
|
) |
|
|
|
# 可以在这里设置其他不需要在 __init__ 中处理的属性 |
|
|
|
self.node_bboxes = state['node_bboxes'].to(self.device) if state['node_bboxes'] is not None else None |
|
|
|
self.parent_indices = state['parent_indices'].to(self.device) if state['parent_indices'] is not None else None |
|
|
|
self.child_indices = state['child_indices'].to(self.device) if state['child_indices'] is not None else None |
|
|
|
self.face_indices_mask = state['face_indices_mask'].to(self.device) if state['face_indices_mask'] is not None else None |
|
|
|
self.is_leaf_mask = state['is_leaf_mask'].to(self.device) if state['is_leaf_mask'] is not None else None |
|
|
|
|
|
|
|
|
|
|
|
def to(self, device=None, dtype=None, non_blocking=False): |
|
|
|
# 调用父类方法迁移基础参数 |
|
|
|