| 
						
						
							
								
							
						
						
					 | 
					@ -5,7 +5,6 @@ import torch.nn as nn | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					import numpy as np | 
					 | 
					 | 
					import numpy as np | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.networks.patch_graph import PatchGraph | 
					 | 
					 | 
					from brep2sdf.networks.patch_graph import PatchGraph | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.utils.logger import logger | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -63,7 +62,6 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """构建静态八叉树结构""" | 
					 | 
					 | 
					        """构建静态八叉树结构""" | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 预计算所有可能的节点数量,确保结果为整数 | 
					 | 
					 | 
					        # 预计算所有可能的节点数量,确保结果为整数 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        total_nodes = int(sum(8**i for i in range(self.max_depth + 1))) | 
					 | 
					 | 
					        total_nodes = int(sum(8**i for i in range(self.max_depth + 1))) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info(f"总节点数量: {total_nodes}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 初始化静态张量,使用整数列表作为形状参数 | 
					 | 
					 | 
					        # 初始化静态张量,使用整数列表作为形状参数 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.bbox.device) | 
					 | 
					 | 
					        self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.bbox.device) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -71,7 +69,6 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.bbox.device) | 
					 | 
					 | 
					        self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.bbox.device) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.bbox.device) | 
					 | 
					 | 
					        self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.bbox.device) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.gpu_memory_stats("树初始化后") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 使用队列进行广度优先遍历 | 
					 | 
					 | 
					        # 使用队列进行广度优先遍历 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        queue = [(0, self.bbox, self.face_indices)]  # (node_idx, bbox, face_indices) | 
					 | 
					 | 
					        queue = [(0, self.bbox, self.face_indices)]  # (node_idx, bbox, face_indices) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        current_idx = 0 | 
					 | 
					 | 
					        current_idx = 0 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -193,6 +190,17 @@ class OctreeNode(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2 | 
					 | 
					 | 
					        mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1) | 
					 | 
					 | 
					        return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    def forward(self, query_points): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        with torch.no_grad(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            param_indices, bboxes = [], [] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            for point in query_points: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                bbox, idx, _ = self.find_leaf(point) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                param_indices.append(idx) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                bboxes.append(bbox) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            param_indices = torch.stack(param_indices) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            bboxes = torch.stack(bboxes) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            return param_indices, bboxes | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: | 
					 | 
					 | 
					    def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        递归打印八叉树结构 | 
					 | 
					 | 
					        递归打印八叉树结构 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |