Browse Source

优化了八叉树并行,确保一个叶节点不超过两个面

final
mckay 12 months ago
parent
commit
fa17441396
  1. 4
      brep2sdf/config/default_config.py
  2. 20
      brep2sdf/data/utils.py
  3. 245
      brep2sdf/networks/decoder.py
  4. 115
      brep2sdf/networks/encoder.py
  5. 96
      brep2sdf/networks/feature_volume.py
  6. 29
      brep2sdf/networks/network.py
  7. 262
      brep2sdf/networks/octree.py
  8. 35
      brep2sdf/networks/patch_graph.py
  9. 69
      brep2sdf/test.py
  10. 213
      brep2sdf/train.py
  11. 2
      brep2sdf/utils/logger.py

4
brep2sdf/config/default_config.py

@ -48,7 +48,7 @@ class TrainConfig:
# 基本训练参数
batch_size: int = 8
num_workers: int = 4
num_epochs: int = 100
num_epochs: int = 1000
learning_rate: float = 0.001
min_lr: float = 1e-5
weight_decay: float = 0.01
@ -90,7 +90,7 @@ class LogConfig:
# 本地日志
log_dir: str = '/home/wch/brep2sdf/logs' # 日志保存目录
log_level: str = 'INFO' # 日志级别
console_level: str = 'DEBUG' # 控制台日志级别
console_level: str = 'INFO' # 控制台日志级别
file_level: str = 'DEBUG' # 文件日志级别
@dataclass

20
brep2sdf/data/utils.py

@ -10,6 +10,9 @@ from OCC.Core.TopoDS import topods, TopoDS_Vertex # 拓扑数据结构
import numpy as np
from scipy.spatial import cKDTree
from torch.nn.utils.rnn import pad_sequence
import torch
from brep2sdf.utils.logger import logger
def load_step(step_path):
@ -35,7 +38,22 @@ def get_bbox(shape, subshape):
xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
return np.array([xmin, ymin, zmin, xmax, ymax, zmax])
def process_surf_ncs_with_dynamic_padding(surf_ncs: np.ndarray) -> torch.Tensor:
    """
    Dynamically pad ragged per-face point arrays into one dense tensor.

    Args:
        surf_ncs: object ndarray of shape (N,); each element is an (M_i, 3)
            float32 array of sampled points (M_i varies per face).

    Returns:
        Tensor of shape (N, M_max, 3) where M_max = max(M_i). Shorter rows
        are padded with +inf so padded slots can be masked out by any
        in-range comparison downstream. Returns an empty (0, 0, 3) tensor
        when surf_ncs is empty (pad_sequence raises on an empty list).
    """
    if len(surf_ncs) == 0:
        # pad_sequence([]) raises RuntimeError; return a well-shaped empty tensor.
        return torch.empty(0, 0, 3, dtype=torch.float32)
    # as_tensor avoids copying arrays that are already float32.
    tensor_list = [torch.as_tensor(arr, dtype=torch.float32) for arr in surf_ncs]
    return pad_sequence(tensor_list, batch_first=True, padding_value=float('inf'))
def normalize(surfs, edges, corners):

245
brep2sdf/networks/decoder.py

@ -1,42 +1,219 @@
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch
import numpy as np
from typing import Tuple, List, Union
from brep2sdf.utils.logger import logger
class Sine(nn.Module):
    """Sine activation for SIREN-style networks: sin(30 * x).

    See the SIREN paper, Sec. 3.2 (final paragraph) and supplement Sec. 1.5
    for the discussion of the frequency factor 30.
    """

    def __init__(self):
        # BUG FIX: was `def __init(self)` (missing trailing underscores), so
        # the intended override never ran and nn.Module's default was used.
        super().__init__()

    def forward(self, input):
        return torch.sin(30 * input)
class Decoder(nn.Module):
def __init__(self,
input_dim: int,
output_dim: int,
hidden_dim: int = 256) :
"""
最简单的Decoder实现
参数:
input_dim: 输入维度
output_dim: 输出维度
hidden_dim: 隐藏层维度 (默认: 256)
"""
def __init__(
self,
d_in: int,
dims_sdf: List[int],
skip_in: Tuple[int, ...] = (),
flag_convex: bool = True,
geometric_init: bool = True,
radius_init: float = 1,
beta: float = 100,
) -> None:
super().__init__()
# 三层全连接网络
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, output_dim)
self.flag_convex = flag_convex
self.skip_in = skip_in
dims_sdf = [d_in] + dims_sdf + [1] #[self.n_branch]
self.sdf_layers = len(dims_sdf)
for layer in range(0, len(dims_sdf) - 1):
if layer + 1 in skip_in:
out_dim = dims_sdf[layer + 1] - d_in
else:
out_dim = dims_sdf[layer + 1]
lin = nn.Linear(dims_sdf[layer], out_dim)
if geometric_init:
if layer == self.sdf_layers - 2:
torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims_sdf[layer]), std=0.00001)
torch.nn.init.constant_(lin.bias, -radius_init)
else:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
setattr(self, "sdf_"+str(layer), lin)
if geometric_init:
if beta > 0:
self.activation = nn.Softplus(beta=beta)
# vanilla relu
else:
self.activation = nn.ReLU()
else:
#siren
self.activation = Sine()
self.final_activation = nn.ReLU()
# composite f_i to h
def forward(self, x: Tensor) -> Tensor:
"""
前向传播
def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
    '''
    Evaluate the per-patch SDF MLP on a batch of feature vectors.

    :param feature_matrix: feature tensor of shape (B, P, D)
        B: batch size
        P: number of patch volumes
        D: feature dimension
    :return:
        f_i: per-patch SDF values, shape (B, P)
    '''
    B, P, D = feature_matrix.shape
    # Flatten so every (sample, patch) pair runs through the MLP at once: (B*P, D)
    x = feature_matrix.view(-1, D)
    for layer in range(0, self.sdf_layers - 1):
        lin = getattr(self, "sdf_" + str(layer))
        if layer in self.skip_in:
            # NOTE(review): concatenates x with ITSELF; IGR/SIREN-style skip
            # connections usually concatenate the network input here — the
            # trailing comment suggests `input` was undefined and this was a
            # workaround. Confirm the intended skip semantics.
            x = torch.cat([x, x], -1) / np.sqrt(2) # Fix undefined 'input'
        x = lin(x)
        if layer < self.sdf_layers - 2:
            # No activation after the final linear layer (raw SDF output).
            x = self.activation(x)
    output_value = x  # all f_i
    # Restore the (B, P) layout
    f_i = output_value.view(B, P)
    return f_i
# 一个基础情形: 输入 fi 形状[P] 和 csg tree,凹凸组合输出h
#注意考虑如何批量处理 (B, P) 和 [csg tree]
class CSGCombiner:
def __init__(self, flag_convex: bool = True, rho: float = 0.05):
self.flag_convex = flag_convex
self.rho = rho
def forward(self, f_i: torch.Tensor, csg_tree) -> torch.Tensor:
    '''
    Combine per-patch SDF values into one SDF per sample via the CSG tree.

    :param f_i: per-patch SDF values, shape (B, P)
    :param csg_tree: nested-list CSG structure of patch indices
        NOTE(review): callers appear to pass a LIST of per-sample trees,
        yet this loop reuses the whole `csg_tree` object for every sample
        — verify whether csg_tree[i] was intended inside the loop.
    :return: combined SDF values, shape (B,)
    '''
    logger.info("\n".join(f"{i}个csg: {t}" for i,t in enumerate(csg_tree)))
    B = f_i.shape[0]
    results = []
    for i in range(B):
        # Combine each sample independently through the CSG hierarchy.
        h = self.nested_cvx_output_soft_blend(
            f_i[i].unsqueeze(0),
            csg_tree,
            self.flag_convex
        )
        results.append(h)
    return torch.cat(results, dim=0).squeeze(1)  # (B, 1) -> (B,)
def nested_cvx_output_soft_blend(
self,
value_matrix: torch.Tensor,
list_operation: List[Union[int, List]],
cvx_flag: bool = True
) -> torch.Tensor:
list_value = []
for v in list_operation:
if not isinstance(v, list):
list_value.append(v)
op_mat = torch.zeros(value_matrix.shape[1], len(list_value),
device=value_matrix.device)
for i in range(len(list_value)):
op_mat[list_value[i]][i] = 1.0
mat_mul = torch.matmul(value_matrix, op_mat)
if len(list_operation) == len(list_value):
return self.max_soft_blend(mat_mul, self.rho) if cvx_flag \
else self.min_soft_blend(mat_mul, self.rho)
参数:
x: 输入张量
返回:
输出张量
"""
# 第一层
h = F.relu(self.fc1(x))
# 第二层
h = F.relu(self.fc2(h))
# 输出层
out = self.fc3(h)
list_output = [mat_mul]
for v in list_operation:
if isinstance(v, list):
list_output.append(
self.nested_cvx_output_soft_blend(
value_matrix, v, not cvx_flag
)
)
return out
return self.max_soft_blend(torch.cat(list_output, 1), self.rho) if cvx_flag \
else self.min_soft_blend(torch.cat(list_output, 1), self.rho)
def min_soft_blend(self, mat, rho):
    """Fold the columns of `mat` with an R-function-style soft minimum.

    Args:
        mat: (B, K) tensor; column k holds the k-th operand per row.
        rho: blending radius controlling how sharply the min is rounded.

    Returns:
        (B, 1) tensor with the soft minimum of each row.
    """
    acc = mat[:, 0]
    for col in range(1, mat.shape[1]):
        cur = mat[:, col]
        # srho < 0 only near the blend region (within radius rho of the corner).
        srho = acc * acc + cur * cur - rho * rho
        correction = 1.0 / (8 * rho * rho) * srho * (srho - srho.abs())
        acc = acc + cur - torch.sqrt(acc * acc + cur * cur + correction)
    return acc.unsqueeze(1)
