| 
						
						
							
								
							
						
						
					 | 
					@ -69,24 +69,37 @@ class OctreeNode: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def subdivide(self): | 
					 | 
					 | 
					    def subdivide(self): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        min_x, min_y, min_z, max_x, max_y, max_z = self.bbox | 
					 | 
					 | 
					        #min_x, min_y, min_z, max_x, max_y, max_z = self.bbox | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 使用索引操作替代解包 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        min_coords = self.bbox[:3]  # [min_x, min_y, min_z] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        max_coords = self.bbox[3:]  # [max_x, max_y, max_z] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 计算中间点 | 
					 | 
					 | 
					        # 计算中间点 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        mid_x = (min_x + max_x) / 2 | 
					 | 
					 | 
					        mid_coords = (min_coords + max_coords) / 2 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        mid_y = (min_y + max_y) / 2 | 
					 | 
					 | 
					         | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        mid_z = (min_z + max_z) / 2 | 
					 | 
					 | 
					        # 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        min_x, min_y, min_z = min_coords[0], min_coords[1], min_coords[2] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        mid_x, mid_y, mid_z = mid_coords[0], mid_coords[1], mid_coords[2] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        max_x, max_y, max_z = max_coords[0], max_coords[1], max_coords[2] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 生成 8 个子包围盒 | 
					 | 
					 | 
					        # 生成 8 个子包围盒 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        child_bboxes = torch.tensor([ | 
					 | 
					 | 
					        child_bboxes = torch.stack([ | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            [min_x, min_y, min_z, mid_x, mid_y, mid_z],  # 前下左 | 
					 | 
					 | 
					            torch.cat([min_coords, mid_coords]),  # 前下左 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            [mid_x, min_y, min_z, max_x, mid_y, mid_z],  # 前下右 | 
					 | 
					 | 
					            torch.cat([torch.tensor([mid_x, min_y, min_z], device=self.bbox.device),  | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            [min_x, mid_y, min_z, mid_x, max_y, mid_z],  # 前上左 | 
					 | 
					 | 
					                       torch.tensor([max_x, mid_y, mid_z], device=self.bbox.device)]),  # 前下右 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            [mid_x, mid_y, min_z, max_x, max_y, mid_z],  # 前上右 | 
					 | 
					 | 
					            torch.cat([torch.tensor([min_x, mid_y, min_z], device=self.bbox.device),  | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            [min_x, min_y, mid_z, mid_x, mid_y, max_z],  # 后下左 | 
					 | 
					 | 
					                       torch.tensor([mid_x, max_y, mid_z], device=self.bbox.device)]),  # 前上左 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            [mid_x, min_y, mid_z, max_x, mid_y, max_z],  # 后下右 | 
					 | 
					 | 
					            torch.cat([torch.tensor([mid_x, mid_y, min_z], device=self.bbox.device),  | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            [min_x, mid_y, mid_z, mid_x, max_y, max_z],  # 后上左 | 
					 | 
					 | 
					                       torch.tensor([max_x, max_y, mid_z], device=self.bbox.device)]),  # 前上右 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            [mid_x, mid_y, mid_z, max_x, max_y, max_z]   # 后上右 | 
					 | 
					 | 
					            torch.cat([torch.tensor([min_x, min_y, mid_z], device=self.bbox.device),  | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        ], dtype=torch.float32, device=OctreeNode.device) | 
					 | 
					 | 
					                       torch.tensor([mid_x, mid_y, max_z], device=self.bbox.device)]),  # 后下左 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            torch.cat([torch.tensor([mid_x, min_y, mid_z], device=self.bbox.device),  | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                       torch.tensor([max_x, mid_y, max_z], device=self.bbox.device)]),  # 后下右 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            torch.cat([torch.tensor([min_x, mid_y, mid_z], device=self.bbox.device),  | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                       torch.tensor([mid_x, max_y, max_z], device=self.bbox.device)]),  # 后上左 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            torch.cat([torch.tensor([mid_x, mid_y, mid_z], device=self.bbox.device),  | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                       torch.tensor([max_x, max_y, max_z], device=self.bbox.device)])   # 后上右 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        ]) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 为每个子包围盒创建子节点,并分配相交的面 | 
					 | 
					 | 
					        # 为每个子包围盒创建子节点,并分配相交的面 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.children = [] | 
					 | 
					 | 
					        self.children = [] | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -98,7 +111,7 @@ class OctreeNode: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                if bbox_intersect(bbox, face_bbox): | 
					 | 
					 | 
					                if bbox_intersect(bbox, face_bbox): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    intersecting_faces.append(face_idx) | 
					 | 
					 | 
					                    intersecting_faces.append(face_idx) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            #print(f"{bbox}: {intersecting_faces}") | 
					 | 
					 | 
					            #print(f"{bbox}: {intersecting_faces}") | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            if intersecting_faces: | 
					 | 
					 | 
					             | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            child_node = OctreeNode( | 
					 | 
					 | 
					            child_node = OctreeNode( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                bbox=bbox, | 
					 | 
					 | 
					                bbox=bbox, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                face_indices=np.array(intersecting_faces), | 
					 | 
					 | 
					                face_indices=np.array(intersecting_faces), | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -112,29 +125,23 @@ class OctreeNode: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def get_child_index(self, query_point: torch.Tensor) -> int: | 
					 | 
					 | 
					    def get_child_index(self, query_point: torch.Tensor) -> int: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        计算点所在子节点的索引 | 
					 | 
					 | 
					        计算点所在子节点的索引 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        :param point: 待检查的点,格式为 (x, y, z) | 
					 | 
					 | 
					        :param query_point: 待检查的点,格式为 (x, y, z) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        :return: 子节点的索引,范围从 0 到 7 | 
					 | 
					 | 
					        :return: 子节点的索引,范围从 0 到 7 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """ | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        #print(query_point) | 
					 | 
					 | 
					        # 确保 query_point 和 bbox 在同一设备上 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        x, y, z = query_point | 
					 | 
					 | 
					        query_point = query_point.to(self.bbox.device) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        #logger.info(f"query_point: {query_point}") | 
					 | 
					 | 
					
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        #logger.info(f"box: {self.bbox}") | 
					 | 
					 | 
					        # 提取 bbox 的最小和最大坐标 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        min_x, min_y, min_z, max_x, max_y, max_z = self.bbox | 
					 | 
					 | 
					        min_coords = self.bbox[:3]  # [min_x, min_y, min_z] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        max_coords = self.bbox[3:]  # [max_x, max_y, max_z] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 计算中间点 | 
					 | 
					 | 
					        # 计算中间点 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        mid_x = (min_x + max_x) / 2 | 
					 | 
					 | 
					        mid_coords = (min_coords + max_coords) / 2 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        mid_y = (min_y + max_y) / 2 | 
					 | 
					 | 
					
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        mid_z = (min_z + max_z) / 2 | 
					 | 
					 | 
					        # 使用布尔比较结果计算索引 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					
 | 
					 | 
					 | 
					        index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum() | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        index = 0 | 
					 | 
					 | 
					
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        if x >= mid_x:  # 修正变量名 | 
					 | 
					 | 
					        return index.unsqueeze(0) | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					            index += 1 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        if y >= mid_y:  # 修正变量名 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            index += 2 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        if z >= mid_z:  # 修正变量名 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            index += 4 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        #logger.info(f"index: {index}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return index | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def get_feature_vector(self, query_point:torch.Tensor): | 
					 | 
					 | 
					    def get_feature_vector(self, query_point:torch.Tensor): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """ | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -150,58 +157,59 @@ class OctreeNode: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        else: | 
					 | 
					 | 
					        else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            index = self.get_child_index(query_point) | 
					 | 
					 | 
					            index = self.get_child_index(query_point) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            try: | 
					 | 
					 | 
					            try: | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                if index < 0 or index >= len(self.children): | 
					 | 
					 | 
					                # 直接访问子节点,不进行显式检查 | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					                    raise IndexError( | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                        f"Child index {index} out of range (0-{len(self.children)-1}) " | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                        f"for query point {query_point.cpu().numpy().tolist()}. " | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                        f"Node bbox: {self.bbox.cpu().numpy().tolist()}" | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                        f"dept info: {self.max_depth}" | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    ) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					                return self.children[index].get_feature_vector(query_point) | 
					 | 
					 | 
					                return self.children[index].get_feature_vector(query_point) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            except IndexError as e: | 
					 | 
					 | 
					            except IndexError as e: | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                logger.error(str(e)) | 
					 | 
					 | 
					                # 记录错误日志并重新抛出异常 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                logger.error( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    f"Error accessing child node: {e}. " | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    f"Query point: {query_point.cpu().numpy().tolist()}, " | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    f"Depth info: {self.max_depth}" | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                raise e | 
					 | 
					 | 
					                raise e | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def trilinear_interpolation(self, query_point:torch.Tensor) -> torch.Tensor: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        使用三线性插值从补丁特征体积中获取查询点的特征向量。 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        :param query_point: 查询点的位置坐标 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        :return: 插值后的特征向量 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """三线性插值""" | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def trilinear_interpolation(self, query_point: torch.Tensor) -> torch.Tensor: | 
					 | 
					 | 
					    def trilinear_interpolation(self, query_point: torch.Tensor) -> torch.Tensor: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        实现三线性插值 | 
					 | 
					 | 
					        实现三线性插值 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        :param query_point: 待插值的点,格式为 (x, y, z) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        :return: 插值结果,形状为 (D,) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """ | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 获取包围盒的边界 | 
					 | 
					 | 
					        # 确保 query_point 和 bbox 在同一设备上 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        min_x, min_y, min_z, max_x, max_y, max_z = self.bbox | 
					 | 
					 | 
					        #query_point = query_point.to(self.bbox.device) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 获取包围盒的最小和最大坐标 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        min_coords = self.bbox[:3]  # [min_x, min_y, min_z] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        max_coords = self.bbox[3:]  # [max_x, max_y, max_z] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 计算归一化坐标 | 
					 | 
					 | 
					        # 计算归一化坐标 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        x = (query_point[0] - min_x) / (max_x - min_x) | 
					 | 
					 | 
					        normalized_coords = (query_point - min_coords) / (max_coords - min_coords + 1e-8)  # 防止除零错误 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        y = (query_point[1] - min_y) / (max_y - min_y) | 
					 | 
					 | 
					        x, y, z = normalized_coords.unbind(dim=-1) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        z = (query_point[2] - min_z) / (max_z - min_z) | 
					 | 
					 | 
					         | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 使用torch.stack避免Python标量转换 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        wx = torch.stack([1 - x, x], dim=-1)  # 保持自动微分 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        wy = torch.stack([1 - y, y], dim=-1) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        wz = torch.stack([1 - z, z], dim=-1) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 获取8个顶点的特征向量 | 
					 | 
					 | 
					        # 获取8个顶点的特征向量 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        c000 = self.patch_feature_volume[0] | 
					 | 
					 | 
					        c = self.patch_feature_volume  # 形状为 (8, D) | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					        c100 = self.patch_feature_volume[1] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        c010 = self.patch_feature_volume[2] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        c110 = self.patch_feature_volume[3] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        c001 = self.patch_feature_volume[4] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        c101 = self.patch_feature_volume[5] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        c011 = self.patch_feature_volume[6] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        c111 = self.patch_feature_volume[7] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 执行三线性插值 | 
					 | 
					 | 
					        # 执行三线性插值 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        c00 = c000 * (1 - x) + c100 * x | 
					 | 
					 | 
					        # 先对 x 轴插值 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        c01 = c001 * (1 - x) + c101 * x | 
					 | 
					 | 
					        c00 = c[0] * wx[0] + c[1] * wx[1] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        c10 = c010 * (1 - x) + c110 * x | 
					 | 
					 | 
					        c01 = c[2] * wx[0] + c[3] * wx[1] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        c11 = c011 * (1 - x) + c111 * x | 
					 | 
					 | 
					        c10 = c[4] * wx[0] + c[5] * wx[1] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        c11 = c[6] * wx[0] + c[7] * wx[1] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        c0 = c00 * (1 - y) + c10 * y | 
					 | 
					 | 
					        # 再对 y 轴插值 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        c1 = c01 * (1 - y) + c11 * y | 
					 | 
					 | 
					        c0 = c00 * wy[0] + c10 * wy[1] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        c1 = c01 * wy[0] + c11 * wy[1] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        return c0 * (1 - z) + c1 * z | 
					 | 
					 | 
					        # 最后对 z 轴插值 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        result = c0 * wz[0] + c1 * wz[1] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        return result | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: | 
					 | 
					 | 
					    def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -229,3 +237,40 @@ class OctreeNode: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        for i, child in enumerate(self.children): | 
					 | 
					 | 
					        for i, child in enumerate(self.children): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            print(f"{indent}  Child {i}:") | 
					 | 
					 | 
					            print(f"{indent}  Child {i}:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            child.print_tree(depth + 1, max_print_depth) | 
					 | 
					 | 
					            child.print_tree(depth + 1, max_print_depth) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					# 保存 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    def state_dict(self): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        """返回节点及其子树的state_dict""" | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        state = { | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            'bbox': self.bbox, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            'max_depth': self.max_depth, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            'face_indices': self.face_indices, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            'is_leaf': self._is_leaf | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        } | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        if self._is_leaf: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            state['patch_feature_volume'] = self.patch_feature_volume | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            state['children'] = [child.state_dict() for child in self.children] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        return state | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    def load_state_dict(self, state_dict): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        """从state_dict加载节点状态""" | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.bbox = state_dict['bbox'] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.max_depth = state_dict['max_depth'] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.face_indices = state_dict['face_indices'] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self._is_leaf = state_dict['is_leaf'] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        if self._is_leaf: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            self.patch_feature_volume = nn.Parameter(state_dict['patch_feature_volume']) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            self.children = [] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            for child_state in state_dict['children']: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                child = OctreeNode( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    bbox=child_state['bbox'], | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    face_indices=child_state['face_indices'], | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    max_depth=child_state['max_depth'] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                child.load_state_dict(child_state) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                self.children.append(child) |