Browse Source

八叉树优化:

1. 去掉 parent_indices索引和child_indices,改用 8*i+1,8*i+8来索引child
2.改变叶节点查找,使用层级遍历方式查找更新叶节点的方式。(为了利用torch的并行计算功能)
final
mckay 1 month ago
parent
commit
6b53d8b1bf
  1. 184
      brep2sdf/networks/octree.py

184
brep2sdf/networks/octree.py

@ -119,8 +119,6 @@ class OctreeNode(nn.Module):
# 改为普通张量属性
self.bbox = bbox.to(self.device) # 显式设备管理
self.node_bboxes = None
self.parent_indices = None
self.child_indices = None
self.is_leaf_mask = None
# 面片索引张量
self.all_face_indices = torch.from_numpy(face_indices).to(self.device)
@ -141,24 +139,21 @@ class OctreeNode(nn.Module):
# 初始化静态张量,使用整数列表作为形状参数
self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device)
self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device)
self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device)
self.face_indices_mask = torch.zeros([int(total_nodes),num_faces], dtype=torch.bool, device=self.device) # 1 代表有
self.is_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device)
self.is_valid_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device)
# 使用队列进行广度优先遍历
queue = [(0, self.bbox, self.all_face_indices)] # (node_idx, bbox, face_indices)
current_idx = 0
while queue:
node_idx, bbox, faces = queue.pop(0)
node_idx, bbox, faces_indices = queue.pop(0)
#logger.debug(f"Processing node {node_idx} with {len(faces)} faces.")
self.node_bboxes[node_idx] = bbox
# 判断 要不要继续分裂
if not self._should_split_node(current_idx, faces, total_nodes):
if not self._should_split_node(node_idx, faces_indices, total_nodes):
continue
self.is_leaf_mask[node_idx] = 0
# 计算子节点边界框
min_coords = bbox[:3]
@ -167,28 +162,28 @@ class OctreeNode(nn.Module):
# 生成8个子节点
child_bboxes = self._generate_child_bboxes(min_coords, mid_coords, max_coords)
intersect_mask = bbox_intersect(self.surf_bbox, faces, child_bboxes)
self.face_indices_mask[current_idx + 1:current_idx + 9, :] = intersect_mask
intersect_mask = bbox_intersect(self.surf_bbox, faces_indices, child_bboxes)
self.face_indices_mask[node_idx * 8 + 1:node_idx * 8 + 9, :] = intersect_mask
# 为每个子节点分配面片
for i, child_bbox in enumerate(child_bboxes):
child_idx = child_idx = current_idx + i + 1
child_idx = 8 * node_idx + i + 1
intersecting_faces = intersect_mask[i].nonzero().flatten()
#logger.debug(f"Node {child_idx} has {len(intersecting_faces)} intersecting faces.")
# 更新节点关系
self.parent_indices[child_idx] = node_idx
self.child_indices[node_idx, i] = child_idx
# 将子节点加入队列
if len(intersecting_faces) > 0:
queue.append((child_idx, child_bbox, intersecting_faces.clone().detach()))
current_idx += 8
else:
self.is_valid_leaf_mask[child_idx] = False
#self.is_valid_leaf_mask = self.is_valid_leaf_mask & self.is_leaf_mask
def _should_split_node(self, current_idx: int,face_indices:torch.Tensor,max_node:int) -> bool:
"""判断节点是否需要分裂"""
# 检查是否达到最大深度
if current_idx + 8 >= max_node:
if current_idx * 8 + 9 >= max_node:
return False
# 检查是否为完全图
@ -225,61 +220,64 @@ class OctreeNode(nn.Module):
return child_bboxes
@torch.jit.export
def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]:
def find_leaf_batch(self, query_points: torch.Tensor) -> torch.Tensor:
"""
修改后的查找叶子节点方法返回face indices
:param query_points: 待查找的点形状为 (3,)
:return: (bbox, face_indices, is_leaf)
批量定位点在八叉树中的对应节点索引
Args:
query_points (torch.Tensor): 点的坐标形状为 [N, 3]
Returns:
torch.Tensor: 每个点对应的节点索引形状为 [N]
"""
# 确保输入是单个点
if query_points.dim() != 1 or query_points.shape[0] != 3:
raise ValueError(f"query_points 必须是形状为 (3,) 的张量,但得到 {query_points.shape}")
num_points = query_points.shape[0]
current_idx = torch.tensor(0, dtype=torch.long, device=query_points.device)
max_iterations = 1000 # 防止无限循环
iteration = 0
# 初始化:所有点从根节点开始
node_indices = torch.zeros(num_points, dtype=torch.int64, device=self.device)
active_mask = torch.full((num_points,),
not self.is_leaf_mask[0].item(),
dtype=torch.bool,
device=self.device) # 标记需要继续处理的点
while iteration < max_iterations:
# 获取当前节点的叶子状态
if self.is_leaf_mask[current_idx].item():
#logger.debug(f"Reached leaf node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.")
if self.face_indices_mask[current_idx].sum() == 0:
parent_idx = self.parent_indices[current_idx]
#logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.")
if parent_idx == -1:
# 根节点没有父节点,返回根节点的信息
return (
self.node_bboxes[current_idx],
torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device), # 新增返回face indices
False
)
return (
self.node_bboxes[parent_idx],
self.face_indices_mask[parent_idx], # 新增返回face indices
False
)
return (
self.node_bboxes[current_idx],
self.face_indices_mask[current_idx], # 新增返回face indices
True
)
# 计算子节点索引
child_idx = self._get_child_indices(query_points.unsqueeze(0),
self.node_bboxes[current_idx].unsqueeze(0))
# 获取下一个要访问的节点
next_idx = self.child_indices[current_idx, child_idx[0]]
# 检查索引是否有效
if next_idx == -1:
raise IndexError(f"Invalid child node index: {child_idx[0]}")
current_idx = next_idx
iteration = 0
while active_mask.any() and iteration < self.max_depth:
# 获取当前活动点的索引和对应的节点信息
active_indices = torch.nonzero(active_mask).squeeze(1) # 当前活动点的索引
current_node_indices = node_indices[active_indices] # 当前活动点对应的节点索引
# 获取当前节点的边界框
current_bboxes = self.node_bboxes[current_node_indices] # 形状为 [M, 6],其中 M 是活动点数
# 将边界框分为 min_coords 和 max_coords
min_coords = current_bboxes[:, :3] # 形状为 [M, 3]
max_coords = current_bboxes[:, 3:] # 形状为 [M, 3]
mid_coords = (min_coords + max_coords) / 2 # 中点坐标,形状为 [M, 3]
# 获取当前活动点的坐标
active_points = query_points[active_indices] # 形状为 [M, 3]
# 判断点属于哪个子节点
child_offsets = ((active_points > mid_coords).long() * torch.tensor([1, 2, 4], device=self.device)).sum(dim=1) # 子节点偏移量
child_indices = current_node_indices * 8 + child_offsets + 1 # 子节点索引
# 更新节点索引
node_indices[active_indices] = child_indices
# 更新活动掩码:如果子节点是叶子节点,则不再处理
is_leaf = self.is_leaf_mask[child_indices]
active_mask[active_indices[is_leaf]] = False
# logger.debug(f"iter{iteration}: active_indices{active_indices}")
is_valid = self.is_valid_leaf_mask[child_indices]
#logger.debug(is_valid)
# 如果是无效叶节点,返回父节点
invalid_leaf_mask = is_leaf & (~is_valid)
#logger.debug(torch.nonzero(invalid_leaf_mask).squeeze(1) )
node_indices[active_indices[invalid_leaf_mask]] = (child_indices[invalid_leaf_mask]-1) // 8
#print(active_indices[invalid_leaf_mask])
#logger.debug(child_indices[invalid_leaf_mask] // 8)
iteration += 1
# 如果达到最大迭代次数,返回当前节点的信息
return self.node_bboxes[current_idx], torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device),bool(self.is_leaf_mask[current_idx].item())
return node_indices
@torch.jit.export
def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
@ -289,21 +287,30 @@ class OctreeNode(nn.Module):
def forward(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
with torch.no_grad():
bboxes: List[torch.Tensor] = []
face_indices_mask: List[torch.Tensor] = []
operator: List[int] = []
for point in query_points:
bbox, faces_mask, _ = self.find_leaf(point)
bboxes.append(bbox)
face_indices_mask.append(faces_mask)
# 获取当前节点的CSG树结构
#csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None
#csg_trees.append(csg_tree) # 保持原始列表结构
operator.append(self.patch_graph.get_operator(faces_mask.nonzero()) if self.patch_graph is not None else -1)
# 使用批量查询方法
#logger.gpu_memory_stats("后")
node_indices = self.find_leaf_batch(query_points) # 点对应的 node idx
#logger.gpu_memory_stats("SDF后")
bboxes = torch.stack([self.node_bboxes[idx] for idx in node_indices])
face_indices_masks = torch.stack([self.face_indices_mask[idx] for idx in node_indices])
#print(face_indices_masks)
# 处理操作符
operators: List[int] = []
if self.patch_graph is not None:
# 批量获取操作符
for i in range(query_points.shape[0]):
faces_mask = face_indices_masks[i]
#print(faces_mask.nonzero())
#print(node_indices[i])
#print(self.is_leaf_mask[node_indices[i]],self.is_valid_leaf_mask[node_indices[i]])
operators.append(self.patch_graph.get_operator(faces_mask.nonzero()))
else:
operators = [-1] * query_points.shape[0]
return (
torch.stack(bboxes),
torch.stack(face_indices_mask),
torch.tensor(operator, dtype=torch.int) # 直接返回列表,不转换为张量
bboxes,
face_indices_masks,
torch.tensor(operators, dtype=torch.int, device=query_points.device)
)
def print_tree(self, max_print_depth: Optional[int] = None) -> None:
@ -327,25 +334,24 @@ class OctreeNode(nn.Module):
indent = " " * depth # 根据深度生成缩进
is_leaf = self.is_leaf_mask[node_idx].item() # 判断是否为叶子节点
is_valid_leaf = self.is_valid_leaf_mask[node_idx].item()
bbox = self.node_bboxes[node_idx].cpu().numpy().tolist() # 获取边界框信息
# 打印当前节点的基本信息
node_type = "Leaf" if is_leaf else "Internal"
log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}")
log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}, VALID:{is_valid_leaf}")
if self.face_indices_mask is not None:
face_indices = self.face_indices_mask[node_idx].nonzero().cpu().numpy().flatten().tolist()
log_lines.append(f"{indent} Face Indices: {face_indices}")
# 如果是叶子节点,打印额外信息
if is_leaf:
child_indices = self.child_indices[node_idx].cpu().numpy().tolist()
log_lines.append(f"{indent} Child Indices: {child_indices}")
pass
# 如果不是叶子节点,递归处理子节点
if not is_leaf:
for i in range(8): # 遍历所有子节点
child_idx = self.child_indices[node_idx, i].item()
if child_idx != -1: # 忽略无效的子节点索引
child_idx = 8 * node_idx + i + 1
if self.is_leaf_mask[child_idx] != -1: # 忽略无效的子节点索引
dfs(child_idx, depth + 1)
# 初始化日志行列表
@ -362,9 +368,8 @@ class OctreeNode(nn.Module):
state = {
'bbox': self.bbox.cpu(), # 转换为CPU张量
'node_bboxes': self.node_bboxes.cpu() if self.node_bboxes is not None else None,
'parent_indices': self.parent_indices.cpu() if self.parent_indices is not None else None,
'child_indices': self.child_indices.cpu() if self.child_indices is not None else None,
'is_leaf_mask': self.is_leaf_mask.cpu() if self.is_leaf_mask is not None else None,
"is_valid_leaf_mask": self.is_valid_leaf_mask.cpu() if self.is_valid_leaf_mask is not None else None,
'all_face_indices': self.all_face_indices.cpu(),
'face_indices_mask':self.face_indices_mask.cpu() if self.face_indices_mask is not None else None,
'surf_bbox': self.surf_bbox.cpu() if self.surf_bbox is not None else None,
@ -389,10 +394,9 @@ class OctreeNode(nn.Module):
)
# 可以在这里设置其他不需要在 __init__ 中处理的属性
self.node_bboxes = state['node_bboxes'].to(self.device) if state['node_bboxes'] is not None else None
self.parent_indices = state['parent_indices'].to(self.device) if state['parent_indices'] is not None else None
self.child_indices = state['child_indices'].to(self.device) if state['child_indices'] is not None else None
self.face_indices_mask = state['face_indices_mask'].to(self.device) if state['face_indices_mask'] is not None else None
self.is_leaf_mask = state['is_leaf_mask'].to(self.device) if state['is_leaf_mask'] is not None else None
self.is_valid_leaf_mask = state['is_valid_leaf_mask'].to(self.device) if state['is_valid_leaf_mask'] is not None else None
def to(self, device=None, dtype=None, non_blocking=False):

Loading…
Cancel
Save