|
|
@ -69,24 +69,37 @@ class OctreeNode: |
|
|
|
|
|
|
|
def subdivide(self): |
|
|
|
|
|
|
|
min_x, min_y, min_z, max_x, max_y, max_z = self.bbox |
|
|
|
#min_x, min_y, min_z, max_x, max_y, max_z = self.bbox |
|
|
|
# 使用索引操作替代解包 |
|
|
|
min_coords = self.bbox[:3] # [min_x, min_y, min_z] |
|
|
|
max_coords = self.bbox[3:] # [max_x, max_y, max_z] |
|
|
|
|
|
|
|
# 计算中间点 |
|
|
|
mid_x = (min_x + max_x) / 2 |
|
|
|
mid_y = (min_y + max_y) / 2 |
|
|
|
mid_z = (min_z + max_z) / 2 |
|
|
|
mid_coords = (min_coords + max_coords) / 2 |
|
|
|
|
|
|
|
# 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z |
|
|
|
min_x, min_y, min_z = min_coords[0], min_coords[1], min_coords[2] |
|
|
|
mid_x, mid_y, mid_z = mid_coords[0], mid_coords[1], mid_coords[2] |
|
|
|
max_x, max_y, max_z = max_coords[0], max_coords[1], max_coords[2] |
|
|
|
|
|
|
|
# 生成 8 个子包围盒 |
|
|
|
child_bboxes = torch.tensor([ |
|
|
|
[min_x, min_y, min_z, mid_x, mid_y, mid_z], # 前下左 |
|
|
|
[mid_x, min_y, min_z, max_x, mid_y, mid_z], # 前下右 |
|
|
|
[min_x, mid_y, min_z, mid_x, max_y, mid_z], # 前上左 |
|
|
|
[mid_x, mid_y, min_z, max_x, max_y, mid_z], # 前上右 |
|
|
|
[min_x, min_y, mid_z, mid_x, mid_y, max_z], # 后下左 |
|
|
|
[mid_x, min_y, mid_z, max_x, mid_y, max_z], # 后下右 |
|
|
|
[min_x, mid_y, mid_z, mid_x, max_y, max_z], # 后上左 |
|
|
|
[mid_x, mid_y, mid_z, max_x, max_y, max_z] # 后上右 |
|
|
|
], dtype=torch.float32, device=OctreeNode.device) |
|
|
|
child_bboxes = torch.stack([ |
|
|
|
torch.cat([min_coords, mid_coords]), # 前下左 |
|
|
|
torch.cat([torch.tensor([mid_x, min_y, min_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, mid_y, mid_z], device=self.bbox.device)]), # 前下右 |
|
|
|
torch.cat([torch.tensor([min_x, mid_y, min_z], device=self.bbox.device), |
|
|
|
torch.tensor([mid_x, max_y, mid_z], device=self.bbox.device)]), # 前上左 |
|
|
|
torch.cat([torch.tensor([mid_x, mid_y, min_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, max_y, mid_z], device=self.bbox.device)]), # 前上右 |
|
|
|
torch.cat([torch.tensor([min_x, min_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([mid_x, mid_y, max_z], device=self.bbox.device)]), # 后下左 |
|
|
|
torch.cat([torch.tensor([mid_x, min_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, mid_y, max_z], device=self.bbox.device)]), # 后下右 |
|
|
|
torch.cat([torch.tensor([min_x, mid_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([mid_x, max_y, max_z], device=self.bbox.device)]), # 后上左 |
|
|
|
torch.cat([torch.tensor([mid_x, mid_y, mid_z], device=self.bbox.device), |
|
|
|
torch.tensor([max_x, max_y, max_z], device=self.bbox.device)]) # 后上右 |
|
|
|
]) |
|
|
|
|
|
|
|
# 为每个子包围盒创建子节点,并分配相交的面 |
|
|
|
self.children = [] |
|
|
@ -98,7 +111,7 @@ class OctreeNode: |
|
|
|
if bbox_intersect(bbox, face_bbox): |
|
|
|
intersecting_faces.append(face_idx) |
|
|
|
#print(f"{bbox}: {intersecting_faces}") |
|
|
|
if intersecting_faces: |
|
|
|
|
|
|
|
child_node = OctreeNode( |
|
|
|
bbox=bbox, |
|
|
|
face_indices=np.array(intersecting_faces), |
|
|
@ -112,29 +125,23 @@ class OctreeNode: |
|
|
|
def get_child_index(self, query_point: torch.Tensor) -> int: |
|
|
|
""" |
|
|
|
计算点所在子节点的索引 |
|
|
|
:param point: 待检查的点,格式为 (x, y, z) |
|
|
|
:param query_point: 待检查的点,格式为 (x, y, z) |
|
|
|
:return: 子节点的索引,范围从 0 到 7 |
|
|
|
""" |
|
|
|
#print(query_point) |
|
|
|
x, y, z = query_point |
|
|
|
#logger.info(f"query_point: {query_point}") |
|
|
|
#logger.info(f"box: {self.bbox}") |
|
|
|
min_x, min_y, min_z, max_x, max_y, max_z = self.bbox |
|
|
|
# 确保 query_point 和 bbox 在同一设备上 |
|
|
|
query_point = query_point.to(self.bbox.device) |
|
|
|
|
|
|
|
# 提取 bbox 的最小和最大坐标 |
|
|
|
min_coords = self.bbox[:3] # [min_x, min_y, min_z] |
|
|
|
max_coords = self.bbox[3:] # [max_x, max_y, max_z] |
|
|
|
|
|
|
|
# 计算中间点 |
|
|
|
mid_x = (min_x + max_x) / 2 |
|
|
|
mid_y = (min_y + max_y) / 2 |
|
|
|
mid_z = (min_z + max_z) / 2 |
|
|
|
|
|
|
|
index = 0 |
|
|
|
if x >= mid_x: # 修正变量名 |
|
|
|
index += 1 |
|
|
|
if y >= mid_y: # 修正变量名 |
|
|
|
index += 2 |
|
|
|
if z >= mid_z: # 修正变量名 |
|
|
|
index += 4 |
|
|
|
#logger.info(f"index: {index}") |
|
|
|
return index |
|
|
|
mid_coords = (min_coords + max_coords) / 2 |
|
|
|
|
|
|
|
# 使用布尔比较结果计算索引 |
|
|
|
index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum() |
|
|
|
|
|
|
|
return index.unsqueeze(0) |
|
|
|
|
|
|
|
def get_feature_vector(self, query_point:torch.Tensor): |
|
|
|
""" |
|
|
@ -150,58 +157,59 @@ class OctreeNode: |
|
|
|
else: |
|
|
|
index = self.get_child_index(query_point) |
|
|
|
try: |
|
|
|
if index < 0 or index >= len(self.children): |
|
|
|
raise IndexError( |
|
|
|
f"Child index {index} out of range (0-{len(self.children)-1}) " |
|
|
|
f"for query point {query_point.cpu().numpy().tolist()}. " |
|
|
|
f"Node bbox: {self.bbox.cpu().numpy().tolist()}" |
|
|
|
f"dept info: {self.max_depth}" |
|
|
|
) |
|
|
|
# 直接访问子节点,不进行显式检查 |
|
|
|
return self.children[index].get_feature_vector(query_point) |
|
|
|
except IndexError as e: |
|
|
|
logger.error(str(e)) |
|
|
|
# 记录错误日志并重新抛出异常 |
|
|
|
logger.error( |
|
|
|
f"Error accessing child node: {e}. " |
|
|
|
f"Query point: {query_point.cpu().numpy().tolist()}, " |
|
|
|
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " |
|
|
|
f"Depth info: {self.max_depth}" |
|
|
|
) |
|
|
|
raise e |
|
|
|
|
|
|
|
def trilinear_interpolation(self, query_point:torch.Tensor) -> torch.Tensor: |
|
|
|
""" |
|
|
|
使用三线性插值从补丁特征体积中获取查询点的特征向量。 |
|
|
|
|
|
|
|
:param query_point: 查询点的位置坐标 |
|
|
|
:return: 插值后的特征向量 |
|
|
|
""" |
|
|
|
"""三线性插值""" |
|
|
|
def trilinear_interpolation(self, query_point: torch.Tensor) -> torch.Tensor: |
|
|
|
""" |
|
|
|
实现三线性插值 |
|
|
|
:param query_point: 待插值的点,格式为 (x, y, z) |
|
|
|
:return: 插值结果,形状为 (D,) |
|
|
|
""" |
|
|
|
# 获取包围盒的边界 |
|
|
|
min_x, min_y, min_z, max_x, max_y, max_z = self.bbox |
|
|
|
# 确保 query_point 和 bbox 在同一设备上 |
|
|
|
#query_point = query_point.to(self.bbox.device) |
|
|
|
|
|
|
|
# 获取包围盒的最小和最大坐标 |
|
|
|
min_coords = self.bbox[:3] # [min_x, min_y, min_z] |
|
|
|
max_coords = self.bbox[3:] # [max_x, max_y, max_z] |
|
|
|
|
|
|
|
# 计算归一化坐标 |
|
|
|
x = (query_point[0] - min_x) / (max_x - min_x) |
|
|
|
y = (query_point[1] - min_y) / (max_y - min_y) |
|
|
|
z = (query_point[2] - min_z) / (max_z - min_z) |
|
|
|
normalized_coords = (query_point - min_coords) / (max_coords - min_coords + 1e-8) # 防止除零错误 |
|
|
|
x, y, z = normalized_coords.unbind(dim=-1) |
|
|
|
|
|
|
|
# 使用torch.stack避免Python标量转换 |
|
|
|
wx = torch.stack([1 - x, x], dim=-1) # 保持自动微分 |
|
|
|
wy = torch.stack([1 - y, y], dim=-1) |
|
|
|
wz = torch.stack([1 - z, z], dim=-1) |
|
|
|
|
|
|
|
# 获取8个顶点的特征向量 |
|
|
|
c000 = self.patch_feature_volume[0] |
|
|
|
c100 = self.patch_feature_volume[1] |
|
|
|
c010 = self.patch_feature_volume[2] |
|
|
|
c110 = self.patch_feature_volume[3] |
|
|
|
c001 = self.patch_feature_volume[4] |
|
|
|
c101 = self.patch_feature_volume[5] |
|
|
|
c011 = self.patch_feature_volume[6] |
|
|
|
c111 = self.patch_feature_volume[7] |
|
|
|
c = self.patch_feature_volume # 形状为 (8, D) |
|
|
|
|
|
|
|
# 执行三线性插值 |
|
|
|
c00 = c000 * (1 - x) + c100 * x |
|
|
|
c01 = c001 * (1 - x) + c101 * x |
|
|
|
c10 = c010 * (1 - x) + c110 * x |
|
|
|
c11 = c011 * (1 - x) + c111 * x |
|
|
|
# 先对 x 轴插值 |
|
|
|
c00 = c[0] * wx[0] + c[1] * wx[1] |
|
|
|
c01 = c[2] * wx[0] + c[3] * wx[1] |
|
|
|
c10 = c[4] * wx[0] + c[5] * wx[1] |
|
|
|
c11 = c[6] * wx[0] + c[7] * wx[1] |
|
|
|
|
|
|
|
c0 = c00 * (1 - y) + c10 * y |
|
|
|
c1 = c01 * (1 - y) + c11 * y |
|
|
|
# 再对 y 轴插值 |
|
|
|
c0 = c00 * wy[0] + c10 * wy[1] |
|
|
|
c1 = c01 * wy[0] + c11 * wy[1] |
|
|
|
|
|
|
|
return c0 * (1 - z) + c1 * z |
|
|
|
# 最后对 z 轴插值 |
|
|
|
result = c0 * wz[0] + c1 * wz[1] |
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: |
|
|
@ -229,3 +237,40 @@ class OctreeNode: |
|
|
|
for i, child in enumerate(self.children): |
|
|
|
print(f"{indent} Child {i}:") |
|
|
|
child.print_tree(depth + 1, max_print_depth) |
|
|
|
|
|
|
|
# 保存 |
|
|
|
def state_dict(self): |
|
|
|
"""返回节点及其子树的state_dict""" |
|
|
|
state = { |
|
|
|
'bbox': self.bbox, |
|
|
|
'max_depth': self.max_depth, |
|
|
|
'face_indices': self.face_indices, |
|
|
|
'is_leaf': self._is_leaf |
|
|
|
} |
|
|
|
|
|
|
|
if self._is_leaf: |
|
|
|
state['patch_feature_volume'] = self.patch_feature_volume |
|
|
|
else: |
|
|
|
state['children'] = [child.state_dict() for child in self.children] |
|
|
|
|
|
|
|
return state |
|
|
|
|
|
|
|
def load_state_dict(self, state_dict): |
|
|
|
"""从state_dict加载节点状态""" |
|
|
|
self.bbox = state_dict['bbox'] |
|
|
|
self.max_depth = state_dict['max_depth'] |
|
|
|
self.face_indices = state_dict['face_indices'] |
|
|
|
self._is_leaf = state_dict['is_leaf'] |
|
|
|
|
|
|
|
if self._is_leaf: |
|
|
|
self.patch_feature_volume = nn.Parameter(state_dict['patch_feature_volume']) |
|
|
|
else: |
|
|
|
self.children = [] |
|
|
|
for child_state in state_dict['children']: |
|
|
|
child = OctreeNode( |
|
|
|
bbox=child_state['bbox'], |
|
|
|
face_indices=child_state['face_indices'], |
|
|
|
max_depth=child_state['max_depth'] |
|
|
|
) |
|
|
|
child.load_state_dict(child_state) |
|
|
|
self.children.append(child) |