Browse Source

Revert "oct 修改中间版本"

This reverts commit 5bd4a8d866.
final
mckay 2 months ago
parent
commit
48755817c0
  1. 26
      brep2sdf/IsoSurfacing.py
  2. 2
      brep2sdf/config/default_config.py
  3. 244
      brep2sdf/networks/feature_volume.py
  4. 448
      brep2sdf/networks/octree.py
  5. 8
      brep2sdf/train.py

26
brep2sdf/IsoSurfacing.py

@ -5,8 +5,6 @@ from skimage import measure
import time import time
import trimesh import trimesh
from brep2sdf.utils.logger import logger
def create_grid(depth, box_size): def create_grid(depth, box_size):
""" """
创建三维网格点 创建三维网格点
@ -123,7 +121,7 @@ def main():
# 设置设备 # 设置设备
device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu") device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}") print(f"Using device: {device}")
model = torch.jit.load(args.input).to(device) model = torch.jit.load(args.input).to(device)
#model = torch.load(args.input).to(device) #model = torch.load(args.input).to(device)
@ -132,32 +130,32 @@ def main():
# 创建网格并预测SDF # 创建网格并预测SDF
points, xx, yy, zz = create_grid(args.depth, args.box_size) points, xx, yy, zz = create_grid(args.depth, args.box_size)
sdf = predict_sdf(model, points, device) sdf = predict_sdf(model, points, device)
logger.info(points.shape) print(points.shape)
logger.info(sdf.shape) print(sdf.shape)
logger.info(sdf) print(sdf)
sdf_grid = sdf.reshape(xx.shape) sdf_grid = sdf.reshape(xx.shape)
# 提取表面 # 提取表面
logger.info("Extracting surface...") print("Extracting surface...")
start_time = time.time() start_time = time.time()
verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method) verts, faces = extract_surface(sdf_grid, xx, yy, zz, args.method)
logger.info(f"Surface extraction took {time.time() - start_time:.2f} seconds") print(f"Surface extraction took {time.time() - start_time:.2f} seconds")
# 保存网格 # 保存网格
save_ply(verts, faces, args.output) save_ply(verts, faces, args.output)
logger.info(f"Mesh saved to {args.output}") print(f"Mesh saved to {args.output}")
# 误差评估(可选) # 误差评估(可选)
if args.compare: if args.compare:
logger.info("Computing SDF error...") print("Computing SDF error...")
gt_mesh = trimesh.load(args.compare) gt_mesh = trimesh.load(args.compare)
avg_abs, avg_rel, max_abs, max_rel = compute_sdf_error( avg_abs, avg_rel, max_abs, max_rel = compute_sdf_error(
model, gt_mesh, args.compres, device model, gt_mesh, args.compres, device
) )
logger.info(f"Average Absolute Error: {avg_abs:.4f}") print(f"Average Absolute Error: {avg_abs:.4f}")
logger.info(f"Average Relative Error: {avg_rel:.4f}") print(f"Average Relative Error: {avg_rel:.4f}")
logger.info(f"Max Absolute Error: {max_abs:.4f}") print(f"Max Absolute Error: {max_abs:.4f}")
logger.info(f"Max Relative Error: {max_rel:.4f}") print(f"Max Relative Error: {max_rel:.4f}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

2
brep2sdf/config/default_config.py

@ -47,7 +47,7 @@ class TrainConfig:
# 基本训练参数 # 基本训练参数
batch_size: int = 8 batch_size: int = 8
num_workers: int = 4 num_workers: int = 4
num_epochs: int = 1 num_epochs: int = 200
learning_rate: float = 0.01 learning_rate: float = 0.01
min_lr: float = 1e-5 min_lr: float = 1e-5
weight_decay: float = 0.01 weight_decay: float = 0.01

244
brep2sdf/networks/feature_volume.py

@ -4,211 +4,51 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import numpy as np
from brep2sdf.utils.logger import logger
class PatchFeatureVolume(nn.Module): class PatchFeatureVolume(nn.Module):
def __init__(self, bbox, resolution=64, feature_dim=64): def __init__(self, bbox:np, resolution=64, feature_dim=64):
super(PatchFeatureVolume, self).__init__() super(PatchFeatureVolume, self).__init__()
# 统一转换为torch张量并确保在CPU初始化 self.bbox = bbox # 补丁的边界框
if isinstance(bbox, np.ndarray): self.resolution = resolution # 网格分辨率
bbox = torch.from_numpy(bbox).float() self.feature_dim = feature_dim # 特征向量维度
elif isinstance(bbox, torch.Tensor):
bbox = bbox.float()
else:
raise TypeError("bbox必须是np.ndarray或torch.Tensor类型")
# 注册为buffer,自动处理设备
self.register_buffer('bbox', bbox)
# 获取最小和最大坐标 # 创建规则的三维网格
self.min_coords = bbox[:3] x = torch.linspace(bbox[0][0], bbox[1][0], resolution)
self.max_coords = bbox[3:] y = torch.linspace(bbox[0][1], bbox[1][1], resolution)
z = torch.linspace(bbox[0][2], bbox[1][2], resolution)
# 处理分辨率参数 grid_x, grid_y, grid_z = torch.meshgrid(x, y, z)
if isinstance(resolution, int): self.register_buffer('grid', torch.stack([grid_x, grid_y, grid_z], dim=-1))
res_x = res_y = res_z = resolution
elif len(resolution) == 3:
res_x, res_y, res_z = resolution
else:
raise ValueError("resolution必须是整数或包含3个整数的元组/列表")
self.resolution = (res_x, res_y, res_z)
self.feature_dim = feature_dim
# 创建网格(使用min_coords和max_coords)
x = torch.linspace(self.bbox[0], self.bbox[3], res_x)
y = torch.linspace(self.bbox[1], self.bbox[4], res_y)
z = torch.linspace(self.bbox[2], self.bbox[5], res_z)
grid = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1)
self.register_buffer('grid', grid)
# 特征体积(自动继承模块的设备) # 初始化特征向量,作为可训练参数
self.feature_volume = nn.Parameter( self.feature_volume = nn.Parameter(torch.randn(resolution, resolution, resolution, feature_dim))
torch.randn(res_x, res_y, res_z, feature_dim)
)
def forward(self, query_points: torch.Tensor): def forward(self, query_points: List[Tuple[float, float, float]]):
# 自动设备对齐 """
was_2d = query_points.dim() == 2 根据查询点的位置从补丁特征体积中获取插值后的特征向量
if was_2d:
query_points = query_points.unsqueeze(0) # [N,3] -> [1,N,3]
features = self.batched_trilinear_interpolation(query_points) # [B,N,D] :param query_points: 查询点的位置坐标形状为 (N, 3)
:return: 插值后的特征向量形状为 (N, feature_dim)
# 恢复原始形状 """
if was_2d: interpolated_features = torch.zeros(query_points.shape[0], self.feature_dim).to(self.feature_volume.device)
features = features.squeeze(0) # [1,N,D] -> [N,D] for i, point in enumerate(query_points):
return features interpolated_feature = self.trilinear_interpolation(point)
interpolated_features[i] = interpolated_feature
def batched_trilinear_interpolation(self, query_points: torch.Tensor): return interpolated_features
B, N, _ = query_points.shape
def trilinear_interpolation(self, query_point):
# 将查询点转换到网格坐标系 """三线性插值"""
x = (query_points[..., 0] - self.min_coords[0]) / (self.max_coords[0] - self.min_coords[0]) * (self.resolution[0] - 1) normalized_coords = ((query_point - torch.tensor(self.bbox[0]).to(self.grid.device)) /
y = (query_points[..., 1] - self.min_coords[1]) / (self.max_coords[1] - self.min_coords[1]) * (self.resolution[1] - 1) (torch.tensor(self.bbox[1]).to(self.grid.device) - torch.tensor(self.bbox[0]).to(self.grid.device))) * (self.resolution - 1)
z = (query_points[..., 2] - self.min_coords[2]) / (self.max_coords[2] - self.min_coords[2]) * (self.resolution[2] - 1) indices = torch.floor(normalized_coords).long()
weights = normalized_coords - indices.float()
# 确保坐标在网格范围内(防止溢出)
x = torch.clamp(x, 0, self.resolution[0]-1e-5) interpolated_feature = torch.zeros(self.feature_dim).to(self.feature_volume.device)
y = torch.clamp(y, 0, self.resolution[1]-1e-5) for di in range(2):
z = torch.clamp(z, 0, self.resolution[2]-1e-5) for dj in range(2):
for dk in range(2):
# 分解为整数坐标和分数部分 weight = (weights[0] if di == 1 else 1 - weights[0]) * \
x0 = torch.floor(x).long() (weights[1] if dj == 1 else 1 - weights[1]) * \
x1 = x0 + 1 (weights[2] if dk == 1 else 1 - weights[2])
x_frac = x - x0.float() index = indices + torch.tensor([di, dj, dk]).to(indices.device)
index = torch.clamp(index, 0, self.resolution - 1)
y0 = torch.floor(y).long() interpolated_feature += weight * self.feature_volume[index[0], index[1], index[2]]
y1 = y0 + 1 return interpolated_feature
y_frac = y - y0.float()
z0 = torch.floor(z).long()
z1 = z0 + 1
z_frac = z - z0.float()
# 处理边界情况(确保索引在合法范围内)
x0 = torch.clamp(x0, 0, self.resolution[0]-1)
x1 = torch.clamp(x1, 0, self.resolution[0]-1)
y0 = torch.clamp(y0, 0, self.resolution[1]-1)
y1 = torch.clamp(y1, 0, self.resolution[1]-1)
z0 = torch.clamp(z0, 0, self.resolution[2]-1)
z1 = torch.clamp(z1, 0, self.resolution[2]-1)
# 将索引展平为1D张量以进行高效的特征提取
x0_flat = x0.view(-1)
x1_flat = x1.view(-1)
y0_flat = y0.view(-1)
y1_flat = y1.view(-1)
z0_flat = z0.view(-1)
z1_flat = z1.view(-1)
# 提取8个顶点的特征
feat_0 = self.feature_volume[x0_flat, y0_flat, z0_flat] # (x0,y0,z0)
feat_1 = self.feature_volume[x1_flat, y0_flat, z0_flat] # (x1,y0,z0)
feat_2 = self.feature_volume[x0_flat, y1_flat, z0_flat] # (x0,y1,z0)
feat_3 = self.feature_volume[x1_flat, y1_flat, z0_flat] # (x1,y1,z0)
feat_4 = self.feature_volume[x0_flat, y0_flat, z1_flat] # (x0,y0,z1)
feat_5 = self.feature_volume[x1_flat, y0_flat, z1_flat] # (x1,y0,z1)
feat_6 = self.feature_volume[x0_flat, y1_flat, z1_flat] # (x0,y1,z1)
feat_7 = self.feature_volume[x1_flat, y1_flat, z1_flat] # (x1,y1,z1)
# 将特征重塑为 [B, N, D]
D = self.feature_volume.shape[-1]
feat_0 = feat_0.view(B, N, D)
feat_1 = feat_1.view(B, N, D)
feat_2 = feat_2.view(B, N, D)
feat_3 = feat_3.view(B, N, D)
feat_4 = feat_4.view(B, N, D)
feat_5 = feat_5.view(B, N, D)
feat_6 = feat_6.view(B, N, D)
feat_7 = feat_7.view(B, N, D)
# 计算各顶点的权重
xw0 = (1 - x_frac).unsqueeze(-1) # [B, N, 1]
xw1 = x_frac.unsqueeze(-1)
yw0 = (1 - y_frac).unsqueeze(-1)
yw1 = y_frac.unsqueeze(-1)
zw0 = (1 - z_frac).unsqueeze(-1)
zw1 = z_frac.unsqueeze(-1)
w0 = xw0 * yw0 * zw0
w1 = xw1 * yw0 * zw0
w2 = xw0 * yw1 * zw0
w3 = xw1 * yw1 * zw0
w4 = xw0 * yw0 * zw1
w5 = xw1 * yw0 * zw1
w6 = xw0 * yw1 * zw1
w7 = xw1 * yw1 * zw1
# 加权求和得到最终特征
output = (
feat_0 * w0 +
feat_1 * w1 +
feat_2 * w2 +
feat_3 * w3 +
feat_4 * w4 +
feat_5 * w5 +
feat_6 * w6 +
feat_7 * w7
)
return output
def test_feature_volume():
# 1. 准备测试数据
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# 定义包围盒 [min_x, min_y, min_z, max_x, max_y, max_z]
bbox = np.array([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], dtype=np.float32)
# 2. 创建特征体积 (使用不同分辨率测试)
feature_vol = PatchFeatureVolume(
bbox=bbox,
resolution=(32, 64, 32), # 各维度不同分辨率
feature_dim=64
).to(device)
# 3. 生成测试点集 (包括边界点和内部点)
# 单个点测试
test_point = torch.tensor([[0.5, 0.3, -0.2]], device=device, dtype=torch.float32)
# 批量点测试 (包含边界情况)
test_points = torch.tensor([
[0.0, 0.0, 0.0], # 中心点
[1.0, 1.0, 1.0], # 最大边界
[-1.0, -1.0, -1.0], # 最小边界
[0.5, 0.5, 0.5], # 内部点
[0.9, -0.8, 0.2] # 非对称点
], device=device, dtype=torch.float32)
# 4. 执行查询
# 测试单个点
single_feature = feature_vol(test_point)
print(f"Single point feature shape: {single_feature.shape}") # 应为 [1, 64]
# 测试批量点
batch_features = feature_vol(test_points)
print(f"Batch features shape: {batch_features.shape}") # 应为 [5, 64]
# 5. 验证结果
assert not torch.isnan(single_feature).any(), "Output contains NaN values"
assert not torch.isnan(batch_features).any(), "Output contains NaN values"
assert single_feature.shape == (1, 64), "Single point output shape mismatch"
assert batch_features.shape == (5, 64), "Batch points output shape mismatch"
# 6. 测试梯度计算
test_point.requires_grad_(True)
feature = feature_vol(test_point)
loss = feature.sum()
loss.backward()
assert test_point.grad is not None, "Gradient not computed"
print("All tests passed!")
if __name__ == "__main__":
test_feature_volume()