def max_soft_blend(self, mat, rho):
    """Fold the columns of `mat` with an R-function-style soft maximum.

    Args:
        mat: (B, K) tensor; column k holds the k-th operand per row.
        rho: blending radius controlling how sharply the max is rounded.

    Returns:
        (B, 1) tensor with the soft maximum of each row.
    """
    acc = mat[:, 0]
    for col in range(1, mat.shape[1]):
        cur = mat[:, col]
        # srho < 0 only near the blend region (within radius rho of the corner).
        srho = acc * acc + cur * cur - rho * rho
        correction = 1.0 / (8 * rho * rho) * srho * (srho - srho.abs())
        acc = acc + cur + torch.sqrt(acc * acc + cur * cur + correction)
    return acc.unsqueeze(1)
def test_csg_combiner():
    """Smoke-test CSGCombiner on a small batch with per-sample CSG trees.

    NOTE(review): this passes the LIST of trees to CSGCombiner.forward,
    whose loop applies the same structure to every sample — confirm that
    forward is meant to accept a list rather than a single tree.
    """
    # Test data (B=3 samples, P=5 patches)
    f_i = torch.tensor([
        [1.0, 2.0, 3.0, 4.0, 5.0],
        [0.5, 1.5, 2.5, 3.5, 4.5],
        [-1.0, 0.0, 1.0, 2.0, 3.0]
    ])
    # A different CSG tree structure per sample
    csg_trees = [
        [0, [1, 2]],    # uses indices 0, 1, 2
        [[0, 1], 3],    # uses indices 0, 1, 3
        [0, 1, [2, 4]]  # uses indices 0, 1, 2, 4
    ]
    # Validate that every index referenced by a tree is in range
    P = f_i.shape[1]
    for i, tree in enumerate(csg_trees):
        def check_indices(node):
            # Recursively walk nested lists; leaves must be valid patch indices.
            if isinstance(node, list):
                for n in node:
                    check_indices(n)
            else:
                assert node < P, f"样本{i}的树包含无效索引{node},P={P}"
        check_indices(tree)
    print("Input SDF values:")
    print(f_i)
    print("\nCSG Trees:")
    for i, tree in enumerate(csg_trees):
        print(f"Sample {i}: {tree}")
    # Convex combination
    print("\nTesting convex combination:")
    combiner_convex = CSGCombiner(flag_convex=True)
    h_convex = combiner_convex.forward(f_i, csg_trees)
    print("Results:", h_convex)
    # Concave combination
    print("\nTesting concave combination:")
    combiner_concave = CSGCombiner(flag_convex=False)
    h_concave = combiner_concave.forward(f_i, csg_trees)
    print("Results:", h_concave)
    # Soft blends at several rho values
    print("\nTesting soft blends:")
    for rho in [0.01, 0.1, 0.5]:
        combiner_soft = CSGCombiner(flag_convex=True, rho=rho)
        h_soft = combiner_soft.forward(f_i, csg_trees)
        print(f"rho={rho}:", h_soft)

if __name__ == "__main__":
    test_csg_combiner()

115
brep2sdf/networks/encoder.py

@ -2,72 +2,89 @@ import torch
import torch.nn as nn
from .octree import OctreeNode
from .feature_volume import PatchFeatureVolume
from brep2sdf.utils.logger import logger
class Encoder(nn.Module):
def __init__(self, octree: OctreeNode, feature_dim: int = 32):
def __init__(self, volume_bboxs:torch.tensor, feature_dim: int = 32):
"""
分离后的编码器接收预构建的八叉树
参数:
octree: 预构建的八叉树结构
volume_bboxs: 所有面片的边界框集合形状为 (N, 2, 3)
feature_dim: 特征维度
"""
super().__init__()
self.feature_dim = feature_dim
# 初始化叶子节点参数
self.register_buffer('num_parameters', torch.tensor(0, dtype=torch.long))
self._leaf_features = None # 将在_init_parameters中初始化
self._init_parameters(octree)
def _init_parameters(self,octree):
stack = [(octree, 0)]
param_count = 0
while stack:
node, _ = stack.pop()
if node._is_leaf:
param_count += 1
else:
for child in node.child_nodes:
if child: stack.append((child, 0))
# 批量计算所有bbox的分辨率
resolutions = self._batch_calculate_resolution(volume_bboxs)
# 初始化多个特征体积
self.feature_volumes = nn.ModuleList([
PatchFeatureVolume(
bbox=bbox,
resolution=int(resolutions[i]),
feature_dim=feature_dim
) for i, bbox in enumerate(volume_bboxs)
])
print(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}")
print(f"Model parameters memory usage: {sum(p.numel() * p.element_size() for p in self.parameters()) / 1024**2:.2f} MB")
# 初始化连续参数张量
self._leaf_features = nn.Parameter(
torch.randn(param_count, 8, self.feature_dim))
def _batch_calculate_resolution(self, bboxes: torch.Tensor) -> torch.Tensor:
"""
批量计算归一化bboxes的分辨率
# 重新遍历设置索引
stack = [(octree, 0)]
index = 0
while stack:
node, _ = stack.pop()
if node._is_leaf:
node.set_param_key(index)
index += 1
else:
for child in node.child_nodes:
if child: stack.append((child, 0))
self.num_parameters.fill_(index)
参数:
bboxes: 归一化边界框张量形状为 (N, 2, 3)
返回:
分辨率张量 (N,)
"""
with torch.no_grad():
# 计算每个bbox的对角线长度(归一化后范围约为0.0-1.732)
diagonals = torch.norm(bboxes[:,3:6] - bboxes[:,0:3], dim=1)
# 根据归一化后的对角线长度调整分辨率
resolutions = torch.zeros_like(diagonals, dtype=torch.long)
resolutions[diagonals > 1.0] = 16 # 大尺寸
resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 8 # 中等尺寸
resolutions[diagonals <= 0.5] = 4 # 小尺寸
return resolutions
def forward(self, query_points: torch.Tensor,param_indices,bboxes) -> torch.Tensor:
batch_size = query_points.shape[0]
# 批量获取特征
unique_ids, inverse_ids = torch.unique(param_indices, return_inverse=True)
all_features = self._leaf_features[unique_ids] # (U, 8, D)
node_features = all_features[inverse_ids] # (B, 8, D)
# 启用混合精度和优化后的插值
with torch.cuda.amp.autocast():
features = self._optimized_trilinear(
query_points,
bboxes.detach(),
node_features
)
def forward(self, query_points: torch.Tensor, volume_indices: torch.Tensor) -> torch.Tensor:
"""
修改后的前向传播返回所有关联volume的特征矩阵
参数:
query_points: 查询点坐标 (B, 3)
volume_indices: 关联的volume索引矩阵 (B, K)
# 添加类型转换确保输出为float32
return features.to(torch.float32) # 添加这行
返回:
特征张量 (B, K, D)
"""
batch_size, num_volumes = volume_indices.shape
all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,
device=query_points.device)
# 遍历每个volume索引
for k in range(num_volumes):
# 获取当前volume的索引 (B,)
current_indices = volume_indices[:, k]
# 遍历所有存在的volume
for vol_id in range(len(self.feature_volumes)):
# 创建掩码 (B,)
mask = (current_indices == vol_id)
if mask.any():
# 获取对应volume的特征 (M, D)
features = self.feature_volumes[vol_id](query_points[mask])
all_features[mask, k] = features
return all_features
def _optimized_trilinear(self, points, bboxes, features):
"""优化后的向量化三线性插值"""

96
brep2sdf/networks/feature_volume.py

