| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -119,8 +119,6 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 改为普通张量属性 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.bbox = bbox.to(self.device)  # 显式设备管理 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.node_bboxes = None   | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.parent_indices = None  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.child_indices = None  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.is_leaf_mask = None   | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 面片索引张量 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.all_face_indices = torch.from_numpy(face_indices).to(self.device) | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -141,24 +139,21 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 初始化静态张量,使用整数列表作为形状参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.face_indices_mask = torch.zeros([int(total_nodes),num_faces], dtype=torch.bool, device=self.device) # 1 代表有 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.is_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.is_valid_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 使用队列进行广度优先遍历 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        queue = [(0, self.bbox, self.all_face_indices)]  # (node_idx, bbox, face_indices) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        current_idx = 0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        while queue: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            node_idx, bbox, faces = queue.pop(0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            node_idx, bbox, faces_indices = queue.pop(0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #logger.debug(f"Processing node {node_idx} with {len(faces)} faces.") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.node_bboxes[node_idx] = bbox | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 判断 要不要继续分裂 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if not self._should_split_node(current_idx, faces, total_nodes): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if not self._should_split_node(node_idx, faces_indices, total_nodes): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                continue | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.is_leaf_mask[node_idx] = 0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 计算子节点边界框 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            min_coords = bbox[:3] | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -167,28 +162,28 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 生成8个子节点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            child_bboxes = self._generate_child_bboxes(min_coords, mid_coords, max_coords) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            intersect_mask = bbox_intersect(self.surf_bbox, faces, child_bboxes) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.face_indices_mask[current_idx + 1:current_idx + 9, :] = intersect_mask | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            intersect_mask = bbox_intersect(self.surf_bbox, faces_indices, child_bboxes) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.face_indices_mask[node_idx * 8 + 1:node_idx * 8 + 9, :] = intersect_mask | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 为每个子节点分配面片 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for i, child_bbox in enumerate(child_bboxes): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                child_idx = child_idx = current_idx + i + 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                child_idx = 8 * node_idx + i + 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                intersecting_faces = intersect_mask[i].nonzero().flatten() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                #logger.debug(f"Node {child_idx} has {len(intersecting_faces)} intersecting faces.") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 更新节点关系 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.parent_indices[child_idx] = node_idx | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.child_indices[node_idx, i] = child_idx | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 将子节点加入队列 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if len(intersecting_faces) > 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    queue.append((child_idx, child_bbox, intersecting_faces.clone().detach())) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            current_idx += 8 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    self.is_valid_leaf_mask[child_idx] = False | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #self.is_valid_leaf_mask = self.is_valid_leaf_mask & self.is_leaf_mask | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _should_split_node(self, current_idx: int,face_indices:torch.Tensor,max_node:int) -> bool: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """判断节点是否需要分裂""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 检查是否达到最大深度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if current_idx + 8 >= max_node: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if current_idx * 8 + 9 >= max_node: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return False | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 检查是否为完全图 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -225,61 +220,64 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return child_bboxes | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    @torch.jit.export | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def find_leaf_batch(self, query_points: torch.Tensor) -> torch.Tensor: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        修改后的查找叶子节点方法,返回face indices | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :param query_points: 待查找的点,形状为 (3,) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :return: (bbox, face_indices, is_leaf) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 确保输入是单个点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if query_points.dim() != 1 or query_points.shape[0] != 3: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise ValueError(f"query_points 必须是形状为 (3,) 的张量,但得到 {query_points.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        current_idx = torch.tensor(0, dtype=torch.long, device=query_points.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        max_iterations = 1000  # 防止无限循环 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        iteration = 0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        while iteration < max_iterations: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取当前节点的叶子状态 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if self.is_leaf_mask[current_idx].item(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                #logger.debug(f"Reached leaf node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if self.face_indices_mask[current_idx].sum() == 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    parent_idx = self.parent_indices[current_idx] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if parent_idx == -1: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        # 根节点没有父节点,返回根节点的信息 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        return ( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            self.node_bboxes[current_idx], | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device),  # 新增返回face indices | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                            False | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    return ( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        self.node_bboxes[parent_idx], | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        self.face_indices_mask[parent_idx],  # 新增返回face indices | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        False | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                return ( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    self.node_bboxes[current_idx], | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    self.face_indices_mask[current_idx],  # 新增返回face indices | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    True | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        批量定位点在八叉树中的对应节点索引。 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 计算子节点索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            child_idx = self._get_child_indices(query_points.unsqueeze(0),  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                              self.node_bboxes[current_idx].unsqueeze(0)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Args: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            query_points (torch.Tensor): 点的坐标,形状为 [N, 3]。 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取下一个要访问的节点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            next_idx = self.child_indices[current_idx, child_idx[0]] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Returns: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            torch.Tensor: 每个点对应的节点索引,形状为 [N]。 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        num_points = query_points.shape[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 检查索引是否有效 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if next_idx == -1: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                raise IndexError(f"Invalid child node index: {child_idx[0]}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 初始化:所有点从根节点开始 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        node_indices = torch.zeros(num_points, dtype=torch.int64, device=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        active_mask = torch.full((num_points,),  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                               not self.is_leaf_mask[0].item(),  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                               dtype=torch.bool,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                               device=self.device)  # 标记需要继续处理的点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            current_idx = next_idx | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        iteration = 0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        while active_mask.any() and iteration < self.max_depth: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取当前活动点的索引和对应的节点信息 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            active_indices = torch.nonzero(active_mask).squeeze(1)  # 当前活动点的索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            current_node_indices = node_indices[active_indices]  # 当前活动点对应的节点索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取当前节点的边界框 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            current_bboxes = self.node_bboxes[current_node_indices]  # 形状为 [M, 6],其中 M 是活动点数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 将边界框分为 min_coords 和 max_coords | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            min_coords = current_bboxes[:, :3]  # 形状为 [M, 3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            max_coords = current_bboxes[:, 3:]  # 形状为 [M, 3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mid_coords = (min_coords + max_coords) / 2  # 中点坐标,形状为 [M, 3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取当前活动点的坐标 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            active_points = query_points[active_indices]  # 形状为 [M, 3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 判断点属于哪个子节点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            child_offsets = ((active_points > mid_coords).long() * torch.tensor([1, 2, 4], device=self.device)).sum(dim=1)  # 子节点偏移量 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            child_indices = current_node_indices * 8 + child_offsets + 1  # 子节点索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 更新节点索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            node_indices[active_indices] = child_indices | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 更新活动掩码:如果子节点是叶子节点,则不再处理 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            is_leaf = self.is_leaf_mask[child_indices] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            active_mask[active_indices[is_leaf]] = False | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # logger.debug(f"iter{iteration}: active_indices{active_indices}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            is_valid = self.is_valid_leaf_mask[child_indices] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #logger.debug(is_valid) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 如果是无效叶节点,返回父节点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            invalid_leaf_mask = is_leaf & (~is_valid) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #logger.debug(torch.nonzero(invalid_leaf_mask).squeeze(1) ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            node_indices[active_indices[invalid_leaf_mask]] = (child_indices[invalid_leaf_mask]-1) // 8 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #print(active_indices[invalid_leaf_mask]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #logger.debug(child_indices[invalid_leaf_mask] // 8) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            iteration += 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 如果达到最大迭代次数,返回当前节点的信息 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return self.node_bboxes[current_idx], torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device),bool(self.is_leaf_mask[current_idx].item()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return node_indices | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    @torch.jit.export | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor: | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -289,21 +287,30 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def forward(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        with torch.no_grad(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            bboxes: List[torch.Tensor] = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            face_indices_mask: List[torch.Tensor] = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            operator: List[int] = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for point in query_points: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                bbox, faces_mask, _ = self.find_leaf(point) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                bboxes.append(bbox) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                face_indices_mask.append(faces_mask) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 获取当前节点的CSG树结构 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                #csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                #csg_trees.append(csg_tree)  # 保持原始列表结构 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                operator.append(self.patch_graph.get_operator(faces_mask.nonzero()) if self.patch_graph is not None else -1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 使用批量查询方法 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #logger.gpu_memory_stats("后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            node_indices = self.find_leaf_batch(query_points) # 点对应的 node idx | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #logger.gpu_memory_stats("SDF后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            bboxes = torch.stack([self.node_bboxes[idx] for idx in node_indices]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            face_indices_masks = torch.stack([self.face_indices_mask[idx] for idx in node_indices]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #print(face_indices_masks) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 处理操作符 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            operators: List[int] = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if self.patch_graph is not None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 批量获取操作符 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                for i in range(query_points.shape[0]): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    faces_mask = face_indices_masks[i] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #print(faces_mask.nonzero()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #print(node_indices[i]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #print(self.is_leaf_mask[node_indices[i]],self.is_valid_leaf_mask[node_indices[i]]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    operators.append(self.patch_graph.get_operator(faces_mask.nonzero())) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                operators = [-1] * query_points.shape[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return ( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.stack(bboxes), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.stack(face_indices_mask), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.tensor(operator, dtype=torch.int)  # 直接返回列表,不转换为张量 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                bboxes, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                face_indices_masks, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.tensor(operators, dtype=torch.int, device=query_points.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def print_tree(self, max_print_depth: Optional[int] = None) -> None: | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -327,25 +334,24 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            indent = "  " * depth  # 根据深度生成缩进 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            is_leaf = self.is_leaf_mask[node_idx].item()  # 判断是否为叶子节点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            is_valid_leaf = self.is_valid_leaf_mask[node_idx].item() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            bbox = self.node_bboxes[node_idx].cpu().numpy().tolist()  # 获取边界框信息 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 打印当前节点的基本信息 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            node_type = "Leaf" if is_leaf else "Internal" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}, VALID:{is_valid_leaf}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if self.face_indices_mask is not None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                face_indices = self.face_indices_mask[node_idx].nonzero().cpu().numpy().flatten().tolist() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                log_lines.append(f"{indent}  Face Indices: {face_indices}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 如果是叶子节点,打印额外信息 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if is_leaf: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                child_indices = self.child_indices[node_idx].cpu().numpy().tolist() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                log_lines.append(f"{indent}  Child Indices: {child_indices}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                pass | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 如果不是叶子节点,递归处理子节点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if not is_leaf: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                for i in range(8):  # 遍历所有子节点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    child_idx = self.child_indices[node_idx, i].item() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if child_idx != -1:  # 忽略无效的子节点索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    child_idx = 8 * node_idx + i + 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if self.is_leaf_mask[child_idx] != -1:  # 忽略无效的子节点索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        dfs(child_idx, depth + 1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 初始化日志行列表 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -362,9 +368,8 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        state = { | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'bbox': self.bbox.cpu(),  # 转换为CPU张量 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'node_bboxes': self.node_bboxes.cpu() if self.node_bboxes is not None else None, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'parent_indices': self.parent_indices.cpu() if self.parent_indices is not None else None, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'child_indices': self.child_indices.cpu() if self.child_indices is not None else None, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'is_leaf_mask': self.is_leaf_mask.cpu() if self.is_leaf_mask is not None else None, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "is_valid_leaf_mask": self.is_valid_leaf_mask.cpu() if self.is_valid_leaf_mask is not None else None, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'all_face_indices': self.all_face_indices.cpu(), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'face_indices_mask':self.face_indices_mask.cpu() if self.face_indices_mask is not None else None, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            'surf_bbox': self.surf_bbox.cpu() if self.surf_bbox is not None else None, | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -389,10 +394,9 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 可以在这里设置其他不需要在 __init__ 中处理的属性 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.node_bboxes = state['node_bboxes'].to(self.device) if state['node_bboxes'] is not None else None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.parent_indices = state['parent_indices'].to(self.device) if state['parent_indices'] is not None else None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.child_indices = state['child_indices'].to(self.device) if state['child_indices'] is not None else None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.face_indices_mask = state['face_indices_mask'].to(self.device) if state['face_indices_mask'] is not None else None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.is_leaf_mask = state['is_leaf_mask'].to(self.device) if state['is_leaf_mask'] is not None else None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.is_valid_leaf_mask = state['is_valid_leaf_mask'].to(self.device) if state['is_valid_leaf_mask'] is not None else None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def to(self, device=None, dtype=None, non_blocking=False): | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |