Browse Source

优化了八叉树并行,确保一个叶节点不超过两个面

final
mckay 12 months ago
parent
commit
fa17441396
  1. 4
      brep2sdf/config/default_config.py
  2. 18
      brep2sdf/data/utils.py
  3. 249
      brep2sdf/networks/decoder.py
  4. 125
      brep2sdf/networks/encoder.py
  5. 98
      brep2sdf/networks/feature_volume.py
  6. 29
      brep2sdf/networks/network.py
  7. 264
      brep2sdf/networks/octree.py
  8. 35
      brep2sdf/networks/patch_graph.py
  9. 69
      brep2sdf/test.py
  10. 209
      brep2sdf/train.py
  11. 2
      brep2sdf/utils/logger.py

4
brep2sdf/config/default_config.py

@ -48,7 +48,7 @@ class TrainConfig:
# 基本训练参数 # 基本训练参数
batch_size: int = 8 batch_size: int = 8
num_workers: int = 4 num_workers: int = 4
num_epochs: int = 100 num_epochs: int = 1000
learning_rate: float = 0.001 learning_rate: float = 0.001
min_lr: float = 1e-5 min_lr: float = 1e-5
weight_decay: float = 0.01 weight_decay: float = 0.01
@ -90,7 +90,7 @@ class LogConfig:
# 本地日志 # 本地日志
log_dir: str = '/home/wch/brep2sdf/logs' # 日志保存目录 log_dir: str = '/home/wch/brep2sdf/logs' # 日志保存目录
log_level: str = 'INFO' # 日志级别 log_level: str = 'INFO' # 日志级别
console_level: str = 'DEBUG' # 控制台日志级别 console_level: str = 'INFO' # 控制台日志级别
file_level: str = 'DEBUG' # 文件日志级别 file_level: str = 'DEBUG' # 文件日志级别
@dataclass @dataclass

18
brep2sdf/data/utils.py

@ -10,6 +10,9 @@ from OCC.Core.TopoDS import topods, TopoDS_Vertex # 拓扑数据结构
import numpy as np import numpy as np
from scipy.spatial import cKDTree from scipy.spatial import cKDTree
from torch.nn.utils.rnn import pad_sequence
import torch
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
def load_step(step_path): def load_step(step_path):
@ -35,7 +38,22 @@ def get_bbox(shape, subshape):
xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get() xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
return np.array([xmin, ymin, zmin, xmax, ymax, zmax]) return np.array([xmin, ymin, zmin, xmax, ymax, zmax])
def process_surf_ncs_with_dynamic_padding(surf_ncs: np.ndarray) -> torch.Tensor:
"""
使用 pad_sequence 动态填充 surf_ncs
参数:
surf_ncs: 形状为 (N,) np.ndarray(dtype=object)每个元素是形状为 (M, 3) float32 数组
返回:
padded_tensor: 形状为 (N, M_max, 3) 的张量其中 M_max 是最长子数组的长度
"""
# 转换为张量列表
tensor_list = [torch.tensor(arr, dtype=torch.float32) for arr in surf_ncs]
# 动态填充
padded_tensor = pad_sequence(tensor_list, batch_first=True, padding_value=float('inf'))
return padded_tensor
def normalize(surfs, edges, corners): def normalize(surfs, edges, corners):

249
brep2sdf/networks/decoder.py