@ -4,50 +4,72 @@ import torch
import torch.nn as nn
class PatchFeatureVolume(nn.Module):
def __init__(self, bbox:np, resolution=64, feature_dim=64):
def __init__(self, bbox: torch.Tensor, resolution=64, feature_dim=64, padding_ratio=0.05):
super(PatchFeatureVolume, self).__init__()
self.bbox = bbox # 补丁的边界框
self.resolution = resolution # 网格分辨率
self.feature_dim = feature_dim # 特征向量维度
# 将输入bbox转换为[min, max]格式
self.resolution = resolution
min_coords = bbox[:3]
max_coords = bbox[3:]
self.original_bbox = torch.stack([min_coords, max_coords])
expanded_bbox = self._expand_bbox(min_coords, max_coords, padding_ratio)
# 创建规则的三维网格
x = torch.linspace(bbox[0][0], bbox[1][0], resolution)
y = torch.linspace(bbox[0][1], bbox[1][1], resolution)
z = torch.linspace(bbox[0][2], bbox[1][2], resolution)
x = torch.linspace(expanded_bbox[0][0], expanded_bbox[1][0], resolution)
y = torch.linspace(expanded_bbox[0][1], expanded_bbox[1][1], resolution)
z = torch.linspace(expanded_bbox[0][2], expanded_bbox[1][2], resolution)
grid_x, grid_y, grid_z = torch.meshgrid(x, y, z)
self.register_buffer('grid', torch.stack([grid_x, grid_y, grid_z], dim=-1))
# 初始化特征向量,作为可训练参数
# 初始化特征向量
self.feature_volume = nn.Parameter(torch.randn(resolution, resolution, resolution, feature_dim))
def forward(self, query_points: List[Tuple[float, float, float]]):
def _expand_bbox(self, min_coords, max_coords, ratio):
# 扩展包围盒范围
center = (min_coords + max_coords) / 2
expanded_min = center - (center - min_coords) * (1 + ratio)
expanded_max = center + (max_coords - center) * (1 + ratio)
return torch.stack([expanded_min, expanded_max])
def forward(self, query_points: torch.Tensor) -> torch.Tensor:
"""批量处理版本的三线性插值
Args:
query_points: 形状为 (B, 3) 的查询点坐标
Returns:
形状为 (B, D) 的特征向量
"""
根据查询点的位置从补丁特征体积中获取插值后的特征向量
# 添加类型转换确保计算稳定性
normalized = ((query_points - self.grid[0,0,0]) /
(self.grid[-1,-1,-1] - self.grid[0,0,0] + 1e-8)) # (B,3)
:param query_points: 查询点的位置坐标形状为 (N, 3)
:return: 插值后的特征向量形状为 (N, feature_dim)
"""
interpolated_features = torch.zeros(query_points.shape[0], self.feature_dim).to(self.feature_volume.device)
for i, point in enumerate(query_points):
interpolated_feature = self.trilinear_interpolation(point)
interpolated_features[i] = interpolated_feature
return interpolated_features
def trilinear_interpolation(self, query_point):
"""三线性插值"""
normalized_coords = ((query_point - torch.tensor(self.bbox[0]).to(self.grid.device)) /
(torch.tensor(self.bbox[1]).to(self.grid.device) - torch.tensor(self.bbox[0]).to(self.grid.device))) * (self.resolution - 1)
indices = torch.floor(normalized_coords).long()
weights = normalized_coords - indices.float()
# 向量化三线性插值
return self._batched_trilinear(normalized)
interpolated_feature = torch.zeros(self.feature_dim).to(self.feature_volume.device)
for di in range(2):
for dj in range(2):
for dk in range(2):
weight = (weights[0] if di == 1 else 1 - weights[0]) * \
(weights[1] if dj == 1 else 1 - weights[1]) * \
(weights[2] if dk == 1 else 1 - weights[2])
index = indices + torch.tensor([di, dj, dk]).to(indices.device)
index = torch.clamp(index, 0, self.resolution - 1)
interpolated_feature += weight * self.feature_volume[index[0], index[1], index[2]]
return interpolated_feature
def _batched_trilinear(self, normalized: torch.Tensor) -> torch.Tensor:
    """Batched trilinear interpolation over the feature volume.

    Args:
        normalized: (B, 3) query coordinates normalized to [0, 1] within
            the volume's bbox — assumes callers normalize first; values
            outside [0, 1] are effectively clamped to the boundary voxels.

    Returns:
        (B, D) interpolated feature vectors.
    """
    # Map normalized coords into continuous voxel space [0, resolution-1].
    uvw = normalized * (self.resolution - 1)
    indices = torch.floor(uvw).long() # (B,3) lower corner of enclosing cell
    weights = uvw - indices.float() # (B,3) fractional offset inside the cell
    # Per-corner weights; the ordering matches the offset table below
    # ([0,0,0], [0,0,1], ..., [1,1,1]), i.e. z varies fastest. (B,8)
    weights = torch.stack([
        (1 - weights[...,0]) * (1 - weights[...,1]) * (1 - weights[...,2]),
        (1 - weights[...,0]) * (1 - weights[...,1]) * weights[...,2],
        (1 - weights[...,0]) * weights[...,1] * (1 - weights[...,2]),
        (1 - weights[...,0]) * weights[...,1] * weights[...,2],
        weights[...,0] * (1 - weights[...,1]) * (1 - weights[...,2]),
        weights[...,0] * (1 - weights[...,1]) * weights[...,2],
        weights[...,0] * weights[...,1] * (1 - weights[...,2]),
        weights[...,0] * weights[...,1] * weights[...,2],
    ], dim=-1) # (B,8)
    # Gather the 8 corner features for each query point. (B,8,3) indices.
    indices = indices.unsqueeze(1).expand(-1,8,-1) + torch.tensor([
        [0,0,0], [0,0,1], [0,1,0], [0,1,1],
        [1,0,0], [1,0,1], [1,1,0], [1,1,1]
    ], device=indices.device)
    # Clamp so corners of boundary cells stay inside the grid.
    indices = torch.clamp(indices, 0, self.resolution-1)
    features = self.feature_volume[indices[...,0], indices[...,1], indices[...,2]] # (B,8,D)
    # Weighted sum of corner features -> (B,D)
    return torch.einsum('bnd,bn->bd', features, weights)

29
brep2sdf/networks/network.py

@ -49,11 +49,13 @@ import torch
import torch.nn as nn
from torch.autograd import grad
from .encoder import Encoder
from .decoder import Decoder
from .decoder import Decoder, CSGCombiner
from brep2sdf.utils.logger import logger
class Net(nn.Module):
def __init__(self,
octree,
volume_bboxs,
feature_dim=64,
decoder_input_dim=64,
decoder_output_dim=1,
@ -69,11 +71,18 @@ class Net(nn.Module):
# 初始化 Encoder
self.encoder = Encoder(
feature_dim=feature_dim,
octree=octree
volume_bboxs= volume_bboxs
)
# 初始化 Decoder
self.decoder = Decoder(input_dim=64, output_dim=1)
self.decoder = Decoder(
d_in=decoder_input_dim,
dims_sdf=[decoder_hidden_dim] * decoder_num_layers,
geometric_init=True,
beta=100
)
self.csg_combiner = CSGCombiner(flag_convex=True)
def forward(self, query_points):
"""
@ -85,14 +94,16 @@ class Net(nn.Module):
output: 解码后的输出结果
"""
# 批量查询所有点的索引和bbox
param_indices,bboxes = self.octree_module.forward(query_points)
print("param_indices requires_grad:", param_indices.requires_grad) # 应该输出False
print("bboxes requires_grad:", bboxes.requires_grad) # 应该输出False
_,face_indices,csg_trees = self.octree_module.forward(query_points)
# 编码
feature_vector = self.encoder.forward(query_points,param_indices,bboxes)
print("feature_vector:", feature_vector.requires_grad)
feature_vectors = self.encoder.forward(query_points,face_indices)
#print("feature_vector:", feature_vectors.requires_grad)
# 解码
output = self.decoder(feature_vector)
logger.gpu_memory_stats("encoder farward后")
f_i = self.decoder(feature_vectors)
logger.gpu_memory_stats("decoder farward后")
output = self.csg_combiner.forward(f_i, csg_trees)
logger.gpu_memory_stats("combine后")
return output

262
brep2sdf/networks/octree.py

@ -4,12 +4,14 @@ import torch
import torch.nn as nn
import numpy as np
from brep2sdf.data.utils import process_surf_ncs_with_dynamic_padding
from brep2sdf.networks.patch_graph import PatchGraph
from brep2sdf.utils.logger import logger
def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
def bbox_intersect_(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
"""判断两个轴对齐包围盒(AABB)是否相交
参数:
@ -28,8 +30,90 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
# 向量化比较
return torch.all((max1 >= min2) & (max2 >= min1))
def if_points_in_box(points: np.ndarray, bbox: torch.Tensor) -> bool:
    """Check whether at least one point lies fully inside an AABB.

    Args:
        points: (N, 3) array of point coordinates.
        bbox: (6,) tensor [xmin, ymin, zmin, xmax, ymax, zmax].

    Returns:
        True iff some point satisfies min <= p <= max on ALL three axes.

    NOTE(review): the previous implementation applied torch.any over the
    raw element-wise comparison, so a single in-range coordinate of any
    point made it return True (and its docstring claimed all-points
    semantics). This version requires a point to be inside on every axis,
    matching the containment test used by the vectorized bbox_intersect.
    """
    # Move the points onto the bbox's device for the comparison.
    points_tensor = torch.tensor(points, dtype=torch.float32, device=bbox.device)
    min_coords = bbox[:3]
    max_coords = bbox[3:]
    # (N, 3) per-axis containment -> (N,) per-point containment -> any point.
    inside = (points_tensor >= min_coords) & (points_tensor <= max_coords)
    return inside.all(dim=-1).any().item()
def bbox_intersect(
    surf_bboxes: torch.Tensor,
    indices: torch.Tensor,
    child_bboxes: torch.Tensor,
    surf_points: torch.Tensor = None
) -> torch.Tensor:
    '''
    Vectorized AABB overlap test between 8 child boxes and selected faces.

    args:
        surf_bboxes: [B, 6] per-face bounding boxes as (min_xyz, max_xyz)
        indices: [N] indices selecting faces of interest, N <= B
        child_bboxes: [8, 6] the eight children of a subdivided parent box
        surf_points: [B, M_max, 3] optional padded point samples per face
            (assumes padded slots hold out-of-range values such as +inf,
            so they fail the containment test — confirm against the
            padding producer)
    return:
        result_mask: [8, B] boolean; entry (c, f) is True when child c
        intersects face f's bbox (and, if surf_points is given, contains
        at least one of face f's sample points). Columns not in `indices`
        stay False.
    '''
    # Start all-False; only the selected columns get written below. [8, B]
    B = surf_bboxes.size(0)
    result_mask = torch.zeros((8, B), dtype=torch.bool).to(surf_bboxes.device)
    logger.debug(result_mask.shape)
    logger.debug(indices.shape)
    # Gather the selected face boxes. [N, 6]
    selected_bboxes = surf_bboxes[indices]
    min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:]  # [N, 3]
    min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:]        # [8, 3]
    logger.debug(selected_bboxes.shape)
    # Two AABBs overlap iff, per axis, each box's max reaches the other's min.
    intersect_mask = torch.all(
        (max1.unsqueeze(0) >= min2.unsqueeze(1)) &  # broadcast to [8, N, 3]
        (max2.unsqueeze(1) >= min1.unsqueeze(0)),   # broadcast to [8, N, 3]
        dim=-1
    )  # [8, N]
    # Scatter the per-selection results into the full-width mask.
    result_mask[:, indices] = intersect_mask
    # Optionally tighten: a child must also contain at least one sample point.
    if surf_points is not None:
        # Points of the selected faces. [N, M_max, 3]
        selected_points = surf_points[indices]
        # Broadcast points against all 8 children.
        points_expanded = selected_points.unsqueeze(1)   # [N, 1, M_max, 3]
        min2_expanded = min2.unsqueeze(0).unsqueeze(2)   # [1, 8, 1, 3]
        max2_expanded = max2.unsqueeze(0).unsqueeze(2)   # [1, 8, 1, 3]
        # Per-point containment on every axis.
        point_in_box_mask = (
            (points_expanded >= min2_expanded) &  # [N, 8, M_max, 3]
            (points_expanded <= max2_expanded)    # [N, 8, M_max, 3]
        ).all(dim=-1)  # [N, 8, M_max]
        # At least one contained point per (face, child) pair -> [8, N].
        points_in_boxes_mask = point_in_box_mask.any(dim=-1).permute(1, 0)
        # AND the point condition into the bbox-overlap condition.
        result_mask[:, indices] = result_mask[:, indices] & points_in_boxes_mask
    logger.debug(result_mask.shape)
    return result_mask
