|
|
@ -73,34 +73,51 @@ class Encoder(nn.Module): |
|
|
|
self.octree = octree |
|
|
|
self.feature_dim = feature_dim |
|
|
|
|
|
|
|
# 初始化叶子节点参数 |
|
|
|
self._leaf_parameters = nn.ParameterList() # 使用 ParameterList 存储参数 |
|
|
|
self.param_key_to_index: Dict[str, int] = {} # 字典映射:param_key -> index |
|
|
|
# 为所有叶子节点注册可学习参数 |
|
|
|
self._init_parameters() |
|
|
|
|
|
|
|
def _init_parameters(self): |
|
|
|
"""为所有叶子节点初始化特征参数""" |
|
|
|
# 使用字典保存所有参数,避免动态属性 |
|
|
|
self._leaf_parameters = nn.ParameterDict() |
|
|
|
# 使用栈模拟递归遍历(避免递归) |
|
|
|
stack = [(self.octree, "root")] # (当前节点, 当前路径) |
|
|
|
param_index = 0 # 参数索引计数器 |
|
|
|
|
|
|
|
while stack: |
|
|
|
node, path = stack.pop() |
|
|
|
|
|
|
|
# 递归遍历树结构 |
|
|
|
def _register_params(node, path=""): |
|
|
|
#logger.debug(node.is_leaf()) |
|
|
|
if node.is_leaf(): |
|
|
|
# 如果是叶子节点,初始化参数 |
|
|
|
param_name = f"leaf_{path}" |
|
|
|
self._leaf_parameters[param_name] = nn.Parameter( |
|
|
|
torch.randn(8, self.feature_dim) # 8个顶点的特征 |
|
|
|
) |
|
|
|
self._leaf_parameters.append(nn.Parameter(torch.randn(8, self.feature_dim))) # 8个顶点的特征 |
|
|
|
self.param_key_to_index[param_name] = param_index # 记录索引 |
|
|
|
node.set_param_key(param_name) # 为节点存储参数键 |
|
|
|
#logger.debug(param_name) |
|
|
|
#logger.debug(node.param_key) |
|
|
|
param_index += 1 |
|
|
|
else: |
|
|
|
# 如果不是叶子节点,继续遍历子节点 |
|
|
|
for i, child in enumerate(node.child_nodes): |
|
|
|
_register_params(child, f"{path}_{i}") |
|
|
|
if child is not None: |
|
|
|
stack.append((child, f"{path}_{i}")) |
|
|
|
|
|
|
|
def get_leaf_parameter(self, param_key: str) -> torch.Tensor: |
|
|
|
""" |
|
|
|
获取叶子节点的特征参数 |
|
|
|
:param param_key: 叶子节点的参数键 |
|
|
|
:return: 对应的参数 |
|
|
|
""" |
|
|
|
if param_key not in self.param_key_to_index: |
|
|
|
raise KeyError(f"Invalid param_key: {param_key}") |
|
|
|
|
|
|
|
target_index = self.param_key_to_index[param_key] |
|
|
|
|
|
|
|
_register_params(self.octree, "root") |
|
|
|
# 使用枚举代替动态索引 |
|
|
|
for index, param in enumerate(self._leaf_parameters): |
|
|
|
if index == target_index: |
|
|
|
return param |
|
|
|
|
|
|
|
def get_leaf_parameter(self, node): |
|
|
|
"""获取叶子节点的特征参数""" |
|
|
|
return self._leaf_parameters[node.param_key] |
|
|
|
raise IndexError(f"Index {target_index} not found in ParameterList") |
|
|
|
|
|
|
|
def forward(self, query_points: torch.Tensor) -> torch.Tensor: |
|
|
|
""" |
|
|
@ -116,12 +133,11 @@ class Encoder(nn.Module): |
|
|
|
|
|
|
|
for i in range(batch_size): |
|
|
|
# 1. 在八叉树中查找包含该点的叶子节点 |
|
|
|
leaf_node = self.octree.find_leaf(query_points[i]) |
|
|
|
bbox, param_key, _ = self.octree.find_leaf(query_points[i]) |
|
|
|
#logger.debug(leaf_node.param_key) |
|
|
|
|
|
|
|
# 2. 获取该节点的特征参数 |
|
|
|
bbox = leaf_node.bbox |
|
|
|
node_features = self.get_leaf_parameter(leaf_node) |
|
|
|
node_features = self.get_leaf_parameter(param_key) |
|
|
|
|
|
|
|
# 3. 使用三线性插值计算特征 |
|
|
|
# (这里需要实现你的插值逻辑) |
|
|
|