@ -1,42 +1,219 @@
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch
from torch import Tensor import numpy as np
from typing import Tuple, List, Union
from brep2sdf.utils.logger import logger
class Sine(nn.Module):
def __init(self):
super().__init__()
def forward(self, input):
# See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
return torch.sin(30 * input)
class Decoder(nn.Module): class Decoder(nn.Module):
def __init__(self, def __init__(
input_dim: int, self,
output_dim: int, d_in: int,
hidden_dim: int = 256) : dims_sdf: List[int],
""" skip_in: Tuple[int, ...] = (),
最简单的Decoder实现 flag_convex: bool = True,
geometric_init: bool = True,
参数: radius_init: float = 1,
input_dim: 输入维度 beta: float = 100,
output_dim: 输出维度 ) -> None:
hidden_dim: 隐藏层维度 (默认: 256)
"""
super().__init__() super().__init__()
# 三层全连接网络 self.flag_convex = flag_convex
self.fc1 = nn.Linear(input_dim, hidden_dim) self.skip_in = skip_in
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, output_dim) dims_sdf = [d_in] + dims_sdf + [1] #[self.n_branch]
self.sdf_layers = len(dims_sdf)
def forward(self, x: Tensor) -> Tensor: for layer in range(0, len(dims_sdf) - 1):
""" if layer + 1 in skip_in:
前向传播 out_dim = dims_sdf[layer + 1] - d_in
else:
参数: out_dim = dims_sdf[layer + 1]
x: 输入张量 lin = nn.Linear(dims_sdf[layer], out_dim)
返回: if geometric_init:
输出张量 if layer == self.sdf_layers - 2:
""" torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims_sdf[layer]), std=0.00001)
# 第一层 torch.nn.init.constant_(lin.bias, -radius_init)
h = F.relu(self.fc1(x)) else:
# 第二层 torch.nn.init.constant_(lin.bias, 0.0)
h = F.relu(self.fc2(h)) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
# 输出层 setattr(self, "sdf_"+str(layer), lin)
out = self.fc3(h) if geometric_init:
if beta > 0:
return out self.activation = nn.Softplus(beta=beta)
# vanilla relu
else:
self.activation = nn.ReLU()
else:
#siren
self.activation = Sine()
self.final_activation = nn.ReLU()
# composite f_i to h
def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
'''
:param feature_matrix: 形状为 (B, P, D) 的特征矩阵
B: 批大小
P: patch volume数量
D: 特征维度
:return:
f_i: 各patch的SDF值 (B, P)
'''
B, P, D = feature_matrix.shape
# 展平处理 (B*P, D)
x = feature_matrix.view(-1, D)
for layer in range(0, self.sdf_layers - 1):
lin = getattr(self, "sdf_" + str(layer))
if layer in self.skip_in:
x = torch.cat([x, x], -1) / np.sqrt(2) # Fix undefined 'input'
x = lin(x)
if layer < self.sdf_layers - 2:
x = self.activation(x)
output_value = x #all f_i
# 恢复维度 (B, P)
f_i = output_value.view(B, P)
return f_i
# 一个基础情形: 输入 fi 形状[P] 和 csg tree,凹凸组合输出h
#注意考虑如何批量处理 (B, P) 和 [csg tree]
class CSGCombiner:
def __init__(self, flag_convex: bool = True, rho: float = 0.05):
self.flag_convex = flag_convex
self.rho = rho
def forward(self, f_i: torch.Tensor, csg_tree) -> torch.Tensor:
'''
:param f_i: 形状为 (B, P) 的各patch SDF值
:param csg_tree: CSG树结构
:return: 组合后的整体SDF (B,)
'''
logger.info("\n".join(f"{i}个csg: {t}" for i,t in enumerate(csg_tree)))
B = f_i.shape[0]
results = []
for i in range(B):
# 处理每个样本的CSG组合
h = self.nested_cvx_output_soft_blend(
f_i[i].unsqueeze(0),
csg_tree,
self.flag_convex
)
results.append(h)
return torch.cat(results, dim=0).squeeze(1) # 从(B,1)变为(B,)
def nested_cvx_output_soft_blend(
self,
value_matrix: torch.Tensor,
list_operation: List[Union[int, List]],
cvx_flag: bool = True
) -> torch.Tensor:
list_value = []
for v in list_operation:
if not isinstance(v, list):
list_value.append(v)
op_mat = torch.zeros(value_matrix.shape[1], len(list_value),
device=value_matrix.device)
for i in range(len(list_value)):
op_mat[list_value[i]][i] = 1.0
mat_mul = torch.matmul(value_matrix, op_mat)
if len(list_operation) == len(list_value):
return self.max_soft_blend(mat_mul, self.rho) if cvx_flag \
else self.min_soft_blend(mat_mul, self.rho)
list_output = [mat_mul]
for v in list_operation:
if isinstance(v, list):
list_output.append(
self.nested_cvx_output_soft_blend(
value_matrix, v, not cvx_flag
)
)
return self.max_soft_blend(torch.cat(list_output, 1), self.rho) if cvx_flag \
else self.min_soft_blend(torch.cat(list_output, 1), self.rho)
def min_soft_blend(self, mat, rho):
res = mat[:,0]
for i in range(1, mat.shape[1]):
srho = res * res + mat[:,i] * mat[:,i] - rho * rho
res = res + mat[:,i] - torch.sqrt(res * res + mat[:,i] * mat[:,i] + 1.0/(8 * rho * rho) * srho * (srho - srho.abs()))
return res.unsqueeze(1)
def max_soft_blend(self, mat, rho):
res = mat[:,0]
for i in range(1, mat.shape[1]):
srho = res * res + mat[:,i] * mat[:,i] - rho * rho
res = res + mat[:,i] + torch.sqrt(res * res + mat[:,i] * mat[:,i] + 1.0/(8 * rho * rho) * srho * (srho - srho.abs()))
return res.unsqueeze(1)
def test_csg_combiner():
# 测试数据 (B=3, P=5)
f_i = torch.tensor([
[1.0, 2.0, 3.0, 4.0, 5.0],
[0.5, 1.5, 2.5, 3.5, 4.5],
[-1.0, 0.0, 1.0, 2.0, 3.0]
])
# 每个样本使用不同的CSG树结构
csg_trees = [
[0, [1, 2]], # 使用索引0,1,2
[[0, 1], 3], # 使用索引0,1,3
[0, 1, [2, 4]] # 使用索引0,1,2,4
]
# 验证所有索引都有效
P = f_i.shape[1]
for i, tree in enumerate(csg_trees):
def check_indices(node):
if isinstance(node, list):
for n in node:
check_indices(n)
else:
assert node < P, f"样本{i}的树包含无效索引{node},P={P}"
check_indices(tree)
print("Input SDF values:")
print(f_i)
print("\nCSG Trees:")
for i, tree in enumerate(csg_trees):
print(f"Sample {i}: {tree}")
# 测试凸组合
print("\nTesting convex combination:")
combiner_convex = CSGCombiner(flag_convex=True)
h_convex = combiner_convex.forward(f_i, csg_trees)
print("Results:", h_convex)
# 测试凹组合
print("\nTesting concave combination:")
combiner_concave = CSGCombiner(flag_convex=False)
h_concave = combiner_concave.forward(f_i, csg_trees)
print("Results:", h_concave)
# 测试不同rho值的软混合
print("\nTesting soft blends:")
for rho in [0.01, 0.1, 0.5]:
combiner_soft = CSGCombiner(flag_convex=True, rho=rho)
h_soft = combiner_soft.forward(f_i, csg_trees)
print(f"rho={rho}:", h_soft)
if __name__ == "__main__":
test_csg_combiner()

125
brep2sdf/networks/encoder.py

@ -2,72 +2,89 @@ import torch
import torch.nn as nn import torch.nn as nn
from .octree import OctreeNode from .octree import OctreeNode
from .feature_volume import PatchFeatureVolume
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, octree: OctreeNode, feature_dim: int = 32): def __init__(self, volume_bboxs:torch.tensor, feature_dim: int = 32):
""" """
分离后的编码器接收预构建的八叉树 分离后的编码器接收预构建的八叉树
参数: 参数:
octree: 预构建的八叉树结构 volume_bboxs: 所有面片的边界框集合形状为 (N, 2, 3)
feature_dim: 特征维度 feature_dim: 特征维度
""" """
super().__init__() super().__init__()
self.feature_dim = feature_dim self.feature_dim = feature_dim
# 初始化叶子节点参数 # 批量计算所有bbox的分辨率
self.register_buffer('num_parameters', torch.tensor(0, dtype=torch.long)) resolutions = self._batch_calculate_resolution(volume_bboxs)
self._leaf_features = None # 将在_init_parameters中初始化
self._init_parameters(octree) # 初始化多个特征体积
self.feature_volumes = nn.ModuleList([
def _init_parameters(self,octree): PatchFeatureVolume(
stack = [(octree, 0)] bbox=bbox,
param_count = 0 resolution=int(resolutions[i]),
feature_dim=feature_dim
while stack: ) for i, bbox in enumerate(volume_bboxs)
node, _ = stack.pop() ])
if node._is_leaf: print(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}")
param_count += 1 print(f"Model parameters memory usage: {sum(p.numel() * p.element_size() for p in self.parameters()) / 1024**2:.2f} MB")
else:
for child in node.child_nodes: def _batch_calculate_resolution(self, bboxes: torch.Tensor) -> torch.Tensor:
if child: stack.append((child, 0)) """
批量计算归一化bboxes的分辨率
# 初始化连续参数张量
self._leaf_features = nn.Parameter( 参数:
torch.randn(param_count, 8, self.feature_dim)) bboxes: 归一化边界框张量形状为 (N, 2, 3)
# 重新遍历设置索引 返回:
stack = [(octree, 0)] 分辨率张量 (N,)
index = 0 """
while stack: with torch.no_grad():
node, _ = stack.pop() # 计算每个bbox的对角线长度(归一化后范围约为0.0-1.732)
if node._is_leaf: diagonals = torch.norm(bboxes[:,3:6] - bboxes[:,0:3], dim=1)
node.set_param_key(index)
index += 1 # 根据归一化后的对角线长度调整分辨率
else: resolutions = torch.zeros_like(diagonals, dtype=torch.long)
for child in node.child_nodes: resolutions[diagonals > 1.0] = 16 # 大尺寸
if child: stack.append((child, 0)) resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 8 # 中等尺寸
self.num_parameters.fill_(index) resolutions[diagonals <= 0.5] = 4 # 小尺寸
def forward(self, query_points: torch.Tensor,param_indices,bboxes) -> torch.Tensor: return resolutions
batch_size = query_points.shape[0]
# 批量获取特征 def forward(self, query_points: torch.Tensor, volume_indices: torch.Tensor) -> torch.Tensor:
unique_ids, inverse_ids = torch.unique(param_indices, return_inverse=True) """
all_features = self._leaf_features[unique_ids] # (U, 8, D) 修改后的前向传播返回所有关联volume的特征矩阵
node_features = all_features[inverse_ids] # (B, 8, D)
参数:
# 启用混合精度和优化后的插值 query_points: 查询点坐标 (B, 3)
with torch.cuda.amp.autocast(): volume_indices: 关联的volume索引矩阵 (B, K)
features = self._optimized_trilinear(
query_points, 返回:
bboxes.detach(), 特征张量 (B, K, D)
node_features """
) batch_size, num_volumes = volume_indices.shape
all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,
# 添加类型转换确保输出为float32 device=query_points.device)
return features.to(torch.float32) # 添加这行
# 遍历每个volume索引
for k in range(num_volumes):
# 获取当前volume的索引 (B,)
current_indices = volume_indices[:, k]
# 遍历所有存在的volume
for vol_id in range(len(self.feature_volumes)):
# 创建掩码 (B,)
mask = (current_indices == vol_id)
if mask.any():
# 获取对应volume的特征 (M, D)
features = self.feature_volumes[vol_id](query_points[mask])
all_features[mask, k] = features
return all_features
def _optimized_trilinear(self, points, bboxes, features): def _optimized_trilinear(self, points, bboxes, features):
"""优化后的向量化三线性插值""" """优化后的向量化三线性插值"""

