Browse Source

可以训练,但是和torch jit还有一些兼容问题

final
mckay 1 month ago
parent
commit
f98c9ac227
  1. 3
      brep2sdf/config/default_config.py
  2. 1
      brep2sdf/data/data.py
  3. 1
      brep2sdf/data/utils.py
  4. 33
      brep2sdf/networks/decoder.py
  5. 22
      brep2sdf/networks/encoder.py
  6. 52
      brep2sdf/networks/network.py
  7. 89
      brep2sdf/networks/octree.py
  8. 34
      brep2sdf/networks/patch_graph.py
  9. 137
      brep2sdf/train.py

3
brep2sdf/config/default_config.py

@ -9,6 +9,7 @@ class ModelConfig:
embed_dim: int = 768 # 3 的 倍数 embed_dim: int = 768 # 3 的 倍数
latent_dim: int = 32 latent_dim: int = 32
octree_max_depth = 6
# 点云采样配置 # 点云采样配置
num_surf_points: int = 64 # 每个面采样点数 num_surf_points: int = 64 # 每个面采样点数
num_edge_points: int = 8 # 每条边采样点数 num_edge_points: int = 8 # 每条边采样点数
@ -48,7 +49,7 @@ class TrainConfig:
# 基本训练参数 # 基本训练参数
batch_size: int = 8 batch_size: int = 8
num_workers: int = 4 num_workers: int = 4
num_epochs: int = 1000 num_epochs: int = 1
learning_rate: float = 0.001 learning_rate: float = 0.001
min_lr: float = 1e-5 min_lr: float = 1e-5
weight_decay: float = 0.01 weight_decay: float = 0.01

1
brep2sdf/data/data.py

@ -141,7 +141,6 @@ def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
return torch.tensor(sdf_array, dtype=torch.float32, device=device) return torch.tensor(sdf_array, dtype=torch.float32, device=device)
def print_data_distribution(data: torch.Tensor) -> None: def print_data_distribution(data: torch.Tensor) -> None:
"""打印数据分布统计信息 """打印数据分布统计信息

1
brep2sdf/data/utils.py

@ -56,6 +56,7 @@ def process_surf_ncs_with_dynamic_padding(surf_ncs: np.ndarray) -> torch.Tensor:
return padded_tensor return padded_tensor
def normalize(surfs, edges, corners): def normalize(surfs, edges, corners):
""" """
将CAD模型归一化到单位立方体空间 将CAD模型归一化到单位立方体空间

33
brep2sdf/networks/decoder.py

@ -88,7 +88,34 @@ class Decoder(nn.Module):
return f_i return f_i
@torch.jit.export
def forward_training_volumes(self, feature_matrix: torch.Tensor) -> torch.Tensor:
'''
:param feature_matrix: 形状为(S, D) 的特征矩阵
S: 采样数量
D: 特征维度
:return:
f: 各patch的SDF值 (S)
'''
# 直接使用输入的特征矩阵,因为形状已经是 (S, D)
x = feature_matrix
for layer in range(0, self.sdf_layers - 1):
lin = getattr(self, "sdf_" + str(layer))
if layer in self.skip_in:
x = torch.cat([x, x], -1) / np.sqrt(2)
x = lin(x)
if layer < self.sdf_layers - 2:
x = self.activation(x)
output_value = x # 所有 f 的值
# 调整输出形状为 (S)
f = output_value.squeeze(-1)
return f
"""
# 一个基础情形: 输入 fi 形状[P] 和 csg tree,凹凸组合输出h # 一个基础情形: 输入 fi 形状[P] 和 csg tree,凹凸组合输出h
#注意考虑如何批量处理 (B, P) 和 [csg tree] #注意考虑如何批量处理 (B, P) 和 [csg tree]
class CSGCombiner: class CSGCombiner:
@ -96,7 +123,8 @@ class CSGCombiner:
self.flag_convex = flag_convex self.flag_convex = flag_convex
self.rho = rho self.rho = rho
def forward(self, f_i: torch.Tensor, csg_tree) -> torch.Tensor: def forward(self, f_i: torch.Tensor, csg_tree
) -> torch.Tensor:
''' '''
:param f_i: 形状为 (B, P) 的各patch SDF值 :param f_i: 形状为 (B, P) 的各patch SDF值
:param csg_tree: CSG树结构 :param csg_tree: CSG树结构
@ -216,4 +244,5 @@ def test_csg_combiner():
print(f"rho={rho}:", h_soft) print(f"rho={rho}:", h_soft)
if __name__ == "__main__": if __name__ == "__main__":
test_csg_combiner() test_csg_combiner()
"""