448
brep2sdf/networks/octree.py

@ -1,129 +1,136 @@
from typing import Tuple, List from typing import Tuple, List, cast, Dict, Any, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
import pickle
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
def bbox_intersect(bbox1: np.ndarray, bbox2: np.ndarray) -> bool: def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool:
"""判断两个轴对齐包围盒(AABB)是否相交 """判断两个轴对齐包围盒(AABB)是否相交
参数: 参数:
bbox1: 形状为 (6,) 数组格式 [min_x, min_y, min_z, max_x, max_y, max_z] bbox1: 形状为 (6,) 张量格式 [min_x, min_y, min_z, max_x, max_y, max_z]
bbox2: 同bbox1格式 bbox2: 同bbox1格式
返回: 返回:
bool: 两包围盒是否相交(包括刚好接触的情况) bool: 两包围盒是否相交(包括刚好接触的情况)
""" """
assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的数组" assert bbox1.shape == (6,) and bbox2.shape == (6,), "输入必须是形状为(6,)的张量"
# 提取min和max坐标 # 提取min和max坐标
min1, max1 = bbox1[:3], bbox1[3:] min1, max1 = bbox1[:3], bbox1[3:]
min2, max2 = bbox2[:3], bbox2[3:] min2, max2 = bbox2[:3], bbox2[3:]
# 向量化比较 # 向量化比较
return np.all((max1 >= min2) & (max2 >= min1)) return torch.all((max1 >= min2) & (max2 >= min1))
class OctreeNode(nn.Module):
class OctreeNode: device=None
feature_dim = None
surf_bbox = None surf_bbox = None
def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, surf_bbox:torch.Tensor = None):
def __init__(self, bbox: np.ndarray, face_indices: np.ndarray, max_depth: int = 5, feature_dim: int = None, surf_bbox: np.ndarray = None): super().__init__()
"""
初始化八叉树节点
:param bbox: 节点的边界框格式为 [min_x, min_y, min_z, max_x, max_y, max_z] (形状为 (6,))
:param face_indices: 当前节点包含的面索引数组
:param max_depth: 八叉树的最大深度
:param feature_dim: 特征维度仅在叶子节点时使用
:param surf_bbox: 面的包围盒数组形状为 (N, 6)
"""
self.bbox = bbox # 节点的边界框 self.bbox = bbox # 节点的边界框
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点
self.children: List['OctreeNode'] = [] # 子节点列表 self.child_nodes: torch.nn.ModuleList = torch.nn.ModuleList() # 子节点列表
self.face_indices = face_indices self.face_indices = face_indices
self.param_key = ""
#self.patch_feature_volume = None # 补丁特征体积,only leaf has #self.patch_feature_volume = None # 补丁特征体积,only leaf has
self._is_leaf = True self._is_leaf = True
#print(f"box shape: {self.bbox.shape}")
if feature_dim is not None: if surf_bbox is not None:
OctreeNode.feature_dim = feature_dim if not isinstance(surf_bbox, torch.Tensor):
if surf_bbox is not None: raise TypeError(
if not isinstance(surf_bbox, np.ndarray): f"surf_bbox 必须是 torch.Tensor 类型,但得到 {type(surf_bbox)}"
raise TypeError(f"surf_bbox 必须是 numpy.ndarray 类型,但得到 {type(surf_bbox)}") )
if surf_bbox.ndim != 2 or surf_bbox.shape[1] != 6: if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
raise ValueError(f"surf_bbox 应为二维数组且形状为 (N,6),但得到 {surf_bbox.shape}") raise ValueError(
OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建 f"surf_bbox 应为二维张量且形状为 (N,6),但得到 {surf_bbox.shape}"
)
OctreeNode.surf_bbox = surf_bbox # NOTE: 只在根节点时创建
OctreeNode.device = bbox.device
def is_leaf(self): def is_leaf(self):
# Check if self.children is None before calling len() # Check if self.child——nodes is None before calling len()
return self._is_leaf return self._is_leaf
def set_param_key(self, k):
self.param_key = k
def conduct_tree(self): def conduct_tree(self):
""" if self.max_depth <= 0 or self.face_indices.shape[0] <= 2:
构建八叉树如果达到最大深度或当前节点包含的面数小于等于2则停止划分
"""
if self.max_depth <= 0 or len(self.face_indices) <= 2:
# 达到最大深度 or 一个单元格至多只有两个面 # 达到最大深度 or 一个单元格至多只有两个面
#self.patch_feature_volume = np.random.randn(8, OctreeNode.feature_dim) return
return
self.subdivide() self.subdivide()
def subdivide(self): def subdivide(self):
"""
将当前节点划分为8个子节点并分配相交的面 #min_x, min_y, min_z, max_x, max_y, max_z = self.bbox
""" # 使用索引操作替代解包
min_coords = self.bbox[:3] # [min_x, min_y, min_z] min_coords = self.bbox[:3] # [min_x, min_y, min_z]
max_coords = self.bbox[3:] # [max_x, max_y, max_z] max_coords = self.bbox[3:] # [max_x, max_y, max_z]
# 计算中间点 # 计算中间点
mid_coords = (min_coords + max_coords) / 2 mid_coords = (min_coords + max_coords) / 2
# 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z # 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z
min_x, min_y, min_z = min_coords min_x, min_y, min_z = min_coords[0], min_coords[1], min_coords[2]
mid_x, mid_y, mid_z = mid_coords mid_x, mid_y, mid_z = mid_coords[0], mid_coords[1], mid_coords[2]
max_x, max_y, max_z = max_coords max_x, max_y, max_z = max_coords[0], max_coords[1], max_coords[2]
# 生成 8 个子包围盒 # 生成 8 个子包围盒
child_bboxes = np.array([ child_bboxes = torch.stack([
[*min_coords, *mid_coords], # 前下左 torch.cat([min_coords, mid_coords]), # 前下左
[mid_x, min_y, min_z, max_x, mid_y, mid_z], # 前下右 torch.cat([torch.tensor([mid_x, min_y, min_z], device=self.bbox.device),
[min_x, mid_y, min_z, mid_x, max_y, mid_z], # 前上左 torch.tensor([max_x, mid_y, mid_z], device=self.bbox.device)]), # 前下右
[mid_x, mid_y, min_z, max_x, max_y, mid_z], # 前上右 torch.cat([torch.tensor([min_x, mid_y, min_z], device=self.bbox.device),
[min_x, min_y, mid_z, mid_x, mid_y, max_z], # 后下左 torch.tensor([mid_x, max_y, mid_z], device=self.bbox.device)]), # 前上左
[mid_x, min_y, mid_z, max_x, mid_y, max_z], # 后下右 torch.cat([torch.tensor([mid_x, mid_y, min_z], device=self.bbox.device),
[min_x, mid_y, mid_z, mid_x, max_y, max_z], # 后上左 torch.tensor([max_x, max_y, mid_z], device=self.bbox.device)]), # 前上右
[mid_x, mid_y, mid_z, max_x, max_y, max_z] # 后上右 torch.cat([torch.tensor([min_x, min_y, mid_z], device=self.bbox.device),
torch.tensor([mid_x, mid_y, max_z], device=self.bbox.device)]), # 后下左
torch.cat([torch.tensor([mid_x, min_y, mid_z], device=self.bbox.device),
torch.tensor([max_x, mid_y, max_z], device=self.bbox.device)]), # 后下右
torch.cat([torch.tensor([min_x, mid_y, mid_z], device=self.bbox.device),
torch.tensor([mid_x, max_y, max_z], device=self.bbox.device)]), # 后上左
torch.cat([torch.tensor([mid_x, mid_y, mid_z], device=self.bbox.device),
torch.tensor([max_x, max_y, max_z], device=self.bbox.device)]) # 后上右
]) ])
# 为每个子包围盒创建子节点,并分配相交的面 # 为每个子包围盒创建子节点,并分配相交的面
self.children = []
for bbox in child_bboxes: for bbox in child_bboxes:
# 找到与子包围盒相交的面 # 找到与子包围盒相交的面
intersecting_faces = [ intersecting_faces = []
face_idx for face_idx in self.face_indices for face_idx in self.face_indices:
if bbox_intersect(bbox, OctreeNode.surf_bbox[face_idx]) face_bbox = OctreeNode.surf_bbox[face_idx]
] if bbox_intersect(bbox, face_bbox):
intersecting_faces.append(face_idx)
#print(f"{bbox}: {intersecting_faces}")
child_node = OctreeNode( child_node = OctreeNode(
bbox=bbox, bbox=bbox,
face_indices=np.array(intersecting_faces), face_indices=np.array(intersecting_faces),
max_depth=self.max_depth - 1 max_depth=self.max_depth - 1
) )
child_node.conduct_tree() child_node.conduct_tree()
self.children.append(child_node) self.child_nodes.append(child_node)
self._is_leaf = False self._is_leaf = False
def get_child_index(self, query_point: np.ndarray) -> int: def get_child_index(self, query_point: torch.Tensor) -> int:
""" """
计算点所在子节点的索引 计算点所在子节点的索引
:param query_point: 待检查的点格式为 (x, y, z) :param query_point: 待检查的点格式为 (x, y, z)
:return: 子节点的索引范围从 0 7 :return: 子节点的索引范围从 0 7
""" """
# 确保 query_point 和 bbox 在同一设备上
query_point = query_point.to(self.bbox.device)
# 提取 bbox 的最小和最大坐标 # 提取 bbox 的最小和最大坐标
min_coords = self.bbox[:3] # [min_x, min_y, min_z] min_coords = self.bbox[:3] # [min_x, min_y, min_z]
max_coords = self.bbox[3:] # [max_x, max_y, max_z] max_coords = self.bbox[3:] # [max_x, max_y, max_z]
@ -132,11 +139,11 @@ class OctreeNode:
mid_coords = (min_coords + max_coords) / 2 mid_coords = (min_coords + max_coords) / 2
# 使用布尔比较结果计算索引 # 使用布尔比较结果计算索引
index = ((query_point >= mid_coords) << np.arange(3)).sum() index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum()
return index return index.item()
def find_leaf(self, query_point: np.ndarray) -> np.ndarray: def find_leaf(self, query_point: torch.Tensor) -> Tuple[torch.Tensor, str, bool]:
""" """
查找包含给定点的叶子节点并返回其信息以元组形式 查找包含给定点的叶子节点并返回其信息以元组形式
:param query_point: 待查找的点 :param query_point: 待查找的点
@ -145,22 +152,66 @@ class OctreeNode:
# 如果当前节点是叶子节点,返回其信息 # 如果当前节点是叶子节点,返回其信息
if self._is_leaf: if self._is_leaf:
#logger.info(f"{self.bbox}, {self.param_key}, {True}") #logger.info(f"{self.bbox}, {self.param_key}, {True}")
return self.face_indices return (self.bbox, self.param_key, True)
# 计算查询点所在的子节点索引 # 计算查询点所在的子节点索引
index = self.get_child_index(query_point) index = self.get_child_index(query_point)
try:
# 直接访问子节点,不进行显式检查 # 遍历子节点列表,找到对应的子节点
return self.children[index].find_leaf(query_point) for i, child_node in enumerate(self.child_nodes):
except IndexError as e: if i == index and child_node is not None:
# 记录错误日志并重新抛出异常 # 递归调用子节点的 find_leaf 方法
logger.error( result = child_node.find_leaf(query_point)
f"Error accessing child node: {e}. "
f"Query point: {query_point.tolist()}, " # 确保返回值是一个元组
f"Node bbox: {self.bbox.tolist()}, " assert isinstance(result, tuple), f"Unexpected return type: {type(result)}"
f"Depth info: {self.max_depth}" return result
)
raise e # 如果找不到有效的子节点,抛出异常
raise IndexError(f"Invalid child node index: {index}")
'''
try:
# 直接访问子节点,不进行显式检查
return self.child_nodes[index].find_leaf(query_point)
except IndexError as e:
# 记录错误日志并重新抛出异常
logger.error(
f"Error accessing child node: {e}. "
f"Query point: {query_point.cpu().numpy().tolist()}, "
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, "
f"Depth info: {self.max_depth}"
)
raise e
'''
'''
def get_feature_vector(self, query_point:torch.Tensor):
"""
预测给定点的 SDF
:param point: 待预测的点格式为 (x, y, z)
:return: 预测的 SDF
"""
# 将点转换为 numpy 数组
# 从根节点开始递归查找包含该点的叶子节点
if self._is_leaf:
return self.trilinear_interpolation(query_point)
else:
index = self.get_child_index(query_point)
try:
# 直接访问子节点,不进行显式检查
return self.child_nodes[index].get_feature_vector(query_point)
except IndexError as e:
# 记录错误日志并重新抛出异常
logger.error(
f"Error accessing child node: {e}. "
f"Query point: {query_point.cpu().numpy().tolist()}, "
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, "
f"Depth info: {self.max_depth}"
)
raise e
'''
@ -178,236 +229,37 @@ class OctreeNode:
# 打印当前节点信息 # 打印当前节点信息
indent = " " * depth indent = " " * depth
node_type = "Leaf" if self._is_leaf else "Internal" node_type = "Leaf" if self._is_leaf else "Internal"
print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.tolist()}") print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.cpu().numpy().tolist()}")
# 打印面片信息(如果有) # 打印面片信息(如果有)
if self.face_indices is not None: if self.face_indices is not None:
print(f"{indent} Face indices: {self.face_indices.tolist()}") print(f"{indent} Face indices: {self.face_indices.tolist()}")
print(f"{indent} len children: {len(self.children)}") print(f"{indent} len child_nodes: {len(self.child_nodes)}")
# 递归打印子节点 # 递归打印子节点
for i, child in enumerate(self.children): for i, child in enumerate(self.child_nodes):
print(f"{indent} Child {i}:") print(f"{indent} Child {i}:")
child.print_tree(depth + 1, max_print_depth) child.print_tree(depth + 1, max_print_depth)
# 保存 def __getstate__(self):
"""支持pickle序列化"""
def save_tree_to_file(self, file_path: str): return self._serialize_node(self)
"""
将八叉树保存到文件 def __setstate__(self, state):
:param file_path: 要保存的文件路径 """支持pickle反序列化"""
""" self = self._deserialize_node(state)
# 获取完整状态字典
state = self.state_dict()
# 添加类级别的静态变量
state['feature_dim'] = OctreeNode.feature_dim
state['surf_bbox'] = OctreeNode.surf_bbox
# 保存到文件
with open(file_path, 'wb') as f:
pickle.dump(state, f)
print(f"八叉树已成功保存到 {file_path}")
@classmethod
def load_tree_from_file(cls, file_path: str) -> 'OctreeNode':
"""
从文件加载八叉树
:param file_path: 要加载的文件路径
:return: 恢复的八叉树根节点
"""
with open(file_path, 'rb') as f:
state = pickle.load(f)
# 恢复类级别的静态变量
cls.feature_dim = state.pop('feature_dim')
cls.surf_bbox = state.pop('surf_bbox')
# 创建根节点
root = cls(
bbox=state['bbox'],
face_indices=state['face_indices'],
max_depth=state['max_depth']
)
# 加载状态
root.load_state_dict(state)
print(f"八叉树已从 {file_path} 成功加载") def _serialize_node(self, node):
return root return {
def state_dict(self): 'bbox': node.bbox,
"""返回节点及其子树的state_dict""" 'is_leaf': node._is_leaf,
state = { 'child_nodes': [self._serialize_node(c) for c in node.child_nodes],
'bbox': self.bbox, 'param_key': node.param_key
'max_depth': self.max_depth,
'face_indices': self.face_indices,
'is_leaf': self._is_leaf
} }
if self._is_leaf:
pass
else:
state['children'] = [child.state_dict() for child in self.children]
return state
def load_state_dict(self, state_dict):
"""从state_dict加载节点状态"""
self.bbox = state_dict['bbox']
self.max_depth = state_dict['max_depth']
self.face_indices = state_dict['face_indices']
self._is_leaf = state_dict['is_leaf']
if self._is_leaf:
return
else:
self.children = []
for child_state in state_dict['children']:
child = OctreeNode(
bbox=child_state['bbox'],
face_indices=child_state['face_indices'],
max_depth=child_state['max_depth']
)
child.load_state_dict(child_state)
self.children.append(child)
def test_octree():
# 1. 测试bbox_intersect函数
print("测试bbox_intersect函数...")
bbox1 = np.array([0, 0, 0, 1, 1, 1])
bbox2 = np.array([0.5, 0.5, 0.5, 1.5, 1.5, 1.5])
assert bbox_intersect(bbox1, bbox2), "相交测试失败"
bbox3 = np.array([2, 2, 2, 3, 3, 3])
assert not bbox_intersect(bbox1, bbox3), "不相交测试失败"
print("bbox_intersect测试通过!\n")
# 2. 创建测试用的面包围盒
# 假设有4个面,每个面有一个包围盒
surf_bbox = np.array([
[0.1, 0.1, 0.1, 0.4, 0.4, 0.4], # 面0
[0.6, 0.1, 0.1, 0.9, 0.4, 0.4], # 面1
[0.1, 0.6, 0.1, 0.4, 0.9, 0.4], # 面2
[0.6, 0.6, 0.1, 0.9, 0.9, 0.4] # 面3
])
# 3. 创建根节点 def _deserialize_node(self, data):
root_bbox = np.array([0, 0, 0, 1, 1, 1]) node = OctreeNode(data['bbox'], 0) # max_depth会在encoder中重建
face_indices = np.arange(len(surf_bbox)) # 初始包含所有面 node._is_leaf = data['is_leaf']
root = OctreeNode( node.param_key = data['param_key']
bbox=root_bbox, node.child_nodes = [self._deserialize_node(c) for c in data['child_nodes']]
face_indices=face_indices, return node
max_depth=2,
feature_dim=32,
surf_bbox=surf_bbox
)
# 4. 构建八叉树
root.conduct_tree()
# 5. 打印树结构(只打印前2层)
print("八叉树结构:")
root.print_tree(max_print_depth=2)
# 6. 测试子节点索引计算
print("\n测试子节点索引计算...")
test_points = [
([0.25, 0.25, 0.25], "应在前下左子节点"),
([0.75, 0.25, 0.25], "应在前下右子节点"),
([0.25, 0.75, 0.25], "应在前上左子节点"),
([0.75, 0.75, 0.25], "应在前上右子节点")
]
for point, desc in test_points:
idx = root.get_child_index(np.array(point))
print(f"{point} {desc}, 计算得到的索引: {idx}")
# 7. 验证叶子节点特征
print("\n验证叶子节点特征:")
for i, child in enumerate(root.children):
if child.is_leaf():
print(f"子节点 {i} 是叶子节点,")
else:
print(f"子节点 {i} 不是叶子节点")
print("\n所有测试完成!")
# ... existing code ...
def test_octree_save_load():
print("\n测试八叉树的保存和加载功能...")
# 1. 创建测试数据
surf_bbox = np.array([
[0.1, 0.1, 0.1, 0.4, 0.4, 0.4], # 面0
[0.6, 0.1, 0.1, 0.9, 0.4, 0.4], # 面1
[0.1, 0.6, 0.1, 0.4, 0.9, 0.4], # 面2
[0.6, 0.6, 0.1, 0.9, 0.9, 0.4] # 面3
])
# 2. 创建原始树
root = OctreeNode(
bbox=np.array([0, 0, 0, 1, 1, 1]),
face_indices=np.arange(len(surf_bbox)),
max_depth=2,
feature_dim=32,
surf_bbox=surf_bbox
)
root.conduct_tree()
# 3. 保存树状态
test_file = 'test_octree.pkl'
root.save_tree_to_file(test_file)
# 4. 从文件加载树
new_root = OctreeNode.load_tree_from_file(test_file)
print("树状态加载成功!")
# 5. 验证加载后的树结构
print("\n验证加载后的树结构:")
# 5.1 验证根节点属性
assert np.allclose(root.bbox, new_root.bbox), "bbox不匹配"
assert root.max_depth == new_root.max_depth, "max_depth不匹配"
assert np.array_equal(root.face_indices, new_root.face_indices), "face_indices不匹配"
assert root._is_leaf == new_root._is_leaf, "is_leaf不匹配"
print("根节点属性验证通过!")
# 5.2 验证叶子节点特征
if root._is_leaf:
#assert np.allclose(root.patch_feature_volume, new_root.patch_feature_volume), "特征不匹配"
print("叶子节点特征验证通过!")
else:
# 递归验证子节点
for i, (orig_child, new_child) in enumerate(zip(root.children, new_root.children)):
print(f"\n验证子节点 {i}:")
assert np.allclose(orig_child.bbox, new_child.bbox), f"子节点{i} bbox不匹配"
assert orig_child.max_depth == new_child.max_depth, f"子节点{i} max_depth不匹配"
assert np.array_equal(orig_child.face_indices, new_child.face_indices), f"子节点{i} face_indices不匹配"
assert orig_child._is_leaf == new_child._is_leaf, f"子节点{i} is_leaf不匹配"
if orig_child._is_leaf:
#assert np.allclose(orig_child.patch_feature_volume, new_child.patch_feature_volume), f"子节点{i} 特征不匹配"
print(f"子节点{i} 叶子节点特征验证通过!")
else:
print(f"子节点{i} 是非叶子节点,继续验证其子节点...")
# 6. 打印部分树结构对比
print("\n原始树结构(前2层):")
root.print_tree(max_print_depth=2)
print("\n加载后的树结构(前2层):")
new_root.print_tree(max_print_depth=2)
print("\n八叉树保存和加载测试全部通过!")
if __name__ == "__main__":
test_octree() # 运行基本功能测试
test_octree_save_load() # 运行保存加载测试