class OctreeNode(nn.Module):
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,device=None):
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,surf_ncs:np.ndarray = None,device=None):
super().__init__()
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 改为普通张量属性
@ -39,52 +123,43 @@ class OctreeNode(nn.Module):
self.child_indices = None
self.is_leaf_mask = None
# 面片索引张量
self.face_indices = torch.from_numpy(face_indices).to(self.device)
self.all_face_indices = torch.from_numpy(face_indices).to(self.device)
self.surf_bbox = surf_bbox.to(self.device) if surf_bbox is not None else None
self.surf_ncs = process_surf_ncs_with_dynamic_padding(surf_ncs).to(self.device)
# PatchGraph作为普通属性
self.patch_graph = patch_graph.to(self.device) if patch_graph is not None else None
self.max_depth = max_depth
# 参数键改为普通张量
self.param_key = torch.tensor(-1, dtype=torch.long, device=self.device)
self._is_leaf = True
# 删除所有register_buffer调用
@torch.jit.export
def set_param_key(self, k: int) -> None:
"""设置参数键值
参数:
k: 参数索引值
"""
self.param_key.fill_(k)
@torch.jit.export
def build_static_tree(self) -> None:
"""构建静态八叉树结构"""
# 预计算所有可能的节点数量,确保结果为整数
total_nodes = int(sum(8**i for i in range(self.max_depth + 1)))
num_faces = self.all_face_indices.shape[0]
# 初始化静态张量,使用整数列表作为形状参数
self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device)
self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device)
self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device)
self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.device)
self.face_indices_mask = torch.zeros([int(total_nodes),num_faces], dtype=torch.bool, device=self.device) # 1 代表有
self.is_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device)
# 使用队列进行广度优先遍历
queue = [(0, self.bbox, self.face_indices)] # (node_idx, bbox, face_indices)
queue = [(0, self.bbox, self.all_face_indices)] # (node_idx, bbox, face_indices)
current_idx = 0
while queue:
node_idx, bbox, faces = queue.pop(0)
#logger.debug(f"Processing node {node_idx} with {len(faces)} faces.")
self.node_bboxes[node_idx] = bbox
# 判断 要不要继续分裂
if not self._should_split_node(current_idx):
self.is_leaf_mask[node_idx] = True
if not self._should_split_node(current_idx, faces, total_nodes):
continue
self.is_leaf_mask[node_idx] = 0
# 计算子节点边界框
min_coords = bbox[:3]
max_coords = bbox[3:]
@ -92,36 +167,35 @@ class OctreeNode(nn.Module):
# 生成8个子节点
child_bboxes = self._generate_child_bboxes(min_coords, mid_coords, max_coords)
intersect_mask = bbox_intersect(self.surf_bbox, faces, child_bboxes)
self.face_indices_mask[current_idx + 1:current_idx + 9, :] = intersect_mask
# 为每个子节点分配面片
for i, child_bbox in enumerate(child_bboxes):
child_idx = current_idx + 1
current_idx += 1
# 找到与子包围盒相交的面
intersecting_faces = []
for face_idx in faces:
face_bbox = self.surf_bbox[face_idx]
if bbox_intersect(child_bbox, face_bbox).item():
intersecting_faces.append(face_idx)
child_idx = child_idx = current_idx + i + 1
intersecting_faces = intersect_mask[i].nonzero().flatten()
#logger.debug(f"Node {child_idx} has {len(intersecting_faces)} intersecting faces.")
# 更新节点关系
self.parent_indices[child_idx] = node_idx
self.child_indices[node_idx, i] = child_idx
# 将子节点加入队列
if intersecting_faces:
queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.device)))
if len(intersecting_faces) > 0:
queue.append((child_idx, child_bbox, intersecting_faces.clone().detach()))
current_idx += 8
def _should_split_node(self, current_depth: int) -> bool:
def _should_split_node(self, current_idx: int,face_indices,max_node:int) -> bool:
"""判断节点是否需要分裂"""
# 检查是否达到最大深度
if current_depth >= self.max_depth:
if current_idx + 8 >= max_node:
return False
# 检查是否为完全图
is_clique = self.patch_graph.is_clique(self.face_indices)
#is_clique = self.patch_graph.is_clique(face_indices)
is_clique = face_indices.shape[0] < 2
if is_clique:
#logger.debug(f"Node {current_idx} is a clique. Stopping split.")
return False
return True
@ -153,9 +227,9 @@ class OctreeNode(nn.Module):
@torch.jit.export
def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]:
"""
查找包含给定点的叶子节点并返回其信息
修改后的查找叶子节点方法返回face indices
:param query_points: 待查找的点形状为 (3,)
:return: 包含叶子节点信息的元组 (bbox, param_key, is_leaf)
:return: (bbox, param_key, face_indices, is_leaf)
"""
# 确保输入是单个点
if query_points.dim() != 1 or query_points.shape[0] != 3:
@ -168,7 +242,28 @@ class OctreeNode(nn.Module):
while iteration < max_iterations:
# 获取当前节点的叶子状态
if self.is_leaf_mask[current_idx].item():
return self.node_bboxes[current_idx], self.param_key, True
#logger.debug(f"Reached leaf node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.")
if self.face_indices_mask[current_idx].sum() == 0:
parent_idx = self.parent_indices[current_idx]
#logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.")
if parent_idx == -1:
# 根节点没有父节点,返回根节点的信息
#logger.warning(f"Reached root node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.")
return (
self.node_bboxes[current_idx],
None, # 新增返回face indices
False
)
return (
self.node_bboxes[parent_idx],
self.face_indices_mask[parent_idx], # 新增返回face indices
False
)
return (
self.node_bboxes[current_idx],
self.face_indices_mask[current_idx], # 新增返回face indices
True
)
# 计算子节点索引
child_idx = self._get_child_indices(query_points.unsqueeze(0),
@ -185,7 +280,7 @@ class OctreeNode(nn.Module):
iteration += 1
# 如果达到最大迭代次数,返回当前节点的信息
return self.node_bboxes[current_idx], self.param_key, bool(self.is_leaf_mask[current_idx].item())
return self.node_bboxes[current_idx], None,bool(self.is_leaf_mask[current_idx].item())
@torch.jit.export
def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
@ -195,43 +290,70 @@ class OctreeNode(nn.Module):
def forward(self, query_points):
with torch.no_grad():
param_indices, bboxes = [], []
bboxes, face_indices_mask, csg_trees = [], [], []
for point in query_points:
bbox, idx, _ = self.find_leaf(point)
param_indices.append(idx)
bbox, faces_mask, _ = self.find_leaf(point)
bboxes.append(bbox)
param_indices = torch.stack(param_indices)
bboxes = torch.stack(bboxes)
# 添加检查代码
return param_indices, bboxes
face_indices_mask.append(faces_mask)
# 获取当前节点的CSG树结构
csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None
csg_trees.append(csg_tree) # 保持原始列表结构
return (
torch.stack(bboxes),
torch.stack(face_indices_mask),
csg_trees # 直接返回列表,不转换为张量
)
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None:
def print_tree(self, max_print_depth: int = None) -> None:
"""
递归打印八叉树结构
使用深度优先遍历 (DFS) 打印树结构父子关系通过缩进体现
参数:
depth: 当前深度 (内部使用)
max_print_depth: 最大打印深度 (None表示打印全部)
max_print_depth (int): 最大打印深度 (None 表示打印全部)
"""
if max_print_depth is not None and depth > max_print_depth:
return
def dfs(node_idx: int, depth: int):
"""
深度优先遍历辅助函数
# 打印当前节点信息
indent = " " * depth
node_type = "Leaf" if self._is_leaf else "Internal"
print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.cpu().numpy().tolist()}")
# 打印面片信息(如果有)
if self.face_indices is not None:
print(f"{indent} Face indices: {self.face_indices.cpu().numpy().tolist()}")
print(f"{indent} Child indices: {self.child_indices.cpu().numpy().tolist()}")
# 打印子节点信息
if self.child_indices is not None:
for i in range(8):
child_idx = self.child_indices[0, i].item()
if child_idx != -1:
print(f"{indent} Child {i}: Node {child_idx}")
参数:
node_idx (int): 当前节点索引
depth (int): 当前节点的深度
"""
# 如果超过最大打印深度,跳过当前节点及其子节点
if max_print_depth is not None and depth > max_print_depth:
return
indent = " " * depth # 根据深度生成缩进
is_leaf = self.is_leaf_mask[node_idx].item() # 判断是否为叶子节点
bbox = self.node_bboxes[node_idx].cpu().numpy().tolist() # 获取边界框信息
# 打印当前节点的基本信息
node_type = "Leaf" if is_leaf else "Internal"
log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}")
if self.face_indices_mask is not None:
face_indices = self.face_indices_mask[node_idx].nonzero().cpu().numpy().flatten().tolist()
log_lines.append(f"{indent} Face Indices: {face_indices}")
# 如果是叶子节点,打印额外信息
if is_leaf:
child_indices = self.child_indices[node_idx].cpu().numpy().tolist()
log_lines.append(f"{indent} Child Indices: {child_indices}")
# 如果不是叶子节点,递归处理子节点
if not is_leaf:
for i in range(8): # 遍历所有子节点
child_idx = self.child_indices[node_idx, i].item()
if child_idx != -1: # 忽略无效的子节点索引
dfs(child_idx, depth + 1)
# 初始化日志行列表
log_lines = []
# 从根节点开始深度优先遍历
dfs(0, 0)
# 统一输出所有日志
logger.debug("\n".join(log_lines))
def __getstate__(self):
"""支持pickle序列化"""
@ -245,7 +367,6 @@ class OctreeNode(nn.Module):
'surf_bbox': self.surf_bbox,
'patch_graph': self.patch_graph,
'max_depth': self.max_depth,
'param_key': self.param_key,
'_is_leaf': self._is_leaf
}
return state
@ -261,7 +382,6 @@ class OctreeNode(nn.Module):
self.surf_bbox = state['surf_bbox']
self.patch_graph = state['patch_graph']
self.max_depth = state['max_depth']
self.param_key = state['param_key']
self._is_leaf = state['_is_leaf']
def to(self, device=None, dtype=None, non_blocking=False):

35
brep2sdf/networks/patch_graph.py

@ -2,6 +2,7 @@ from typing import Tuple
import torch
import torch.nn as nn
import numpy as np
from brep2sdf.utils.logger import logger
class PatchGraph(nn.Module):
def __init__(self, num_patches: int, device: torch.device = None):
@ -56,13 +57,45 @@ class PatchGraph(nn.Module):
return subgraph_edges, subgraph_types
def get_csg_tree(self, node_faces_mask: torch.Tensor):
    """Build a CSG combination tree for the faces selected by a boolean mask.

    Faces connected by a concave edge (edge_type == 0) are greedily paired
    into [u, v] sub-lists first; all remaining faces enter the tree as
    plain leaves.

    Args:
        node_faces_mask: boolean mask over all patches; True entries are
            the faces belonging to the current octree node.

    Returns:
        Nested list structure describing the CSG hierarchy, e.g.
        [[1, 2], 0, 3] — (1, 2) combined first, then 0 and 3 as leaves.
        Returns [] when the patch graph has no edges.
    """
    # NOTE(review): was a stray `print`; routed through the module logger
    # to keep diagnostics consistent with the rest of the file.
    logger.debug(f"node_faces_mask: {node_faces_mask}")
    if self.edge_index is None:
        return []
    node_faces = node_faces_mask.nonzero()
    node_faces = node_faces.flatten().to('cpu').numpy()
    logger.debug(f"node_faces: {node_faces}")
    node_set = set(node_faces)  # fast membership lookup for this node's faces
    visited = set()
    csg_tree = []
    # Pair faces joined by concave edges first (edge_type == 0).
    concave_edges = self.edge_index[:, self.edge_type == 0].cpu().numpy().T
    for u, v in concave_edges:
        u, v = int(u), int(v)
        if u in node_set and v in node_set and u not in visited and v not in visited:
            csg_tree.append([u, v])
            visited.update({u, v})
    # Remaining (unpaired) faces of this node become leaves of the tree.
    remaining = [int(f) for f in node_faces if f not in visited]
    csg_tree.extend(remaining)
    return csg_tree