22
brep2sdf/networks/encoder.py

@ -75,18 +75,34 @@ class Encoder(nn.Module):
current_indices = volume_indices[:, k] current_indices = volume_indices[:, k]
# 遍历所有存在的volume # 遍历所有存在的volume
for vol_id in range(len(self.feature_volumes)): for vol_id, volume in enumerate(self.feature_volumes):
# 创建掩码 (B,) # 创建掩码 (B,)
mask = (current_indices == vol_id) mask = (current_indices == vol_id)
if mask.any(): if mask.any():
# 获取对应volume的特征 (M, D) # 获取对应volume的特征 (M, D)
features = self.feature_volumes[vol_id](query_points[mask]) features = volume.forward(query_points[mask])
all_features[mask, k] = features all_features[mask, k] = features
return all_features return all_features
@torch.jit.export
def forward_training_volumes(self, surf_points: torch.Tensor, patch_id: int) -> torch.Tensor:
"""
处理表面采样点的特征提取
参数:
surf_points: 表面采样点 (S, 3), S: #sampled point per feature_volumes.
返回:
特征张量 (S, D)
"""
# 使用枚举遍历 feature_volumes,避免直接索引
for idx, volume in enumerate(self.feature_volumes):
if idx == patch_id:
return volume.forward(surf_points)
return torch.zeros(surf_points.shape[0], self.feature_dim, device=surf_points.device)
return features
def _optimized_trilinear(self, points, bboxes, features): def _optimized_trilinear(self, points, bboxes, features)-> torch.Tensor:
"""优化后的向量化三线性插值""" """优化后的向量化三线性插值"""
# 添加显式类型转换确保计算稳定性 # 添加显式类型转换确保计算稳定性
min_coords = bboxes[..., :3].to(torch.float32) min_coords = bboxes[..., :3].to(torch.float32)

52
brep2sdf/networks/network.py

@ -49,7 +49,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.autograd import grad from torch.autograd import grad
from .encoder import Encoder from .encoder import Encoder
from .decoder import Decoder, CSGCombiner from .decoder import Decoder
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
class Net(nn.Module): class Net(nn.Module):
@ -82,8 +82,9 @@ class Net(nn.Module):
beta=100 beta=100
) )
self.csg_combiner = CSGCombiner(flag_convex=True) #self.csg_combiner = CSGCombiner(flag_convex=True)
@torch.jit.export
def forward(self, query_points): def forward(self, query_points):
""" """
前向传播 前向传播
@ -94,18 +95,55 @@ class Net(nn.Module):
output: 解码后的输出结果 output: 解码后的输出结果
""" """
# 批量查询所有点的索引和bbox # 批量查询所有点的索引和bbox
_,face_indices,csg_trees = self.octree_module.forward(query_points) _,face_indices_mask,operator = self.octree_module.forward(query_points)
# 编码 # 编码
feature_vectors = self.encoder.forward(query_points,face_indices) feature_vectors = self.encoder.forward(query_points,face_indices_mask)
#print("feature_vector:", feature_vectors.requires_grad) print("feature_vector:", feature_vectors.shape)
# 解码 # 解码
logger.gpu_memory_stats("encoder farward后") logger.gpu_memory_stats("encoder farward后")
f_i = self.decoder(feature_vectors) f_i = self.decoder(feature_vectors) # (B, P)
logger.gpu_memory_stats("decoder farward后") logger.gpu_memory_stats("decoder farward后")
output = self.csg_combiner.forward(f_i, csg_trees)
output = f_i[:, 0]
# 提取有效值并填充到固定大小 (B, max_patches)
padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device) # (B, max_patches)
for i in range(f_i.shape[0]):
sample_valid_values = f_i[i][face_indices_mask[i]] # (N,), N <= P
num_valid = min(len(sample_valid_values), 2)
padded_f_i[i, :num_valid] = sample_valid_values[:num_valid]
# 找到需要组合的行
mask_concave = (operator == 0)
mask_convex = (operator == 1)
# 对 operator == 0 的样本取最大值
if mask_concave.any():
output[mask_concave] = torch.max(padded_f_i[mask_concave], dim=1).values
# 对 operator == 1 的样本取最小值
if mask_convex.any():
output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values
logger.gpu_memory_stats("combine后") logger.gpu_memory_stats("combine后")
return output return output
@torch.jit.export
def forward_training_volumes(self, surf_points, patch_id:int):
"""
only surf sampled points
surf_points (P, S):
return (P, S)
"""
feature_mat = self.encoder.forward_training_volumes(surf_points,patch_id)
f_i = self.decoder.forward_training_volumes(feature_mat)
return f_i.squeeze()
def gradient(inputs, outputs): def gradient(inputs, outputs):
d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device) d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)