98
brep2sdf/networks/feature_volume.py

@ -4,50 +4,72 @@ import torch
import torch.nn as nn import torch.nn as nn
class PatchFeatureVolume(nn.Module): class PatchFeatureVolume(nn.Module):
def __init__(self, bbox:np, resolution=64, feature_dim=64): def __init__(self, bbox: torch.Tensor, resolution=64, feature_dim=64, padding_ratio=0.05):
super(PatchFeatureVolume, self).__init__() super(PatchFeatureVolume, self).__init__()
self.bbox = bbox # 补丁的边界框 # 将输入bbox转换为[min, max]格式
self.resolution = resolution # 网格分辨率 self.resolution = resolution
self.feature_dim = feature_dim # 特征向量维度 min_coords = bbox[:3]
max_coords = bbox[3:]
self.original_bbox = torch.stack([min_coords, max_coords])
expanded_bbox = self._expand_bbox(min_coords, max_coords, padding_ratio)
# 创建规则的三维网格 # 创建规则的三维网格
x = torch.linspace(bbox[0][0], bbox[1][0], resolution) x = torch.linspace(expanded_bbox[0][0], expanded_bbox[1][0], resolution)
y = torch.linspace(bbox[0][1], bbox[1][1], resolution) y = torch.linspace(expanded_bbox[0][1], expanded_bbox[1][1], resolution)
z = torch.linspace(bbox[0][2], bbox[1][2], resolution) z = torch.linspace(expanded_bbox[0][2], expanded_bbox[1][2], resolution)
grid_x, grid_y, grid_z = torch.meshgrid(x, y, z) grid_x, grid_y, grid_z = torch.meshgrid(x, y, z)
self.register_buffer('grid', torch.stack([grid_x, grid_y, grid_z], dim=-1)) self.register_buffer('grid', torch.stack([grid_x, grid_y, grid_z], dim=-1))
# 初始化特征向量,作为可训练参数 # 初始化特征向量
self.feature_volume = nn.Parameter(torch.randn(resolution, resolution, resolution, feature_dim)) self.feature_volume = nn.Parameter(torch.randn(resolution, resolution, resolution, feature_dim))
def forward(self, query_points: List[Tuple[float, float, float]]): def _expand_bbox(self, min_coords, max_coords, ratio):
""" # 扩展包围盒范围
根据查询点的位置从补丁特征体积中获取插值后的特征向量 center = (min_coords + max_coords) / 2
expanded_min = center - (center - min_coords) * (1 + ratio)
expanded_max = center + (max_coords - center) * (1 + ratio)
return torch.stack([expanded_min, expanded_max])
:param query_points: 查询点的位置坐标形状为 (N, 3) def forward(self, query_points: torch.Tensor) -> torch.Tensor:
:return: 插值后的特征向量形状为 (N, feature_dim) """批量处理版本的三线性插值
Args:
query_points: 形状为 (B, 3) 的查询点坐标
Returns:
形状为 (B, D) 的特征向量
""" """
interpolated_features = torch.zeros(query_points.shape[0], self.feature_dim).to(self.feature_volume.device) # 添加类型转换确保计算稳定性
for i, point in enumerate(query_points): normalized = ((query_points - self.grid[0,0,0]) /
interpolated_feature = self.trilinear_interpolation(point) (self.grid[-1,-1,-1] - self.grid[0,0,0] + 1e-8)) # (B,3)
interpolated_features[i] = interpolated_feature
return interpolated_features # 向量化三线性插值
return self._batched_trilinear(normalized)
def trilinear_interpolation(self, query_point):
"""三线性插值""" def _batched_trilinear(self, normalized: torch.Tensor) -> torch.Tensor:
normalized_coords = ((query_point - torch.tensor(self.bbox[0]).to(self.grid.device)) / """批量处理的三线性插值"""
(torch.tensor(self.bbox[1]).to(self.grid.device) - torch.tensor(self.bbox[0]).to(self.grid.device))) * (self.resolution - 1) # 计算8个顶点的权重
indices = torch.floor(normalized_coords).long() uvw = normalized * (self.resolution - 1)
weights = normalized_coords - indices.float() indices = torch.floor(uvw).long() # (B,3)
weights = uvw - indices.float() # (B,3)
interpolated_feature = torch.zeros(self.feature_dim).to(self.feature_volume.device)
for di in range(2): # 计算8个顶点的权重组合 (B,8)
for dj in range(2): weights = torch.stack([
for dk in range(2): (1 - weights[...,0]) * (1 - weights[...,1]) * (1 - weights[...,2]),
weight = (weights[0] if di == 1 else 1 - weights[0]) * \ (1 - weights[...,0]) * (1 - weights[...,1]) * weights[...,2],
(weights[1] if dj == 1 else 1 - weights[1]) * \ (1 - weights[...,0]) * weights[...,1] * (1 - weights[...,2]),
(weights[2] if dk == 1 else 1 - weights[2]) (1 - weights[...,0]) * weights[...,1] * weights[...,2],
index = indices + torch.tensor([di, dj, dk]).to(indices.device) weights[...,0] * (1 - weights[...,1]) * (1 - weights[...,2]),
index = torch.clamp(index, 0, self.resolution - 1) weights[...,0] * (1 - weights[...,1]) * weights[...,2],
interpolated_feature += weight * self.feature_volume[index[0], index[1], index[2]] weights[...,0] * weights[...,1] * (1 - weights[...,2]),
return interpolated_feature weights[...,0] * weights[...,1] * weights[...,2],
], dim=-1) # (B,8)
# 获取8个顶点的特征 (B,8,D)
indices = indices.unsqueeze(1).expand(-1,8,-1) + torch.tensor([
[0,0,0], [0,0,1], [0,1,0], [0,1,1],
[1,0,0], [1,0,1], [1,1,0], [1,1,1]
], device=indices.device)
indices = torch.clamp(indices, 0, self.resolution-1)
features = self.feature_volume[indices[...,0], indices[...,1], indices[...,2]] # (B,8,D)
# 加权求和 (B,D)
return torch.einsum('bnd,bn->bd', features, weights)

