| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -8,58 +8,7 @@ from .octree import OctreeNode | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.config.default_config import get_default_config | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.utils.logger import logger | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import numpy as np | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					''' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class Encoder: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, surf_bbox: torch.Tensor, origin_bbox: torch.Tensor, max_depth: int, feature_dim:int = 64): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        初始化表面八叉树管理器 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        参数: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_bbox: 表面包围盒的世界坐标,形状为 (num_edges, 6), dtype=float32 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            origin_bbox: 原点包围盒的世界坐标,形状为 (6), dtype=float32 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            max_depth: 八叉树的最大深度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.max_depth = max_depth | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 初始化根节点,包含所有面索引。NOTE: 实际上需要判断 aabb 包含问题,但是这里默认 origin_bbox 由这些 face 计算,所以不再重复判断 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        num_faces = surf_bbox.shape[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #print(f"surf_bbox: {surf_bbox.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #print(f"origin_bbox: {origin_bbox.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.root = OctreeNode( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            bbox=origin_bbox, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            face_indices=np.arange(num_faces),  # 初始包含所有面 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            max_depth=self.max_depth, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            feature_dim=feature_dim, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_bbox=surf_bbox | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #print(surf_bbox) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("starting octree conduction") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.root.conduct_tree() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("complete octree conduction") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #self.root.print_tree(0)  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def get_feature_vector(self, query_point): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return self.root.get_feature_vector(query_point) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def forward(self, query_points): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        前向传播,处理批量查询点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        参数: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            query_points: 查询点的位置坐标,形状为(batch_size, 3) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        返回: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            feature_vectors: 查询点的特征向量,形状为(batch_size, feature_dim) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        batch_size = query_points.shape[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        feature_vectors = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for i in range(batch_size): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            feature_vector = self.get_feature_vector(query_points[i]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            feature_vectors.append(feature_vector) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return torch.stack(feature_vectors, dim=0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					''' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class Encoder(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, octree: OctreeNode, feature_dim: int = 32): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -74,49 +23,50 @@ class Encoder(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.feature_dim = feature_dim | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 初始化叶子节点参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.register_buffer('num_parameters', torch.tensor(0, dtype=torch.long)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._leaf_parameters = nn.ParameterList()  # 使用 ParameterList 存储参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.param_key_to_index: Dict[str, int] = {}  # 字典映射:param_key -> index | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 为所有叶子节点注册可学习参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 不再需要 param_key_to_index 字典 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._init_parameters() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _init_parameters(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """为所有叶子节点初始化特征参数""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        stack = [(self.octree, "root")]  # (当前节点, 当前路径) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        stack = [(self.octree, 0)]  # (当前节点, 参数索引) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        param_index = 0  # 参数索引计数器 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        while stack: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            node, path = stack.pop() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            node, param_index = stack.pop() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if node._is_leaf: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 如果是叶子节点,初始化参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                param_name = f"leaf_{path}" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self._leaf_parameters.append(nn.Parameter(torch.randn(8, self.feature_dim)))  # 8个顶点的特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.param_key_to_index[param_name] = param_index  # 记录索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                node.set_param_key(param_name)  # 为节点存储参数键 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                node.set_param_key(param_index)  # 直接使用索引作为键 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                param_index += 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 如果不是叶子节点,继续遍历子节点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                for i, child in enumerate(node.child_nodes): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if child is not None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        stack.append((child, f"{path}_{i}")) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        stack.append((child, param_index)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def get_leaf_parameter(self, param_key: str) -> torch.Tensor: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 保存参数总数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.num_parameters.fill_(param_index) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    @torch.jit.export | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def get_leaf_parameter(self, param_index: torch.Tensor) -> torch.Tensor: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        获取叶子节点的特征参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :param param_key: 叶子节点的参数键 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :param param_index: 叶子节点的参数索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        :return: 对应的参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if param_key not in self.param_key_to_index: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise KeyError(f"Invalid param_key: {param_key}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        target_index = self.param_key_to_index[param_key] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        index = param_index.item()  # 转换为Python整数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if index < 0 or index >= self.num_parameters.item(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise IndexError(f"Parameter index {index} out of range") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 使用枚举代替动态索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for index, param in enumerate(self._leaf_parameters): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if index == target_index: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 使用列表推导代替直接索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for i, param in enumerate(self._leaf_parameters): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if i == index: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                return param | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        raise IndexError(f"Index {target_index} not found in ParameterList") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        raise IndexError(f"Parameter index {index} not found") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def forward(self, query_points: torch.Tensor) -> torch.Tensor: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -132,11 +82,10 @@ class Encoder(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for i in range(batch_size): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 1. 在八叉树中查找包含该点的叶子节点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            bbox, param_key, _ = self.octree.find_leaf(query_points[i]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #logger.debug(leaf_node.param_key) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            bbox, param_index, _ = self.octree.find_leaf(query_points[i]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 2. 获取该节点的特征参数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            node_features = self.get_leaf_parameter(param_key) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            node_features = self.get_leaf_parameter(param_index) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 3. 使用三线性插值计算特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # (这里需要实现你的插值逻辑) | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |