Browse Source

octree单独作为一个模块,且解决显存不足问题

final
mckay 2 months ago
parent
commit
f0af2208a9
  1. 23
      brep2sdf/networks/encoder.py
  2. 27
      brep2sdf/networks/loss.py
  3. 10
      brep2sdf/networks/network.py
  4. 14
      brep2sdf/networks/octree.py
  5. 2
      brep2sdf/train.py

23
brep2sdf/networks/encoder.py

@ -14,16 +14,15 @@ class Encoder(nn.Module):
feature_dim: 特征维度 feature_dim: 特征维度
""" """
super().__init__() super().__init__()
self.octree = octree
self.feature_dim = feature_dim self.feature_dim = feature_dim
# 初始化叶子节点参数 # 初始化叶子节点参数
self.register_buffer('num_parameters', torch.tensor(0, dtype=torch.long)) self.register_buffer('num_parameters', torch.tensor(0, dtype=torch.long))
self._leaf_features = None # 将在_init_parameters中初始化 self._leaf_features = None # 将在_init_parameters中初始化
self._init_parameters() self._init_parameters(octree)
def _init_parameters(self): def _init_parameters(self,octree):
stack = [(self.octree, 0)] stack = [(octree, 0)]
param_count = 0 param_count = 0
while stack: while stack:
@ -39,7 +38,7 @@ class Encoder(nn.Module):
torch.randn(param_count, 8, self.feature_dim)) torch.randn(param_count, 8, self.feature_dim))
# 重新遍历设置索引 # 重新遍历设置索引
stack = [(self.octree, 0)] stack = [(octree, 0)]
index = 0 index = 0
while stack: while stack:
node, _ = stack.pop() node, _ = stack.pop()
@ -51,19 +50,9 @@ class Encoder(nn.Module):
if child: stack.append((child, 0)) if child: stack.append((child, 0))
self.num_parameters.fill_(index) self.num_parameters.fill_(index)
def forward(self, query_points: torch.Tensor) -> torch.Tensor: def forward(self, query_points: torch.Tensor,param_indices,bboxes) -> torch.Tensor:
batch_size = query_points.shape[0] batch_size = query_points.shape[0]
# 批量查询所有点的索引和bbox
with torch.no_grad():
param_indices, bboxes = [], []
for point in query_points:
bbox, idx, _ = self.octree.find_leaf(point)
param_indices.append(idx)
bboxes.append(bbox)
param_indices = torch.stack(param_indices)
bboxes = torch.stack(bboxes)
# 批量获取特征 # 批量获取特征
unique_ids, inverse_ids = torch.unique(param_indices, return_inverse=True) unique_ids, inverse_ids = torch.unique(param_indices, return_inverse=True)
all_features = self._leaf_features[unique_ids] # (U, 8, D) all_features = self._leaf_features[unique_ids] # (U, 8, D)

27
brep2sdf/networks/loss.py

