diff --git a/brep2sdf/IsoSurfacing.py b/brep2sdf/IsoSurfacing.py index 5645f5c..faa9a97 100644 --- a/brep2sdf/IsoSurfacing.py +++ b/brep2sdf/IsoSurfacing.py @@ -5,6 +5,8 @@ from skimage import measure import time import trimesh +from brep2sdf.utils.logger import logger + def create_grid(depth, box_size): """ 创建三维网格点 @@ -121,7 +123,7 @@ def main(): # 设置设备 device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") + logger.info(f"Using device: {device}") model = torch.jit.load(args.input).to(device) #model = torch.load(args.input).to(device) @@ -130,32 +132,32 @@ def main(): # 创建网格并预测SDF points, xx, yy, zz = create_grid(args.depth, args.box_size) sdf = predict_sdf(model, points, device) - print(points.shape) - print(sdf.shape) - print(sdf) + logger.info(points.shape) + logger.info(sdf.shape) + logger.info(sdf) sdf_grid = sdf.reshape(xx.shape) # 提取表面 - print("Extracting surface...") + logger.info("Extracting surface...") start_time = time.time() verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method) - print(f"Surface extraction took {time.time() - start_time:.2f} seconds") + logger.info(f"Surface extraction took {time.time() - start_time:.2f} seconds") # 保存网格 save_ply(verts, faces, args.output) - print(f"Mesh saved to {args.output}") + logger.info(f"Mesh saved to {args.output}") # 误差评估(可选) if args.compare: - print("Computing SDF error...") + logger.info("Computing SDF error...") gt_mesh = trimesh.load(args.compare) avg_abs, avg_rel, max_abs, max_rel = compute_sdf_error( model, gt_mesh, args.compres, device ) - print(f"Average Absolute Error: {avg_abs:.4f}") - print(f"Average Relative Error: {avg_rel:.4f}") - print(f"Max Absolute Error: {max_abs:.4f}") - print(f"Max Relative Error: {max_rel:.4f}") + logger.info(f"Average Absolute Error: {avg_abs:.4f}") + logger.info(f"Average Relative Error: {avg_rel:.4f}") + logger.info(f"Max Absolute Error: {max_abs:.4f}") + logger.info(f"Max Relative Error: {max_rel:.4f}") if __name__ == "__main__": main() \ No newline at end of file diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index aac078e..ae2759c 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -47,7 +47,7 @@ class TrainConfig: # 基本训练参数 batch_size: int = 8 num_workers: int = 4 - num_epochs: int = 200 + num_epochs: int = 1 learning_rate: float = 0.01 min_lr: float = 1e-5 weight_decay: float = 0.01 diff --git a/brep2sdf/networks/feature_volume.py b/brep2sdf/networks/feature_volume.py index f1bffd8..3531ed6 100644 --- a/brep2sdf/networks/feature_volume.py +++ b/brep2sdf/networks/feature_volume.py @@ -4,51 +4,211 @@ import torch import torch.nn as nn import torch.optim as optim +import numpy as np + +from brep2sdf.utils.logger import logger + class PatchFeatureVolume(nn.Module): - def __init__(self, bbox:np, resolution=64, feature_dim=64): + def __init__(self, bbox, resolution=64, feature_dim=64): super(PatchFeatureVolume, self).__init__() - self.bbox = bbox # 补丁的边界框 - self.resolution = resolution # 网格分辨率 - self.feature_dim = feature_dim # 特征向量维度 + # 统一转换为torch张量并确保在CPU初始化 + if isinstance(bbox, np.ndarray): + bbox = torch.from_numpy(bbox).float() + elif isinstance(bbox, torch.Tensor): + bbox = bbox.float() + else: + raise TypeError("bbox必须是np.ndarray或torch.Tensor类型") + + # 注册为buffer,自动处理设备 + self.register_buffer('bbox', bbox) - # 创建规则的三维网格 - x = torch.linspace(bbox[0][0], bbox[1][0], resolution) - y = torch.linspace(bbox[0][1], bbox[1][1], resolution) - z = torch.linspace(bbox[0][2], bbox[1][2], resolution) - grid_x, grid_y, grid_z = torch.meshgrid(x, y, z) - self.register_buffer('grid', torch.stack([grid_x, grid_y, grid_z], dim=-1)) + # 获取最小和最大坐标 + self.min_coords = bbox[:3] + self.max_coords = bbox[3:] + + # 处理分辨率参数 + if isinstance(resolution, int): + res_x = res_y = res_z = resolution + elif len(resolution) == 3: + res_x, res_y, res_z = resolution + else: + raise ValueError("resolution必须是整数或包含3个整数的元组/列表") + + self.resolution = (res_x, res_y, res_z) + self.feature_dim = feature_dim + + # 创建网格(使用min_coords和max_coords) + x = torch.linspace(self.bbox[0], self.bbox[3], res_x) + y = torch.linspace(self.bbox[1], self.bbox[4], res_y) + z = torch.linspace(self.bbox[2], self.bbox[5], res_z) + grid = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1) + self.register_buffer('grid', grid) - # 初始化特征向量,作为可训练参数 - self.feature_volume = nn.Parameter(torch.randn(resolution, resolution, resolution, feature_dim)) + # 特征体积(自动继承模块的设备) + self.feature_volume = nn.Parameter( + torch.randn(res_x, res_y, res_z, feature_dim) + ) - def forward(self, query_points: List[Tuple[float, float, float]]): - """ - 根据查询点的位置,从补丁特征体积中获取插值后的特征向量。 + def forward(self, query_points: torch.Tensor): + # 自动设备对齐 + was_2d = query_points.dim() == 2 + if was_2d: + query_points = query_points.unsqueeze(0) # [N,3] -> [1,N,3] - :param query_points: 查询点的位置坐标,形状为 (N, 3) - :return: 插值后的特征向量,形状为 (N, feature_dim) - """ - interpolated_features = torch.zeros(query_points.shape[0], self.feature_dim).to(self.feature_volume.device) - for i, point in enumerate(query_points): - interpolated_feature = self.trilinear_interpolation(point) - interpolated_features[i] = interpolated_feature - return interpolated_features - - def trilinear_interpolation(self, query_point): - """三线性插值""" - normalized_coords = ((query_point - torch.tensor(self.bbox[0]).to(self.grid.device)) / - (torch.tensor(self.bbox[1]).to(self.grid.device) - torch.tensor(self.bbox[0]).to(self.grid.device))) * (self.resolution - 1) - indices = torch.floor(normalized_coords).long() - weights = normalized_coords - indices.float() - - interpolated_feature = torch.zeros(self.feature_dim).to(self.feature_volume.device) - for di in range(2): - for dj in range(2): - for dk in range(2): - weight = (weights[0] if di == 1 else 1 - weights[0]) * \ - (weights[1] if dj == 1 else 1 - weights[1]) * \ - (weights[2] if dk == 1 else 1 - weights[2]) - index = indices + torch.tensor([di, dj, dk]).to(indices.device) - index = torch.clamp(index, 0, self.resolution - 1) - interpolated_feature += weight * self.feature_volume[index[0], index[1], index[2]] - return interpolated_feature \ No newline at end of file + features = self.batched_trilinear_interpolation(query_points) # [B,N,D] + + # 恢复原始形状 + if was_2d: + features = features.squeeze(0) # [1,N,D] -> [N,D] + return features + + def batched_trilinear_interpolation(self, query_points: torch.Tensor): + B, N, _ = query_points.shape + + # 将查询点转换到网格坐标系 + x = (query_points[..., 0] - self.min_coords[0]) / (self.max_coords[0] - self.min_coords[0]) * (self.resolution[0] - 1) + y = (query_points[..., 1] - self.min_coords[1]) / (self.max_coords[1] - self.min_coords[1]) * (self.resolution[1] - 1) + z = (query_points[..., 2] - self.min_coords[2]) / (self.max_coords[2] - self.min_coords[2]) * (self.resolution[2] - 1) + + # 确保坐标在网格范围内(防止溢出) + x = torch.clamp(x, 0, self.resolution[0]-1e-5) + y = torch.clamp(y, 0, self.resolution[1]-1e-5) + z = torch.clamp(z, 0, self.resolution[2]-1e-5) + + # 分解为整数坐标和分数部分 + x0 = torch.floor(x).long() + x1 = x0 + 1 + x_frac = x - x0.float() + + y0 = torch.floor(y).long() + y1 = y0 + 1 + y_frac = y - y0.float() + + z0 = torch.floor(z).long() + z1 = z0 + 1 + z_frac = z - z0.float() + + # 处理边界情况(确保索引在合法范围内) + x0 = torch.clamp(x0, 0, self.resolution[0]-1) + x1 = torch.clamp(x1, 0, self.resolution[0]-1) + y0 = torch.clamp(y0, 0, self.resolution[1]-1) + y1 = torch.clamp(y1, 0, self.resolution[1]-1) + z0 = torch.clamp(z0, 0, self.resolution[2]-1) + z1 = torch.clamp(z1, 0, self.resolution[2]-1) + + # 将索引展平为1D张量以进行高效的特征提取 + x0_flat = x0.view(-1) + x1_flat = x1.view(-1) + y0_flat = y0.view(-1) + y1_flat = y1.view(-1) + z0_flat = z0.view(-1) + z1_flat = z1.view(-1) + + # 提取8个顶点的特征 + feat_0 = self.feature_volume[x0_flat, y0_flat, z0_flat] # (x0,y0,z0) + feat_1 = self.feature_volume[x1_flat, y0_flat, z0_flat] # (x1,y0,z0) + feat_2 = self.feature_volume[x0_flat, y1_flat, z0_flat] # (x0,y1,z0) + feat_3 = self.feature_volume[x1_flat, y1_flat, z0_flat] # (x1,y1,z0) + feat_4 = self.feature_volume[x0_flat, y0_flat, z1_flat] # (x0,y0,z1) + feat_5 = self.feature_volume[x1_flat, y0_flat, z1_flat] # (x1,y0,z1) + feat_6 = self.feature_volume[x0_flat, y1_flat, z1_flat] # (x0,y1,z1) + feat_7 = self.feature_volume[x1_flat, y1_flat, z1_flat] # (x1,y1,z1) + + # 将特征重塑为 [B, N, D] + D = self.feature_volume.shape[-1] + feat_0 = feat_0.view(B, N, D) + feat_1 = feat_1.view(B, N, D) + feat_2 = feat_2.view(B, N, D) + feat_3 = feat_3.view(B, N, D) + feat_4 = feat_4.view(B, N, D) + feat_5 = feat_5.view(B, N, D) + feat_6 = feat_6.view(B, N, D) + feat_7 = feat_7.view(B, N, D) + + # 计算各顶点的权重 + xw0 = (1 - x_frac).unsqueeze(-1) # [B, N, 1] + xw1 = x_frac.unsqueeze(-1) + yw0 = (1 - y_frac).unsqueeze(-1) + yw1 = y_frac.unsqueeze(-1) + zw0 = (1 - z_frac).unsqueeze(-1) + zw1 = z_frac.unsqueeze(-1) + + w0 = xw0 * yw0 * zw0 + w1 = xw1 * yw0 * zw0 + w2 = xw0 * yw1 * zw0 + w3 = xw1 * yw1 * zw0 + w4 = xw0 * yw0 * zw1 + w5 = xw1 * yw0 * zw1 + w6 = xw0 * yw1 * zw1 + w7 = xw1 * yw1 * zw1 + + # 加权求和得到最终特征 + output = ( + feat_0 * w0 + + feat_1 * w1 + + feat_2 * w2 + + feat_3 * w3 + + feat_4 * w4 + + feat_5 * w5 + + feat_6 * w6 + + feat_7 * w7 + ) + + return output + + + +def test_feature_volume(): + # 1. 准备测试数据 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # 定义包围盒 [min_x, min_y, min_z, max_x, max_y, max_z] + bbox = np.array([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], dtype=np.float32) + + # 2. 创建特征体积 (使用不同分辨率测试) + feature_vol = PatchFeatureVolume( + bbox=bbox, + resolution=(32, 64, 32), # 各维度不同分辨率 + feature_dim=64 + ).to(device) + + # 3. 生成测试点集 (包括边界点和内部点) + # 单个点测试 + test_point = torch.tensor([[0.5, 0.3, -0.2]], device=device, dtype=torch.float32) + + # 批量点测试 (包含边界情况) + test_points = torch.tensor([ + [0.0, 0.0, 0.0], # 中心点 + [1.0, 1.0, 1.0], # 最大边界 + [-1.0, -1.0, -1.0], # 最小边界 + [0.5, 0.5, 0.5], # 内部点 + [0.9, -0.8, 0.2] # 非对称点 + ], device=device, dtype=torch.float32) + + # 4. 执行查询 + # 测试单个点 + single_feature = feature_vol(test_point) + print(f"Single point feature shape: {single_feature.shape}") # 应为 [1, 64] + + # 测试批量点 + batch_features = feature_vol(test_points) + print(f"Batch features shape: {batch_features.shape}") # 应为 [5, 64] + + # 5. 验证结果 + assert not torch.isnan(single_feature).any(), "Output contains NaN values" + assert not torch.isnan(batch_features).any(), "Output contains NaN values" + assert single_feature.shape == (1, 64), "Single point output shape mismatch" + assert batch_features.shape == (5, 64), "Batch points output shape mismatch" + + # 6. 测试梯度计算 + test_point.requires_grad_(True) + feature = feature_vol(test_point) + loss = feature.sum() + loss.backward() + assert test_point.grad is not None, "Gradient not computed" + + print("All tests passed!") + +if __name__ == "__main__": + test_feature_volume() \ No newline at end of file diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index 4b2a308..aea8b19 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -1,136 +1,129 @@ -from typing import Tuple, List, cast, Dict, Any, Tuple +from typing import Tuple, List import torch import torch.nn as nn import torch.nn.functional as F import numpy as np +import pickle from brep2sdf.utils.logger import logger -def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool: +def bbox_intersect(bbox1: np.ndarray, bbox2: np.ndarray) -> bool: """判断两个轴对齐包围盒(AABB)是否相交 参数: - bbox1: 形状为 (6,) 的张量,格式 [min_x, min_y, min_z, max_x, max_y, max_z] + bbox1: 形状为 (6,) 的数组,格式 [min_x, min_y, min_z, max_x, max_y, max_z] bbox2: 同bbox1格式 返回: bool: 两包围盒是否相交(包括刚好接触的情况) """ - assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的张量" + assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的数组" # 提取min和max坐标 min1, max1 = bbox1[:3], bbox1[3:] min2, max2 = bbox2[:3], bbox2[3:] # 向量化比较 - return torch.all((max1 >= min2) & (max2 >= min1)) + return np.all((max1 >= min2) & (max2 >= min1)) -class OctreeNode(nn.Module): - device=None + +class OctreeNode: + feature_dim = None surf_bbox = None - def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, surf_bbox:torch.Tensor = None): - super().__init__() + + def __init__(self, bbox: np.ndarray, face_indices: np.ndarray, max_depth: int = 5, feature_dim: int = None, surf_bbox: np.ndarray = None): + """ + 初始化八叉树节点。 + :param bbox: 节点的边界框,格式为 [min_x, min_y, min_z, max_x, max_y, max_z] (形状为 (6,)) + :param face_indices: 当前节点包含的面索引数组 + :param max_depth: 八叉树的最大深度 + :param feature_dim: 特征维度(仅在叶子节点时使用) + :param surf_bbox: 面的包围盒数组,形状为 (N, 6) + """ self.bbox = bbox # 节点的边界框 self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 - self.child_nodes: torch.nn.ModuleList = torch.nn.ModuleList() # 子节点列表 + self.children: List['OctreeNode'] = [] # 子节点列表 self.face_indices = face_indices - self.param_key = "" #self.patch_feature_volume = None # 补丁特征体积,only leaf has self._is_leaf = True - #print(f"box shape: {self.bbox.shape}") - if surf_bbox is not None: - if not isinstance(surf_bbox, torch.Tensor): - raise TypeError( - f"surf_bbox 必须是 torch.Tensor 类型,但得到 {type(surf_bbox)}" - ) - if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6: - raise ValueError( - f"surf_bbox 应为二维张量且形状为 (N,6),但得到 {surf_bbox.shape}" - ) - OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建 - OctreeNode.device = bbox.device + if feature_dim is not None: + OctreeNode.feature_dim = feature_dim + if surf_bbox is not None: + if not isinstance(surf_bbox, np.ndarray): + raise TypeError(f"surf_bbox 必须是 numpy.ndarray 类型,但得到 {type(surf_bbox)}") + if surf_bbox.ndim != 2 or surf_bbox.shape[1] != 6: + raise ValueError(f"surf_bbox 应为二维数组且形状为 (N,6),但得到 {surf_bbox.shape}") + OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建 def is_leaf(self): - # Check if self.child——nodes is None before calling len() + # Check if self.children is None before calling len() return self._is_leaf - def set_param_key(self, k): - self.param_key = k - def conduct_tree(self): - if self.max_depth <= 0 or self.face_indices.shape[0] <= 2: + """ + 构建八叉树:如果达到最大深度或当前节点包含的面数小于等于2,则停止划分。 + """ + if self.max_depth <= 0 or len(self.face_indices) <= 2: # 达到最大深度 or 一个单元格至多只有两个面 - return + #self.patch_feature_volume = np.random.randn(8, OctreeNode.feature_dim) + return self.subdivide() - def subdivide(self): - - #min_x, min_y, min_z, max_x, max_y, max_z = self.bbox - # 使用索引操作替代解包 + """ + 将当前节点划分为8个子节点,并分配相交的面。 + """ min_coords = self.bbox[:3] # [min_x, min_y, min_z] max_coords = self.bbox[3:] # [max_x, max_y, max_z] # 计算中间点 mid_coords = (min_coords + max_coords) / 2 - + # 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z - min_x, min_y, min_z = min_coords[0], min_coords[1], min_coords[2] - mid_x, mid_y, mid_z = mid_coords[0], mid_coords[1], mid_coords[2] - max_x, max_y, max_z = max_coords[0], max_coords[1], max_coords[2] + min_x, min_y, min_z = min_coords + mid_x, mid_y, mid_z = mid_coords + max_x, max_y, max_z = max_coords # 生成 8 个子包围盒 - child_bboxes = torch.stack([ - torch.cat([min_coords, mid_coords]), # 前下左 - torch.cat([torch.tensor([mid_x, min_y, min_z], device=self.bbox.device), - torch.tensor([max_x, mid_y, mid_z], device=self.bbox.device)]), # 前下右 - torch.cat([torch.tensor([min_x, mid_y, min_z], device=self.bbox.device), - torch.tensor([mid_x, max_y, mid_z], device=self.bbox.device)]), # 前上左 - torch.cat([torch.tensor([mid_x, mid_y, min_z], device=self.bbox.device), - torch.tensor([max_x, max_y, mid_z], device=self.bbox.device)]), # 前上右 - torch.cat([torch.tensor([min_x, min_y, mid_z], device=self.bbox.device), - torch.tensor([mid_x, mid_y, max_z], device=self.bbox.device)]), # 后下左 - torch.cat([torch.tensor([mid_x, min_y, mid_z], device=self.bbox.device), - torch.tensor([max_x, mid_y, max_z], device=self.bbox.device)]), # 后下右 - torch.cat([torch.tensor([min_x, mid_y, mid_z], device=self.bbox.device), - torch.tensor([mid_x, max_y, max_z], device=self.bbox.device)]), # 后上左 - torch.cat([torch.tensor([mid_x, mid_y, mid_z], device=self.bbox.device), - torch.tensor([max_x, max_y, max_z], device=self.bbox.device)]) # 后上右 + child_bboxes = np.array([ + [*min_coords, *mid_coords], # 前下左 + [mid_x, min_y, min_z, max_x, mid_y, mid_z], # 前下右 + [min_x, mid_y, min_z, mid_x, max_y, mid_z], # 前上左 + [mid_x, mid_y, min_z, max_x, max_y, mid_z], # 前上右 + [min_x, min_y, mid_z, mid_x, mid_y, max_z], # 后下左 + [mid_x, min_y, mid_z, max_x, mid_y, max_z], # 后下右 + [min_x, mid_y, mid_z, mid_x, max_y, max_z], # 后上左 + [mid_x, mid_y, mid_z, max_x, max_y, max_z] # 后上右 ]) - + # 为每个子包围盒创建子节点,并分配相交的面 + self.children = [] for bbox in child_bboxes: # 找到与子包围盒相交的面 - intersecting_faces = [] - for face_idx in self.face_indices: - face_bbox = OctreeNode.surf_bbox[face_idx] - if bbox_intersect(bbox, face_bbox): - intersecting_faces.append(face_idx) - #print(f"{bbox}: {intersecting_faces}") - + intersecting_faces = [ + face_idx for face_idx in self.face_indices + if bbox_intersect(bbox, OctreeNode.surf_bbox[face_idx]) + ] child_node = OctreeNode( bbox=bbox, face_indices=np.array(intersecting_faces), max_depth=self.max_depth - 1 ) child_node.conduct_tree() - self.child_nodes.append(child_node) - + self.children.append(child_node) + self._is_leaf = False - def get_child_index(self, query_point: torch.Tensor) -> int: + def get_child_index(self, query_point: np.ndarray) -> int: """ - 计算点所在子节点的索引 + 计算点所在子节点的索引。 :param query_point: 待检查的点,格式为 (x, y, z) :return: 子节点的索引,范围从 0 到 7 """ - # 确保 query_point 和 bbox 在同一设备上 - query_point = query_point.to(self.bbox.device) - # 提取 bbox 的最小和最大坐标 min_coords = self.bbox[:3] # [min_x, min_y, min_z] max_coords = self.bbox[3:] # [max_x, max_y, max_z] @@ -139,11 +132,11 @@ class OctreeNode(nn.Module): mid_coords = (min_coords + max_coords) / 2 # 使用布尔比较结果计算索引 - index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum() + index = ((query_point >= mid_coords) << np.arange(3)).sum() - return index.item() + return index - def find_leaf(self, query_point: torch.Tensor) -> Tuple[torch.Tensor, str, bool]: + def find_leaf(self, query_point: np.ndarray) -> np.ndarray: """ 查找包含给定点的叶子节点,并返回其信息(以元组形式) :param query_point: 待查找的点 @@ -152,66 +145,22 @@ class OctreeNode(nn.Module): # 如果当前节点是叶子节点,返回其信息 if self._is_leaf: #logger.info(f"{self.bbox}, {self.param_key}, {True}") - return (self.bbox, self.param_key, True) + return self.face_indices # 计算查询点所在的子节点索引 index = self.get_child_index(query_point) - - # 遍历子节点列表,找到对应的子节点 - for i, child_node in enumerate(self.child_nodes): - if i == index and child_node is not None: - # 递归调用子节点的 find_leaf 方法 - result = child_node.find_leaf(query_point) - - # 确保返回值是一个元组 - assert isinstance(result, tuple), f"Unexpected return type: {type(result)}" - return result - - # 如果找不到有效的子节点,抛出异常 - raise IndexError(f"Invalid child node index: {index}") - ''' - try: - # 直接访问子节点,不进行显式检查 - return self.child_nodes[index].find_leaf(query_point) - except IndexError as e: - # 记录错误日志并重新抛出异常 - logger.error( - f"Error accessing child node: {e}. " - f"Query point: {query_point.cpu().numpy().tolist()}, " - f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " - f"Depth info: {self.max_depth}" - ) - raise e - ''' - - ''' - def get_feature_vector(self, query_point:torch.Tensor): - """ - 预测给定点的 SDF 值 - :param point: 待预测的点,格式为 (x, y, z) - :return: 预测的 SDF 值 - """ - # 将点转换为 numpy 数组 - - # 从根节点开始递归查找包含该点的叶子节点 - if self._is_leaf: - return self.trilinear_interpolation(query_point) - else: - index = self.get_child_index(query_point) - try: - # 直接访问子节点,不进行显式检查 - return self.child_nodes[index].get_feature_vector(query_point) - except IndexError as e: - # 记录错误日志并重新抛出异常 - logger.error( - f"Error accessing child node: {e}. " - f"Query point: {query_point.cpu().numpy().tolist()}, " - f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " - f"Depth info: {self.max_depth}" - ) - raise e - ''' - + try: + # 直接访问子节点,不进行显式检查 + return self.children[index].find_leaf(query_point) + except IndexError as e: + # 记录错误日志并重新抛出异常 + logger.error( + f"Error accessing child node: {e}. " + f"Query point: {query_point.tolist()}, " + f"Node bbox: {self.bbox.tolist()}, " + f"Depth info: {self.max_depth}" + ) + raise e @@ -229,37 +178,236 @@ class OctreeNode(nn.Module): # 打印当前节点信息 indent = " " * depth node_type = "Leaf" if self._is_leaf else "Internal" - print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.cpu().numpy().tolist()}") + print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.tolist()}") # 打印面片信息(如果有) if self.face_indices is not None: print(f"{indent} Face indices: {self.face_indices.tolist()}") - print(f"{indent} len child_nodes: {len(self.child_nodes)}") + print(f"{indent} len children: {len(self.children)}") # 递归打印子节点 - for i, child in enumerate(self.child_nodes): + for i, child in enumerate(self.children): print(f"{indent} Child {i}:") child.print_tree(depth + 1, max_print_depth) - def __getstate__(self): - """支持pickle序列化""" - return self._serialize_node(self) - - def __setstate__(self, state): - """支持pickle反序列化""" - self = self._deserialize_node(state) +# 保存 + + def save_tree_to_file(self, file_path: str): + """ + 将八叉树保存到文件 + :param file_path: 要保存的文件路径 + """ + + # 获取完整状态字典 + state = self.state_dict() + + # 添加类级别的静态变量 + state['feature_dim'] = OctreeNode.feature_dim + state['surf_bbox'] = OctreeNode.surf_bbox + + # 保存到文件 + with open(file_path, 'wb') as f: + pickle.dump(state, f) + + print(f"八叉树已成功保存到 {file_path}") + + @classmethod + def load_tree_from_file(cls, file_path: str) -> 'OctreeNode': + """ + 从文件加载八叉树 + :param file_path: 要加载的文件路径 + :return: 恢复的八叉树根节点 + """ + + + with open(file_path, 'rb') as f: + state = pickle.load(f) + + # 恢复类级别的静态变量 + cls.feature_dim = state.pop('feature_dim') + cls.surf_bbox = state.pop('surf_bbox') + + # 创建根节点 + root = cls( + bbox=state['bbox'], + face_indices=state['face_indices'], + max_depth=state['max_depth'] + ) + + # 加载状态 + root.load_state_dict(state) - def _serialize_node(self, node): - return { - 'bbox': node.bbox, - 'is_leaf': node._is_leaf, - 'child_nodes': [self._serialize_node(c) for c in node.child_nodes], - 'param_key': node.param_key + print(f"八叉树已从 {file_path} 成功加载") + return root + def state_dict(self): + """返回节点及其子树的state_dict""" + state = { + 'bbox': self.bbox, + 'max_depth': self.max_depth, + 'face_indices': self.face_indices, + 'is_leaf': self._is_leaf } + + if self._is_leaf: + pass + else: + state['children'] = [child.state_dict() for child in self.children] + + return state + + def load_state_dict(self, state_dict): + """从state_dict加载节点状态""" + self.bbox = state_dict['bbox'] + self.max_depth = state_dict['max_depth'] + self.face_indices = state_dict['face_indices'] + self._is_leaf = state_dict['is_leaf'] + + if self._is_leaf: + return + else: + self.children = [] + for child_state in state_dict['children']: + child = OctreeNode( + bbox=child_state['bbox'], + face_indices=child_state['face_indices'], + max_depth=child_state['max_depth'] + ) + child.load_state_dict(child_state) + self.children.append(child) + + + + +def test_octree(): + # 1. 测试bbox_intersect函数 + print("测试bbox_intersect函数...") + bbox1 = np.array([0, 0, 0, 1, 1, 1]) + bbox2 = np.array([0.5, 0.5, 0.5, 1.5, 1.5, 1.5]) + assert bbox_intersect(bbox1, bbox2), "相交测试失败" + + bbox3 = np.array([2, 2, 2, 3, 3, 3]) + assert not bbox_intersect(bbox1, bbox3), "不相交测试失败" + print("bbox_intersect测试通过!\n") + + # 2. 创建测试用的面包围盒 + # 假设有4个面,每个面有一个包围盒 + surf_bbox = np.array([ + [0.1, 0.1, 0.1, 0.4, 0.4, 0.4], # 面0 + [0.6, 0.1, 0.1, 0.9, 0.4, 0.4], # 面1 + [0.1, 0.6, 0.1, 0.4, 0.9, 0.4], # 面2 + [0.6, 0.6, 0.1, 0.9, 0.9, 0.4] # 面3 + ]) - def _deserialize_node(self, data): - node = OctreeNode(data['bbox'], 0) # max_depth会在encoder中重建 - node._is_leaf = data['is_leaf'] - node.param_key = data['param_key'] - node.child_nodes = [self._deserialize_node(c) for c in data['child_nodes']] - return node \ No newline at end of file + # 3. 创建根节点 + root_bbox = np.array([0, 0, 0, 1, 1, 1]) + face_indices = np.arange(len(surf_bbox)) # 初始包含所有面 + root = OctreeNode( + bbox=root_bbox, + face_indices=face_indices, + max_depth=2, + feature_dim=32, + surf_bbox=surf_bbox + ) + + # 4. 构建八叉树 + root.conduct_tree() + + # 5. 打印树结构(只打印前2层) + print("八叉树结构:") + root.print_tree(max_print_depth=2) + + # 6. 测试子节点索引计算 + print("\n测试子节点索引计算...") + test_points = [ + ([0.25, 0.25, 0.25], "应在前下左子节点"), + ([0.75, 0.25, 0.25], "应在前下右子节点"), + ([0.25, 0.75, 0.25], "应在前上左子节点"), + ([0.75, 0.75, 0.25], "应在前上右子节点") + ] + + for point, desc in test_points: + idx = root.get_child_index(np.array(point)) + print(f"点 {point} {desc}, 计算得到的索引: {idx}") + + # 7. 验证叶子节点特征 + print("\n验证叶子节点特征:") + for i, child in enumerate(root.children): + if child.is_leaf(): + print(f"子节点 {i} 是叶子节点,") + else: + print(f"子节点 {i} 不是叶子节点") + + print("\n所有测试完成!") + +# ... existing code ... + +def test_octree_save_load(): + print("\n测试八叉树的保存和加载功能...") + + # 1. 创建测试数据 + surf_bbox = np.array([ + [0.1, 0.1, 0.1, 0.4, 0.4, 0.4], # 面0 + [0.6, 0.1, 0.1, 0.9, 0.4, 0.4], # 面1 + [0.1, 0.6, 0.1, 0.4, 0.9, 0.4], # 面2 + [0.6, 0.6, 0.1, 0.9, 0.9, 0.4] # 面3 + ]) + + # 2. 创建原始树 + root = OctreeNode( + bbox=np.array([0, 0, 0, 1, 1, 1]), + face_indices=np.arange(len(surf_bbox)), + max_depth=2, + feature_dim=32, + surf_bbox=surf_bbox + ) + root.conduct_tree() + + # 3. 保存树状态 + test_file = 'test_octree.pkl' + root.save_tree_to_file(test_file) + + # 4. 从文件加载树 + new_root = OctreeNode.load_tree_from_file(test_file) + print("树状态加载成功!") + + # 5. 验证加载后的树结构 + print("\n验证加载后的树结构:") + + # 5.1 验证根节点属性 + assert np.allclose(root.bbox, new_root.bbox), "bbox不匹配" + assert root.max_depth == new_root.max_depth, "max_depth不匹配" + assert np.array_equal(root.face_indices, new_root.face_indices), "face_indices不匹配" + assert root._is_leaf == new_root._is_leaf, "is_leaf不匹配" + print("根节点属性验证通过!") + + # 5.2 验证叶子节点特征 + if root._is_leaf: + #assert np.allclose(root.patch_feature_volume, new_root.patch_feature_volume), "特征不匹配" + print("叶子节点特征验证通过!") + else: + # 递归验证子节点 + for i, (orig_child, new_child) in enumerate(zip(root.children, new_root.children)): + print(f"\n验证子节点 {i}:") + assert np.allclose(orig_child.bbox, new_child.bbox), f"子节点{i} bbox不匹配" + assert orig_child.max_depth == new_child.max_depth, f"子节点{i} max_depth不匹配" + assert np.array_equal(orig_child.face_indices, new_child.face_indices), f"子节点{i} face_indices不匹配" + assert orig_child._is_leaf == new_child._is_leaf, f"子节点{i} is_leaf不匹配" + + if orig_child._is_leaf: + #assert np.allclose(orig_child.patch_feature_volume, new_child.patch_feature_volume), f"子节点{i} 特征不匹配" + print(f"子节点{i} 叶子节点特征验证通过!") + else: + print(f"子节点{i} 是非叶子节点,继续验证其子节点...") + + # 6. 打印部分树结构对比 + print("\n原始树结构(前2层):") + root.print_tree(max_print_depth=2) + + print("\n加载后的树结构(前2层):") + new_root.print_tree(max_print_depth=2) + + print("\n八叉树保存和加载测试全部通过!") + +if __name__ == "__main__": + test_octree() # 运行基本功能测试 + test_octree_save_load() # 运行保存加载测试 diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 474f856..b846f8b 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -91,6 +91,7 @@ class Trainer: # 将曲面点云列表转换为 (N*M, 4) 数组 surfs = self.data["surf_ncs"] + #logger.debug(self.data['faceEdge_adj'].shape) self.sdf_data = prepare_sdf_data( surfs, normals = self.data["surf_pnt_normals"], @@ -111,7 +112,7 @@ class Trainer: ) - self.build_tree(surf_bbox=surf_bbox, max_depth=4) + self.build_tree(surf_bbox=self.data['surf_bbox_ncs'], max_depth=4) self.model = Net( @@ -278,9 +279,8 @@ class Trainer: def _tracing_model(self): """保存模型""" self.model.eval() - # 确保模型中的所有逻辑都兼容 TorchScript - scripted_model = torch.jit.script(self.model) - torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt") + self.root.save_tree_to_file(f"/home/wch/brep2sdf/data/output_data/{self.model_name}_tree.pkl") + torch.save(self.model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt") def _load_checkpoint(self, checkpoint_path): """从检查点恢复训练状态"""