|
|
@ -119,8 +119,6 @@ class OctreeNode(nn.Module): |
|
|
|
# 改为普通张量属性 |
|
|
|
self.bbox = bbox.to(self.device) # 显式设备管理 |
|
|
|
self.node_bboxes = None |
|
|
|
self.parent_indices = None |
|
|
|
self.child_indices = None |
|
|
|
self.is_leaf_mask = None |
|
|
|
# 面片索引张量 |
|
|
|
self.all_face_indices = torch.from_numpy(face_indices).to(self.device) |
|
|
@ -141,24 +139,21 @@ class OctreeNode(nn.Module): |
|
|
|
|
|
|
|
# 初始化静态张量,使用整数列表作为形状参数 |
|
|
|
self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device) |
|
|
|
self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device) |
|
|
|
self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device) |
|
|
|
self.face_indices_mask = torch.zeros([int(total_nodes),num_faces], dtype=torch.bool, device=self.device) # 1 代表有 |
|
|
|
self.is_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device) |
|
|
|
self.is_valid_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device) |
|
|
|
# 使用队列进行广度优先遍历 |
|
|
|
queue = [(0, self.bbox, self.all_face_indices)] # (node_idx, bbox, face_indices) |
|
|
|
current_idx = 0 |
|
|
|
|
|
|
|
while queue: |
|
|
|
node_idx, bbox, faces = queue.pop(0) |
|
|
|
node_idx, bbox, faces_indices = queue.pop(0) |
|
|
|
|
|
|
|
#logger.debug(f"Processing node {node_idx} with {len(faces)} faces.") |
|
|
|
self.node_bboxes[node_idx] = bbox |
|
|
|
|
|
|
|
# 判断 要不要继续分裂 |
|
|
|
if not self._should_split_node(current_idx, faces, total_nodes): |
|
|
|
if not self._should_split_node(node_idx, faces_indices, total_nodes): |
|
|
|
continue |
|
|
|
|
|
|
|
self.is_leaf_mask[node_idx] = 0 |
|
|
|
# 计算子节点边界框 |
|
|
|
min_coords = bbox[:3] |
|
|
@ -167,28 +162,28 @@ class OctreeNode(nn.Module): |
|
|
|
|
|
|
|
# 生成8个子节点 |
|
|
|
child_bboxes = self._generate_child_bboxes(min_coords, mid_coords, max_coords) |
|
|
|
intersect_mask = bbox_intersect(self.surf_bbox, faces, child_bboxes) |
|
|
|
self.face_indices_mask[current_idx + 1:current_idx + 9, :] = intersect_mask |
|
|
|
intersect_mask = bbox_intersect(self.surf_bbox, faces_indices, child_bboxes) |
|
|
|
self.face_indices_mask[node_idx * 8 + 1:node_idx * 8 + 9, :] = intersect_mask |
|
|
|
|
|
|
|
# 为每个子节点分配面片 |
|
|
|
for i, child_bbox in enumerate(child_bboxes): |
|
|
|
child_idx = child_idx = current_idx + i + 1 |
|
|
|
child_idx = 8 * node_idx + i + 1 |
|
|
|
|
|
|
|
intersecting_faces = intersect_mask[i].nonzero().flatten() |
|
|
|
#logger.debug(f"Node {child_idx} has {len(intersecting_faces)} intersecting faces.") |
|
|
|
# 更新节点关系 |
|
|
|
self.parent_indices[child_idx] = node_idx |
|
|
|
self.child_indices[node_idx, i] = child_idx |
|
|
|
|
|
|
|
# 将子节点加入队列 |
|
|
|
if len(intersecting_faces) > 0: |
|
|
|
queue.append((child_idx, child_bbox, intersecting_faces.clone().detach())) |
|
|
|
current_idx += 8 |
|
|
|
else: |
|
|
|
self.is_valid_leaf_mask[child_idx] = False |
|
|
|
|
|
|
|
#self.is_valid_leaf_mask = self.is_valid_leaf_mask & self.is_leaf_mask |
|
|
|
|
|
|
|
def _should_split_node(self, current_idx: int,face_indices:torch.Tensor,max_node:int) -> bool: |
|
|
|
"""判断节点是否需要分裂""" |
|
|
|
# 检查是否达到最大深度 |
|
|
|
if current_idx + 8 >= max_node: |
|
|
|
if current_idx * 8 + 9 >= max_node: |
|
|
|
return False |
|
|
|
|
|
|
|
# 检查是否为完全图 |
|
|
@ -225,61 +220,64 @@ class OctreeNode(nn.Module): |
|
|
|
return child_bboxes |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]: |
|
|
|
def find_leaf_batch(self, query_points: torch.Tensor) -> torch.Tensor: |
|
|
|
""" |
|
|
|
修改后的查找叶子节点方法,返回face indices |
|
|
|
:param query_points: 待查找的点,形状为 (3,) |
|
|
|
:return: (bbox, face_indices, is_leaf) |
|
|
|
""" |
|
|
|
# 确保输入是单个点 |
|
|
|
if query_points.dim() != 1 or query_points.shape[0] != 3: |
|
|
|
raise ValueError(f"query_points 必须是形状为 (3,) 的张量,但得到 {query_points.shape}") |
|
|
|
|
|
|
|
current_idx = torch.tensor(0, dtype=torch.long, device=query_points.device) |
|
|
|
max_iterations = 1000 # 防止无限循环 |
|
|
|
iteration = 0 |
|
|
|
|
|
|
|
while iteration < max_iterations: |
|
|
|
# 获取当前节点的叶子状态 |
|
|
|
if self.is_leaf_mask[current_idx].item(): |
|
|
|
#logger.debug(f"Reached leaf node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.") |
|
|
|
if self.face_indices_mask[current_idx].sum() == 0: |
|
|
|
parent_idx = self.parent_indices[current_idx] |
|
|
|
#logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.") |
|
|
|
if parent_idx == -1: |
|
|
|
# 根节点没有父节点,返回根节点的信息 |
|
|
|
return ( |
|
|
|
self.node_bboxes[current_idx], |
|
|
|
torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device), # 新增返回face indices |
|
|
|
False |
|
|
|
) |
|
|
|
return ( |
|
|
|
self.node_bboxes[parent_idx], |
|
|
|
self.face_indices_mask[parent_idx], # 新增返回face indices |
|
|
|
False |
|
|
|
) |
|
|
|
return ( |
|
|
|
self.node_bboxes[current_idx], |
|
|
|
self.face_indices_mask[current_idx], # 新增返回face indices |
|
|
|
True |
|
|
|
) |
|
|
|
批量定位点在八叉树中的对应节点索引。 |
|
|
|
|
|
|
|
# 计算子节点索引 |
|
|
|
child_idx = self._get_child_indices(query_points.unsqueeze(0), |
|
|
|
self.node_bboxes[current_idx].unsqueeze(0)) |
|
|
|
Args: |
|
|
|
query_points (torch.Tensor): 点的坐标,形状为 [N, 3]。 |
|
|
|
|
|
|
|
# 获取下一个要访问的节点 |
|
|
|
next_idx = self.child_indices[current_idx, child_idx[0]] |
|
|
|
Returns: |
|
|
|
torch.Tensor: 每个点对应的节点索引,形状为 [N]。 |
|
|
|
""" |
|
|
|
num_points = query_points.shape[0] |
|
|
|
|
|
|
|
# 检查索引是否有效 |
|
|
|
if next_idx == -1: |
|
|
|
raise IndexError(f"Invalid child node index: {child_idx[0]}") |
|
|
|
# 初始化:所有点从根节点开始 |
|
|
|
node_indices = torch.zeros(num_points, dtype=torch.int64, device=self.device) |
|
|
|
active_mask = torch.full((num_points,), |
|
|
|
not self.is_leaf_mask[0].item(), |
|
|
|
dtype=torch.bool, |
|
|
|
device=self.device) # 标记需要继续处理的点 |
|
|
|
|
|
|
|
current_idx = next_idx |
|
|
|
iteration = 0 |
|
|
|
while active_mask.any() and iteration < self.max_depth: |
|
|
|
# 获取当前活动点的索引和对应的节点信息 |
|
|
|
active_indices = torch.nonzero(active_mask).squeeze(1) # 当前活动点的索引 |
|
|
|
current_node_indices = node_indices[active_indices] # 当前活动点对应的节点索引 |
|
|
|
|
|
|
|
# 获取当前节点的边界框 |
|
|
|
current_bboxes = self.node_bboxes[current_node_indices] # 形状为 [M, 6],其中 M 是活动点数 |
|
|
|
|
|
|
|
# 将边界框分为 min_coords 和 max_coords |
|
|
|
min_coords = current_bboxes[:, :3] # 形状为 [M, 3] |
|
|
|
max_coords = current_bboxes[:, 3:] # 形状为 [M, 3] |
|
|
|
mid_coords = (min_coords + max_coords) / 2 # 中点坐标,形状为 [M, 3] |
|
|
|
|
|
|
|
# 获取当前活动点的坐标 |
|
|
|
active_points = query_points[active_indices] # 形状为 [M, 3] |
|
|
|
|
|
|
|
# 判断点属于哪个子节点 |
|
|
|
child_offsets = ((active_points > mid_coords).long() * torch.tensor([1, 2, 4], device=self.device)).sum(dim=1) # 子节点偏移量 |
|
|
|
child_indices = current_node_indices * 8 + child_offsets + 1 # 子节点索引 |
|
|
|
|
|
|
|
# 更新节点索引 |
|
|
|
node_indices[active_indices] = child_indices |
|
|
|
|
|
|
|
# 更新活动掩码:如果子节点是叶子节点,则不再处理 |
|
|
|
is_leaf = self.is_leaf_mask[child_indices] |
|
|
|
active_mask[active_indices[is_leaf]] = False |
|
|
|
# logger.debug(f"iter{iteration}: active_indices{active_indices}") |
|
|
|
|
|
|
|
is_valid = self.is_valid_leaf_mask[child_indices] |
|
|
|
#logger.debug(is_valid) |
|
|
|
# 如果是无效叶节点,返回父节点 |
|
|
|
invalid_leaf_mask = is_leaf & (~is_valid) |
|
|
|
#logger.debug(torch.nonzero(invalid_leaf_mask).squeeze(1) ) |
|
|
|
node_indices[active_indices[invalid_leaf_mask]] = (child_indices[invalid_leaf_mask]-1) // 8 |
|
|
|
#print(active_indices[invalid_leaf_mask]) |
|
|
|
#logger.debug(child_indices[invalid_leaf_mask] // 8) |
|
|
|
iteration += 1 |
|
|
|
|
|
|
|
# 如果达到最大迭代次数,返回当前节点的信息 |
|
|
|
return self.node_bboxes[current_idx], torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device),bool(self.is_leaf_mask[current_idx].item()) |
|
|
|
return node_indices |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor: |
|
|
@ -289,21 +287,30 @@ class OctreeNode(nn.Module): |
|
|
|
|
|
|
|
def forward(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
|
with torch.no_grad(): |
|
|
|
bboxes: List[torch.Tensor] = [] |
|
|
|
face_indices_mask: List[torch.Tensor] = [] |
|
|
|
operator: List[int] = [] |
|
|
|
for point in query_points: |
|
|
|
bbox, faces_mask, _ = self.find_leaf(point) |
|
|
|
bboxes.append(bbox) |
|
|
|
face_indices_mask.append(faces_mask) |
|
|
|
# 获取当前节点的CSG树结构 |
|
|
|
#csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None |
|
|
|
#csg_trees.append(csg_tree) # 保持原始列表结构 |
|
|
|
operator.append(self.patch_graph.get_operator(faces_mask.nonzero()) if self.patch_graph is not None else -1) |
|
|
|
# 使用批量查询方法 |
|
|
|
#logger.gpu_memory_stats("后") |
|
|
|
node_indices = self.find_leaf_batch(query_points) # 点对应的 node idx |
|
|
|
#logger.gpu_memory_stats("SDF后") |
|
|
|
bboxes = torch.stack([self.node_bboxes[idx] for idx in node_indices]) |
|
|
|
face_indices_masks = torch.stack([self.face_indices_mask[idx] for idx in node_indices]) |
|
|
|
#print(face_indices_masks) |
|
|
|
# 处理操作符 |
|
|
|
operators: List[int] = [] |
|
|
|
if self.patch_graph is not None: |
|
|
|
# 批量获取操作符 |
|
|
|
for i in range(query_points.shape[0]): |
|
|
|
faces_mask = face_indices_masks[i] |
|
|
|
#print(faces_mask.nonzero()) |
|
|
|
#print(node_indices[i]) |
|
|
|
#print(self.is_leaf_mask[node_indices[i]],self.is_valid_leaf_mask[node_indices[i]]) |
|
|
|
operators.append(self.patch_graph.get_operator(faces_mask.nonzero())) |
|
|
|
else: |
|
|
|
operators = [-1] * query_points.shape[0] |
|
|
|
|
|
|
|
return ( |
|
|
|
torch.stack(bboxes), |
|
|
|
torch.stack(face_indices_mask), |
|
|
|
torch.tensor(operator, dtype=torch.int) # 直接返回列表,不转换为张量 |
|
|
|
bboxes, |
|
|
|
face_indices_masks, |
|
|
|
torch.tensor(operators, dtype=torch.int, device=query_points.device) |
|
|
|
) |
|
|
|
|
|
|
|
def print_tree(self, max_print_depth: Optional[int] = None) -> None: |
|
|
@ -327,25 +334,24 @@ class OctreeNode(nn.Module): |
|
|
|
|
|
|
|
indent = " " * depth # 根据深度生成缩进 |
|
|
|
is_leaf = self.is_leaf_mask[node_idx].item() # 判断是否为叶子节点 |
|
|
|
is_valid_leaf = self.is_valid_leaf_mask[node_idx].item() |
|
|
|
bbox = self.node_bboxes[node_idx].cpu().numpy().tolist() # 获取边界框信息 |
|
|
|
|
|
|
|
# 打印当前节点的基本信息 |
|
|
|
node_type = "Leaf" if is_leaf else "Internal" |
|
|
|
log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}") |
|
|
|
log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}, VALID:{is_valid_leaf}") |
|
|
|
if self.face_indices_mask is not None: |
|
|
|
face_indices = self.face_indices_mask[node_idx].nonzero().cpu().numpy().flatten().tolist() |
|
|
|
log_lines.append(f"{indent} Face Indices: {face_indices}") |
|
|
|
# 如果是叶子节点,打印额外信息 |
|
|
|
if is_leaf: |
|
|
|
|
|
|
|
child_indices = self.child_indices[node_idx].cpu().numpy().tolist() |
|
|
|
log_lines.append(f"{indent} Child Indices: {child_indices}") |
|
|
|
pass |
|
|
|
|
|
|
|
# 如果不是叶子节点,递归处理子节点 |
|
|
|
if not is_leaf: |
|
|
|
for i in range(8): # 遍历所有子节点 |
|
|
|
child_idx = self.child_indices[node_idx, i].item() |
|
|
|
if child_idx != -1: # 忽略无效的子节点索引 |
|
|
|
child_idx = 8 * node_idx + i + 1 |
|
|
|
if self.is_leaf_mask[child_idx] != -1: # 忽略无效的子节点索引 |
|
|
|
dfs(child_idx, depth + 1) |
|
|
|
|
|
|
|
# 初始化日志行列表 |
|
|
@ -362,9 +368,8 @@ class OctreeNode(nn.Module): |
|
|
|
state = { |
|
|
|
'bbox': self.bbox.cpu(), # 转换为CPU张量 |
|
|
|
'node_bboxes': self.node_bboxes.cpu() if self.node_bboxes is not None else None, |
|
|
|
'parent_indices': self.parent_indices.cpu() if self.parent_indices is not None else None, |
|
|
|
'child_indices': self.child_indices.cpu() if self.child_indices is not None else None, |
|
|
|
'is_leaf_mask': self.is_leaf_mask.cpu() if self.is_leaf_mask is not None else None, |
|
|
|
"is_valid_leaf_mask": self.is_valid_leaf_mask.cpu() if self.is_valid_leaf_mask is not None else None, |
|
|
|
'all_face_indices': self.all_face_indices.cpu(), |
|
|
|
'face_indices_mask':self.face_indices_mask.cpu() if self.face_indices_mask is not None else None, |
|
|
|
'surf_bbox': self.surf_bbox.cpu() if self.surf_bbox is not None else None, |
|
|
@ -389,10 +394,9 @@ class OctreeNode(nn.Module): |
|
|
|
) |
|
|
|
# 可以在这里设置其他不需要在 __init__ 中处理的属性 |
|
|
|
self.node_bboxes = state['node_bboxes'].to(self.device) if state['node_bboxes'] is not None else None |
|
|
|
self.parent_indices = state['parent_indices'].to(self.device) if state['parent_indices'] is not None else None |
|
|
|
self.child_indices = state['child_indices'].to(self.device) if state['child_indices'] is not None else None |
|
|
|
self.face_indices_mask = state['face_indices_mask'].to(self.device) if state['face_indices_mask'] is not None else None |
|
|
|
self.is_leaf_mask = state['is_leaf_mask'].to(self.device) if state['is_leaf_mask'] is not None else None |
|
|
|
self.is_valid_leaf_mask = state['is_valid_leaf_mask'].to(self.device) if state['is_valid_leaf_mask'] is not None else None |
|
|
|
|
|
|
|
|
|
|
|
def to(self, device=None, dtype=None, non_blocking=False): |
|
|
|