def is_clique(self, node_faces: torch.Tensor) -> bool:
"""检查给定面片集合是否构成完全图
参数:
node_faces: 要检查的面片索引集合
face [0,1,2,3,4,]
返回:
bool: 是否为完全图

69
brep2sdf/test.py

@ -1,4 +1,69 @@
import torch
model = torch.jit.load("/home/wch/brep2sdf/data/output_data/00000054.pt")
print(model)
from typing import List, Tuple
def bbox_intersect(surf_bboxes: torch.Tensor, indices: torch.Tensor, child_bboxes: torch.Tensor) -> torch.Tensor:
    """Test axis-aligned bounding-box overlap between selected boxes and child boxes.

    Args:
        surf_bboxes: (B, 6) tensor; each row is [xmin, ymin, zmin, xmax, ymax, zmax].
        indices: (N,) index tensor selecting rows of surf_bboxes (N <= B).
        child_bboxes: (8, 6) tensor of the eight octant boxes of a parent box.

    Returns:
        (8, N) boolean mask; entry [c, n] is True when child box c overlaps
        the box surf_bboxes[indices[n]] (touching faces count as overlap).
    """
    chosen = surf_bboxes[indices]                           # (N, 6)
    lo_sel, hi_sel = chosen[:, :3], chosen[:, 3:]           # (N, 3) each
    lo_ch, hi_ch = child_bboxes[:, :3], child_bboxes[:, 3:]  # (8, 3) each

    # Two AABBs overlap iff their intervals overlap on every axis:
    # hi_sel >= lo_ch and hi_ch >= lo_sel, broadcast to (8, N, 3).
    per_axis = (hi_sel.unsqueeze(0) >= lo_ch.unsqueeze(1)) & (hi_ch.unsqueeze(1) >= lo_sel.unsqueeze(0))
    return per_axis.all(dim=-1)                             # (8, N)
# Manual smoke test for bbox_intersect.
if __name__ == "__main__":
    # Build input data: three axis-aligned boxes [min_xyz, max_xyz].
    surf_bboxes = torch.tensor([
        [0, 0, 0, 1, 1, 1],              # cube 1
        [0.5, 0.5, 0.5, 1.5, 1.5, 1.5],  # cube 2
        [2, 2, 2, 3, 3, 3]               # cube 3
    ])  # [B=3, 6]
    indices = torch.tensor([0, 1])  # select the first two cubes

    # Split the parent box [0, 0, 0, 2, 2, 2] into its eight octants.
    parent_bbox = torch.tensor([0, 0, 0, 2, 2, 2])
    center = (parent_bbox[:3] + parent_bbox[3:]) / 2
    child_bboxes = torch.tensor([
        [parent_bbox[0], parent_bbox[1], parent_bbox[2], center[0], center[1], center[2]],  # lower-left-front
        [center[0], parent_bbox[1], parent_bbox[2], parent_bbox[3], center[1], center[2]],  # lower-right-front
        [parent_bbox[0], center[1], parent_bbox[2], center[0], parent_bbox[4], center[2]],  # upper-left-front
        [center[0], center[1], parent_bbox[2], parent_bbox[3], parent_bbox[4], center[2]],  # upper-right-front
        [parent_bbox[0], parent_bbox[1], center[2], center[0], center[1], parent_bbox[5]],  # lower-left-back
        [center[0], parent_bbox[1], center[2], parent_bbox[3], center[1], parent_bbox[5]],  # lower-right-back
        [parent_bbox[0], center[1], center[2], center[0], parent_bbox[4], parent_bbox[5]],  # upper-left-back
        [center[0], center[1], center[2], parent_bbox[3], parent_bbox[4], parent_bbox[5]]   # upper-right-back
    ])  # [8, 6]

    # Run the intersection test.
    intersect_mask = bbox_intersect(surf_bboxes, indices, child_bboxes)

    print("Intersect Mask:")
    print(intersect_mask)

    # Convert the boolean mask into per-child index lists.
    child_indices = []
    for i in range(8):  # iterate over the eight child octants
        intersecting_faces = indices[intersect_mask[i]]  # face indices intersecting child i
        child_indices.append(intersecting_faces)

    # NOTE: the loop variable is named `sel` (not `indices`) so it does not
    # shadow the `indices` selection tensor defined above.
    print("\nChild Indices:")
    for i, sel in enumerate(child_indices):
        print(f"Child {i}: {sel}")

213
brep2sdf/train.py

@ -116,11 +116,12 @@ class Trainer:
device=self.device
)
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=8)
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=6)
logger.gpu_memory_stats("数初始化后")
self.model = Net(
octree=self.root,
volume_bboxs=surf_bbox,
feature_dim=64
).to(self.device)
logger.gpu_memory_stats("模型初始化后")
@ -138,7 +139,7 @@ class Trainer:
logger.info(f"初始化完成,正在处理模型 {self.model_name}")
def build_tree(self,surf_bbox, graph, max_depth=6):
def build_tree(self,surf_bbox, graph, max_depth=9):
num_faces = surf_bbox.shape[0]
bbox = self._calculate_global_bbox(surf_bbox)
self.root = OctreeNode(
@ -147,13 +148,13 @@ class Trainer:
patch_graph=graph,
max_depth=max_depth,
surf_bbox=surf_bbox,
surf_ncs=self.data['surf_ncs']
)
#print(surf_bbox)
logger.info("starting octree conduction")
self.root.build_static_tree()
logger.info("complete octree conduction")
#self.root.print_tree(0)
self.root.print_tree()
def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
"""
@ -190,10 +191,6 @@ class Trainer:
def train_epoch(self, epoch: int) -> float:
self.model.train()
total_loss = 0.0
step = 0 # 如果你的训练是分批次的,这里应该用批次索引
# --- 1. 检查输入数据 ---
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列)
# 并且 SDF 值总是在最后一列
@ -201,114 +198,124 @@ class Trainer:
logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
return float('inf')
points = self.sdf_data[:, 0:3].clone().detach() # 取前3列作为点
gt_sdf = self.sdf_data[:, -1].clone().detach() # 取最后一列作为SDF真值
normals = None
if args.use_normal:
if self.sdf_data.shape[1] < 7: # 检查是否有足够的列给法线
logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.")
return float('inf')
normals = self.sdf_data[:, 3:6].clone().detach() # 取中间3列作为法线
# 执行检查
if self.debug_mode:
if check_tensor(points, "Input Points", epoch, step): return float('inf')
if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf')
self.model.train()
total_loss = 0.0
step = 0 # 如果你的训练是分批次的,这里应该用批次索引
batch_size = 10240 # 设置合适的batch大小
# 将数据分成多个batch
num_points = self.sdf_data.shape[0]
num_batches = (num_points + batch_size - 1) // batch_size
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, num_points)
points = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点
gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值
normals = None
if args.use_normal:
# 只有在请求法线时才检查 normals
if check_tensor(normals, "Input Normals", epoch, step): return float('inf')
if self.sdf_data.shape[1] < 7: # 检查是否有足够的列给法线
logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.")
return float('inf')
normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach() # 取中间3列作为法线
# 执行检查
if self.debug_mode:
if check_tensor(points, "Input Points", epoch, step): return float('inf')
if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf')
if args.use_normal:
# 只有在请求法线时才检查 normals
if check_tensor(normals, "Input Normals", epoch, step): return float('inf')
# --- 准备模型输入,启用梯度 ---
points.requires_grad_(True) # 在检查之后启用梯度
# --- 前向传播 ---
self.optimizer.zero_grad()
pred_sdf = self.model(points)
# --- 准备模型输入,启用梯度 ---
points.requires_grad_(True) # 在检查之后启用梯度
if self.debug_mode:
# --- 检查前向传播的输出 ---
logger.gpu_memory_stats("前向传播后")
# --- 2. 检查模型输出 ---
#if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf')
# --- 前向传播 ---
self.optimizer.zero_grad()
pred_sdf = self.model(points)
# --- 计算损失 ---
loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败
loss_details = {}
try:
# --- 3. 检查损失计算前的输入 ---
# (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf)
#if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
#if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
if args.use_normal:
# 检查法线和带梯度的点
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss(
points,
normals, # 传递检查过的 normals
gt_sdf,
pred_sdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# --- 4. 检查损失计算结果 ---
if self.debug_mode:
if check_tensor(loss, "Calculated Loss", epoch, step):
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
if loss_details: logger.error(f"Loss Details: {loss_details}")
return float('inf') # 如果损失无效,停止这个epoch
except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # 如果计算出错,停止这个epoch
logger.gpu_memory_stats("损失计算后")
# --- 反向传播和优化 ---
try:
loss.backward()
# --- 5. (可选) 检查梯度 ---
# for name, param in self.model.named_parameters():
# if param.grad is not None:
# if check_tensor(param.grad, f"Gradient/{name}", epoch, step):
# logger.warning(f"Epoch {epoch} Step {step}: Bad gradient for {name}. Consider clipping or zeroing.")
# # 例如:param.grad.data.clamp_(-1, 1) # 值裁剪
# # 或在 optimizer.step() 前进行范数裁剪:
# # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# --- (推荐) 添加梯度裁剪 ---
# 防止梯度爆炸,这可能是导致 inf/nan 的原因之一
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # 范数裁剪
self.optimizer.step()
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
# 如果你想看是哪个操作导致的,可以启用 anomaly detection
# torch.autograd.set_detect_anomaly(True) # 放在训练开始前
return float('inf') # 如果反向传播或优化出错,停止这个epoch
# --- 记录和累加损失 ---
current_loss = loss.item()
if not np.isfinite(current_loss): # 再次确认损失是有效的数值
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
# --- 检查前向传播的输出 ---
logger.gpu_memory_stats("前向传播后")
# --- 2. 检查模型输出 ---
#if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf')
# --- 计算损失 ---
loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败
loss_details = {}
try:
# --- 3. 检查损失计算前的输入 ---
# (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf)
#if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
#if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
if args.use_normal:
# 检查法线和带梯度的点
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss(
points,
normals, # 传递检查过的 normals
gt_sdf,
pred_sdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# --- 4. 检查损失计算结果 ---
if self.debug_mode:
if check_tensor(loss, "Calculated Loss", epoch, step):
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
if loss_details: logger.error(f"Loss Details: {loss_details}")
return float('inf') # 如果损失无效,停止这个epoch
except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # 如果计算出错,停止这个epoch
logger.gpu_memory_stats("损失计算后")
# --- 反向传播和优化 ---
try:
loss.backward()
# --- 5. (可选) 检查梯度 ---
# for name, param in self.model.named_parameters():
# if param.grad is not None:
# if check_tensor(param.grad, f"Gradient/{name}", epoch, step):
# logger.warning(f"Epoch {epoch} Step {step}: Bad gradient for {name}. Consider clipping or zeroing.")
# # 例如:param.grad.data.clamp_(-1, 1) # 值裁剪
# # 或在 optimizer.step() 前进行范数裁剪:
# # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
# --- (推荐) 添加梯度裁剪 ---
# 防止梯度爆炸,这可能是导致 inf/nan 的原因之一
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # 范数裁剪
self.optimizer.step()
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
# 如果你想看是哪个操作导致的,可以启用 anomaly detection
# torch.autograd.set_detect_anomaly(True) # 放在训练开始前
return float('inf') # 如果反向传播或优化出错,停止这个epoch
# --- 记录和累加损失 ---
current_loss = loss.item()
if not np.isfinite(current_loss): # 再次确认损失是有效的数值
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
total_loss += current_loss
total_loss += current_loss
del loss
torch.cuda.empty_cache()
# 记录训练进度 (只记录有效的损失)
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
# (如果你的训练分批次,这里应该继续循环下一批次)
# step += 1
del loss
torch.cuda.empty_cache() # 清空缓存
return total_loss # 对于单批次训练,直接返回当前损失
def validate(self, epoch: int) -> float:

2
brep2sdf/utils/logger.py

@ -227,7 +227,7 @@ class BRepLogger:
stats.append(f" 峰值: {max_allocated:.1f} MB")
# 一次性输出所有统计信息
self.info("\n".join(stats))
self.debug("\n".join(stats))
# 获取每个张量的内存使用情况
if include_trace:

Loading…
Cancel
Save