|
|
@ -5,7 +5,6 @@ import torch.nn as nn |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from brep2sdf.networks.patch_graph import PatchGraph |
|
|
|
from brep2sdf.utils.logger import logger |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -63,7 +62,6 @@ class OctreeNode(nn.Module): |
|
|
|
"""构建静态八叉树结构""" |
|
|
|
# 预计算所有可能的节点数量,确保结果为整数 |
|
|
|
total_nodes = int(sum(8**i for i in range(self.max_depth + 1))) |
|
|
|
logger.info(f"总节点数量: {total_nodes}") |
|
|
|
|
|
|
|
# 初始化静态张量,使用整数列表作为形状参数 |
|
|
|
self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.bbox.device) |
|
|
@ -71,7 +69,6 @@ class OctreeNode(nn.Module): |
|
|
|
self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.bbox.device) |
|
|
|
self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.bbox.device) |
|
|
|
|
|
|
|
logger.gpu_memory_stats("树初始化后") |
|
|
|
# 使用队列进行广度优先遍历 |
|
|
|
queue = [(0, self.bbox, self.face_indices)] # (node_idx, bbox, face_indices) |
|
|
|
current_idx = 0 |
|
|
@ -193,6 +190,17 @@ class OctreeNode(nn.Module): |
|
|
|
mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2 |
|
|
|
return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1) |
|
|
|
|
|
|
|
def forward(self, query_points): |
|
|
|
with torch.no_grad(): |
|
|
|
param_indices, bboxes = [], [] |
|
|
|
for point in query_points: |
|
|
|
bbox, idx, _ = self.find_leaf(point) |
|
|
|
param_indices.append(idx) |
|
|
|
bboxes.append(bbox) |
|
|
|
param_indices = torch.stack(param_indices) |
|
|
|
bboxes = torch.stack(bboxes) |
|
|
|
return param_indices, bboxes |
|
|
|
|
|
|
|
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: |
|
|
|
""" |
|
|
|
递归打印八叉树结构 |
|
|
|