29
brep2sdf/networks/network.py

@ -49,11 +49,13 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.autograd import grad from torch.autograd import grad
from .encoder import Encoder from .encoder import Encoder
from .decoder import Decoder from .decoder import Decoder, CSGCombiner
from brep2sdf.utils.logger import logger
class Net(nn.Module): class Net(nn.Module):
def __init__(self, def __init__(self,
octree, octree,
volume_bboxs,
feature_dim=64, feature_dim=64,
decoder_input_dim=64, decoder_input_dim=64,
decoder_output_dim=1, decoder_output_dim=1,
@ -69,11 +71,18 @@ class Net(nn.Module):
# 初始化 Encoder # 初始化 Encoder
self.encoder = Encoder( self.encoder = Encoder(
feature_dim=feature_dim, feature_dim=feature_dim,
octree=octree volume_bboxs= volume_bboxs
) )
# 初始化 Decoder # 初始化 Decoder
self.decoder = Decoder(input_dim=64, output_dim=1) self.decoder = Decoder(
d_in=decoder_input_dim,
dims_sdf=[decoder_hidden_dim] * decoder_num_layers,
geometric_init=True,
beta=100
)
self.csg_combiner = CSGCombiner(flag_convex=True)
def forward(self, query_points): def forward(self, query_points):
""" """
@ -85,14 +94,16 @@ class Net(nn.Module):
output: 解码后的输出结果 output: 解码后的输出结果
""" """
# 批量查询所有点的索引和bbox # 批量查询所有点的索引和bbox
param_indices,bboxes = self.octree_module.forward(query_points) _,face_indices,csg_trees = self.octree_module.forward(query_points)
print("param_indices requires_grad:", param_indices.requires_grad) # 应该输出False
print("bboxes requires_grad:", bboxes.requires_grad) # 应该输出False
# 编码 # 编码
feature_vector = self.encoder.forward(query_points,param_indices,bboxes) feature_vectors = self.encoder.forward(query_points,face_indices)
print("feature_vector:", feature_vector.requires_grad) #print("feature_vector:", feature_vectors.requires_grad)
# 解码 # 解码
output = self.decoder(feature_vector) logger.gpu_memory_stats("encoder farward后")
f_i = self.decoder(feature_vectors)
logger.gpu_memory_stats("decoder farward后")
output = self.csg_combiner.forward(f_i, csg_trees)
logger.gpu_memory_stats("combine后")
return output return output

264
brep2sdf/networks/octree.py

@ -4,12 +4,14 @@ import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from brep2sdf.data.utils import process_surf_ncs_with_dynamic_padding
from brep2sdf.networks.patch_graph import PatchGraph from brep2sdf.networks.patch_graph import PatchGraph
from brep2sdf.utils.logger import logger
def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor: def bbox_intersect_(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
"""判断两个轴对齐包围盒(AABB)是否相交 """判断两个轴对齐包围盒(AABB)是否相交
参数: 参数:
@ -28,8 +30,90 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
# 向量化比较 # 向量化比较
return torch.all((max1 >= min2) & (max2 >= min1)) return torch.all((max1 >= min2) & (max2 >= min1))
def if_points_in_box(points: np.ndarray, bbox: torch.Tensor) -> bool:
"""判断点是否在AABB包围盒内
参数:
points: 形状为 (N, 3) 的数组表示N个点的坐标
bbox: 形状为 (6,) 的张量表示AABB包围盒的坐标
返回:
bool: 如果所有点都在包围盒内返回True否则返回False
"""
# 将 points 转换为 torch.Tensor
points_tensor = torch.tensor(points, dtype=torch.float32, device=bbox.device)
# 提取min和max坐标
min_coords = bbox[:3]
max_coords = bbox[3:]
#logger.debug(f"min_coords: {min_coords}, max_coords: {max_coords}")
# 向量化比较
return torch.any((points_tensor >= min_coords) & (points_tensor <= max_coords)).item()
def bbox_intersect(
surf_bboxes: torch.Tensor,
indices: torch.Tensor,
child_bboxes: torch.Tensor,
surf_points: torch.Tensor = None
) -> torch.Tensor:
'''
args:
surf_bboxes: [B, 6] - 表示多个包围盒的张量每个包围盒由其最小和最大坐标定义
indices: 选择bbox, [N], N <= B - 用于选择特定包围盒的索引张量
child_bboxes: [8, 6] - 一个包围盒被分成八个子包围盒后的结果
surf_points: [B, M_max, 3] - 每个包围盒对应的点云数据可选
return:
result_mask: [8, B] - 布尔掩码表示每个子边界框与所有包围盒是否相交
且是否包含至少一个点如果提供了点云
'''
# 初始化全为 False 的结果掩码 [8, B]
B = surf_bboxes.size(0)
result_mask = torch.zeros((8, B), dtype=torch.bool).to(surf_bboxes.device)
logger.debug(result_mask.shape)
logger.debug(indices.shape)
# 提取选中的边界框
selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6]
min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3]
min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3]
logger.debug(selected_bboxes.shape)
# 计算子包围盒与选中包围盒的交集
intersect_mask = torch.all(
(max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3]
(max2.unsqueeze(1) >= min1.unsqueeze(0)), # 形状为 [8, N, 3]
dim=-1
) # 最终形状为 [8, N]
# 更新结果掩码中选中的部分
result_mask[:, indices] = intersect_mask
# 如果提供了点云,进一步检查点是否在子包围盒内
if surf_points is not None:
# 提取选中的点云
selected_points = surf_points[indices] # 形状为 [N, M_max, 3]
# 将点云广播到子边界框的维度
points_expanded = selected_points.unsqueeze(1) # 形状为 [N, 1, M_max, 3]
min2_expanded = min2.unsqueeze(0).unsqueeze(2) # 形状为 [1, 8, 1, 3]
max2_expanded = max2.unsqueeze(0).unsqueeze(2) # 形状为 [1, 8, 1, 3]
# 判断点是否在子边界框内
point_in_box_mask = (
(points_expanded >= min2_expanded) & # 形状为 [N, 8, M_max, 3]
(points_expanded <= max2_expanded) # 形状为 [N, 8, M_max, 3]
).all(dim=-1) # 最终形状为 [N, 8, M_max]
# 检查每个子边界框是否包含至少一个点
points_in_boxes_mask = point_in_box_mask.any(dim=-1).permute(1, 0) # 形状为 [8, N]
# 合并交集条件和点云条件
result_mask[:, indices] = result_mask[:, indices] & points_in_boxes_mask
logger.debug(result_mask.shape)
return result_mask
class OctreeNode(nn.Module): class OctreeNode(nn.Module):
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,device=None): def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,surf_ncs:np.ndarray = None,device=None):
super().__init__() super().__init__()
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 改为普通张量属性 # 改为普通张量属性
@ -39,52 +123,43 @@ class OctreeNode(nn.Module):
self.child_indices = None self.child_indices = None
self.is_leaf_mask = None self.is_leaf_mask = None
# 面片索引张量 # 面片索引张量
self.face_indices = torch.from_numpy(face_indices).to(self.device) self.all_face_indices = torch.from_numpy(face_indices).to(self.device)
self.surf_bbox = surf_bbox.to(self.device) if surf_bbox is not None else None self.surf_bbox = surf_bbox.to(self.device) if surf_bbox is not None else None
self.surf_ncs = process_surf_ncs_with_dynamic_padding(surf_ncs).to(self.device)
# PatchGraph作为普通属性 # PatchGraph作为普通属性
self.patch_graph = patch_graph.to(self.device) if patch_graph is not None else None self.patch_graph = patch_graph.to(self.device) if patch_graph is not None else None
self.max_depth = max_depth self.max_depth = max_depth
# 参数键改为普通张量
self.param_key = torch.tensor(-1, dtype=torch.long, device=self.device)
self._is_leaf = True self._is_leaf = True
# 删除所有register_buffer调用
@torch.jit.export
def set_param_key(self, k: int) -> None:
"""设置参数键值
参数:
k: 参数索引值
"""
self.param_key.fill_(k)
@torch.jit.export @torch.jit.export
def build_static_tree(self) -> None: def build_static_tree(self) -> None:
"""构建静态八叉树结构""" """构建静态八叉树结构"""
# 预计算所有可能的节点数量,确保结果为整数 # 预计算所有可能的节点数量,确保结果为整数
total_nodes = int(sum(8**i for i in range(self.max_depth + 1))) total_nodes = int(sum(8**i for i in range(self.max_depth + 1)))
num_faces = self.all_face_indices.shape[0]
# 初始化静态张量,使用整数列表作为形状参数 # 初始化静态张量,使用整数列表作为形状参数
self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device) self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device)
self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device) self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device)
self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device) self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device)
self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.device) self.face_indices_mask = torch.zeros([int(total_nodes),num_faces], dtype=torch.bool, device=self.device) # 1 代表有
self.is_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device)
# 使用队列进行广度优先遍历 # 使用队列进行广度优先遍历
queue = [(0, self.bbox, self.face_indices)] # (node_idx, bbox, face_indices) queue = [(0, self.bbox, self.all_face_indices)] # (node_idx, bbox, face_indices)
current_idx = 0 current_idx = 0
while queue: while queue:
node_idx, bbox, faces = queue.pop(0) node_idx, bbox, faces = queue.pop(0)
#logger.debug(f"Processing node {node_idx} with {len(faces)} faces.")
self.node_bboxes[node_idx] = bbox self.node_bboxes[node_idx] = bbox
# 判断 要不要继续分裂 # 判断 要不要继续分裂
if not self._should_split_node(current_idx): if not self._should_split_node(current_idx, faces, total_nodes):
self.is_leaf_mask[node_idx] = True
continue continue
self.is_leaf_mask[node_idx] = 0
# 计算子节点边界框 # 计算子节点边界框
min_coords = bbox[:3] min_coords = bbox[:3]
max_coords = bbox[3:] max_coords = bbox[3:]
@ -92,36 +167,35 @@ class OctreeNode(nn.Module):
# 生成8个子节点 # 生成8个子节点
child_bboxes = self._generate_child_bboxes(min_coords, mid_coords, max_coords) child_bboxes = self._generate_child_bboxes(min_coords, mid_coords, max_coords)
intersect_mask = bbox_intersect(self.surf_bbox, faces, child_bboxes)
self.face_indices_mask[current_idx + 1:current_idx + 9, :] = intersect_mask
# 为每个子节点分配面片 # 为每个子节点分配面片
for i, child_bbox in enumerate(child_bboxes): for i, child_bbox in enumerate(child_bboxes):
child_idx = current_idx + 1 child_idx = child_idx = current_idx + i + 1
current_idx += 1
# 找到与子包围盒相交的面
intersecting_faces = []
for face_idx in faces:
face_bbox = self.surf_bbox[face_idx]
if bbox_intersect(child_bbox, face_bbox).item():
intersecting_faces.append(face_idx)
intersecting_faces = intersect_mask[i].nonzero().flatten()
#logger.debug(f"Node {child_idx} has {len(intersecting_faces)} intersecting faces.")
# 更新节点关系 # 更新节点关系
self.parent_indices[child_idx] = node_idx self.parent_indices[child_idx] = node_idx
self.child_indices[node_idx, i] = child_idx self.child_indices[node_idx, i] = child_idx
# 将子节点加入队列 # 将子节点加入队列
if intersecting_faces: if len(intersecting_faces) > 0:
queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.device))) queue.append((child_idx, child_bbox, intersecting_faces.clone().detach()))
current_idx += 8
def _should_split_node(self, current_depth: int) -> bool: def _should_split_node(self, current_idx: int,face_indices,max_node:int) -> bool:
"""判断节点是否需要分裂""" """判断节点是否需要分裂"""
# 检查是否达到最大深度 # 检查是否达到最大深度
if current_depth >= self.max_depth: if current_idx + 8 >= max_node:
return False return False
# 检查是否为完全图 # 检查是否为完全图
is_clique = self.patch_graph.is_clique(self.face_indices) #is_clique = self.patch_graph.is_clique(face_indices)
is_clique = face_indices.shape[0] < 2
if is_clique: if is_clique:
#logger.debug(f"Node {current_idx} is a clique. Stopping split.")
return False return False
return True return True
@ -153,9 +227,9 @@ class OctreeNode(nn.Module):
@torch.jit.export @torch.jit.export
def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]: def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]:
""" """
查找包含给定点的叶子节点并返回其信息 修改后的查找叶子节点方法返回face indices
:param query_points: 待查找的点形状为 (3,) :param query_points: 待查找的点形状为 (3,)
:return: 包含叶子节点信息的元组 (bbox, param_key, is_leaf) :return: (bbox, param_key, face_indices, is_leaf)
""" """
# 确保输入是单个点 # 确保输入是单个点
if query_points.dim() != 1 or query_points.shape[0] != 3: if query_points.dim() != 1 or query_points.shape[0] != 3:
@ -168,7 +242,28 @@ class OctreeNode(nn.Module):
while iteration < max_iterations: while iteration < max_iterations:
# 获取当前节点的叶子状态 # 获取当前节点的叶子状态
if self.is_leaf_mask[current_idx].item(): if self.is_leaf_mask[current_idx].item():
return self.node_bboxes[current_idx], self.param_key, True #logger.debug(f"Reached leaf node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.")
if self.face_indices_mask[current_idx].sum() == 0:
parent_idx = self.parent_indices[current_idx]
#logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.")
if parent_idx == -1:
# 根节点没有父节点,返回根节点的信息
#logger.warning(f"Reached root node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.")
return (
self.node_bboxes[current_idx],
None, # 新增返回face indices
False
)
return (
self.node_bboxes[parent_idx],
self.face_indices_mask[parent_idx], # 新增返回face indices
False
)
return (
self.node_bboxes[current_idx],
self.face_indices_mask[current_idx], # 新增返回face indices
True
)
# 计算子节点索引 # 计算子节点索引
child_idx = self._get_child_indices(query_points.unsqueeze(0), child_idx = self._get_child_indices(query_points.unsqueeze(0),
@ -185,7 +280,7 @@ class OctreeNode(nn.Module):
iteration += 1 iteration += 1
# 如果达到最大迭代次数,返回当前节点的信息 # 如果达到最大迭代次数,返回当前节点的信息
return self.node_bboxes[current_idx], self.param_key, bool(self.is_leaf_mask[current_idx].item()) return self.node_bboxes[current_idx], None,bool(self.is_leaf_mask[current_idx].item())
@torch.jit.export @torch.jit.export
def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor: def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
@ -195,43 +290,70 @@ class OctreeNode(nn.Module):
def forward(self, query_points): def forward(self, query_points):
with torch.no_grad(): with torch.no_grad():
param_indices, bboxes = [], [] bboxes, face_indices_mask, csg_trees = [], [], []
for point in query_points: for point in query_points:
bbox, idx, _ = self.find_leaf(point) bbox, faces_mask, _ = self.find_leaf(point)
param_indices.append(idx)
bboxes.append(bbox) bboxes.append(bbox)
param_indices = torch.stack(param_indices) face_indices_mask.append(faces_mask)
bboxes = torch.stack(bboxes) # 获取当前节点的CSG树结构
# 添加检查代码 csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None
return param_indices, bboxes csg_trees.append(csg_tree) # 保持原始列表结构
return (
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: torch.stack(bboxes),
torch.stack(face_indices_mask),
csg_trees # 直接返回列表,不转换为张量
)
def print_tree(self, max_print_depth: int = None) -> None:
""" """
递归打印八叉树结构 使用深度优先遍历 (DFS) 打印树结构父子关系通过缩进体现
参数: 参数:
depth: 当前深度 (内部使用) max_print_depth (int): 最大打印深度 (None 表示打印全部)
max_print_depth: 最大打印深度 (None表示打印全部)
""" """
if max_print_depth is not None and depth > max_print_depth: def dfs(node_idx: int, depth: int):
return """
深度优先遍历辅助函数
# 打印当前节点信息
indent = " " * depth 参数:
node_type = "Leaf" if self._is_leaf else "Internal" node_idx (int): 当前节点索引
print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.cpu().numpy().tolist()}") depth (int): 当前节点的深度
"""
# 打印面片信息(如果有) # 如果超过最大打印深度,跳过当前节点及其子节点
if self.face_indices is not None: if max_print_depth is not None and depth > max_print_depth:
print(f"{indent} Face indices: {self.face_indices.cpu().numpy().tolist()}") return
print(f"{indent} Child indices: {self.child_indices.cpu().numpy().tolist()}")
indent = " " * depth # 根据深度生成缩进
# 打印子节点信息 is_leaf = self.is_leaf_mask[node_idx].item() # 判断是否为叶子节点
if self.child_indices is not None: bbox = self.node_bboxes[node_idx].cpu().numpy().tolist() # 获取边界框信息
for i in range(8):
child_idx = self.child_indices[0, i].item() # 打印当前节点的基本信息
if child_idx != -1: node_type = "Leaf" if is_leaf else "Internal"
print(f"{indent} Child {i}: Node {child_idx}") log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}")
if self.face_indices_mask is not None:
face_indices = self.face_indices_mask[node_idx].nonzero().cpu().numpy().flatten().tolist()
log_lines.append(f"{indent} Face Indices: {face_indices}")
# 如果是叶子节点,打印额外信息
if is_leaf:
child_indices = self.child_indices[node_idx].cpu().numpy().tolist()
log_lines.append(f"{indent} Child Indices: {child_indices}")
# 如果不是叶子节点,递归处理子节点
if not is_leaf:
for i in range(8): # 遍历所有子节点
child_idx = self.child_indices[node_idx, i].item()
if child_idx != -1: # 忽略无效的子节点索引
dfs(child_idx, depth + 1)
# 初始化日志行列表
log_lines = []
# 从根节点开始深度优先遍历
dfs(0, 0)
# 统一输出所有日志
logger.debug("\n".join(log_lines))
def __getstate__(self): def __getstate__(self):
"""支持pickle序列化""" """支持pickle序列化"""
@ -245,7 +367,6 @@ class OctreeNode(nn.Module):
'surf_bbox': self.surf_bbox, 'surf_bbox': self.surf_bbox,
'patch_graph': self.patch_graph, 'patch_graph': self.patch_graph,
'max_depth': self.max_depth, 'max_depth': self.max_depth,
'param_key': self.param_key,
'_is_leaf': self._is_leaf '_is_leaf': self._is_leaf
} }
return state return state
@ -261,7 +382,6 @@ class OctreeNode(nn.Module):
self.surf_bbox = state['surf_bbox'] self.surf_bbox = state['surf_bbox']
self.patch_graph = state['patch_graph'] self.patch_graph = state['patch_graph']
self.max_depth = state['max_depth'] self.max_depth = state['max_depth']
self.param_key = state['param_key']
self._is_leaf = state['_is_leaf'] self._is_leaf = state['_is_leaf']
def to(self, device=None, dtype=None, non_blocking=False): def to(self, device=None, dtype=None, non_blocking=False):