89
brep2sdf/networks/octree.py

@ -1,4 +1,4 @@
from typing import Tuple from typing import Tuple,List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -54,7 +54,7 @@ def bbox_intersect(
surf_bboxes: torch.Tensor, surf_bboxes: torch.Tensor,
indices: torch.Tensor, indices: torch.Tensor,
child_bboxes: torch.Tensor, child_bboxes: torch.Tensor,
surf_points: torch.Tensor = None surf_points: Optional[torch.Tensor]=None
) -> torch.Tensor: ) -> torch.Tensor:
''' '''
args: args:
@ -69,15 +69,15 @@ def bbox_intersect(
# 初始化全为 False 的结果掩码 [8, B] # 初始化全为 False 的结果掩码 [8, B]
B = surf_bboxes.size(0) B = surf_bboxes.size(0)
result_mask = torch.zeros((8, B), dtype=torch.bool).to(surf_bboxes.device) result_mask = torch.zeros((8, B), dtype=torch.bool).to(surf_bboxes.device)
logger.debug(result_mask.shape) #logger.debug(result_mask.shape)
logger.debug(indices.shape) #logger.debug(indices.shape)
# 提取选中的边界框 # 提取选中的边界框
selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6] selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6]
min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3] min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3]
min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3] min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3]
logger.debug(selected_bboxes.shape) #logger.debug(selected_bboxes.shape)
# 计算子包围盒与选中包围盒的交集 # 计算子包围盒与选中包围盒的交集
intersect_mask = torch.all( intersect_mask = torch.all(
(max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3] (max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3]
@ -109,11 +109,11 @@ def bbox_intersect(
# 合并交集条件和点云条件 # 合并交集条件和点云条件
result_mask[:, indices] = result_mask[:, indices] & points_in_boxes_mask result_mask[:, indices] = result_mask[:, indices] & points_in_boxes_mask
logger.debug(result_mask.shape) #logger.debug(result_mask.shape)
return result_mask return result_mask
class OctreeNode(nn.Module): class OctreeNode(nn.Module):
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,surf_ncs:np.ndarray = None,device=None): def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: Optional[torch.Tensor]=None, patch_graph: Optional[PatchGraph] = None,surf_ncs:Optional[np.ndarray] = None,device:Optional[torch.device]=None):
super().__init__() super().__init__()
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 改为普通张量属性 # 改为普通张量属性
@ -185,7 +185,7 @@ class OctreeNode(nn.Module):
queue.append((child_idx, child_bbox, intersecting_faces.clone().detach())) queue.append((child_idx, child_bbox, intersecting_faces.clone().detach()))
current_idx += 8 current_idx += 8
def _should_split_node(self, current_idx: int,face_indices,max_node:int) -> bool: def _should_split_node(self, current_idx: int,face_indices:torch.Tensor,max_node:int) -> bool:
"""判断节点是否需要分裂""" """判断节点是否需要分裂"""
# 检查是否达到最大深度 # 检查是否达到最大深度
if current_idx + 8 >= max_node: if current_idx + 8 >= max_node:
@ -229,7 +229,7 @@ class OctreeNode(nn.Module):
""" """
修改后的查找叶子节点方法返回face indices 修改后的查找叶子节点方法返回face indices
:param query_points: 待查找的点形状为 (3,) :param query_points: 待查找的点形状为 (3,)
:return: (bbox, param_key, face_indices, is_leaf) :return: (bbox, face_indices, is_leaf)
""" """
# 确保输入是单个点 # 确保输入是单个点
if query_points.dim() != 1 or query_points.shape[0] != 3: if query_points.dim() != 1 or query_points.shape[0] != 3:
@ -248,10 +248,9 @@ class OctreeNode(nn.Module):
#logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.") #logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.")
if parent_idx == -1: if parent_idx == -1:
# 根节点没有父节点,返回根节点的信息 # 根节点没有父节点,返回根节点的信息
#logger.warning(f"Reached root node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.")
return ( return (
self.node_bboxes[current_idx], self.node_bboxes[current_idx],
None, # 新增返回face indices torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device), # 新增返回face indices
False False
) )
return ( return (
@ -280,7 +279,7 @@ class OctreeNode(nn.Module):
iteration += 1 iteration += 1
# 如果达到最大迭代次数,返回当前节点的信息 # 如果达到最大迭代次数,返回当前节点的信息
return self.node_bboxes[current_idx], None,bool(self.is_leaf_mask[current_idx].item()) return self.node_bboxes[current_idx], torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device),bool(self.is_leaf_mask[current_idx].item())
@torch.jit.export @torch.jit.export
def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor: def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
@ -288,23 +287,26 @@ class OctreeNode(nn.Module):
mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2 mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2
return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1) return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1)
def forward(self, query_points): def forward(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
with torch.no_grad(): with torch.no_grad():
bboxes, face_indices_mask, csg_trees = [], [], [] bboxes: List[torch.Tensor] = []
face_indices_mask: List[torch.Tensor] = []
operator: List[int] = []
for point in query_points: for point in query_points:
bbox, faces_mask, _ = self.find_leaf(point) bbox, faces_mask, _ = self.find_leaf(point)
bboxes.append(bbox) bboxes.append(bbox)
face_indices_mask.append(faces_mask) face_indices_mask.append(faces_mask)
# 获取当前节点的CSG树结构 # 获取当前节点的CSG树结构
csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None #csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None
csg_trees.append(csg_tree) # 保持原始列表结构 #csg_trees.append(csg_tree) # 保持原始列表结构
operator.append(self.patch_graph.get_operator(faces_mask.nonzero()) if self.patch_graph is not None else -1)
return ( return (
torch.stack(bboxes), torch.stack(bboxes),
torch.stack(face_indices_mask), torch.stack(face_indices_mask),
csg_trees # 直接返回列表,不转换为张量 torch.tensor(operator, dtype=torch.int) # 直接返回列表,不转换为张量
) )
def print_tree(self, max_print_depth: int = None) -> None: def print_tree(self, max_print_depth: Optional[int] = None) -> None:
""" """
使用深度优先遍历 (DFS) 打印树结构父子关系通过缩进体现 使用深度优先遍历 (DFS) 打印树结构父子关系通过缩进体现
@ -353,36 +355,45 @@ class OctreeNode(nn.Module):
dfs(0, 0) dfs(0, 0)
# 统一输出所有日志 # 统一输出所有日志
logger.debug("\n".join(log_lines)) #logger.debug("\n".join(log_lines))
def __getstate__(self): def __getstate__(self):
"""支持pickle序列化""" """支持pickle序列化"""
state = { state = {
'bbox': self.bbox, 'bbox': self.bbox.cpu(), # 转换为CPU张量
'node_bboxes': self.node_bboxes, 'node_bboxes': self.node_bboxes.cpu() if self.node_bboxes is not None else None,
'parent_indices': self.parent_indices, 'parent_indices': self.parent_indices.cpu() if self.parent_indices is not None else None,
'child_indices': self.child_indices, 'child_indices': self.child_indices.cpu() if self.child_indices is not None else None,
'is_leaf_mask': self.is_leaf_mask, 'is_leaf_mask': self.is_leaf_mask.cpu() if self.is_leaf_mask is not None else None,
'face_indices': self.face_indices, 'all_face_indices': self.all_face_indices.cpu(),
'surf_bbox': self.surf_bbox, 'face_indices_mask':self.face_indices_mask.cpu() if self.face_indices_mask is not None else None,
'patch_graph': self.patch_graph, 'surf_bbox': self.surf_bbox.cpu() if self.surf_bbox is not None else None,
'surf_ncs': self.surf_ncs.cpu() if self.surf_ncs is not None else None,
'patch_graph': self.patch_graph, # 假设PatchGraph支持序列化
'max_depth': self.max_depth, 'max_depth': self.max_depth,
'_is_leaf': self._is_leaf 'device': str(self.device) # 保存设备信息
} }
return state return state
def __setstate__(self, state): def __setstate__(self, state):
"""支持pickle反序列化""" """支持pickle反序列化"""
self.bbox = state['bbox'] # 手动调用 __init__ 方法
self.node_bboxes = state['node_bboxes'] self.__init__(
self.parent_indices = state['parent_indices'] bbox=state['bbox'],
self.child_indices = state['child_indices'] face_indices=state['all_face_indices'].cpu().numpy(),
self.is_leaf_mask = state['is_leaf_mask'] patch_graph=state['patch_graph'],
self.face_indices = state['face_indices'] max_depth=state['max_depth'],
self.surf_bbox = state['surf_bbox'] surf_bbox=state['surf_bbox'],
self.patch_graph = state['patch_graph'] surf_ncs=state['surf_ncs'],
self.max_depth = state['max_depth'] device=torch.device(state['device'])
self._is_leaf = state['_is_leaf'] )
# 可以在这里设置其他不需要在 __init__ 中处理的属性
self.node_bboxes = state['node_bboxes'].to(self.device) if state['node_bboxes'] is not None else None
self.parent_indices = state['parent_indices'].to(self.device) if state['parent_indices'] is not None else None
self.child_indices = state['child_indices'].to(self.device) if state['child_indices'] is not None else None
self.face_indices_mask = state['face_indices_mask'].to(self.device) if state['face_indices_mask'] is not None else None
self.is_leaf_mask = state['is_leaf_mask'].to(self.device) if state['is_leaf_mask'] is not None else None
def to(self, device=None, dtype=None, non_blocking=False): def to(self, device=None, dtype=None, non_blocking=False):
# 调用父类方法迁移基础参数 # 调用父类方法迁移基础参数

34
brep2sdf/networks/patch_graph.py

@ -71,7 +71,7 @@ class PatchGraph(nn.Module):
return [] return []
node_faces = node_faces_mask.nonzero() node_faces = node_faces_mask.nonzero()
node_faces = node_faces.flatten().to('cpu').numpy() node_faces = node_faces.flatten().to('cpu').numpy()
logger.debug(f"node_faces: {node_faces}") #logger.debug(f"node_faces: {node_faces}")
node_set = set(node_faces) # 创建输入面片的集合用于快速查找 node_set = set(node_faces) # 创建输入面片的集合用于快速查找
visited = set() visited = set()
csg_tree = [] csg_tree = []
@ -89,6 +89,32 @@ class PatchGraph(nn.Module):
csg_tree.extend(remaining) csg_tree.extend(remaining)
return csg_tree return csg_tree
def get_operator(self, node_faces: torch.Tensor):
# node_faces: shape (<=2,)
# 返回 0: 凹边, 1: 凸边,
node_faces = node_faces.flatten().to(self.device)
num_faces = node_faces.numel()
if num_faces == 1:
# 这里设置凸边是因为 后续会补一个 f2 = inf, h = min(f1, f2)
# 因为 f2 = inf, 所以 h = f1
return 1 # 只有一个面
if num_faces > 2:
#logger.warning("get_operator 输入数量为{} > 2,将只取前两个面片进行处理".format(num_faces))
node_faces = node_faces[:2]
#if self.edge_index is None or self.edge_type is None:
#logger.warning("edge_index 或 edge_type 未设置")
# 查找这两个面之间的边
mask = ((self.edge_index[0] == node_faces[0]) & (self.edge_index[1] == node_faces[1])) | \
((self.edge_index[0] == node_faces[1]) & (self.edge_index[1] == node_faces[0]))
if not mask.any():
#logger.warning("没有面可以用")
return 3
edge_types = self.edge_type[mask]
# 如果有多条边,返回第一个
return int(edge_types[0].item())
def is_clique(self, node_faces: torch.Tensor) -> bool: def is_clique(self, node_faces: torch.Tensor) -> bool:
"""检查给定面片集合是否构成完全图 """检查给定面片集合是否构成完全图
@ -150,7 +176,7 @@ class PatchGraph(nn.Module):
@staticmethod @staticmethod
def from_preprocessed_data( def from_preprocessed_data(
surf_wcs: np.ndarray, # 形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组 surf_ncs: np.ndarray, # 形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组
edgeFace_adj: np.ndarray, # 形状为(num_edges, num_faces)的int32数组 edgeFace_adj: np.ndarray, # 形状为(num_edges, num_faces)的int32数组
edge_types: np.ndarray, # 形状为(num_edges,)的int32数组 edge_types: np.ndarray, # 形状为(num_edges,)的int32数组
device: torch.device = None device: torch.device = None
@ -158,7 +184,7 @@ class PatchGraph(nn.Module):
"""从预处理的数据直接构建面片邻接图 """从预处理的数据直接构建面片邻接图
参数: 参数:
surf_wcs: 世界坐标系下的曲面几何数据形状为(N,)的对象数组每个元素是形状为(M, 3)的float32数组 surf_ncs: 归一化坐标系下的曲面几何数据形状为(N,)的对象数组每个元素是形状为(M, 3)的float32数组
edgeFace_adj: -面邻接矩阵形状为(num_edges, num_faces)的int32数组1表示边与面相邻 edgeFace_adj: -面邻接矩阵形状为(num_edges, num_faces)的int32数组1表示边与面相邻
edge_types: 边的类型数组形状为(num_edges,)的int32数组0表示凹边1表示凸边 edge_types: 边的类型数组形状为(num_edges,)的int32数组0表示凹边1表示凸边
@ -167,7 +193,7 @@ class PatchGraph(nn.Module):
- edge_index: 形状为(2, num_edges*2)的torch.long张量表示双向边的连接关系 - edge_index: 形状为(2, num_edges*2)的torch.long张量表示双向边的连接关系
- edge_type: 形状为(num_edges*2,)的torch.long张量表示每条边的类型 - edge_type: 形状为(num_edges*2,)的torch.long张量表示每条边的类型
""" """
num_faces = len(surf_wcs) num_faces = len(surf_ncs)
graph = PatchGraph(num_faces,device) graph = PatchGraph(num_faces,device)
# 构建边的索引和类型 # 构建边的索引和类型

137
brep2sdf/train.py

@ -104,7 +104,7 @@ class Trainer:
# 构建面片邻接图 # 构建面片邻接图
graph = PatchGraph.from_preprocessed_data( graph = PatchGraph.from_preprocessed_data(
surf_wcs=self.data['surf_wcs'], surf_ncs=self.data['surf_ncs'],
edgeFace_adj=self.data['edgeFace_adj'], edgeFace_adj=self.data['edgeFace_adj'],
edge_types=self.data['edge_types'], edge_types=self.data['edge_types'],
device='cuda' if args.octree_cuda else 'cpu' device='cuda' if args.octree_cuda else 'cpu'
@ -115,9 +115,15 @@ class Trainer:
dtype=torch.float32, dtype=torch.float32,
device=self.device device=self.device
) )
max_depth = config.model.octree_max_depth
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=6) if not args.force_reprocess:
logger.gpu_memory_stats("数初始化后") if not self._load_octree():
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth)
elif self.root.max_depth != max_depth:
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth)
else:
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth)
logger.gpu_memory_stats("树初始化后")
self.model = Net( self.model = Net(
octree=self.root, octree=self.root,
@ -140,6 +146,7 @@ class Trainer:
def build_tree(self,surf_bbox, graph, max_depth=9): def build_tree(self,surf_bbox, graph, max_depth=9):
logger.info("开始构造八叉树...")
num_faces = surf_bbox.shape[0] num_faces = surf_bbox.shape[0]
bbox = self._calculate_global_bbox(surf_bbox) bbox = self._calculate_global_bbox(surf_bbox)
self.root = OctreeNode( self.root = OctreeNode(
@ -155,6 +162,7 @@ class Trainer:
self.root.build_static_tree() self.root.build_static_tree()
logger.info("complete octree conduction") logger.info("complete octree conduction")
self.root.print_tree() self.root.print_tree()
self._save_octree()
def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor: def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
""" """
@ -188,8 +196,80 @@ class Trainer:
# # 返回合并后的边界框 # # 返回合并后的边界框
# return torch.cat([global_min, global_max]) # return torch.cat([global_min, global_max])
# return [-0.5,] # 这个是错误的 # return [-0.5,] # 这个是错误的
def train_epoch_stage1(self, epoch: int):
total_loss = 0.0 # 初始化总损失
for step, surf_points in enumerate(self.data['surf_ncs']): # 定义 step 变量
points = torch.tensor(surf_points, device=self.device)
gt_sdf = torch.zeros(points.shape[0], device=self.device)
normals = None
if args.use_normal:
normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device)
# --- 准备模型输入,启用梯度 ---
points.requires_grad_(True) # 在检查之后启用梯度
# --- 前向传播 ---
self.optimizer.zero_grad()
pred_sdf = self.model.forward_training_volumes(points, step)
if self.debug_mode:
# --- 检查前向传播的输出 ---
logger.gpu_memory_stats("前向传播后")
# --- 计算损失 ---
loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败
loss_details = {}
try:
if args.use_normal:
loss, loss_details = self.loss_manager.compute_loss(
points,
normals,
gt_sdf,
pred_sdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
if self.debug_mode:
if check_tensor(loss, "Calculated Loss", epoch, step):
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
if loss_details:
logger.error(f"Loss Details: {loss_details}")
return float('inf') # 如果损失无效,停止这个epoch
except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # 如果计算出错,停止这个epoch
# --- 反向传播和优化 ---
try:
loss.backward()
# --- (推荐) 添加梯度裁剪 ---
# 防止梯度爆炸,这可能是导致 inf/nan 的原因之一
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # 范数裁剪
self.optimizer.step()
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
return float('inf') # 如果反向传播或优化出错,停止这个epoch
# --- 记录和累加损失 ---
current_loss = loss.item()
if not np.isfinite(current_loss): # 再次确认损失是有效的数值
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
total_loss += current_loss
del loss
torch.cuda.empty_cache()
# 记录训练进度 (只记录有效的损失)
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
return total_loss
def train_epoch(self, epoch: int) -> float: def train_epoch(self, epoch: int) -> float:
# --- 1. 检查输入数据 --- # --- 1. 检查输入数据 ---
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列)
@ -347,7 +427,8 @@ class Trainer:
for epoch in range(start_epoch, self.config.train.num_epochs + 1): for epoch in range(start_epoch, self.config.train.num_epochs + 1):
# 训练一个epoch # 训练一个epoch
train_loss = self.train_epoch(epoch) train_loss = self.train_epoch_stage1(epoch)
#train_loss = self.train_epoch(epoch)
# 验证 # 验证
''' '''
@ -445,6 +526,50 @@ class Trainer:
except Exception as e: except Exception as e:
logger.error(f"加载checkpoint失败: {str(e)}") logger.error(f"加载checkpoint失败: {str(e)}")
raise raise
# ... existing code ...
def _save_octree(self):
"""
保存八叉树到文件
八叉树保存路径基于模型名称和配置中的检查点目录
"""
checkpoint_dir = os.path.join(
self.config.train.checkpoint_dir,
self.model_name
)
octree_path = os.path.join(checkpoint_dir, "octree.pth")
try:
# 保存八叉树的根节点
torch.save(self.root, octree_path)
logger.info(f"八叉树已保存到 {octree_path}")
except Exception as e:
logger.error(f"保存八叉树失败: {str(e)}")
def _load_octree(self)->bool:
"""
从文件加载八叉树
尝试从基于模型名称和配置检查点目录的路径加载八叉树
"""
checkpoint_dir = os.path.join(
self.config.train.checkpoint_dir,
self.model_name
)
octree_path = os.path.join(checkpoint_dir, "octree.pth")
try:
if os.path.exists(octree_path):
# 加载八叉树的根节点
self.root = torch.load(octree_path, weights_only=False)
logger.info(f"八叉树已从 {octree_path} 加载")
return True
else:
logger.warning(f"八叉树文件 {octree_path} 不存在,无法加载。")
except Exception as e:
logger.error(f"加载八叉树失败: {str(e)}")
return False
def main(): def main():
# 这里需要初始化配置 # 这里需要初始化配置
config = get_default_config() config = get_default_config()

Loading…
Cancel
Save