|
|
@ -29,25 +29,29 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor: |
|
|
|
return torch.all((max1 >= min2) & (max2 >= min1)) |
|
|
|
|
|
|
|
class OctreeNode(nn.Module): |
|
|
|
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None): |
|
|
|
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,device=None): |
|
|
|
super().__init__() |
|
|
|
# 静态张量存储节点信息 |
|
|
|
self.register_buffer('bbox', bbox) # 当前节点的边界框 |
|
|
|
self.register_buffer('node_bboxes', None) # 所有节点的边界框 |
|
|
|
self.register_buffer('parent_indices', None) # 父节点索引 |
|
|
|
self.register_buffer('child_indices', None) # 子节点索引 |
|
|
|
self.register_buffer('is_leaf_mask', None) # 叶子节点标记 |
|
|
|
self.register_buffer('face_indices', torch.from_numpy(face_indices).to(bbox.device)) # 面片索引张量 |
|
|
|
self.register_buffer('surf_bbox', surf_bbox) # 面片边界框 |
|
|
|
|
|
|
|
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
# 改为普通张量属性 |
|
|
|
self.bbox = bbox.to(self.device) # 显式设备管理 |
|
|
|
self.node_bboxes = None |
|
|
|
self.parent_indices = None |
|
|
|
self.child_indices = None |
|
|
|
self.is_leaf_mask = None |
|
|
|
# 面片索引张量 |
|
|
|
self.face_indices = torch.from_numpy(face_indices).to(self.device) |
|
|
|
self.surf_bbox = surf_bbox.to(self.device) if surf_bbox is not None else None |
|
|
|
|
|
|
|
# PatchGraph作为普通属性 |
|
|
|
self.patch_graph = patch_graph # 不再使用register_buffer |
|
|
|
self.patch_graph = patch_graph.to(self.device) if patch_graph is not None else None |
|
|
|
|
|
|
|
self.max_depth = max_depth |
|
|
|
# 将param_key改为张量 |
|
|
|
self.register_buffer('param_key', torch.tensor(-1, dtype=torch.long)) |
|
|
|
# 参数键改为普通张量 |
|
|
|
self.param_key = torch.tensor(-1, dtype=torch.long, device=self.device) |
|
|
|
self._is_leaf = True |
|
|
|
|
|
|
|
# 删除所有register_buffer调用 |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def set_param_key(self, k: int) -> None: |
|
|
|
"""设置参数键值 |
|
|
@ -64,11 +68,10 @@ class OctreeNode(nn.Module): |
|
|
|
total_nodes = int(sum(8**i for i in range(self.max_depth + 1))) |
|
|
|
|
|
|
|
# 初始化静态张量,使用整数列表作为形状参数 |
|
|
|
self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.bbox.device) |
|
|
|
self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.bbox.device) |
|
|
|
self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.bbox.device) |
|
|
|
self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.bbox.device) |
|
|
|
|
|
|
|
self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device) |
|
|
|
self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device) |
|
|
|
self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device) |
|
|
|
self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.device) |
|
|
|
# 使用队列进行广度优先遍历 |
|
|
|
queue = [(0, self.bbox, self.face_indices)] # (node_idx, bbox, face_indices) |
|
|
|
current_idx = 0 |
|
|
@ -108,7 +111,7 @@ class OctreeNode(nn.Module): |
|
|
|
|
|
|
|
# 将子节点加入队列 |
|
|
|
if intersecting_faces: |
|
|
|
queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.bbox.device))) |
|
|
|
queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.device))) |
|
|
|
|
|
|
|
def _should_split_node(self, current_depth: int) -> bool: |
|
|
|
"""判断节点是否需要分裂""" |
|
|
@ -127,7 +130,7 @@ class OctreeNode(nn.Module): |
|
|
|
def _generate_child_bboxes(self, min_coords: torch.Tensor, mid_coords: torch.Tensor, max_coords: torch.Tensor) -> torch.Tensor: |
|
|
|
# 使用 with torch.no_grad() 减少梯度计算的内存占用 |
|
|
|
with torch.no_grad(): |
|
|
|
child_bboxes = torch.zeros([8, 6], device=self.bbox.device) |
|
|
|
child_bboxes = torch.zeros([8, 6], device=self.device) |
|
|
|
|
|
|
|
# 使用向量化操作生成所有子节点边界框 |
|
|
|
child_bboxes[0] = torch.cat([min_coords, mid_coords]) # 前下左 |
|
|
@ -199,6 +202,7 @@ class OctreeNode(nn.Module): |
|
|
|
bboxes.append(bbox) |
|
|
|
param_indices = torch.stack(param_indices) |
|
|
|
bboxes = torch.stack(bboxes) |
|
|
|
# 添加检查代码 |
|
|
|
return param_indices, bboxes |
|
|
|
|
|
|
|
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: |
|
|
@ -258,4 +262,21 @@ class OctreeNode(nn.Module): |
|
|
|
self.patch_graph = state['patch_graph'] |
|
|
|
self.max_depth = state['max_depth'] |
|
|
|
self.param_key = state['param_key'] |
|
|
|
self._is_leaf = state['_is_leaf'] |
|
|
|
self._is_leaf = state['_is_leaf'] |
|
|
|
|
|
|
|
def to(self, device=None, dtype=None, non_blocking=False): |
|
|
|
# 调用父类方法迁移基础参数 |
|
|
|
super().to(device, dtype, non_blocking) |
|
|
|
|
|
|
|
# 迁移自定义属性 |
|
|
|
if self.patch_graph is not None: |
|
|
|
if hasattr(self.patch_graph, 'to'): |
|
|
|
self.patch_graph = self.patch_graph.to(device=device, dtype=dtype) |
|
|
|
else: |
|
|
|
# 手动移动非Module属性 |
|
|
|
for attr in ['edge_index', 'edge_type', 'patch_features']: |
|
|
|
tensor = getattr(self.patch_graph, attr, None) |
|
|
|
if tensor is not None: |
|
|
|
setattr(self.patch_graph, attr, tensor.to(device=device, dtype=dtype)) |
|
|
|
|
|
|
|
return self |