From 6b53d8b1bf988ef5fd87daf1885126fd19e697e0 Mon Sep 17 00:00:00 2001 From: mckay Date: Sun, 27 Apr 2025 20:36:19 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=AB=E5=8F=89=E6=A0=91=E4=BC=98=E5=8C=96?= =?UTF-8?q?=EF=BC=9A=201.=20=E5=8E=BB=E6=8E=89=20parent=5Findices=E7=B4=A2?= =?UTF-8?q?=E5=BC=95=E5=92=8Cchild=5Findices=EF=BC=8C=E6=94=B9=E7=94=A8=20?= =?UTF-8?q?8*i+1,8*i+8=E6=9D=A5=E7=B4=A2=E5=BC=95child=202.=E6=94=B9?= =?UTF-8?q?=E5=8F=98=E5=8F=B6=E8=8A=82=E7=82=B9=E6=9F=A5=E6=89=BE=EF=BC=8C?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E5=B1=82=E7=BA=A7=E9=81=8D=E5=8E=86=E6=96=B9?= =?UTF-8?q?=E5=BC=8F=E6=9F=A5=E6=89=BE=E6=9B=B4=E6=96=B0=E5=8F=B6=E8=8A=82?= =?UTF-8?q?=E7=82=B9=E7=9A=84=E6=96=B9=E5=BC=8F=E3=80=82=EF=BC=88=E4=B8=BA?= =?UTF-8?q?=E4=BA=86=E5=88=A9=E7=94=A8torch=E7=9A=84=E5=B9=B6=E8=A1=8C?= =?UTF-8?q?=E8=AE=A1=E7=AE=97=E5=8A=9F=E8=83=BD=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/octree.py | 182 ++++++++++++++++++------------------ 1 file changed, 93 insertions(+), 89 deletions(-) diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index 1726914..a9de420 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -119,8 +119,6 @@ class OctreeNode(nn.Module): # 改为普通张量属性 self.bbox = bbox.to(self.device) # 显式设备管理 self.node_bboxes = None - self.parent_indices = None - self.child_indices = None self.is_leaf_mask = None # 面片索引张量 self.all_face_indices = torch.from_numpy(face_indices).to(self.device) @@ -141,24 +139,21 @@ class OctreeNode(nn.Module): # 初始化静态张量,使用整数列表作为形状参数 self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device) - self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device) - self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device) self.face_indices_mask = torch.zeros([int(total_nodes),num_faces], dtype=torch.bool, device=self.device) # 1 代表有 self.is_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device) + self.is_valid_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device) # 使用队列进行广度优先遍历 queue = [(0, self.bbox, self.all_face_indices)] # (node_idx, bbox, face_indices) - current_idx = 0 while queue: - node_idx, bbox, faces = queue.pop(0) + node_idx, bbox, faces_indices = queue.pop(0) #logger.debug(f"Processing node {node_idx} with {len(faces)} faces.") self.node_bboxes[node_idx] = bbox # 判断 要不要继续分裂 - if not self._should_split_node(current_idx, faces, total_nodes): + if not self._should_split_node(node_idx, faces_indices, total_nodes): continue - self.is_leaf_mask[node_idx] = 0 # 计算子节点边界框 min_coords = bbox[:3] @@ -167,28 +162,28 @@ class OctreeNode(nn.Module): # 生成8个子节点 child_bboxes = self._generate_child_bboxes(min_coords, mid_coords, max_coords) - intersect_mask = bbox_intersect(self.surf_bbox, faces, child_bboxes) - self.face_indices_mask[current_idx + 1:current_idx + 9, :] = intersect_mask + intersect_mask = bbox_intersect(self.surf_bbox, faces_indices, child_bboxes) + self.face_indices_mask[node_idx * 8 + 1:node_idx * 8 + 9, :] = intersect_mask # 为每个子节点分配面片 for i, child_bbox in enumerate(child_bboxes): - child_idx = child_idx = current_idx + i + 1 + child_idx = 8 * node_idx + i + 1 intersecting_faces = intersect_mask[i].nonzero().flatten() #logger.debug(f"Node {child_idx} has {len(intersecting_faces)} intersecting faces.") - # 更新节点关系 - self.parent_indices[child_idx] = node_idx - self.child_indices[node_idx, i] = child_idx - + # 将子节点加入队列 if len(intersecting_faces) > 0: queue.append((child_idx, child_bbox, intersecting_faces.clone().detach())) - current_idx += 8 + else: + self.is_valid_leaf_mask[child_idx] = False + + #self.is_valid_leaf_mask = self.is_valid_leaf_mask & self.is_leaf_mask def _should_split_node(self, current_idx: int,face_indices:torch.Tensor,max_node:int) -> bool: """判断节点是否需要分裂""" # 检查是否达到最大深度 - if current_idx + 8 >= max_node: + if current_idx * 8 + 9 >= max_node: return False # 检查是否为完全图 @@ -225,61 +220,64 @@ class OctreeNode(nn.Module): return child_bboxes @torch.jit.export - def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]: - """ - 修改后的查找叶子节点方法,返回face indices - :param query_points: 待查找的点,形状为 (3,) - :return: (bbox, face_indices, is_leaf) + def find_leaf_batch(self, query_points: torch.Tensor) -> torch.Tensor: """ - # 确保输入是单个点 - if query_points.dim() != 1 or query_points.shape[0] != 3: - raise ValueError(f"query_points 必须是形状为 (3,) 的张量,但得到 {query_points.shape}") + 批量定位点在八叉树中的对应节点索引。 + + Args: + query_points (torch.Tensor): 点的坐标,形状为 [N, 3]。 - current_idx = torch.tensor(0, dtype=torch.long, device=query_points.device) - max_iterations = 1000 # 防止无限循环 - iteration = 0 + Returns: + torch.Tensor: 每个点对应的节点索引,形状为 [N]。 + """ + num_points = query_points.shape[0] + + # 初始化:所有点从根节点开始 + node_indices = torch.zeros(num_points, dtype=torch.int64, device=self.device) + active_mask = torch.full((num_points,), + not self.is_leaf_mask[0].item(), + dtype=torch.bool, + device=self.device) # 标记需要继续处理的点 - while iteration < max_iterations: - # 获取当前节点的叶子状态 - if self.is_leaf_mask[current_idx].item(): - #logger.debug(f"Reached leaf node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.") - if self.face_indices_mask[current_idx].sum() == 0: - parent_idx = self.parent_indices[current_idx] - #logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.") - if parent_idx == -1: - # 根节点没有父节点,返回根节点的信息 - return ( - self.node_bboxes[current_idx], - torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device), # 新增返回face indices - False - ) - return ( - self.node_bboxes[parent_idx], - self.face_indices_mask[parent_idx], # 新增返回face indices - False - ) - return ( - self.node_bboxes[current_idx], - self.face_indices_mask[current_idx], # 新增返回face indices - True - ) + iteration = 0 + while active_mask.any() and iteration < self.max_depth: + # 获取当前活动点的索引和对应的节点信息 + active_indices = torch.nonzero(active_mask).squeeze(1) # 当前活动点的索引 + current_node_indices = node_indices[active_indices] # 当前活动点对应的节点索引 - # 计算子节点索引 - child_idx = self._get_child_indices(query_points.unsqueeze(0), - self.node_bboxes[current_idx].unsqueeze(0)) + # 获取当前节点的边界框 + current_bboxes = self.node_bboxes[current_node_indices] # 形状为 [M, 6],其中 M 是活动点数 - # 获取下一个要访问的节点 - next_idx = self.child_indices[current_idx, child_idx[0]] + # 将边界框分为 min_coords 和 max_coords + min_coords = current_bboxes[:, :3] # 形状为 [M, 3] + max_coords = current_bboxes[:, 3:] # 形状为 [M, 3] + mid_coords = (min_coords + max_coords) / 2 # 中点坐标,形状为 [M, 3] - # 检查索引是否有效 - if next_idx == -1: - raise IndexError(f"Invalid child node index: {child_idx[0]}") - - current_idx = next_idx - iteration += 1 + # 获取当前活动点的坐标 + active_points = query_points[active_indices] # 形状为 [M, 3] + + # 判断点属于哪个子节点 + child_offsets = ((active_points > mid_coords).long() * torch.tensor([1, 2, 4], device=self.device)).sum(dim=1) # 子节点偏移量 + child_indices = current_node_indices * 8 + child_offsets + 1 # 子节点索引 + + # 更新节点索引 + node_indices[active_indices] = child_indices - # 如果达到最大迭代次数,返回当前节点的信息 - return self.node_bboxes[current_idx], torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device),bool(self.is_leaf_mask[current_idx].item()) + # 更新活动掩码:如果子节点是叶子节点,则不再处理 + is_leaf = self.is_leaf_mask[child_indices] + active_mask[active_indices[is_leaf]] = False + # logger.debug(f"iter{iteration}: active_indices{active_indices}") + + is_valid = self.is_valid_leaf_mask[child_indices] + #logger.debug(is_valid) + # 如果是无效叶节点,返回父节点 + invalid_leaf_mask = is_leaf & (~is_valid) + #logger.debug(torch.nonzero(invalid_leaf_mask).squeeze(1) ) + node_indices[active_indices[invalid_leaf_mask]] = (child_indices[invalid_leaf_mask]-1) // 8 + #print(active_indices[invalid_leaf_mask]) + #logger.debug(child_indices[invalid_leaf_mask] // 8) + iteration += 1 + return node_indices @torch.jit.export def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor: @@ -289,21 +287,30 @@ class OctreeNode(nn.Module): def forward(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: with torch.no_grad(): - bboxes: List[torch.Tensor] = [] - face_indices_mask: List[torch.Tensor] = [] - operator: List[int] = [] - for point in query_points: - bbox, faces_mask, _ = self.find_leaf(point) - bboxes.append(bbox) - face_indices_mask.append(faces_mask) - # 获取当前节点的CSG树结构 - #csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None - #csg_trees.append(csg_tree) # 保持原始列表结构 - operator.append(self.patch_graph.get_operator(faces_mask.nonzero()) if self.patch_graph is not None else -1) + # 使用批量查询方法 + #logger.gpu_memory_stats("后") + node_indices = self.find_leaf_batch(query_points) # 点对应的 node idx + #logger.gpu_memory_stats("SDF后") + bboxes = torch.stack([self.node_bboxes[idx] for idx in node_indices]) + face_indices_masks = torch.stack([self.face_indices_mask[idx] for idx in node_indices]) + #print(face_indices_masks) + # 处理操作符 + operators: List[int] = [] + if self.patch_graph is not None: + # 批量获取操作符 + for i in range(query_points.shape[0]): + faces_mask = face_indices_masks[i] + #print(faces_mask.nonzero()) + #print(node_indices[i]) + #print(self.is_leaf_mask[node_indices[i]],self.is_valid_leaf_mask[node_indices[i]]) + operators.append(self.patch_graph.get_operator(faces_mask.nonzero())) + else: + operators = [-1] * query_points.shape[0] + return ( - torch.stack(bboxes), - torch.stack(face_indices_mask), - torch.tensor(operator, dtype=torch.int) # 直接返回列表,不转换为张量 + bboxes, + face_indices_masks, + torch.tensor(operators, dtype=torch.int, device=query_points.device) ) def print_tree(self, max_print_depth: Optional[int] = None) -> None: @@ -327,25 +334,24 @@ class OctreeNode(nn.Module): indent = " " * depth # 根据深度生成缩进 is_leaf = self.is_leaf_mask[node_idx].item() # 判断是否为叶子节点 + is_valid_leaf = self.is_valid_leaf_mask[node_idx].item() bbox = self.node_bboxes[node_idx].cpu().numpy().tolist() # 获取边界框信息 # 打印当前节点的基本信息 node_type = "Leaf" if is_leaf else "Internal" - log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}") + log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}, VALID:{is_valid_leaf}") if self.face_indices_mask is not None: face_indices = self.face_indices_mask[node_idx].nonzero().cpu().numpy().flatten().tolist() log_lines.append(f"{indent} Face Indices: {face_indices}") # 如果是叶子节点,打印额外信息 if is_leaf: - - child_indices = self.child_indices[node_idx].cpu().numpy().tolist() - log_lines.append(f"{indent} Child Indices: {child_indices}") + pass # 如果不是叶子节点,递归处理子节点 if not is_leaf: for i in range(8): # 遍历所有子节点 - child_idx = self.child_indices[node_idx, i].item() - if child_idx != -1: # 忽略无效的子节点索引 + child_idx = 8 * node_idx + i + 1 + if self.is_leaf_mask[child_idx] != -1: # 忽略无效的子节点索引 dfs(child_idx, depth + 1) # 初始化日志行列表 @@ -362,9 +368,8 @@ class OctreeNode(nn.Module): state = { 'bbox': self.bbox.cpu(), # 转换为CPU张量 'node_bboxes': self.node_bboxes.cpu() if self.node_bboxes is not None else None, - 'parent_indices': self.parent_indices.cpu() if self.parent_indices is not None else None, - 'child_indices': self.child_indices.cpu() if self.child_indices is not None else None, 'is_leaf_mask': self.is_leaf_mask.cpu() if self.is_leaf_mask is not None else None, + "is_valid_leaf_mask": self.is_valid_leaf_mask.cpu() if self.is_valid_leaf_mask is not None else None, 'all_face_indices': self.all_face_indices.cpu(), 'face_indices_mask':self.face_indices_mask.cpu() if self.face_indices_mask is not None else None, 'surf_bbox': self.surf_bbox.cpu() if self.surf_bbox is not None else None, @@ -389,10 +394,9 @@ class OctreeNode(nn.Module): ) # 可以在这里设置其他不需要在 __init__ 中处理的属性 self.node_bboxes = state['node_bboxes'].to(self.device) if state['node_bboxes'] is not None else None - self.parent_indices = state['parent_indices'].to(self.device) if state['parent_indices'] is not None else None - self.child_indices = state['child_indices'].to(self.device) if state['child_indices'] is not None else None self.face_indices_mask = state['face_indices_mask'].to(self.device) if state['face_indices_mask'] is not None else None self.is_leaf_mask = state['is_leaf_mask'].to(self.device) if state['is_leaf_mask'] is not None else None + self.is_valid_leaf_mask = state['is_valid_leaf_mask'].to(self.device) if state['is_valid_leaf_mask'] is not None else None def to(self, device=None, dtype=None, non_blocking=False):