| 
						
						
							
								
							
						
						
					 | 
					@ -29,25 +29,29 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    return torch.all((max1 >= min2) & (max2 >= min1)) | 
					 | 
					 | 
					    return torch.all((max1 >= min2) & (max2 >= min1)) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					class OctreeNode(nn.Module): | 
					 | 
					 | 
					class OctreeNode(nn.Module): | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None): | 
					 | 
					 | 
					    def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,device=None): | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        super().__init__() | 
					 | 
					 | 
					        super().__init__() | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 静态张量存储节点信息 | 
					 | 
					 | 
					        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.register_buffer('bbox', bbox)  # 当前节点的边界框 | 
					 | 
					 | 
					        # 改为普通张量属性 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.register_buffer('node_bboxes', None)  # 所有节点的边界框 | 
					 | 
					 | 
					        self.bbox = bbox.to(self.device)  # 显式设备管理 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.register_buffer('parent_indices', None)  # 父节点索引 | 
					 | 
					 | 
					        self.node_bboxes = None   | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.register_buffer('child_indices', None)  # 子节点索引 | 
					 | 
					 | 
					        self.parent_indices = None  | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.register_buffer('is_leaf_mask', None)  # 叶子节点标记 | 
					 | 
					 | 
					        self.child_indices = None  | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.register_buffer('face_indices', torch.from_numpy(face_indices).to(bbox.device))  # 面片索引张量 | 
					 | 
					 | 
					        self.is_leaf_mask = None   | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.register_buffer('surf_bbox', surf_bbox)  # 面片边界框 | 
					 | 
					 | 
					        # 面片索引张量 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					         | 
					 | 
					 | 
					        self.face_indices = torch.from_numpy(face_indices).to(self.device) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.surf_bbox = surf_bbox.to(self.device) if surf_bbox is not None else None | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # PatchGraph作为普通属性 | 
					 | 
					 | 
					        # PatchGraph作为普通属性 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.patch_graph = patch_graph  # 不再使用register_buffer | 
					 | 
					 | 
					        self.patch_graph = patch_graph.to(self.device) if patch_graph is not None else None | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.max_depth = max_depth | 
					 | 
					 | 
					        self.max_depth = max_depth | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 将param_key改为张量 | 
					 | 
					 | 
					        # 参数键改为普通张量 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.register_buffer('param_key', torch.tensor(-1, dtype=torch.long)) | 
					 | 
					 | 
					        self.param_key = torch.tensor(-1, dtype=torch.long, device=self.device) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        self._is_leaf = True | 
					 | 
					 | 
					        self._is_leaf = True | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 删除所有register_buffer调用 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    @torch.jit.export | 
					 | 
					 | 
					    @torch.jit.export | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def set_param_key(self, k: int) -> None: | 
					 | 
					 | 
					    def set_param_key(self, k: int) -> None: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """设置参数键值 | 
					 | 
					 | 
					        """设置参数键值 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -64,11 +68,10 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        total_nodes = int(sum(8**i for i in range(self.max_depth + 1))) | 
					 | 
					 | 
					        total_nodes = int(sum(8**i for i in range(self.max_depth + 1))) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 初始化静态张量,使用整数列表作为形状参数 | 
					 | 
					 | 
					        # 初始化静态张量,使用整数列表作为形状参数 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.bbox.device) | 
					 | 
					 | 
					        self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.bbox.device) | 
					 | 
					 | 
					        self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.bbox.device) | 
					 | 
					 | 
					        self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.bbox.device) | 
					 | 
					 | 
					        self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.device) | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        # 使用队列进行广度优先遍历 | 
					 | 
					 | 
					        # 使用队列进行广度优先遍历 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        queue = [(0, self.bbox, self.face_indices)]  # (node_idx, bbox, face_indices) | 
					 | 
					 | 
					        queue = [(0, self.bbox, self.face_indices)]  # (node_idx, bbox, face_indices) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        current_idx = 0 | 
					 | 
					 | 
					        current_idx = 0 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -108,7 +111,7 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                # 将子节点加入队列 | 
					 | 
					 | 
					                # 将子节点加入队列 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                if intersecting_faces: | 
					 | 
					 | 
					                if intersecting_faces: | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.bbox.device))) | 
					 | 
					 | 
					                    queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.device))) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def _should_split_node(self, current_depth: int) -> bool: | 
					 | 
					 | 
					    def _should_split_node(self, current_depth: int) -> bool: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """判断节点是否需要分裂""" | 
					 | 
					 | 
					        """判断节点是否需要分裂""" | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -127,7 +130,7 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def _generate_child_bboxes(self, min_coords: torch.Tensor, mid_coords: torch.Tensor, max_coords: torch.Tensor) -> torch.Tensor: | 
					 | 
					 | 
					    def _generate_child_bboxes(self, min_coords: torch.Tensor, mid_coords: torch.Tensor, max_coords: torch.Tensor) -> torch.Tensor: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 使用 with torch.no_grad() 减少梯度计算的内存占用 | 
					 | 
					 | 
					        # 使用 with torch.no_grad() 减少梯度计算的内存占用 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        with torch.no_grad(): | 
					 | 
					 | 
					        with torch.no_grad(): | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            child_bboxes = torch.zeros([8, 6], device=self.bbox.device) | 
					 | 
					 | 
					            child_bboxes = torch.zeros([8, 6], device=self.device) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 使用向量化操作生成所有子节点边界框 | 
					 | 
					 | 
					            # 使用向量化操作生成所有子节点边界框 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            child_bboxes[0] = torch.cat([min_coords, mid_coords])  # 前下左 | 
					 | 
					 | 
					            child_bboxes[0] = torch.cat([min_coords, mid_coords])  # 前下左 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -199,6 +202,7 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                bboxes.append(bbox) | 
					 | 
					 | 
					                bboxes.append(bbox) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            param_indices = torch.stack(param_indices) | 
					 | 
					 | 
					            param_indices = torch.stack(param_indices) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            bboxes = torch.stack(bboxes) | 
					 | 
					 | 
					            bboxes = torch.stack(bboxes) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 添加检查代码 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            return param_indices, bboxes | 
					 | 
					 | 
					            return param_indices, bboxes | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: | 
					 | 
					 | 
					    def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -258,4 +262,21 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.patch_graph = state['patch_graph'] | 
					 | 
					 | 
					        self.patch_graph = state['patch_graph'] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.max_depth = state['max_depth'] | 
					 | 
					 | 
					        self.max_depth = state['max_depth'] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.param_key = state['param_key'] | 
					 | 
					 | 
					        self.param_key = state['param_key'] | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self._is_leaf = state['_is_leaf'] | 
					 | 
					 | 
					        self._is_leaf = state['_is_leaf'] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    def to(self, device=None, dtype=None, non_blocking=False): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 调用父类方法迁移基础参数 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        super().to(device, dtype, non_blocking) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 迁移自定义属性 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        if self.patch_graph is not None: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            if hasattr(self.patch_graph, 'to'): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                self.patch_graph = self.patch_graph.to(device=device, dtype=dtype) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                # 手动移动非Module属性 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                for attr in ['edge_index', 'edge_type', 'patch_features']: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    tensor = getattr(self.patch_graph, attr, None) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    if tensor is not None: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        setattr(self.patch_graph, attr, tensor.to(device=device, dtype=dtype)) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        return self |