|
@ -28,21 +28,20 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool: |
|
|
# 向量化比较 |
|
|
# 向量化比较 |
|
|
return torch.all((max1 >= min2) & (max2 >= min1)) |
|
|
return torch.all((max1 >= min2) & (max2 >= min1)) |
|
|
|
|
|
|
|
|
class OctreeNode: |
|
|
class OctreeNode(nn.Module): |
|
|
feature_dim=None |
|
|
|
|
|
device=None |
|
|
device=None |
|
|
surf_bbox = None |
|
|
surf_bbox = None |
|
|
def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, feature_dim:int = None, surf_bbox:torch.Tensor = None): |
|
|
def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, surf_bbox:torch.Tensor = None): |
|
|
|
|
|
super().__init__() |
|
|
self.bbox = bbox # 节点的边界框 |
|
|
self.bbox = bbox # 节点的边界框 |
|
|
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 |
|
|
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 |
|
|
self.children: List['OctreeNode'] = [] # 子节点列表 |
|
|
self.child_nodes: List['OctreeNode'] = [] # 子节点列表 |
|
|
self.face_indices = face_indices |
|
|
self.face_indices = face_indices |
|
|
self.patch_feature_volume = None # 补丁特征体积,only leaf has |
|
|
self.param_key = None |
|
|
|
|
|
#self.patch_feature_volume = None # 补丁特征体积,only leaf has |
|
|
self._is_leaf = True |
|
|
self._is_leaf = True |
|
|
#print(f"box shape: {self.bbox.shape}") |
|
|
#print(f"box shape: {self.bbox.shape}") |
|
|
|
|
|
|
|
|
if feature_dim is not None: |
|
|
|
|
|
OctreeNode.feature_dim = feature_dim |
|
|
|
|
|
if surf_bbox is not None: |
|
|
if surf_bbox is not None: |
|
|
if not isinstance(surf_bbox, torch.Tensor): |
|
|
if not isinstance(surf_bbox, torch.Tensor): |
|
|
raise TypeError( |
|
|
raise TypeError( |
|
@ -56,13 +55,15 @@ class OctreeNode: |
|
|
OctreeNode.device = bbox.device |
|
|
OctreeNode.device = bbox.device |
|
|
|
|
|
|
|
|
def is_leaf(self): |
|
|
def is_leaf(self): |
|
|
# Check if self.children is None before calling len() |
|
|
# Check if self.child——nodes is None before calling len() |
|
|
return self._is_leaf |
|
|
return self._is_leaf |
|
|
|
|
|
|
|
|
|
|
|
def set_param_key(self, k): |
|
|
|
|
|
self.param_key = k |
|
|
|
|
|
|
|
|
def conduct_tree(self): |
|
|
def conduct_tree(self): |
|
|
if self.max_depth <= 0 or self.face_indices.shape[0] <= 2: |
|
|
if self.max_depth <= 0 or self.face_indices.shape[0] <= 2: |
|
|
# 达到最大深度 or 一个单元格至多只有两个面 |
|
|
# 达到最大深度 or 一个单元格至多只有两个面 |
|
|
self.patch_feature_volume = nn.Parameter(torch.randn(8, OctreeNode.feature_dim, device=OctreeNode.device)) |
|
|
|
|
|
return |
|
|
return |
|
|
self.subdivide() |
|
|
self.subdivide() |
|
|
|
|
|
|
|
@ -102,7 +103,7 @@ class OctreeNode: |
|
|
]) |
|
|
]) |
|
|
|
|
|
|
|
|
# 为每个子包围盒创建子节点,并分配相交的面 |
|
|
# 为每个子包围盒创建子节点,并分配相交的面 |
|
|
self.children = [] |
|
|
self.child_nodes = [] |
|
|
for bbox in child_bboxes: |
|
|
for bbox in child_bboxes: |
|
|
# 找到与子包围盒相交的面 |
|
|
# 找到与子包围盒相交的面 |
|
|
intersecting_faces = [] |
|
|
intersecting_faces = [] |
|
@ -118,7 +119,7 @@ class OctreeNode: |
|
|
max_depth=self.max_depth - 1 |
|
|
max_depth=self.max_depth - 1 |
|
|
) |
|
|
) |
|
|
child_node.conduct_tree() |
|
|
child_node.conduct_tree() |
|
|
self.children.append(child_node) |
|
|
self.child_nodes.append(child_node) |
|
|
|
|
|
|
|
|
self._is_leaf = False |
|
|
self._is_leaf = False |
|
|
|
|
|
|
|
@ -143,6 +144,25 @@ class OctreeNode: |
|
|
|
|
|
|
|
|
return index.unsqueeze(0) |
|
|
return index.unsqueeze(0) |
|
|
|
|
|
|
|
|
|
|
|
def find_leaf(self, query_point:torch.Tensor): |
|
|
|
|
|
# 从根节点开始递归查找包含该点的叶子节点 |
|
|
|
|
|
if self._is_leaf: |
|
|
|
|
|
return self |
|
|
|
|
|
else: |
|
|
|
|
|
index = self.get_child_index(query_point) |
|
|
|
|
|
try: |
|
|
|
|
|
# 直接访问子节点,不进行显式检查 |
|
|
|
|
|
return self.child_nodes[index].find_leaf(query_point) |
|
|
|
|
|
except IndexError as e: |
|
|
|
|
|
# 记录错误日志并重新抛出异常 |
|
|
|
|
|
logger.error( |
|
|
|
|
|
f"Error accessing child node: {e}. " |
|
|
|
|
|
f"Query point: {query_point.cpu().numpy().tolist()}, " |
|
|
|
|
|
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " |
|
|
|
|
|
f"Depth info: {self.max_depth}" |
|
|
|
|
|
) |
|
|
|
|
|
raise e |
|
|
|
|
|
|
|
|
def get_feature_vector(self, query_point:torch.Tensor): |
|
|
def get_feature_vector(self, query_point:torch.Tensor): |
|
|
""" |
|
|
""" |
|
|
预测给定点的 SDF 值 |
|
|
预测给定点的 SDF 值 |
|
@ -158,7 +178,7 @@ class OctreeNode: |
|
|
index = self.get_child_index(query_point) |
|
|
index = self.get_child_index(query_point) |
|
|
try: |
|
|
try: |
|
|
# 直接访问子节点,不进行显式检查 |
|
|
# 直接访问子节点,不进行显式检查 |
|
|
return self.children[index].get_feature_vector(query_point) |
|
|
return self.child_nodes[index].get_feature_vector(query_point) |
|
|
except IndexError as e: |
|
|
except IndexError as e: |
|
|
# 记录错误日志并重新抛出异常 |
|
|
# 记录错误日志并重新抛出异常 |
|
|
logger.error( |
|
|
logger.error( |
|
@ -170,46 +190,7 @@ class OctreeNode: |
|
|
raise e |
|
|
raise e |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def trilinear_interpolation(self, query_point: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
""" |
|
|
|
|
|
实现三线性插值 |
|
|
|
|
|
:param query_point: 待插值的点,格式为 (x, y, z) |
|
|
|
|
|
:return: 插值结果,形状为 (D,) |
|
|
|
|
|
""" |
|
|
|
|
|
# 确保 query_point 和 bbox 在同一设备上 |
|
|
|
|
|
#query_point = query_point.to(self.bbox.device) |
|
|
|
|
|
|
|
|
|
|
|
# 获取包围盒的最小和最大坐标 |
|
|
|
|
|
min_coords = self.bbox[:3] # [min_x, min_y, min_z] |
|
|
|
|
|
max_coords = self.bbox[3:] # [max_x, max_y, max_z] |
|
|
|
|
|
|
|
|
|
|
|
# 计算归一化坐标 |
|
|
|
|
|
normalized_coords = (query_point - min_coords) / (max_coords - min_coords + 1e-8) # 防止除零错误 |
|
|
|
|
|
x, y, z = normalized_coords.unbind(dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
# 使用torch.stack避免Python标量转换 |
|
|
|
|
|
wx = torch.stack([1 - x, x], dim=-1) # 保持自动微分 |
|
|
|
|
|
wy = torch.stack([1 - y, y], dim=-1) |
|
|
|
|
|
wz = torch.stack([1 - z, z], dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
# 获取8个顶点的特征向量 |
|
|
|
|
|
c = self.patch_feature_volume # 形状为 (8, D) |
|
|
|
|
|
|
|
|
|
|
|
# 执行三线性插值 |
|
|
|
|
|
# 先对 x 轴插值 |
|
|
|
|
|
c00 = c[0] * wx[0] + c[1] * wx[1] |
|
|
|
|
|
c01 = c[2] * wx[0] + c[3] * wx[1] |
|
|
|
|
|
c10 = c[4] * wx[0] + c[5] * wx[1] |
|
|
|
|
|
c11 = c[6] * wx[0] + c[7] * wx[1] |
|
|
|
|
|
|
|
|
|
|
|
# 再对 y 轴插值 |
|
|
|
|
|
c0 = c00 * wy[0] + c10 * wy[1] |
|
|
|
|
|
c1 = c01 * wy[0] + c11 * wy[1] |
|
|
|
|
|
|
|
|
|
|
|
# 最后对 z 轴插值 |
|
|
|
|
|
result = c0 * wz[0] + c1 * wz[1] |
|
|
|
|
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: |
|
|
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: |
|
@ -231,46 +212,32 @@ class OctreeNode: |
|
|
# 打印面片信息(如果有) |
|
|
# 打印面片信息(如果有) |
|
|
if self.face_indices is not None: |
|
|
if self.face_indices is not None: |
|
|
print(f"{indent} Face indices: {self.face_indices.tolist()}") |
|
|
print(f"{indent} Face indices: {self.face_indices.tolist()}") |
|
|
print(f"{indent} len children: {len(self.children)}") |
|
|
print(f"{indent} len child_nodes: {len(self.child_nodes)}") |
|
|
|
|
|
|
|
|
# 递归打印子节点 |
|
|
# 递归打印子节点 |
|
|
for i, child in enumerate(self.children): |
|
|
for i, child in enumerate(self.child_nodes): |
|
|
print(f"{indent} Child {i}:") |
|
|
print(f"{indent} Child {i}:") |
|
|
child.print_tree(depth + 1, max_print_depth) |
|
|
child.print_tree(depth + 1, max_print_depth) |
|
|
|
|
|
|
|
|
# 保存 |
|
|
def __getstate__(self): |
|
|
def state_dict(self): |
|
|
"""支持pickle序列化""" |
|
|
"""返回节点及其子树的state_dict""" |
|
|
return self._serialize_node(self) |
|
|
state = { |
|
|
|
|
|
'bbox': self.bbox, |
|
|
|
|
|
'max_depth': self.max_depth, |
|
|
|
|
|
'face_indices': self.face_indices, |
|
|
|
|
|
'is_leaf': self._is_leaf |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if self._is_leaf: |
|
|
|
|
|
state['patch_feature_volume'] = self.patch_feature_volume |
|
|
|
|
|
else: |
|
|
|
|
|
state['children'] = [child.state_dict() for child in self.children] |
|
|
|
|
|
|
|
|
|
|
|
return state |
|
|
def __setstate__(self, state): |
|
|
|
|
|
"""支持pickle反序列化""" |
|
|
|
|
|
self = self._deserialize_node(state) |
|
|
|
|
|
|
|
|
def load_state_dict(self, state_dict): |
|
|
def _serialize_node(self, node): |
|
|
"""从state_dict加载节点状态""" |
|
|
return { |
|
|
self.bbox = state_dict['bbox'] |
|
|
'bbox': node.bbox, |
|
|
self.max_depth = state_dict['max_depth'] |
|
|
'is_leaf': node._is_leaf, |
|
|
self.face_indices = state_dict['face_indices'] |
|
|
'child_nodes': [self._serialize_node(c) for c in node.child_nodes], |
|
|
self._is_leaf = state_dict['is_leaf'] |
|
|
'param_key': node.param_key |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
if self._is_leaf: |
|
|
def _deserialize_node(self, data): |
|
|
self.patch_feature_volume = nn.Parameter(state_dict['patch_feature_volume']) |
|
|
node = OctreeNode(data['bbox'], 0) # max_depth会在encoder中重建 |
|
|
else: |
|
|
node._is_leaf = data['is_leaf'] |
|
|
self.children = [] |
|
|
node.param_key = data['param_key'] |
|
|
for child_state in state_dict['children']: |
|
|
node.child_nodes = [self._deserialize_node(c) for c in data['child_nodes']] |
|
|
child = OctreeNode( |
|
|
return node |
|
|
bbox=child_state['bbox'], |
|
|
|
|
|
face_indices=child_state['face_indices'], |
|
|
|
|
|
max_depth=child_state['max_depth'] |
|
|
|
|
|
) |
|
|
|
|
|
child.load_state_dict(child_state) |
|
|
|
|
|
self.children.append(child) |
|
|
|