From f0af2208a909205f34585ac297ba0bbe1fe45f59 Mon Sep 17 00:00:00 2001 From: mckay Date: Mon, 21 Apr 2025 21:18:20 +0800 Subject: [PATCH] =?UTF-8?q?octree=E5=8D=95=E7=8B=AC=E4=BD=9C=E4=B8=BA?= =?UTF-8?q?=E4=B8=80=E4=B8=AA=E6=A8=A1=E5=9D=97=EF=BC=8C=E4=B8=94=E8=A7=A3?= =?UTF-8?q?=E5=86=B3=E6=98=BE=E5=AD=98=E4=B8=8D=E8=B6=B3=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/encoder.py | 23 ++++++----------------- brep2sdf/networks/loss.py | 27 ++++++++++++++++----------- brep2sdf/networks/network.py | 10 +++++++--- brep2sdf/networks/octree.py | 14 +++++++++++--- brep2sdf/train.py | 2 +- 5 files changed, 41 insertions(+), 35 deletions(-) diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 3e68850..6d3ddcf 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -14,16 +14,15 @@ class Encoder(nn.Module): feature_dim: 特征维度 """ super().__init__() - self.octree = octree self.feature_dim = feature_dim # 初始化叶子节点参数 self.register_buffer('num_parameters', torch.tensor(0, dtype=torch.long)) self._leaf_features = None # 将在_init_parameters中初始化 - self._init_parameters() + self._init_parameters(octree) - def _init_parameters(self): - stack = [(self.octree, 0)] + def _init_parameters(self,octree): + stack = [(octree, 0)] param_count = 0 while stack: @@ -39,7 +38,7 @@ class Encoder(nn.Module): torch.randn(param_count, 8, self.feature_dim)) # 重新遍历设置索引 - stack = [(self.octree, 0)] + stack = [(octree, 0)] index = 0 while stack: node, _ = stack.pop() @@ -51,19 +50,9 @@ class Encoder(nn.Module): if child: stack.append((child, 0)) self.num_parameters.fill_(index) - def forward(self, query_points: torch.Tensor) -> torch.Tensor: + def forward(self, query_points: torch.Tensor,param_indices,bboxes) -> torch.Tensor: batch_size = query_points.shape[0] - - # 批量查询所有点的索引和bbox - with torch.no_grad(): - param_indices, bboxes = [], [] - for point in query_points: - bbox, idx, _ = self.octree.find_leaf(point) - param_indices.append(idx) - bboxes.append(bbox) - param_indices = torch.stack(param_indices) - bboxes = torch.stack(bboxes) - + # 批量获取特征 unique_ids, inverse_ids = torch.unique(param_indices, return_inverse=True) all_features = self._leaf_features[unique_ids] # (U, 8, D) diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py index ebf6777..d6fcdd9 100644 --- a/brep2sdf/networks/loss.py +++ b/brep2sdf/networks/loss.py @@ -1,6 +1,6 @@ import torch from .network import gradient - +from brep2sdf.utils.logger import logger class LossManager: def __init__(self, ablation, **condition_kwargs): @@ -49,20 +49,22 @@ class LossManager: def position_loss(self, pred_sdfs: torch.Tensor, gt_sdfs: torch.Tensor) -> torch.Tensor: """ 计算流型损失的逻辑 - + :param pred_sdfs: 预测的SDF值,形状为 (N, 1) :param gt_sdfs: 真实的SDF值,形状为 (N, 1) :return: 计算得到的流型损失,标量 """ - # 计算预测值与真实值的差 - diff = pred_sdfs - gt_sdfs - - # 计算平方误差 - squared_diff = torch.pow(diff, 2) - - # 计算均值 - manifold_loss = torch.mean(squared_diff) - + with torch.no_grad(): # 当前上下文 + # 显式分离张量 + pred_sdfs = pred_sdfs.detach() + gt_sdfs = gt_sdfs.detach() + squared_diff = torch.pow(pred_sdfs - gt_sdfs, 2) + manifold_loss = torch.mean(squared_diff) + + # 显式释放中间变量 + del squared_diff + torch.cuda.empty_cache() # 立即释放缓存 + return manifold_loss def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, pred_sdfs) -> torch.Tensor: @@ -135,10 +137,13 @@ class LossManager: :return: 计算得到的流型损失值 """ # 计算流形损失 + #logger.gpu_memory_stats("计算流型损失前") manifold_loss = self.position_loss(pred_sdfs,gt_sdfs) + #logger.gpu_memory_stats("计算流型损失后") # 计算法线损失 normals_loss = self.normals_loss(normals, points, pred_sdfs) + #logger.gpu_memory_stats("计算法线损失后") # 汇总损失 loss_details = { diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index 3a56d23..bd5d15b 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -63,11 +63,13 @@ class Net(nn.Module): decoder_skip_connections=True): super().__init__() + + self.octree_module = octree # 初始化 Encoder self.encoder = Encoder( - octree=octree, - feature_dim=feature_dim + feature_dim=feature_dim, + octree=octree ) # 初始化 Decoder @@ -82,8 +84,10 @@ class Net(nn.Module): 返回: output: 解码后的输出结果 """ + # 批量查询所有点的索引和bbox + param_indices,bboxes = self.octree_module.forward(query_points) # 编码 - feature_vector = self.encoder.forward(query_points) + feature_vector = self.encoder.forward(query_points,param_indices,bboxes) # 解码 output = self.decoder(feature_vector) diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index 844017b..8e66baa 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -5,7 +5,6 @@ import torch.nn as nn import numpy as np from brep2sdf.networks.patch_graph import PatchGraph -from brep2sdf.utils.logger import logger @@ -63,7 +62,6 @@ class OctreeNode(nn.Module): """构建静态八叉树结构""" # 预计算所有可能的节点数量,确保结果为整数 total_nodes = int(sum(8**i for i in range(self.max_depth + 1))) - logger.info(f"总节点数量: {total_nodes}") # 初始化静态张量,使用整数列表作为形状参数 self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.bbox.device) @@ -71,7 +69,6 @@ class OctreeNode(nn.Module): self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.bbox.device) self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.bbox.device) - logger.gpu_memory_stats("树初始化后") # 使用队列进行广度优先遍历 queue = [(0, self.bbox, self.face_indices)] # (node_idx, bbox, face_indices) current_idx = 0 @@ -193,6 +190,17 @@ class OctreeNode(nn.Module): mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2 return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1) + def forward(self, query_points): + with torch.no_grad(): + param_indices, bboxes = [], [] + for point in query_points: + bbox, idx, _ = self.find_leaf(point) + param_indices.append(idx) + bboxes.append(bbox) + param_indices = torch.stack(param_indices) + bboxes = torch.stack(bboxes) + return param_indices, bboxes + def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: """ 递归打印八叉树结构 diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 2a71781..b624168 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -237,7 +237,7 @@ class Trainer: # 检查法线和带梯度的点 #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") - + logger.gpu_memory_stats("计算损失前") loss, loss_details = self.loss_manager.compute_loss( points, normals, # 传递检查过的 normals