35
brep2sdf/networks/patch_graph.py

@ -2,6 +2,7 @@ from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from brep2sdf.utils.logger import logger
class PatchGraph(nn.Module): class PatchGraph(nn.Module):
def __init__(self, num_patches: int, device: torch.device = None): def __init__(self, num_patches: int, device: torch.device = None):
@ -56,13 +57,45 @@ class PatchGraph(nn.Module):
return subgraph_edges, subgraph_types return subgraph_edges, subgraph_types
def get_csg_tree(self, node_faces_mask: torch.Tensor):
"""生成CSG组合树结构
参数:
node_faces: 要处理的面片索引集合形状为 (N,)
返回:
嵌套列表结构表示CSG组合层次
示例:
[[0, [1,2]], 3] 表示0与(1和2的组合)进行凹组合然后与3进行凸组合
"""
print("node_faces_mask:", node_faces_mask)
if self.edge_index is None:
return []
node_faces = node_faces_mask.nonzero()
node_faces = node_faces.flatten().to('cpu').numpy()
logger.debug(f"node_faces: {node_faces}")
node_set = set(node_faces) # 创建输入面片的集合用于快速查找
visited = set()
csg_tree = []
# 优先处理凹边连接
concave_edges = self.edge_index[:, self.edge_type == 0].cpu().numpy().T
for u, v in concave_edges:
u, v = int(u), int(v)
if u in node_set and v in node_set and u not in visited and v not in visited:
csg_tree.append([u, v])
visited.update({u, v})
# 处理剩余面片(只包含输入的面片)
remaining = [int(f) for f in node_faces if f not in visited]
csg_tree.extend(remaining)
return csg_tree
def is_clique(self, node_faces: torch.Tensor) -> bool: def is_clique(self, node_faces: torch.Tensor) -> bool:
"""检查给定面片集合是否构成完全图 """检查给定面片集合是否构成完全图
参数: 参数:
node_faces: 要检查的面片索引集合 node_faces: 要检查的面片索引集合
face [0,1,2,3,4,]
返回: 返回:
bool: 是否为完全图 bool: 是否为完全图

