diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py
index bd5d15b..10c8f25 100644
--- a/brep2sdf/networks/network.py
+++ b/brep2sdf/networks/network.py
@@ -64,7 +64,7 @@ class Net(nn.Module):
         super().__init__()
 
-        self.octree_module = octree
+        self.octree_module = octree.to("cpu")
 
         # Initialize the Encoder
         self.encoder = Encoder(
@@ -86,9 +86,11 @@ class Net(nn.Module):
         """
         # Batch-query indices and bboxes for all query points
         param_indices, bboxes = self.octree_module.forward(query_points)
+        print("param_indices requires_grad:", param_indices.requires_grad)  # should print False
+        print("bboxes requires_grad:", bboxes.requires_grad)  # should print False
         # Encode
         feature_vector = self.encoder.forward(query_points, param_indices, bboxes)
-
+        print("feature_vector requires_grad:", feature_vector.requires_grad)
         # Decode
         output = self.decoder(feature_vector)
         return output
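# ----------------------------------------------------------------------
# Illustrative sketch (not part of the patch). The three print() calls
# above are one-off debugging aids; a reusable alternative is a small
# assertion helper. `assert_detached` is a hypothetical name, and the
# index computation below merely stands in for a real octree lookup.
import torch

def assert_detached(name: str, t: torch.Tensor) -> None:
    # Octree queries are pure indexing, so their outputs should never
    # participate in autograd; fail fast if one of them does.
    if t.requires_grad:
        raise RuntimeError(f"{name} unexpectedly requires grad")

points = torch.rand(4, 3)
with torch.no_grad():               # indexing-only work stays out of the graph
    indices = (points * 10).long()  # stand-in for an octree lookup
assert_detached("param_indices", indices)
# ----------------------------------------------------------------------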
diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py
index 8e66baa..4c8f3c9 100644
--- a/brep2sdf/networks/octree.py
+++ b/brep2sdf/networks/octree.py
@@ -29,25 +29,29 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
     return torch.all((max1 >= min2) & (max2 >= min1))
 
 class OctreeNode(nn.Module):
-    def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None):
+    def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None, device=None):
         super().__init__()
-        # Node info stored in static tensors
-        self.register_buffer('bbox', bbox)                # bounding box of this node
-        self.register_buffer('node_bboxes', None)         # bounding boxes of all nodes
-        self.register_buffer('parent_indices', None)      # parent node indices
-        self.register_buffer('child_indices', None)       # child node indices
-        self.register_buffer('is_leaf_mask', None)        # leaf-node flags
-        self.register_buffer('face_indices', torch.from_numpy(face_indices).to(bbox.device))  # face index tensor
-        self.register_buffer('surf_bbox', surf_bbox)      # per-face bounding boxes
-
+        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        # Plain tensor attributes with explicit device management
+        self.bbox = bbox.to(self.device)
+        self.node_bboxes = None
+        self.parent_indices = None
+        self.child_indices = None
+        self.is_leaf_mask = None
+        # Face index tensor
+        self.face_indices = torch.from_numpy(face_indices).to(self.device)
+        self.surf_bbox = surf_bbox.to(self.device) if surf_bbox is not None else None
+
+        # PatchGraph as a plain attribute
-        self.patch_graph = patch_graph  # no longer uses register_buffer
+        self.patch_graph = patch_graph.to(self.device) if patch_graph is not None else None
         self.max_depth = max_depth
-        # param_key as a tensor
-        self.register_buffer('param_key', torch.tensor(-1, dtype=torch.long))
+        # Parameter key as a plain tensor
+        self.param_key = torch.tensor(-1, dtype=torch.long, device=self.device)
         self._is_leaf = True
+        # All register_buffer calls removed
 
+    @torch.jit.export
     def set_param_key(self, k: int) -> None:
         """Set the parameter key
@@ -64,11 +68,10 @@ class OctreeNode(nn.Module):
         total_nodes = int(sum(8**i for i in range(self.max_depth + 1)))
 
         # Initialize the static tensors, using integer lists as shape arguments
-        self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.bbox.device)
-        self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.bbox.device)
-        self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.bbox.device)
-        self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.bbox.device)
-
+        self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device)
+        self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device)
+        self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device)
+        self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.device)
         # Breadth-first traversal with a queue
         queue = [(0, self.bbox, self.face_indices)]  # (node_idx, bbox, face_indices)
         current_idx = 0
@@ -108,7 +111,7 @@ class OctreeNode(nn.Module):
 
                 # Enqueue the child node
                 if intersecting_faces:
-                    queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.bbox.device)))
+                    queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.device)))
 
     def _should_split_node(self, current_depth: int) -> bool:
         """Decide whether the node needs to split"""
@@ -127,7 +130,7 @@ class OctreeNode(nn.Module):
     def _generate_child_bboxes(self, min_coords: torch.Tensor, mid_coords: torch.Tensor, max_coords: torch.Tensor) -> torch.Tensor:
         # Use torch.no_grad() to reduce the memory held for gradient tracking
         with torch.no_grad():
-            child_bboxes = torch.zeros([8, 6], device=self.bbox.device)
+            child_bboxes = torch.zeros([8, 6], device=self.device)
 
             # Generate all child bounding boxes with vectorized operations
             child_bboxes[0] = torch.cat([min_coords, mid_coords])  # front-bottom-left
@@ -199,6 +202,7 @@ class OctreeNode(nn.Module):
             bboxes.append(bbox)
         param_indices = torch.stack(param_indices)
         bboxes = torch.stack(bboxes)
+        # Add validation checks here
        return param_indices, bboxes
 
     def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None:
@@ -258,4 +262,30 @@ class OctreeNode(nn.Module):
         self.patch_graph = state['patch_graph']
         self.max_depth = state['max_depth']
         self.param_key = state['param_key']
-        self._is_leaf = state['_is_leaf']
\ No newline at end of file
+        self._is_leaf = state['_is_leaf']
+
+    def to(self, device=None, dtype=None, non_blocking=False):
+        # Migrate registered parameters/buffers via the parent class
+        super().to(device, dtype, non_blocking)
+        if device is not None:
+            self.device = torch.device(device)
+
+        # Plain tensor attributes are invisible to nn.Module.to(), so move them by hand
+        for attr in ['bbox', 'node_bboxes', 'parent_indices', 'child_indices',
+                     'is_leaf_mask', 'face_indices', 'surf_bbox', 'param_key']:
+            tensor = getattr(self, attr, None)
+            if tensor is not None:
+                setattr(self, attr, tensor.to(self.device, non_blocking=non_blocking))
+
+        # Migrate the custom PatchGraph attribute
+        if self.patch_graph is not None:
+            if hasattr(self.patch_graph, 'to'):
+                self.patch_graph = self.patch_graph.to(device=self.device, dtype=dtype)
+            else:
+                # Manually move non-Module attributes
+                for attr in ['edge_index', 'edge_type', 'patch_features']:
+                    tensor = getattr(self.patch_graph, attr, None)
+                    if tensor is not None:
+                        setattr(self.patch_graph, attr, tensor.to(device=self.device, dtype=dtype))
+
+        return self
\ No newline at end of file
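# ----------------------------------------------------------------------
# Illustrative sketch (not part of the patch): why OctreeNode needs a
# to() override at all. nn.Module.to() only migrates parameters and
# registered buffers; once the buffers become plain tensor attributes,
# they stay behind unless moved explicitly. `TableModule` is a made-up
# minimal example of the same pattern.
import torch
import torch.nn as nn

class TableModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(4))  # moved by nn.Module.to()
        self.lookup = torch.arange(4)              # plain attribute: NOT moved

    def to(self, device=None, dtype=None, non_blocking=False):
        # Let nn.Module migrate parameters and registered buffers first
        super().to(device=device, dtype=dtype, non_blocking=non_blocking)
        # Then move the plain attribute by hand, as the patch does
        if device is not None:
            self.lookup = self.lookup.to(device, non_blocking=non_blocking)
        return self

m = TableModule().to("cpu")                # no-op move, works on any machine
assert m.lookup.device == m.weight.device  # would fail on CUDA without the override
# ----------------------------------------------------------------------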
""" + if features.shape[0] != self.num_patches: + raise ValueError(f"特征数量 {features.shape[0]} 与面片数量 {self.num_patches} 不匹配") + + # 添加梯度隔离 + with torch.no_grad(): + self.patch_features = features.detach().to(self.device).requires_grad_(False) def get_subgraph(self, node_faces: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """获取子图的边和类型""" @@ -41,38 +56,7 @@ class PatchGraph(nn.Module): return subgraph_edges, subgraph_types - @staticmethod - def from_preprocessed_data(surf_wcs: np.ndarray, edgeFace_adj: np.ndarray, edge_types: np.ndarray, device: torch.device = None) -> 'PatchGraph': - num_faces = len(surf_wcs) - graph = PatchGraph(num_faces, device) - - edge_pairs = [] - edge_types_list = [] - - for edge_idx in range(len(edgeFace_adj)): - connected_faces = np.where(edgeFace_adj[edge_idx] == 1)[0] - if len(connected_faces) == 2: - face1, face2 = connected_faces - edge_pairs.extend([[face1, face2], [face2, face1]]) - edge_type = edge_types[edge_idx] - edge_types_list.extend([edge_type, edge_type]) - - if edge_pairs: - edge_index = torch.tensor(edge_pairs, dtype=torch.long, device=graph.device).t() - edge_type = torch.tensor(edge_types_list, dtype=torch.long, device=graph.device) - graph.set_edges(edge_index, edge_type) - - return graph - - def set_features(self, features: torch.Tensor) -> None: - """设置面片特征 - - 参数: - features: 形状为 (N, F) 的张量,表示面片的特征向量 - """ - if features.shape[0] != self.num_patches: - raise ValueError(f"特征数量 {features.shape[0]} 与面片数量 {self.num_patches} 不匹配") - self.patch_features = features + def is_clique(self, node_faces: torch.Tensor) -> bool: """检查给定面片集合是否构成完全图 @@ -96,7 +80,6 @@ class PatchGraph(nn.Module): # 计算实际的边数(考虑无向图) actual_edges = len(subgraph_edges[0]) // 2 - return actual_edges == expected_edges def combine_sdf(self, sdf_values: torch.Tensor) -> torch.Tensor: @@ -136,7 +119,8 @@ class PatchGraph(nn.Module): def from_preprocessed_data( surf_wcs: np.ndarray, # 形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组 edgeFace_adj: np.ndarray, # 形状为(num_edges, num_faces)的int32数组 - edge_types: np.ndarray # 形状为(num_edges,)的int32数组 + edge_types: np.ndarray, # 形状为(num_edges,)的int32数组 + device: torch.device = None ) -> 'PatchGraph': """从预处理的数据直接构建面片邻接图 @@ -151,7 +135,7 @@ class PatchGraph(nn.Module): - edge_type: 形状为(num_edges*2,)的torch.long张量,表示每条边的类型 """ num_faces = len(surf_wcs) - graph = PatchGraph(num_faces) + graph = PatchGraph(num_faces,device) # 构建边的索引和类型 edge_pairs = [] @@ -174,3 +158,22 @@ class PatchGraph(nn.Module): graph.set_edges(edge_index, edge_type) return graph + + def to(self, device=None, dtype=None, non_blocking=False): + # 调用父类方法迁移基础参数 + super().to(device, dtype, non_blocking) + + # 更新设备信息 + if device is not None: + self.device = device + + # 迁移自定义张量属性 + tensor_attrs = ['edge_index', 'edge_type', 'patch_features'] + for attr in tensor_attrs: + tensor = getattr(self, attr) + if tensor is not None: + setattr(self, attr, tensor.to(device=self.device, + dtype=dtype, + non_blocking=non_blocking)) + + return self diff --git a/brep2sdf/train.py b/brep2sdf/train.py index b624168..64fe51e 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -41,6 +41,12 @@ parser.add_argument( help='从指定的checkpoint文件继续训练' ) +parser.add_argument( + '--octree-cuda', + action='store_true', # 默认为 False,如果用户指定该参数,则为 True + help='使用CUDA加速Octree构建' +) + args = parser.parse_args() @@ -100,7 +106,8 @@ class Trainer: graph = PatchGraph.from_preprocessed_data( surf_wcs=self.data['surf_wcs'], edgeFace_adj=self.data['edgeFace_adj'], - edge_types=self.data['edge_types'] + 
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index b624168..64fe51e 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -41,6 +41,12 @@ parser.add_argument(
     help='resume training from the given checkpoint file'
 )
 
+parser.add_argument(
+    '--octree-cuda',
+    action='store_true',  # defaults to False; True when the user passes the flag
+    help='use CUDA to accelerate octree construction'
+)
+
 args = parser.parse_args()
@@ -100,7 +106,8 @@ class Trainer:
         graph = PatchGraph.from_preprocessed_data(
             surf_wcs=self.data['surf_wcs'],
             edgeFace_adj=self.data['edgeFace_adj'],
-            edge_types=self.data['edge_types']
+            edge_types=self.data['edge_types'],
+            device='cuda' if args.octree_cuda else 'cpu'
         )
         # Initialize the network
         surf_bbox = torch.tensor(
@@ -109,7 +116,7 @@ class Trainer:
             device=self.device
         )
 
-        self.build_tree(surf_bbox=surf_bbox, graph=graph, max_depth=4)
+        self.build_tree(surf_bbox=surf_bbox, graph=graph, max_depth=8)
         logger.gpu_memory_stats("after tree initialization")
 
         self.model = Net(
@@ -139,7 +146,8 @@ class Trainer:
             face_indices=np.arange(num_faces),  # initially contains every face
             patch_graph=graph,
             max_depth=max_depth,
-            surf_bbox=surf_bbox
+            surf_bbox=surf_bbox,
+
         )
         #print(surf_bbox)
         logger.info("starting octree construction")
@@ -160,8 +168,7 @@ class Trainer:
         # Directly define the fixed unit-cube bounding box
         # Note: make sure the tensor is created on the right device
         fixed_bbox = torch.tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5],
-                                  dtype=torch.float32,
-                                  device=self.device)  # assumes self.device holds the target device
+                                  dtype=torch.float32)  # now created on CPU; the octree manages devices explicitly
         logger.debug(f"using fixed global bounding box: {fixed_bbox.cpu().numpy()}")
         return fixed_bbox
@@ -277,7 +284,6 @@ class Trainer:
                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # norm clipping
                 self.optimizer.step()
-
             except Exception as backward_e:
                 logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
                 # Enable anomaly detection to pinpoint the failing op
@@ -300,6 +306,8 @@ class Trainer:
         # (if training is batched, continue the loop with the next batch here)
         # step += 1
 
+        del loss
+        torch.cuda.empty_cache()  # release cached GPU blocks
         return total_loss  # for single-batch training, return the current loss directly
diff --git a/brep2sdf/utils/logger.py b/brep2sdf/utils/logger.py
index 0f6f728..33bb409 100644
--- a/brep2sdf/utils/logger.py
+++ b/brep2sdf/utils/logger.py
@@ -205,17 +205,25 @@ class BRepLogger:
         if not torch.cuda.is_available():
             return
 
-        torch.cuda.synchronize()  # synchronize all CUDA operations
+        torch.cuda.synchronize()
+        # Remember the previous allocation so deltas can be reported
+        if not hasattr(self, '_last_allocated'):
+            self._last_allocated = 0
+
         allocated = torch.cuda.memory_allocated() / 1024**2
         reserved = torch.cuda.memory_reserved() / 1024**2
+        delta_allocated = allocated - self._last_allocated  # change since the last call
         max_allocated = torch.cuda.max_memory_allocated() / 1024**2
+        # Update the baseline
+        self._last_allocated = allocated
 
         tag_str = f" [{tag}]" if tag else ""
         stats = []
         stats.append(f"\n=== GPU memory stats{tag_str} ===")
-        stats.append(f"  allocated: {allocated:.1f} MB")
-        stats.append(f"  reserved: {reserved:.1f} MB")
+        stats.append(f"  currently allocated: {allocated:.1f} MB")
+        stats.append(f"  delta since last call: {delta_allocated:.1f} MB")
+        stats.append(f"  cached (reserved): {reserved:.1f} MB")
         stats.append(f"  peak: {max_allocated:.1f} MB")
 
         # Emit all statistics at once
@@ -252,6 +260,7 @@ class BRepLogger:
         if torch.cuda.is_available():
             torch.cuda.reset_peak_memory_stats()
             torch.cuda.empty_cache()
+            self._last_allocated = 0  # reset the delta baseline
         self.info("GPU memory statistics reset")
 
 def timeit(func):
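# ----------------------------------------------------------------------
# Illustrative sketch (not part of the patch): using allocation deltas to
# bracket a suspect region, in the spirit of the gpu_memory_stats() change.
# `report_delta` is a hypothetical helper, not part of BRepLogger.
import torch

def report_delta(tag: str, before_mb: float) -> float:
    torch.cuda.synchronize()  # flush pending kernels before measuring
    now_mb = torch.cuda.memory_allocated() / 1024**2
    print(f"[{tag}] +{now_mb - before_mb:.1f} MB (total {now_mb:.1f} MB)")
    return now_mb

if torch.cuda.is_available():
    base = torch.cuda.memory_allocated() / 1024**2
    x = torch.rand(1024, 1024, device="cuda")  # ~4 MB of float32
    base = report_delta("after alloc", base)
    del x
    torch.cuda.empty_cache()                   # mirrors the train-loop change
    report_delta("after free", base)
# ----------------------------------------------------------------------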