From 48755817c09920693b302bc2f5c00e7efd56de34 Mon Sep 17 00:00:00 2001 From: mckay Date: Thu, 17 Apr 2025 15:02:42 +0800 Subject: [PATCH] =?UTF-8?q?Revert=20"oct=20=E4=BF=AE=E6=94=B9=E4=B8=AD?= =?UTF-8?q?=E9=97=B4=E7=89=88=E6=9C=AC"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 5bd4a8d8662d694c0cf161dc630023bf0b5c806e. --- brep2sdf/IsoSurfacing.py | 26 +- brep2sdf/config/default_config.py | 2 +- brep2sdf/networks/feature_volume.py | 244 +++------------ brep2sdf/networks/octree.py | 448 ++++++++++------------------ brep2sdf/train.py | 8 +- 5 files changed, 209 insertions(+), 519 deletions(-) diff --git a/brep2sdf/IsoSurfacing.py b/brep2sdf/IsoSurfacing.py index faa9a97..5645f5c 100644 --- a/brep2sdf/IsoSurfacing.py +++ b/brep2sdf/IsoSurfacing.py @@ -5,8 +5,6 @@ from skimage import measure import time import trimesh -from brep2sdf.utils.logger import logger - def create_grid(depth, box_size): """ 创建三维网格点 @@ -123,7 +121,7 @@ def main(): # 设置设备 device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu") - logger.info(f"Using device: {device}") + print(f"Using device: {device}") model = torch.jit.load(args.input).to(device) #model = torch.load(args.input).to(device) @@ -132,32 +130,32 @@ def main(): # 创建网格并预测SDF points, xx, yy, zz = create_grid(args.depth, args.box_size) sdf = predict_sdf(model, points, device) - logger.info(points.shape) - logger.info(sdf.shape) - logger.info(sdf) + print(points.shape) + print(sdf.shape) + print(sdf) sdf_grid = sdf.reshape(xx.shape) # 提取表面 - logger.info("Extracting surface...") + print("Extracting surface...") start_time = time.time() verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method) - logger.info(f"Surface extraction took {time.time() - start_time:.2f} seconds") + print(f"Surface extraction took {time.time() - start_time:.2f} seconds") # 保存网格 save_ply(verts, faces, args.output) - logger.info(f"Mesh saved to {args.output}") + print(f"Mesh saved to {args.output}") # 误差评估(可选) if args.compare: - logger.info("Computing SDF error...") + print("Computing SDF error...") gt_mesh = trimesh.load(args.compare) avg_abs, avg_rel, max_abs, max_rel = compute_sdf_error( model, gt_mesh, args.compres, device ) - logger.info(f"Average Absolute Error: {avg_abs:.4f}") - logger.info(f"Average Relative Error: {avg_rel:.4f}") - logger.info(f"Max Absolute Error: {max_abs:.4f}") - logger.info(f"Max Relative Error: {max_rel:.4f}") + print(f"Average Absolute Error: {avg_abs:.4f}") + print(f"Average Relative Error: {avg_rel:.4f}") + print(f"Max Absolute Error: {max_abs:.4f}") + print(f"Max Relative Error: {max_rel:.4f}") if __name__ == "__main__": main() \ No newline at end of file diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index ae2759c..aac078e 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -47,7 +47,7 @@ class TrainConfig: # 基本训练参数 batch_size: int = 8 num_workers: int = 4 - num_epochs: int = 1 + num_epochs: int = 200 learning_rate: float = 0.01 min_lr: float = 1e-5 weight_decay: float = 0.01 diff --git a/brep2sdf/networks/feature_volume.py b/brep2sdf/networks/feature_volume.py index 3531ed6..f1bffd8 100644 --- a/brep2sdf/networks/feature_volume.py +++ b/brep2sdf/networks/feature_volume.py @@ -4,211 +4,51 @@ import torch import torch.nn as nn import torch.optim as optim -import numpy as np - -from brep2sdf.utils.logger import logger - class PatchFeatureVolume(nn.Module): - def __init__(self, bbox, resolution=64, feature_dim=64): + def __init__(self, bbox:np, resolution=64, feature_dim=64): super(PatchFeatureVolume, self).__init__() - # 统一转换为torch张量并确保在CPU初始化 - if isinstance(bbox, np.ndarray): - bbox = torch.from_numpy(bbox).float() - elif isinstance(bbox, torch.Tensor): - bbox = bbox.float() - else: - raise TypeError("bbox必须是np.ndarray或torch.Tensor类型") - - # 注册为buffer,自动处理设备 - self.register_buffer('bbox', bbox) + self.bbox = bbox # 补丁的边界框 + self.resolution = resolution # 网格分辨率 + self.feature_dim = feature_dim # 特征向量维度 - # 获取最小和最大坐标 - self.min_coords = bbox[:3] - self.max_coords = bbox[3:] - - # 处理分辨率参数 - if isinstance(resolution, int): - res_x = res_y = res_z = resolution - elif len(resolution) == 3: - res_x, res_y, res_z = resolution - else: - raise ValueError("resolution必须是整数或包含3个整数的元组/列表") - - self.resolution = (res_x, res_y, res_z) - self.feature_dim = feature_dim - - # 创建网格(使用min_coords和max_coords) - x = torch.linspace(self.bbox[0], self.bbox[3], res_x) - y = torch.linspace(self.bbox[1], self.bbox[4], res_y) - z = torch.linspace(self.bbox[2], self.bbox[5], res_z) - grid = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1) - self.register_buffer('grid', grid) + # 创建规则的三维网格 + x = torch.linspace(bbox[0][0], bbox[1][0], resolution) + y = torch.linspace(bbox[0][1], bbox[1][1], resolution) + z = torch.linspace(bbox[0][2], bbox[1][2], resolution) + grid_x, grid_y, grid_z = torch.meshgrid(x, y, z) + self.register_buffer('grid', torch.stack([grid_x, grid_y, grid_z], dim=-1)) - # 特征体积(自动继承模块的设备) - self.feature_volume = nn.Parameter( - torch.randn(res_x, res_y, res_z, feature_dim) - ) + # 初始化特征向量,作为可训练参数 + self.feature_volume = nn.Parameter(torch.randn(resolution, resolution, resolution, feature_dim)) - def forward(self, query_points: torch.Tensor): - # 自动设备对齐 - was_2d = query_points.dim() == 2 - if was_2d: - query_points = query_points.unsqueeze(0) # [N,3] -> [1,N,3] + def forward(self, query_points: List[Tuple[float, float, float]]): + """ + 根据查询点的位置,从补丁特征体积中获取插值后的特征向量。 - features = self.batched_trilinear_interpolation(query_points) # [B,N,D] - - # 恢复原始形状 - if was_2d: - features = features.squeeze(0) # [1,N,D] -> [N,D] - return features - - def batched_trilinear_interpolation(self, query_points: torch.Tensor): - B, N, _ = query_points.shape - - # 将查询点转换到网格坐标系 - x = (query_points[..., 0] - self.min_coords[0]) / (self.max_coords[0] - self.min_coords[0]) * (self.resolution[0] - 1) - y = (query_points[..., 1] - self.min_coords[1]) / (self.max_coords[1] - self.min_coords[1]) * (self.resolution[1] - 1) - z = (query_points[..., 2] - self.min_coords[2]) / (self.max_coords[2] - self.min_coords[2]) * (self.resolution[2] - 1) - - # 确保坐标在网格范围内(防止溢出) - x = torch.clamp(x, 0, self.resolution[0]-1e-5) - y = torch.clamp(y, 0, self.resolution[1]-1e-5) - z = torch.clamp(z, 0, self.resolution[2]-1e-5) - - # 分解为整数坐标和分数部分 - x0 = torch.floor(x).long() - x1 = x0 + 1 - x_frac = x - x0.float() - - y0 = torch.floor(y).long() - y1 = y0 + 1 - y_frac = y - y0.float() - - z0 = torch.floor(z).long() - z1 = z0 + 1 - z_frac = z - z0.float() - - # 处理边界情况(确保索引在合法范围内) - x0 = torch.clamp(x0, 0, self.resolution[0]-1) - x1 = torch.clamp(x1, 0, self.resolution[0]-1) - y0 = torch.clamp(y0, 0, self.resolution[1]-1) - y1 = torch.clamp(y1, 0, self.resolution[1]-1) - z0 = torch.clamp(z0, 0, self.resolution[2]-1) - z1 = torch.clamp(z1, 0, self.resolution[2]-1) - - # 将索引展平为1D张量以进行高效的特征提取 - x0_flat = x0.view(-1) - x1_flat = x1.view(-1) - y0_flat = y0.view(-1) - y1_flat = y1.view(-1) - z0_flat = z0.view(-1) - z1_flat = z1.view(-1) - - # 提取8个顶点的特征 - feat_0 = self.feature_volume[x0_flat, y0_flat, z0_flat] # (x0,y0,z0) - feat_1 = self.feature_volume[x1_flat, y0_flat, z0_flat] # (x1,y0,z0) - feat_2 = self.feature_volume[x0_flat, y1_flat, z0_flat] # (x0,y1,z0) - feat_3 = self.feature_volume[x1_flat, y1_flat, z0_flat] # (x1,y1,z0) - feat_4 = self.feature_volume[x0_flat, y0_flat, z1_flat] # (x0,y0,z1) - feat_5 = self.feature_volume[x1_flat, y0_flat, z1_flat] # (x1,y0,z1) - feat_6 = self.feature_volume[x0_flat, y1_flat, z1_flat] # (x0,y1,z1) - feat_7 = self.feature_volume[x1_flat, y1_flat, z1_flat] # (x1,y1,z1) - - # 将特征重塑为 [B, N, D] - D = self.feature_volume.shape[-1] - feat_0 = feat_0.view(B, N, D) - feat_1 = feat_1.view(B, N, D) - feat_2 = feat_2.view(B, N, D) - feat_3 = feat_3.view(B, N, D) - feat_4 = feat_4.view(B, N, D) - feat_5 = feat_5.view(B, N, D) - feat_6 = feat_6.view(B, N, D) - feat_7 = feat_7.view(B, N, D) - - # 计算各顶点的权重 - xw0 = (1 - x_frac).unsqueeze(-1) # [B, N, 1] - xw1 = x_frac.unsqueeze(-1) - yw0 = (1 - y_frac).unsqueeze(-1) - yw1 = y_frac.unsqueeze(-1) - zw0 = (1 - z_frac).unsqueeze(-1) - zw1 = z_frac.unsqueeze(-1) - - w0 = xw0 * yw0 * zw0 - w1 = xw1 * yw0 * zw0 - w2 = xw0 * yw1 * zw0 - w3 = xw1 * yw1 * zw0 - w4 = xw0 * yw0 * zw1 - w5 = xw1 * yw0 * zw1 - w6 = xw0 * yw1 * zw1 - w7 = xw1 * yw1 * zw1 - - # 加权求和得到最终特征 - output = ( - feat_0 * w0 + - feat_1 * w1 + - feat_2 * w2 + - feat_3 * w3 + - feat_4 * w4 + - feat_5 * w5 + - feat_6 * w6 + - feat_7 * w7 - ) - - return output - - - -def test_feature_volume(): - # 1. 准备测试数据 - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - print(f"Using device: {device}") - - # 定义包围盒 [min_x, min_y, min_z, max_x, max_y, max_z] - bbox = np.array([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], dtype=np.float32) - - # 2. 创建特征体积 (使用不同分辨率测试) - feature_vol = PatchFeatureVolume( - bbox=bbox, - resolution=(32, 64, 32), # 各维度不同分辨率 - feature_dim=64 - ).to(device) - - # 3. 生成测试点集 (包括边界点和内部点) - # 单个点测试 - test_point = torch.tensor([[0.5, 0.3, -0.2]], device=device, dtype=torch.float32) - - # 批量点测试 (包含边界情况) - test_points = torch.tensor([ - [0.0, 0.0, 0.0], # 中心点 - [1.0, 1.0, 1.0], # 最大边界 - [-1.0, -1.0, -1.0], # 最小边界 - [0.5, 0.5, 0.5], # 内部点 - [0.9, -0.8, 0.2] # 非对称点 - ], device=device, dtype=torch.float32) - - # 4. 执行查询 - # 测试单个点 - single_feature = feature_vol(test_point) - print(f"Single point feature shape: {single_feature.shape}") # 应为 [1, 64] - - # 测试批量点 - batch_features = feature_vol(test_points) - print(f"Batch features shape: {batch_features.shape}") # 应为 [5, 64] - - # 5. 验证结果 - assert not torch.isnan(single_feature).any(), "Output contains NaN values" - assert not torch.isnan(batch_features).any(), "Output contains NaN values" - assert single_feature.shape == (1, 64), "Single point output shape mismatch" - assert batch_features.shape == (5, 64), "Batch points output shape mismatch" - - # 6. 测试梯度计算 - test_point.requires_grad_(True) - feature = feature_vol(test_point) - loss = feature.sum() - loss.backward() - assert test_point.grad is not None, "Gradient not computed" - - print("All tests passed!") - -if __name__ == "__main__": - test_feature_volume() \ No newline at end of file + :param query_points: 查询点的位置坐标,形状为 (N, 3) + :return: 插值后的特征向量,形状为 (N, feature_dim) + """ + interpolated_features = torch.zeros(query_points.shape[0], self.feature_dim).to(self.feature_volume.device) + for i, point in enumerate(query_points): + interpolated_feature = self.trilinear_interpolation(point) + interpolated_features[i] = interpolated_feature + return interpolated_features + + def trilinear_interpolation(self, query_point): + """三线性插值""" + normalized_coords = ((query_point - torch.tensor(self.bbox[0]).to(self.grid.device)) / + (torch.tensor(self.bbox[1]).to(self.grid.device) - torch.tensor(self.bbox[0]).to(self.grid.device))) * (self.resolution - 1) + indices = torch.floor(normalized_coords).long() + weights = normalized_coords - indices.float() + + interpolated_feature = torch.zeros(self.feature_dim).to(self.feature_volume.device) + for di in range(2): + for dj in range(2): + for dk in range(2): + weight = (weights[0] if di == 1 else 1 - weights[0]) * \ + (weights[1] if dj == 1 else 1 - weights[1]) * \ + (weights[2] if dk == 1 else 1 - weights[2]) + index = indices + torch.tensor([di, dj, dk]).to(indices.device) + index = torch.clamp(index, 0, self.resolution - 1) + interpolated_feature += weight * self.feature_volume[index[0], index[1], index[2]] + return interpolated_feature \ No newline at end of file diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index aea8b19..4b2a308 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -1,129 +1,136 @@ -from typing import Tuple, List +from typing import Tuple, List, cast, Dict, Any, Tuple import torch import torch.nn as nn import torch.nn.functional as F import numpy as np -import pickle from brep2sdf.utils.logger import logger -def bbox_intersect(bbox1: np.ndarray, bbox2: np.ndarray) -> bool: +def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool: """判断两个轴对齐包围盒(AABB)是否相交 参数: - bbox1: 形状为 (6,) 的数组,格式 [min_x, min_y, min_z, max_x, max_y, max_z] + bbox1: 形状为 (6,) 的张量,格式 [min_x, min_y, min_z, max_x, max_y, max_z] bbox2: 同bbox1格式 返回: bool: 两包围盒是否相交(包括刚好接触的情况) """ - assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的数组" + assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的张量" # 提取min和max坐标 min1, max1 = bbox1[:3], bbox1[3:] min2, max2 = bbox2[:3], bbox2[3:] # 向量化比较 - return np.all((max1 >= min2) & (max2 >= min1)) + return torch.all((max1 >= min2) & (max2 >= min1)) - -class OctreeNode: - feature_dim = None +class OctreeNode(nn.Module): + device=None surf_bbox = None - - def __init__(self, bbox: np.ndarray, face_indices: np.ndarray, max_depth: int = 5, feature_dim: int = None, surf_bbox: np.ndarray = None): - """ - 初始化八叉树节点。 - :param bbox: 节点的边界框,格式为 [min_x, min_y, min_z, max_x, max_y, max_z] (形状为 (6,)) - :param face_indices: 当前节点包含的面索引数组 - :param max_depth: 八叉树的最大深度 - :param feature_dim: 特征维度(仅在叶子节点时使用) - :param surf_bbox: 面的包围盒数组,形状为 (N, 6) - """ + def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, surf_bbox:torch.Tensor = None): + super().__init__() self.bbox = bbox # 节点的边界框 self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 - self.children: List['OctreeNode'] = [] # 子节点列表 + self.child_nodes: torch.nn.ModuleList = torch.nn.ModuleList() # 子节点列表 self.face_indices = face_indices + self.param_key = "" #self.patch_feature_volume = None # 补丁特征体积,only leaf has self._is_leaf = True + #print(f"box shape: {self.bbox.shape}") - if feature_dim is not None: - OctreeNode.feature_dim = feature_dim - if surf_bbox is not None: - if not isinstance(surf_bbox, np.ndarray): - raise TypeError(f"surf_bbox 必须是 numpy.ndarray 类型,但得到 {type(surf_bbox)}") - if surf_bbox.ndim != 2 or surf_bbox.shape[1] != 6: - raise ValueError(f"surf_bbox 应为二维数组且形状为 (N,6),但得到 {surf_bbox.shape}") - OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建 + if surf_bbox is not None: + if not isinstance(surf_bbox, torch.Tensor): + raise TypeError( + f"surf_bbox 必须是 torch.Tensor 类型,但得到 {type(surf_bbox)}" + ) + if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6: + raise ValueError( + f"surf_bbox 应为二维张量且形状为 (N,6),但得到 {surf_bbox.shape}" + ) + OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建 + OctreeNode.device = bbox.device def is_leaf(self): - # Check if self.children is None before calling len() + # Check if self.child——nodes is None before calling len() return self._is_leaf + def set_param_key(self, k): + self.param_key = k + def conduct_tree(self): - """ - 构建八叉树:如果达到最大深度或当前节点包含的面数小于等于2,则停止划分。 - """ - if self.max_depth <= 0 or len(self.face_indices) <= 2: + if self.max_depth <= 0 or self.face_indices.shape[0] <= 2: # 达到最大深度 or 一个单元格至多只有两个面 - #self.patch_feature_volume = np.random.randn(8, OctreeNode.feature_dim) - return + return self.subdivide() + def subdivide(self): - """ - 将当前节点划分为8个子节点,并分配相交的面。 - """ + + #min_x, min_y, min_z, max_x, max_y, max_z = self.bbox + # 使用索引操作替代解包 min_coords = self.bbox[:3] # [min_x, min_y, min_z] max_coords = self.bbox[3:] # [max_x, max_y, max_z] # 计算中间点 mid_coords = (min_coords + max_coords) / 2 - + # 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z - min_x, min_y, min_z = min_coords - mid_x, mid_y, mid_z = mid_coords - max_x, max_y, max_z = max_coords + min_x, min_y, min_z = min_coords[0], min_coords[1], min_coords[2] + mid_x, mid_y, mid_z = mid_coords[0], mid_coords[1], mid_coords[2] + max_x, max_y, max_z = max_coords[0], max_coords[1], max_coords[2] # 生成 8 个子包围盒 - child_bboxes = np.array([ - [*min_coords, *mid_coords], # 前下左 - [mid_x, min_y, min_z, max_x, mid_y, mid_z], # 前下右 - [min_x, mid_y, min_z, mid_x, max_y, mid_z], # 前上左 - [mid_x, mid_y, min_z, max_x, max_y, mid_z], # 前上右 - [min_x, min_y, mid_z, mid_x, mid_y, max_z], # 后下左 - [mid_x, min_y, mid_z, max_x, mid_y, max_z], # 后下右 - [min_x, mid_y, mid_z, mid_x, max_y, max_z], # 后上左 - [mid_x, mid_y, mid_z, max_x, max_y, max_z] # 后上右 + child_bboxes = torch.stack([ + torch.cat([min_coords, mid_coords]), # 前下左 + torch.cat([torch.tensor([mid_x, min_y, min_z], device=self.bbox.device), + torch.tensor([max_x, mid_y, mid_z], device=self.bbox.device)]), # 前下右 + torch.cat([torch.tensor([min_x, mid_y, min_z], device=self.bbox.device), + torch.tensor([mid_x, max_y, mid_z], device=self.bbox.device)]), # 前上左 + torch.cat([torch.tensor([mid_x, mid_y, min_z], device=self.bbox.device), + torch.tensor([max_x, max_y, mid_z], device=self.bbox.device)]), # 前上右 + torch.cat([torch.tensor([min_x, min_y, mid_z], device=self.bbox.device), + torch.tensor([mid_x, mid_y, max_z], device=self.bbox.device)]), # 后下左 + torch.cat([torch.tensor([mid_x, min_y, mid_z], device=self.bbox.device), + torch.tensor([max_x, mid_y, max_z], device=self.bbox.device)]), # 后下右 + torch.cat([torch.tensor([min_x, mid_y, mid_z], device=self.bbox.device), + torch.tensor([mid_x, max_y, max_z], device=self.bbox.device)]), # 后上左 + torch.cat([torch.tensor([mid_x, mid_y, mid_z], device=self.bbox.device), + torch.tensor([max_x, max_y, max_z], device=self.bbox.device)]) # 后上右 ]) - + # 为每个子包围盒创建子节点,并分配相交的面 - self.children = [] for bbox in child_bboxes: # 找到与子包围盒相交的面 - intersecting_faces = [ - face_idx for face_idx in self.face_indices - if bbox_intersect(bbox, OctreeNode.surf_bbox[face_idx]) - ] + intersecting_faces = [] + for face_idx in self.face_indices: + face_bbox = OctreeNode.surf_bbox[face_idx] + if bbox_intersect(bbox, face_bbox): + intersecting_faces.append(face_idx) + #print(f"{bbox}: {intersecting_faces}") + child_node = OctreeNode( bbox=bbox, face_indices=np.array(intersecting_faces), max_depth=self.max_depth - 1 ) child_node.conduct_tree() - self.children.append(child_node) - + self.child_nodes.append(child_node) + self._is_leaf = False - def get_child_index(self, query_point: np.ndarray) -> int: + def get_child_index(self, query_point: torch.Tensor) -> int: """ - 计算点所在子节点的索引。 + 计算点所在子节点的索引 :param query_point: 待检查的点,格式为 (x, y, z) :return: 子节点的索引,范围从 0 到 7 """ + # 确保 query_point 和 bbox 在同一设备上 + query_point = query_point.to(self.bbox.device) + # 提取 bbox 的最小和最大坐标 min_coords = self.bbox[:3] # [min_x, min_y, min_z] max_coords = self.bbox[3:] # [max_x, max_y, max_z] @@ -132,11 +139,11 @@ class OctreeNode: mid_coords = (min_coords + max_coords) / 2 # 使用布尔比较结果计算索引 - index = ((query_point >= mid_coords) << np.arange(3)).sum() + index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum() - return index + return index.item() - def find_leaf(self, query_point: np.ndarray) -> np.ndarray: + def find_leaf(self, query_point: torch.Tensor) -> Tuple[torch.Tensor, str, bool]: """ 查找包含给定点的叶子节点,并返回其信息(以元组形式) :param query_point: 待查找的点 @@ -145,22 +152,66 @@ class OctreeNode: # 如果当前节点是叶子节点,返回其信息 if self._is_leaf: #logger.info(f"{self.bbox}, {self.param_key}, {True}") - return self.face_indices + return (self.bbox, self.param_key, True) # 计算查询点所在的子节点索引 index = self.get_child_index(query_point) - try: - # 直接访问子节点,不进行显式检查 - return self.children[index].find_leaf(query_point) - except IndexError as e: - # 记录错误日志并重新抛出异常 - logger.error( - f"Error accessing child node: {e}. " - f"Query point: {query_point.tolist()}, " - f"Node bbox: {self.bbox.tolist()}, " - f"Depth info: {self.max_depth}" - ) - raise e + + # 遍历子节点列表,找到对应的子节点 + for i, child_node in enumerate(self.child_nodes): + if i == index and child_node is not None: + # 递归调用子节点的 find_leaf 方法 + result = child_node.find_leaf(query_point) + + # 确保返回值是一个元组 + assert isinstance(result, tuple), f"Unexpected return type: {type(result)}" + return result + + # 如果找不到有效的子节点,抛出异常 + raise IndexError(f"Invalid child node index: {index}") + ''' + try: + # 直接访问子节点,不进行显式检查 + return self.child_nodes[index].find_leaf(query_point) + except IndexError as e: + # 记录错误日志并重新抛出异常 + logger.error( + f"Error accessing child node: {e}. " + f"Query point: {query_point.cpu().numpy().tolist()}, " + f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " + f"Depth info: {self.max_depth}" + ) + raise e + ''' + + ''' + def get_feature_vector(self, query_point:torch.Tensor): + """ + 预测给定点的 SDF 值 + :param point: 待预测的点,格式为 (x, y, z) + :return: 预测的 SDF 值 + """ + # 将点转换为 numpy 数组 + + # 从根节点开始递归查找包含该点的叶子节点 + if self._is_leaf: + return self.trilinear_interpolation(query_point) + else: + index = self.get_child_index(query_point) + try: + # 直接访问子节点,不进行显式检查 + return self.child_nodes[index].get_feature_vector(query_point) + except IndexError as e: + # 记录错误日志并重新抛出异常 + logger.error( + f"Error accessing child node: {e}. " + f"Query point: {query_point.cpu().numpy().tolist()}, " + f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " + f"Depth info: {self.max_depth}" + ) + raise e + ''' + @@ -178,236 +229,37 @@ class OctreeNode: # 打印当前节点信息 indent = " " * depth node_type = "Leaf" if self._is_leaf else "Internal" - print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.tolist()}") + print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.cpu().numpy().tolist()}") # 打印面片信息(如果有) if self.face_indices is not None: print(f"{indent} Face indices: {self.face_indices.tolist()}") - print(f"{indent} len children: {len(self.children)}") + print(f"{indent} len child_nodes: {len(self.child_nodes)}") # 递归打印子节点 - for i, child in enumerate(self.children): + for i, child in enumerate(self.child_nodes): print(f"{indent} Child {i}:") child.print_tree(depth + 1, max_print_depth) -# 保存 - - def save_tree_to_file(self, file_path: str): - """ - 将八叉树保存到文件 - :param file_path: 要保存的文件路径 - """ - - # 获取完整状态字典 - state = self.state_dict() - - # 添加类级别的静态变量 - state['feature_dim'] = OctreeNode.feature_dim - state['surf_bbox'] = OctreeNode.surf_bbox - - # 保存到文件 - with open(file_path, 'wb') as f: - pickle.dump(state, f) - - print(f"八叉树已成功保存到 {file_path}") - - @classmethod - def load_tree_from_file(cls, file_path: str) -> 'OctreeNode': - """ - 从文件加载八叉树 - :param file_path: 要加载的文件路径 - :return: 恢复的八叉树根节点 - """ - - - with open(file_path, 'rb') as f: - state = pickle.load(f) - - # 恢复类级别的静态变量 - cls.feature_dim = state.pop('feature_dim') - cls.surf_bbox = state.pop('surf_bbox') - - # 创建根节点 - root = cls( - bbox=state['bbox'], - face_indices=state['face_indices'], - max_depth=state['max_depth'] - ) - - # 加载状态 - root.load_state_dict(state) + def __getstate__(self): + """支持pickle序列化""" + return self._serialize_node(self) + + def __setstate__(self, state): + """支持pickle反序列化""" + self = self._deserialize_node(state) - print(f"八叉树已从 {file_path} 成功加载") - return root - def state_dict(self): - """返回节点及其子树的state_dict""" - state = { - 'bbox': self.bbox, - 'max_depth': self.max_depth, - 'face_indices': self.face_indices, - 'is_leaf': self._is_leaf + def _serialize_node(self, node): + return { + 'bbox': node.bbox, + 'is_leaf': node._is_leaf, + 'child_nodes': [self._serialize_node(c) for c in node.child_nodes], + 'param_key': node.param_key } - - if self._is_leaf: - pass - else: - state['children'] = [child.state_dict() for child in self.children] - - return state - - def load_state_dict(self, state_dict): - """从state_dict加载节点状态""" - self.bbox = state_dict['bbox'] - self.max_depth = state_dict['max_depth'] - self.face_indices = state_dict['face_indices'] - self._is_leaf = state_dict['is_leaf'] - - if self._is_leaf: - return - else: - self.children = [] - for child_state in state_dict['children']: - child = OctreeNode( - bbox=child_state['bbox'], - face_indices=child_state['face_indices'], - max_depth=child_state['max_depth'] - ) - child.load_state_dict(child_state) - self.children.append(child) - - - - -def test_octree(): - # 1. 测试bbox_intersect函数 - print("测试bbox_intersect函数...") - bbox1 = np.array([0, 0, 0, 1, 1, 1]) - bbox2 = np.array([0.5, 0.5, 0.5, 1.5, 1.5, 1.5]) - assert bbox_intersect(bbox1, bbox2), "相交测试失败" - - bbox3 = np.array([2, 2, 2, 3, 3, 3]) - assert not bbox_intersect(bbox1, bbox3), "不相交测试失败" - print("bbox_intersect测试通过!\n") - - # 2. 创建测试用的面包围盒 - # 假设有4个面,每个面有一个包围盒 - surf_bbox = np.array([ - [0.1, 0.1, 0.1, 0.4, 0.4, 0.4], # 面0 - [0.6, 0.1, 0.1, 0.9, 0.4, 0.4], # 面1 - [0.1, 0.6, 0.1, 0.4, 0.9, 0.4], # 面2 - [0.6, 0.6, 0.1, 0.9, 0.9, 0.4] # 面3 - ]) - # 3. 创建根节点 - root_bbox = np.array([0, 0, 0, 1, 1, 1]) - face_indices = np.arange(len(surf_bbox)) # 初始包含所有面 - root = OctreeNode( - bbox=root_bbox, - face_indices=face_indices, - max_depth=2, - feature_dim=32, - surf_bbox=surf_bbox - ) - - # 4. 构建八叉树 - root.conduct_tree() - - # 5. 打印树结构(只打印前2层) - print("八叉树结构:") - root.print_tree(max_print_depth=2) - - # 6. 测试子节点索引计算 - print("\n测试子节点索引计算...") - test_points = [ - ([0.25, 0.25, 0.25], "应在前下左子节点"), - ([0.75, 0.25, 0.25], "应在前下右子节点"), - ([0.25, 0.75, 0.25], "应在前上左子节点"), - ([0.75, 0.75, 0.25], "应在前上右子节点") - ] - - for point, desc in test_points: - idx = root.get_child_index(np.array(point)) - print(f"点 {point} {desc}, 计算得到的索引: {idx}") - - # 7. 验证叶子节点特征 - print("\n验证叶子节点特征:") - for i, child in enumerate(root.children): - if child.is_leaf(): - print(f"子节点 {i} 是叶子节点,") - else: - print(f"子节点 {i} 不是叶子节点") - - print("\n所有测试完成!") - -# ... existing code ... - -def test_octree_save_load(): - print("\n测试八叉树的保存和加载功能...") - - # 1. 创建测试数据 - surf_bbox = np.array([ - [0.1, 0.1, 0.1, 0.4, 0.4, 0.4], # 面0 - [0.6, 0.1, 0.1, 0.9, 0.4, 0.4], # 面1 - [0.1, 0.6, 0.1, 0.4, 0.9, 0.4], # 面2 - [0.6, 0.6, 0.1, 0.9, 0.9, 0.4] # 面3 - ]) - - # 2. 创建原始树 - root = OctreeNode( - bbox=np.array([0, 0, 0, 1, 1, 1]), - face_indices=np.arange(len(surf_bbox)), - max_depth=2, - feature_dim=32, - surf_bbox=surf_bbox - ) - root.conduct_tree() - - # 3. 保存树状态 - test_file = 'test_octree.pkl' - root.save_tree_to_file(test_file) - - # 4. 从文件加载树 - new_root = OctreeNode.load_tree_from_file(test_file) - print("树状态加载成功!") - - # 5. 验证加载后的树结构 - print("\n验证加载后的树结构:") - - # 5.1 验证根节点属性 - assert np.allclose(root.bbox, new_root.bbox), "bbox不匹配" - assert root.max_depth == new_root.max_depth, "max_depth不匹配" - assert np.array_equal(root.face_indices, new_root.face_indices), "face_indices不匹配" - assert root._is_leaf == new_root._is_leaf, "is_leaf不匹配" - print("根节点属性验证通过!") - - # 5.2 验证叶子节点特征 - if root._is_leaf: - #assert np.allclose(root.patch_feature_volume, new_root.patch_feature_volume), "特征不匹配" - print("叶子节点特征验证通过!") - else: - # 递归验证子节点 - for i, (orig_child, new_child) in enumerate(zip(root.children, new_root.children)): - print(f"\n验证子节点 {i}:") - assert np.allclose(orig_child.bbox, new_child.bbox), f"子节点{i} bbox不匹配" - assert orig_child.max_depth == new_child.max_depth, f"子节点{i} max_depth不匹配" - assert np.array_equal(orig_child.face_indices, new_child.face_indices), f"子节点{i} face_indices不匹配" - assert orig_child._is_leaf == new_child._is_leaf, f"子节点{i} is_leaf不匹配" - - if orig_child._is_leaf: - #assert np.allclose(orig_child.patch_feature_volume, new_child.patch_feature_volume), f"子节点{i} 特征不匹配" - print(f"子节点{i} 叶子节点特征验证通过!") - else: - print(f"子节点{i} 是非叶子节点,继续验证其子节点...") - - # 6. 打印部分树结构对比 - print("\n原始树结构(前2层):") - root.print_tree(max_print_depth=2) - - print("\n加载后的树结构(前2层):") - new_root.print_tree(max_print_depth=2) - - print("\n八叉树保存和加载测试全部通过!") - -if __name__ == "__main__": - test_octree() # 运行基本功能测试 - test_octree_save_load() # 运行保存加载测试 + def _deserialize_node(self, data): + node = OctreeNode(data['bbox'], 0) # max_depth会在encoder中重建 + node._is_leaf = data['is_leaf'] + node.param_key = data['param_key'] + node.child_nodes = [self._deserialize_node(c) for c in data['child_nodes']] + return node \ No newline at end of file diff --git a/brep2sdf/train.py b/brep2sdf/train.py index b846f8b..474f856 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -91,7 +91,6 @@ class Trainer: # 将曲面点云列表转换为 (N*M, 4) 数组 surfs = self.data["surf_ncs"] - #logger.debug(self.data['faceEdge_adj'].shape) self.sdf_data = prepare_sdf_data( surfs, normals = self.data["surf_pnt_normals"], @@ -112,7 +111,7 @@ class Trainer: ) - self.build_tree(surf_bbox=self.data['surf_bbox_ncs'], max_depth=4) + self.build_tree(surf_bbox=surf_bbox, max_depth=4) self.model = Net( @@ -279,8 +278,9 @@ class Trainer: def _tracing_model(self): """保存模型""" self.model.eval() - self.root.save_tree_to_file(f"/home/wch/brep2sdf/data/output_data/{self.model_name}_tree.pkl") - torch.save(self.model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt") + # 确保模型中的所有逻辑都兼容 TorchScript + scripted_model = torch.jit.script(self.model) + torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt") def _load_checkpoint(self, checkpoint_path): """从检查点恢复训练状态"""