@ -1,6 +1,6 @@
import torch import torch
from .network import gradient from .network import gradient
from brep2sdf.utils.logger import logger
class LossManager: class LossManager:
def __init__(self, ablation, **condition_kwargs): def __init__(self, ablation, **condition_kwargs):
@ -49,20 +49,22 @@ class LossManager:
def position_loss(self, pred_sdfs: torch.Tensor, gt_sdfs: torch.Tensor) -> torch.Tensor: def position_loss(self, pred_sdfs: torch.Tensor, gt_sdfs: torch.Tensor) -> torch.Tensor:
""" """
计算流型损失的逻辑 计算流型损失的逻辑
:param pred_sdfs: 预测的SDF值形状为 (N, 1) :param pred_sdfs: 预测的SDF值形状为 (N, 1)
:param gt_sdfs: 真实的SDF值形状为 (N, 1) :param gt_sdfs: 真实的SDF值形状为 (N, 1)
:return: 计算得到的流型损失标量 :return: 计算得到的流型损失标量
""" """
# 计算预测值与真实值的差 with torch.no_grad(): # 当前上下文
diff = pred_sdfs - gt_sdfs # 显式分离张量
pred_sdfs = pred_sdfs.detach()
# 计算平方误差 gt_sdfs = gt_sdfs.detach()
squared_diff = torch.pow(diff, 2) squared_diff = torch.pow(pred_sdfs - gt_sdfs, 2)
manifold_loss = torch.mean(squared_diff)
# 计算均值
manifold_loss = torch.mean(squared_diff) # 显式释放中间变量
del squared_diff
torch.cuda.empty_cache() # 立即释放缓存
return manifold_loss return manifold_loss
def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, pred_sdfs) -> torch.Tensor: def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, pred_sdfs) -> torch.Tensor:
@ -135,10 +137,13 @@ class LossManager:
:return: 计算得到的流型损失值 :return: 计算得到的流型损失值
""" """
# 计算流形损失 # 计算流形损失
#logger.gpu_memory_stats("计算流型损失前")
manifold_loss = self.position_loss(pred_sdfs,gt_sdfs) manifold_loss = self.position_loss(pred_sdfs,gt_sdfs)
#logger.gpu_memory_stats("计算流型损失后")
# 计算法线损失 # 计算法线损失
normals_loss = self.normals_loss(normals, points, pred_sdfs) normals_loss = self.normals_loss(normals, points, pred_sdfs)
#logger.gpu_memory_stats("计算法线损失后")
# 汇总损失 # 汇总损失
loss_details = { loss_details = {

10
brep2sdf/networks/network.py

@ -63,11 +63,13 @@ class Net(nn.Module):
decoder_skip_connections=True): decoder_skip_connections=True):
super().__init__() super().__init__()
self.octree_module = octree
# 初始化 Encoder # 初始化 Encoder
self.encoder = Encoder( self.encoder = Encoder(
octree=octree, feature_dim=feature_dim,
feature_dim=feature_dim octree=octree
) )
# 初始化 Decoder # 初始化 Decoder
@ -82,8 +84,10 @@ class Net(nn.Module):
返回: 返回:
output: 解码后的输出结果 output: 解码后的输出结果
""" """
# 批量查询所有点的索引和bbox
param_indices,bboxes = self.octree_module.forward(query_points)
# 编码 # 编码
feature_vector = self.encoder.forward(query_points) feature_vector = self.encoder.forward(query_points,param_indices,bboxes)
# 解码 # 解码
output = self.decoder(feature_vector) output = self.decoder(feature_vector)

14
brep2sdf/networks/octree.py

@ -5,7 +5,6 @@ import torch.nn as nn
import numpy as np import numpy as np
from brep2sdf.networks.patch_graph import PatchGraph from brep2sdf.networks.patch_graph import PatchGraph
from brep2sdf.utils.logger import logger
@ -63,7 +62,6 @@ class OctreeNode(nn.Module):
"""构建静态八叉树结构""" """构建静态八叉树结构"""
# 预计算所有可能的节点数量,确保结果为整数 # 预计算所有可能的节点数量,确保结果为整数
total_nodes = int(sum(8**i for i in range(self.max_depth + 1))) total_nodes = int(sum(8**i for i in range(self.max_depth + 1)))
logger.info(f"总节点数量: {total_nodes}")
# 初始化静态张量,使用整数列表作为形状参数 # 初始化静态张量,使用整数列表作为形状参数
self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.bbox.device) self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.bbox.device)
@ -71,7 +69,6 @@ class OctreeNode(nn.Module):
self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.bbox.device) self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.bbox.device)
self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.bbox.device) self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.bbox.device)
logger.gpu_memory_stats("树初始化后")
# 使用队列进行广度优先遍历 # 使用队列进行广度优先遍历
queue = [(0, self.bbox, self.face_indices)] # (node_idx, bbox, face_indices) queue = [(0, self.bbox, self.face_indices)] # (node_idx, bbox, face_indices)
current_idx = 0 current_idx = 0
@ -193,6 +190,17 @@ class OctreeNode(nn.Module):
mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2 mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2
return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1) return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1)
def forward(self, query_points):
with torch.no_grad():
param_indices, bboxes = [], []
for point in query_points:
bbox, idx, _ = self.find_leaf(point)
param_indices.append(idx)
bboxes.append(bbox)
param_indices = torch.stack(param_indices)
bboxes = torch.stack(bboxes)
return param_indices, bboxes
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None:
""" """
递归打印八叉树结构 递归打印八叉树结构

2
brep2sdf/train.py

@ -237,7 +237,7 @@ class Trainer:
# 检查法线和带梯度的点 # 检查法线和带梯度的点
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("计算损失前")
loss, loss_details = self.loss_manager.compute_loss( loss, loss_details = self.loss_manager.compute_loss(
points, points,
normals, # 传递检查过的 normals normals, # 传递检查过的 normals

Loading…
Cancel
Save