69
brep2sdf/test.py

@ -1,4 +1,69 @@
import torch import torch
model = torch.jit.load("/home/wch/brep2sdf/data/output_data/00000054.pt") from typing import List, Tuple
print(model)
def bbox_intersect(surf_bboxes: torch.Tensor, indices: torch.Tensor, child_bboxes: torch.Tensor) -> torch.Tensor:
'''
args:
surf_bboxes: [B, 6] - 表示多个包围盒的张量每个包围盒由其最小和最大坐标定义
indices: 选择bbox, [N], N <= B - 用于选择特定包围盒的索引张量
child_bboxes: [8, 6] - 一个包围盒被分成八个子包围盒后的结果
return:
intersect_mask: [8, N] - 布尔掩码表示每个子包围盒与选择的包围盒是否相交
'''
# 提取选中的边界框
selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6]
min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3]
min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3]
# 确保广播机制正常工作
intersect_mask = torch.all(
(max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3]
(max2.unsqueeze(1) >= min1.unsqueeze(0)), # 形状为 [8, N, 3]
dim=-1
) # 最终形状为 [8, N]
return intersect_mask
# 测试程序
if __name__ == "__main__":
# 构造输入数据
surf_bboxes = torch.tensor([
[0, 0, 0, 1, 1, 1], # 立方体 1
[0.5, 0.5, 0.5, 1.5, 1.5, 1.5], # 立方体 2
[2, 2, 2, 3, 3, 3] # 立方体 3
]) # [B=3, 6]
indices = torch.tensor([0, 1]) # 选择前两个立方体
# 假设父边界框为 [0, 0, 0, 2, 2, 2],生成其八个子边界框
parent_bbox = torch.tensor([0, 0, 0, 2, 2, 2])
center = (parent_bbox[:3] + parent_bbox[3:]) / 2
child_bboxes = torch.tensor([
[parent_bbox[0], parent_bbox[1], parent_bbox[2], center[0], center[1], center[2]], # 左下前
[center[0], parent_bbox[1], parent_bbox[2], parent_bbox[3], center[1], center[2]], # 右下前
[parent_bbox[0], center[1], parent_bbox[2], center[0], parent_bbox[4], center[2]], # 左上前
[center[0], center[1], parent_bbox[2], parent_bbox[3], parent_bbox[4], center[2]], # 右上前
[parent_bbox[0], parent_bbox[1], center[2], center[0], center[1], parent_bbox[5]], # 左下后
[center[0], parent_bbox[1], center[2], parent_bbox[3], center[1], parent_bbox[5]], # 右下后
[parent_bbox[0], center[1], center[2], center[0], parent_bbox[4], parent_bbox[5]], # 左上后
[center[0], center[1], center[2], parent_bbox[3], parent_bbox[4], parent_bbox[5]] # 右上后
]) # [8, 6]
# 调用函数
intersect_mask = bbox_intersect(surf_bboxes, indices, child_bboxes)
# 输出结果
print("Intersect Mask:")
print(intersect_mask)
# 将布尔掩码转换为索引列表
child_indices = []
for i in range(8): # 遍历每个子节点
intersecting_faces = indices[intersect_mask[i]] # 获取当前子节点的相交面片索引
child_indices.append(intersecting_faces)
# 打印每个子节点对应的相交索引
print("\nChild Indices:")
for i, indices in enumerate(child_indices):
print(f"Child {i}: {indices}")

