
Device-related optimizations; the peak GPU memory issue remains.

final
mckay, 2 months ago
commit b89989bbc1
Changed files (lines changed):

  6   brep2sdf/networks/network.py
  63  brep2sdf/networks/octree.py
  87  brep2sdf/networks/patch_graph.py
  20  brep2sdf/train.py
  15  brep2sdf/utils/logger.py

brep2sdf/networks/network.py (6 lines changed)

@@ -64,7 +64,7 @@ class Net(nn.Module):
         super().__init__()
-        self.octree_module = octree
+        self.octree_module = octree.to("cpu")
         # Initialize the encoder
         self.encoder = Encoder(
@@ -86,9 +86,11 @@ class Net(nn.Module):
         """
         # Batch-query the indices and bboxes of all points
         param_indices, bboxes = self.octree_module.forward(query_points)
+        print("param_indices requires_grad:", param_indices.requires_grad)  # should print False
+        print("bboxes requires_grad:", bboxes.requires_grad)  # should print False
         # Encode
         feature_vector = self.encoder.forward(query_points, param_indices, bboxes)
+        print("feature_vector:", feature_vector.requires_grad)
         # Decode
         output = self.decoder(feature_vector)
         return output
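Moving the octree to the CPU only pays off if the query path routes tensors through it without leaving GPU copies behind. A minimal sketch of that round trip, using a hypothetical helper and assuming the lookup is non-differentiable (as the requires_grad checks above suggest):

import torch

def query_cpu_octree(octree_cpu, query_points_gpu):
    # Run the lookup on the octree's device; only the compact
    # index/bbox results travel back to the GPU.
    pts = query_points_gpu.detach().cpu()
    param_indices, bboxes = octree_cpu.forward(pts)
    dev = query_points_gpu.device
    return param_indices.to(dev), bboxes.to(dev)

As committed, forward() passes query_points straight to the CPU-resident module, so the caller has to guarantee that the points are on a compatible device.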

brep2sdf/networks/octree.py (63 lines changed)

@@ -29,25 +29,29 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
     return torch.all((max1 >= min2) & (max2 >= min1))

 class OctreeNode(nn.Module):
-    def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None):
+    def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None, device=None):
         super().__init__()
-        # Static tensors holding the node data
-        self.register_buffer('bbox', bbox)  # bounding box of this node
-        self.register_buffer('node_bboxes', None)  # bounding boxes of all nodes
-        self.register_buffer('parent_indices', None)  # parent node indices
-        self.register_buffer('child_indices', None)  # child node indices
-        self.register_buffer('is_leaf_mask', None)  # leaf-node mask
-        self.register_buffer('face_indices', torch.from_numpy(face_indices).to(bbox.device))  # face index tensor
-        self.register_buffer('surf_bbox', surf_bbox)  # face bounding boxes
+        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        # Changed to plain tensor attributes
+        self.bbox = bbox.to(self.device)  # explicit device management
+        self.node_bboxes = None
+        self.parent_indices = None
+        self.child_indices = None
+        self.is_leaf_mask = None
+        # Face index tensor
+        self.face_indices = torch.from_numpy(face_indices).to(self.device)
+        self.surf_bbox = surf_bbox.to(self.device) if surf_bbox is not None else None
         # PatchGraph stays a plain attribute
-        self.patch_graph = patch_graph  # no longer uses register_buffer
+        self.patch_graph = patch_graph.to(self.device) if patch_graph is not None else None
         self.max_depth = max_depth
-        # Store param_key as a tensor
-        self.register_buffer('param_key', torch.tensor(-1, dtype=torch.long))
+        # The param key becomes a plain tensor
+        self.param_key = torch.tensor(-1, dtype=torch.long, device=self.device)
         self._is_leaf = True
+        # All register_buffer calls removed

     @torch.jit.export
     def set_param_key(self, k: int) -> None:
         """Set the parameter key
@@ -64,11 +68,10 @@ class OctreeNode(nn.Module):
         total_nodes = int(sum(8**i for i in range(self.max_depth + 1)))
         # Initialize static tensors, using integer lists as shape arguments
-        self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.bbox.device)
-        self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.bbox.device)
-        self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.bbox.device)
-        self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.bbox.device)
+        self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device)
+        self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device)
+        self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device)
+        self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.device)
         # Breadth-first traversal driven by a queue
         queue = [(0, self.bbox, self.face_indices)]  # (node_idx, bbox, face_indices)
         current_idx = 0
@@ -108,7 +111,7 @@ class OctreeNode(nn.Module):
             # Enqueue the child node
             if intersecting_faces:
-                queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.bbox.device)))
+                queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.device)))

     def _should_split_node(self, current_depth: int) -> bool:
         """Decide whether this node needs to split"""
@@ -127,7 +130,7 @@ class OctreeNode(nn.Module):
     def _generate_child_bboxes(self, min_coords: torch.Tensor, mid_coords: torch.Tensor, max_coords: torch.Tensor) -> torch.Tensor:
         # Use torch.no_grad() to cut the memory spent on gradient tracking
         with torch.no_grad():
-            child_bboxes = torch.zeros([8, 6], device=self.bbox.device)
+            child_bboxes = torch.zeros([8, 6], device=self.device)

             # Build all child bounding boxes with vectorized operations
             child_bboxes[0] = torch.cat([min_coords, mid_coords])  # front-bottom-left
@@ -199,6 +202,7 @@ class OctreeNode(nn.Module):
             bboxes.append(bbox)
         param_indices = torch.stack(param_indices)
         bboxes = torch.stack(bboxes)
+        # Add checks here
         return param_indices, bboxes

     def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None:
@@ -258,4 +262,21 @@ class OctreeNode(nn.Module):
         self.patch_graph = state['patch_graph']
         self.max_depth = state['max_depth']
         self.param_key = state['param_key']
         self._is_leaf = state['_is_leaf']
+
+    def to(self, device=None, dtype=None, non_blocking=False):
+        # Let the parent class migrate the registered parameters first
+        super().to(device, dtype, non_blocking)
+        # Then migrate the custom attributes
+        if self.patch_graph is not None:
+            if hasattr(self.patch_graph, 'to'):
+                self.patch_graph = self.patch_graph.to(device=device, dtype=dtype)
+            else:
+                # Manually move non-Module attributes
+                for attr in ['edge_index', 'edge_type', 'patch_features']:
+                    tensor = getattr(self.patch_graph, attr, None)
+                    if tensor is not None:
+                        setattr(self.patch_graph, attr, tensor.to(device=device, dtype=dtype))
+        return self
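Because register_buffer is gone, nn.Module.to() no longer migrates these tensors automatically; the override above does it by hand. A self-contained sketch of the pattern (hypothetical class, not the repo's code):

import torch
import torch.nn as nn

class PlainTensorModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute: invisible to state_dict(), parameters() and to()
        self.table = torch.zeros(8, 6)

    def to(self, device=None, dtype=None, non_blocking=False):
        super().to(device, dtype, non_blocking)   # parameters/buffers first
        if self.table is not None:                # then the plain tensors
            self.table = self.table.to(device=device, dtype=dtype,
                                       non_blocking=non_blocking)
        return self

m = PlainTensorModule().to("cpu")
assert m.table.device.type == "cpu"

The trade-off is that plain attributes also drop out of state_dict(), which is why the class keeps explicit __setstate__ handling for patch_graph, max_depth, param_key and _is_leaf.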

brep2sdf/networks/patch_graph.py (87 lines changed)

@@ -9,10 +9,10 @@ class PatchGraph(nn.Module):
         self.num_patches = num_patches
         self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-        # Register the buffers
-        self.register_buffer('edge_index', None)  # edge connectivity (2, E)
-        self.register_buffer('edge_type', None)  # edge types (E,), 0: concave, 1: convex
-        self.register_buffer('patch_features', None)  # patch features (N, F)
+        # register_buffer calls removed; plain attributes instead
+        self.edge_index = None  # tensor of shape (2, E)
+        self.edge_type = None  # tensor of shape (E,)
+        self.patch_features = None  # tensor of shape (N, F)

     def set_edges(self, edge_index: torch.Tensor, edge_type: torch.Tensor) -> None:
         """Set the edge information
@@ -25,9 +25,24 @@ class PatchGraph(nn.Module):
             raise ValueError(f"edge_index must be a tensor of shape (2, E), got {edge_index.shape}")
         if edge_index.shape[1] != edge_type.shape[0]:
             raise ValueError("edge counts of edge_index and edge_type do not match")
-        self.edge_index = edge_index.to(self.device)
-        self.edge_type = edge_type.to(self.device)
+        # Add gradient isolation
+        with torch.no_grad():
+            self.edge_index = edge_index.detach().to(self.device).requires_grad_(False)
+            self.edge_type = edge_type.detach().to(self.device).requires_grad_(False)
+
+    def set_features(self, features: torch.Tensor) -> None:
+        """Set the patch features
+
+        Args:
+            features: tensor of shape (N, F) holding one feature vector per patch
+        """
+        if features.shape[0] != self.num_patches:
+            raise ValueError(f"feature count {features.shape[0]} does not match patch count {self.num_patches}")
+        # Add gradient isolation
+        with torch.no_grad():
+            self.patch_features = features.detach().to(self.device).requires_grad_(False)

     def get_subgraph(self, node_faces: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get the edges and edge types of a subgraph"""
@@ -41,38 +56,7 @@ class PatchGraph(nn.Module):
         return subgraph_edges, subgraph_types

-    @staticmethod
-    def from_preprocessed_data(surf_wcs: np.ndarray, edgeFace_adj: np.ndarray, edge_types: np.ndarray, device: torch.device = None) -> 'PatchGraph':
-        num_faces = len(surf_wcs)
-        graph = PatchGraph(num_faces, device)
-        edge_pairs = []
-        edge_types_list = []
-        for edge_idx in range(len(edgeFace_adj)):
-            connected_faces = np.where(edgeFace_adj[edge_idx] == 1)[0]
-            if len(connected_faces) == 2:
-                face1, face2 = connected_faces
-                edge_pairs.extend([[face1, face2], [face2, face1]])
-                edge_type = edge_types[edge_idx]
-                edge_types_list.extend([edge_type, edge_type])
-        if edge_pairs:
-            edge_index = torch.tensor(edge_pairs, dtype=torch.long, device=graph.device).t()
-            edge_type = torch.tensor(edge_types_list, dtype=torch.long, device=graph.device)
-            graph.set_edges(edge_index, edge_type)
-        return graph
-
-    def set_features(self, features: torch.Tensor) -> None:
-        """Set the patch features
-
-        Args:
-            features: tensor of shape (N, F) holding one feature vector per patch
-        """
-        if features.shape[0] != self.num_patches:
-            raise ValueError(f"feature count {features.shape[0]} does not match patch count {self.num_patches}")
-        self.patch_features = features
-
     def is_clique(self, node_faces: torch.Tensor) -> bool:
         """Check whether the given set of patches forms a complete graph
@@ -96,7 +80,6 @@ class PatchGraph(nn.Module):
         # Count the actual number of edges (undirected graph)
         actual_edges = len(subgraph_edges[0]) // 2
         return actual_edges == expected_edges
-
     def combine_sdf(self, sdf_values: torch.Tensor) -> torch.Tensor:
@@ -136,7 +119,8 @@ class PatchGraph(nn.Module):
     def from_preprocessed_data(
         surf_wcs: np.ndarray,  # object array of shape (N,); each element is a float32 array of shape (M, 3)
         edgeFace_adj: np.ndarray,  # int32 array of shape (num_edges, num_faces)
-        edge_types: np.ndarray  # int32 array of shape (num_edges,)
+        edge_types: np.ndarray,  # int32 array of shape (num_edges,)
+        device: torch.device = None
     ) -> 'PatchGraph':
         """Build the patch adjacency graph directly from preprocessed data
@@ -151,7 +135,7 @@ class PatchGraph(nn.Module):
             - edge_type: torch.long tensor of shape (num_edges*2,) giving each edge's type
         """
         num_faces = len(surf_wcs)
-        graph = PatchGraph(num_faces)
+        graph = PatchGraph(num_faces, device)

         # Build the edge indices and types
         edge_pairs = []
@@ -174,3 +158,22 @@ class PatchGraph(nn.Module):
             graph.set_edges(edge_index, edge_type)

         return graph
+
+    def to(self, device=None, dtype=None, non_blocking=False):
+        # Let the parent class migrate the registered parameters first
+        super().to(device, dtype, non_blocking)
+        # Update the recorded device
+        if device is not None:
+            self.device = device
+        # Migrate the custom tensor attributes
+        tensor_attrs = ['edge_index', 'edge_type', 'patch_features']
+        for attr in tensor_attrs:
+            tensor = getattr(self, attr)
+            if tensor is not None:
+                setattr(self, attr, tensor.to(device=self.device,
+                                              dtype=dtype,
+                                              non_blocking=non_blocking))
+        return self
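The detach-and-disable pattern used in set_edges/set_features matters for peak memory: a stored tensor that still carries a grad_fn pins its upstream autograd graph for as long as it lives. A minimal demonstration (names here are hypothetical):

import torch

def store_static(t: torch.Tensor, device) -> torch.Tensor:
    # detach() drops the grad_fn so the stored copy cannot retain upstream
    # state; no_grad() keeps the transfer itself out of the graph.
    with torch.no_grad():
        return t.detach().to(device).requires_grad_(False)

x = torch.randn(4, requires_grad=True)
y = x * 2                      # y.grad_fn is set; storing y pins the graph
y_static = store_static(y, "cpu")
assert y_static.grad_fn is None and not y_static.requires_grad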

brep2sdf/train.py (20 lines changed)

@@ -41,6 +41,12 @@ parser.add_argument(
     help='resume training from the given checkpoint file'
 )

+parser.add_argument(
+    '--octree-cuda',
+    action='store_true',  # defaults to False; becomes True when the flag is passed
+    help='build the octree with CUDA acceleration'
+)
+
 args = parser.parse_args()
@@ -100,7 +106,8 @@ class Trainer:
         graph = PatchGraph.from_preprocessed_data(
             surf_wcs=self.data['surf_wcs'],
             edgeFace_adj=self.data['edgeFace_adj'],
-            edge_types=self.data['edge_types']
+            edge_types=self.data['edge_types'],
+            device='cuda' if args.octree_cuda else 'cpu'
         )

         # Initialize the network
         surf_bbox=torch.tensor(
@@ -109,7 +116,7 @@ class Trainer:
             device=self.device
         )
-        self.build_tree(surf_bbox=surf_bbox, graph=graph, max_depth=4)
+        self.build_tree(surf_bbox=surf_bbox, graph=graph, max_depth=8)
         logger.gpu_memory_stats("after tree initialization")

         self.model = Net(
@@ -139,7 +146,8 @@ class Trainer:
             face_indices=np.arange(num_faces),  # initially contains all faces
             patch_graph=graph,
             max_depth=max_depth,
-            surf_bbox=surf_bbox
+            surf_bbox=surf_bbox,
         )
         #print(surf_bbox)
         logger.info("starting octree conduction")
@@ -160,8 +168,7 @@ class Trainer:
         # Directly define a fixed unit-cube bounding box
         # Note: make sure the tensor is created on the correct device
         fixed_bbox = torch.tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5],
-                                  dtype=torch.float32,
-                                  device=self.device)  # assumes self.device holds the target device
+                                  dtype=torch.float32)  # assumes self.device holds the target device
         logger.debug(f"Using the fixed global bounding box: {fixed_bbox.cpu().numpy()}")
         return fixed_bbox
@@ -277,7 +284,6 @@ class Trainer:
                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # norm clipping
                 self.optimizer.step()
-
             except Exception as backward_e:
                 logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
                 # To see which op caused it, enable anomaly detection
@@ -300,6 +306,8 @@ class Trainer:
             # (if training is split into batches, loop to the next batch here)
             # step += 1

+        del loss
+        torch.cuda.empty_cache()  # clear the CUDA cache
         return total_loss  # for single-batch training, return the current loss directly
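The del loss added at the end of the step drops the last Python reference to the loss tensor and its autograd metadata before the cache is flushed; otherwise those blocks stay allocated across iterations. A sketch of the pattern in isolation (model, optimizer and batch are placeholders, not the Trainer's actual API):

import torch

def train_step(model, optimizer, batch):
    optimizer.zero_grad(set_to_none=True)
    loss = model(batch).mean()
    loss.backward()
    optimizer.step()
    loss_value = loss.item()      # keep only the Python scalar
    del loss                      # drop the last reference to the graph root
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # now the cached blocks can be released
    return loss_value

Note that empty_cache() shrinks the reserved (cached) figure rather than the allocated one, and the next allocation pays the cost of re-requesting memory from the driver, so it is mainly a tool for chasing peaks.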

brep2sdf/utils/logger.py (15 lines changed)

@@ -205,17 +205,25 @@ class BRepLogger:
         if not torch.cuda.is_available():
             return
-        torch.cuda.synchronize()  # synchronize all CUDA operations
+        torch.cuda.synchronize()
+
+        # New instance variable recording the previous memory reading
+        if not hasattr(self, '_last_allocated'):
+            self._last_allocated = 0

         allocated = torch.cuda.memory_allocated() / 1024**2
         reserved = torch.cuda.memory_reserved() / 1024**2
+        delta_allocated = allocated - self._last_allocated  # compute the increment
         max_allocated = torch.cuda.max_memory_allocated() / 1024**2
+        # Update the last recorded value
+        self._last_allocated = allocated

         tag_str = f" [{tag}]" if tag else ""
         stats = []
         stats.append(f"\n=== GPU memory status{tag_str} ===")
-        stats.append(f"  allocated: {allocated:.1f} MB")
-        stats.append(f"  cached: {reserved:.1f} MB")
+        stats.append(f"  currently allocated: {allocated:.1f} MB")
+        stats.append(f"  delta allocated: {delta_allocated:.1f} MB")
+        stats.append(f"  reserved cache: {reserved:.1f} MB")
         stats.append(f"  peak: {max_allocated:.1f} MB")

         # Emit all the statistics at once
@@ -252,6 +260,7 @@ class BRepLogger:
         if torch.cuda.is_available():
             torch.cuda.reset_peak_memory_stats()
             torch.cuda.empty_cache()
+            self._last_allocated = 0  # reset the baseline
         self.info("GPU memory statistics have been reset")

 def timeit(func):
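The delta bookkeeping above turns each gpu_memory_stats() call into a differential measurement between two instrumentation points, which is how a creeping peak can be localized. A standalone sketch of the same idea (hypothetical class name, not the BRepLogger API):

import torch

class GpuMemProbe:
    def __init__(self):
        self._last_allocated = 0.0

    def snapshot(self, tag: str = "") -> str:
        if not torch.cuda.is_available():
            return "CUDA unavailable"
        torch.cuda.synchronize()                  # settle pending kernels first
        allocated = torch.cuda.memory_allocated() / 1024**2
        delta = allocated - self._last_allocated  # growth since the last call
        self._last_allocated = allocated
        reserved = torch.cuda.memory_reserved() / 1024**2
        peak = torch.cuda.max_memory_allocated() / 1024**2
        return (f"[{tag}] allocated {allocated:.1f} MB "
                f"(delta {delta:+.1f} MB), reserved {reserved:.1f} MB, "
                f"peak {peak:.1f} MB")

probe = GpuMemProbe()
# probe.snapshot("after tree initialization")  # call at each instrumentation point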