8
brep2sdf/train.py

@ -91,7 +91,6 @@ class Trainer:
# 将曲面点云列表转换为 (N*M, 4) 数组 # 将曲面点云列表转换为 (N*M, 4) 数组
surfs = self.data["surf_ncs"] surfs = self.data["surf_ncs"]
#logger.debug(self.data['faceEdge_adj'].shape)
self.sdf_data = prepare_sdf_data( self.sdf_data = prepare_sdf_data(
surfs, surfs,
normals = self.data["surf_pnt_normals"], normals = self.data["surf_pnt_normals"],
@ -112,7 +111,7 @@ class Trainer:
) )
self.build_tree(surf_bbox=self.data['surf_bbox_ncs'], max_depth=4) self.build_tree(surf_bbox=surf_bbox, max_depth=4)
self.model = Net( self.model = Net(
@ -279,8 +278,9 @@ class Trainer:
def _tracing_model(self): def _tracing_model(self):
"""保存模型""" """保存模型"""
self.model.eval() self.model.eval()
self.root.save_tree_to_file(f"/home/wch/brep2sdf/data/output_data/{self.model_name}_tree.pkl") # 确保模型中的所有逻辑都兼容 TorchScript
torch.save(self.model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt") scripted_model = torch.jit.script(self.model)
torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
def _load_checkpoint(self, checkpoint_path): def _load_checkpoint(self, checkpoint_path):
"""从检查点恢复训练状态""" """从检查点恢复训练状态"""

Loading…
Cancel
Save