209
brep2sdf/train.py

@ -116,11 +116,12 @@ class Trainer:
device=self.device device=self.device
) )
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=8) self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=6)
logger.gpu_memory_stats("数初始化后") logger.gpu_memory_stats("数初始化后")
self.model = Net( self.model = Net(
octree=self.root, octree=self.root,
volume_bboxs=surf_bbox,
feature_dim=64 feature_dim=64
).to(self.device) ).to(self.device)
logger.gpu_memory_stats("模型初始化后") logger.gpu_memory_stats("模型初始化后")
@ -138,7 +139,7 @@ class Trainer:
logger.info(f"初始化完成,正在处理模型 {self.model_name}") logger.info(f"初始化完成,正在处理模型 {self.model_name}")
def build_tree(self,surf_bbox, graph, max_depth=6): def build_tree(self,surf_bbox, graph, max_depth=9):
num_faces = surf_bbox.shape[0] num_faces = surf_bbox.shape[0]
bbox = self._calculate_global_bbox(surf_bbox) bbox = self._calculate_global_bbox(surf_bbox)
self.root = OctreeNode( self.root = OctreeNode(
@ -147,13 +148,13 @@ class Trainer:
patch_graph=graph, patch_graph=graph,
max_depth=max_depth, max_depth=max_depth,
surf_bbox=surf_bbox, surf_bbox=surf_bbox,
surf_ncs=self.data['surf_ncs']
) )
#print(surf_bbox) #print(surf_bbox)
logger.info("starting octree conduction") logger.info("starting octree conduction")
self.root.build_static_tree() self.root.build_static_tree()
logger.info("complete octree conduction") logger.info("complete octree conduction")
#self.root.print_tree(0) self.root.print_tree()
def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor: def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
""" """
@ -190,10 +191,6 @@ class Trainer:
def train_epoch(self, epoch: int) -> float: def train_epoch(self, epoch: int) -> float:
self.model.train()
total_loss = 0.0
step = 0 # 如果你的训练是分批次的,这里应该用批次索引
# --- 1. 检查输入数据 --- # --- 1. 检查输入数据 ---
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列)
# 并且 SDF 值总是在最后一列 # 并且 SDF 值总是在最后一列
@ -201,114 +198,124 @@ class Trainer:
logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.") logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
return float('inf') return float('inf')
points = self.sdf_data[:, 0:3].clone().detach() # 取前3列作为点 self.model.train()
gt_sdf = self.sdf_data[:, -1].clone().detach() # 取最后一列作为SDF真值 total_loss = 0.0
normals = None step = 0 # 如果你的训练是分批次的,这里应该用批次索引
if args.use_normal: batch_size = 10240 # 设置合适的batch大小
if self.sdf_data.shape[1] < 7: # 检查是否有足够的列给法线
logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.") # 将数据分成多个batch
return float('inf') num_points = self.sdf_data.shape[0]
normals = self.sdf_data[:, 3:6].clone().detach() # 取中间3列作为法线 num_batches = (num_points + batch_size - 1) // batch_size
# 执行检查 for batch_idx in range(num_batches):
if self.debug_mode: start_idx = batch_idx * batch_size
if check_tensor(points, "Input Points", epoch, step): return float('inf') end_idx = min((batch_idx + 1) * batch_size, num_points)
if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf') points = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点
gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值
normals = None
if args.use_normal: if args.use_normal:
# 只有在请求法线时才检查 normals if self.sdf_data.shape[1] < 7: # 检查是否有足够的列给法线
if check_tensor(normals, "Input Normals", epoch, step): return float('inf') logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.")
return float('inf')
normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach() # 取中间3列作为法线
# --- 准备模型输入,启用梯度 --- # 执行检查
points.requires_grad_(True) # 在检查之后启用梯度 if self.debug_mode:
if check_tensor(points, "Input Points", epoch, step): return float('inf')
if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf')
if args.use_normal:
# 只有在请求法线时才检查 normals
if check_tensor(normals, "Input Normals", epoch, step): return float('inf')
# --- 前向传播 ---
self.optimizer.zero_grad()
pred_sdf = self.model(points)
if self.debug_mode: # --- 准备模型输入,启用梯度 ---
# --- 检查前向传播的输出 --- points.requires_grad_(True) # 在检查之后启用梯度
logger.gpu_memory_stats("前向传播后")
# --- 2. 检查模型输出 ---
#if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf')
# --- 计算损失 --- # --- 前向传播 ---
loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败 self.optimizer.zero_grad()
loss_details = {} pred_sdf = self.model(points)
try:
# --- 3. 检查损失计算前的输入 ---
# (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf)
#if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
#if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
if args.use_normal:
# 检查法线和带梯度的点
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss(
points,
normals, # 传递检查过的 normals
gt_sdf,
pred_sdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# --- 4. 检查损失计算结果 ---
if self.debug_mode: if self.debug_mode:
if check_tensor(loss, "Calculated Loss", epoch, step): # --- 检查前向传播的输出 ---
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.") logger.gpu_memory_stats("前向传播后")
if loss_details: logger.error(f"Loss Details: {loss_details}") # --- 2. 检查模型输出 ---
return float('inf') # 如果损失无效,停止这个epoch #if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf')
except Exception as loss_e: # --- 计算损失 ---
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True) loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败
return float('inf') # 如果计算出错,停止这个epoch loss_details = {}
logger.gpu_memory_stats("损失计算后") try:
# --- 3. 检查损失计算前的输入 ---
# (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf)
#if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
#if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
if args.use_normal:
# 检查法线和带梯度的点
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss(
points,
normals, # 传递检查过的 normals
gt_sdf,
pred_sdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# --- 4. 检查损失计算结果 ---
if self.debug_mode:
if check_tensor(loss, "Calculated Loss", epoch, step):
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
if loss_details: logger.error(f"Loss Details: {loss_details}")
return float('inf') # 如果损失无效,停止这个epoch
except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # 如果计算出错,停止这个epoch
logger.gpu_memory_stats("损失计算后")
# --- 反向传播和优化 ---
try:
loss.backward()
# --- 5. (可选) 检查梯度 ---
# for name, param in self.model.named_parameters():
# if param.grad is not None:
# if check_tensor(param.grad, f"Gradient/{name}", epoch, step):
# logger.warning(f"Epoch {epoch} Step {step}: Bad gradient for {name}. Consider clipping or zeroing.")
# # 例如:param.grad.data.clamp_(-1, 1) # 值裁剪
# # 或在 optimizer.step() 前进行范数裁剪:
# # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# --- (推荐) 添加梯度裁剪 ---
# 防止梯度爆炸,这可能是导致 inf/nan 的原因之一
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # 范数裁剪
self.optimizer.step()
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
# 如果你想看是哪个操作导致的,可以启用 anomaly detection
# torch.autograd.set_detect_anomaly(True) # 放在训练开始前
return float('inf') # 如果反向传播或优化出错,停止这个epoch
# --- 记录和累加损失 ---
current_loss = loss.item()
if not np.isfinite(current_loss): # 再次确认损失是有效的数值
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
# --- 反向传播和优化 --- total_loss += current_loss
try: del loss
loss.backward() torch.cuda.empty_cache()
# --- 5. (可选) 检查梯度 ---
# for name, param in self.model.named_parameters():
# if param.grad is not None:
# if check_tensor(param.grad, f"Gradient/{name}", epoch, step):
# logger.warning(f"Epoch {epoch} Step {step}: Bad gradient for {name}. Consider clipping or zeroing.")
# # 例如:param.grad.data.clamp_(-1, 1) # 值裁剪
# # 或在 optimizer.step() 前进行范数裁剪:
# # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# --- (推荐) 添加梯度裁剪 ---
# 防止梯度爆炸,这可能是导致 inf/nan 的原因之一
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # 范数裁剪
self.optimizer.step()
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
# 如果你想看是哪个操作导致的,可以启用 anomaly detection
# torch.autograd.set_detect_anomaly(True) # 放在训练开始前
return float('inf') # 如果反向传播或优化出错,停止这个epoch
# --- 记录和累加损失 ---
current_loss = loss.item()
if not np.isfinite(current_loss): # 再次确认损失是有效的数值
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
total_loss += current_loss
# 记录训练进度 (只记录有效的损失) # 记录训练进度 (只记录有效的损失)
logger.info(f'Train Epoch: {epoch:4d}]\t' logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}') f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}") if loss_details: logger.info(f"Loss Details: {loss_details}")
# (如果你的训练分批次,这里应该继续循环下一批次)
# step += 1
del loss
torch.cuda.empty_cache() # 清空缓存
return total_loss # 对于单批次训练,直接返回当前损失 return total_loss # 对于单批次训练,直接返回当前损失
def validate(self, epoch: int) -> float: def validate(self, epoch: int) -> float:

2
brep2sdf/utils/logger.py

@ -227,7 +227,7 @@ class BRepLogger:
stats.append(f" 峰值: {max_allocated:.1f} MB") stats.append(f" 峰值: {max_allocated:.1f} MB")
# 一次性输出所有统计信息 # 一次性输出所有统计信息
self.info("\n".join(stats)) self.debug("\n".join(stats))
# 获取每个张量的内存使用情况 # 获取每个张量的内存使用情况
if include_trace: if include_trace:

Loading…
Cancel
Save