From 60abaf6aebd5dd52daa11663be71bdf0a114631e Mon Sep 17 00:00:00 2001
From: mckay
Date: Fri, 23 May 2025 15:32:31 +0800
Subject: [PATCH] =?UTF-8?q?cuda=E5=8A=A0=E9=80=9Fencoder=20forward?=
 =?UTF-8?q?=EF=BC=8C=E4=BD=86=E6=98=AF=E6=95=88=E6=9E=9C=E4=B8=8D=E5=A5=BD?=
 =?UTF-8?q?=EF=BC=8C=E5=81=9A=E4=BF=9D=E5=AD=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/networks/encoder.py | 34 ++++++++++++++++++++++++----------
 1 file changed, 24 insertions(+), 10 deletions(-)

diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py
index 61c5168..49e4803 100644
--- a/brep2sdf/networks/encoder.py
+++ b/brep2sdf/networks/encoder.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import numpy as np
-
+import time
 from .octree import OctreeNode
 from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder
 from brep2sdf.utils.logger import logger
@@ -95,17 +95,31 @@ class Encoder(nn.Module):
         all_features = torch.zeros(batch_size, num_volumes, self.feature_dim, device=query_points.device)
         background_features = self.background.forward(query_points) # (B, D)
+        start_time = time.time()
+        # 创建 CUDA 流
+        streams = [torch.cuda.Stream() for _ in range(len(self.feature_volumes))]
+        features_list = [None] * len(self.feature_volumes)
-        # 遍历每个volume索引
+        # 并行计算
         for vol_id, volume in enumerate(self.feature_volumes):
             mask = volume_indices_mask[:, vol_id].squeeze()
-            #logger.debug(f"mask:{mask},shape:{mask.shape},mask.any():{mask.any()}")
-            if mask.any():
-                # 获取对应volume的特征 (M, D)
-                features = volume.forward(query_points[mask])
-                all_features[mask, vol_id] = 0.9 * background_features[mask] + 0.1 * features
-
-        #all_features[:, :] = background_features.unsqueeze(1)
+            if not mask.any():
+                continue
+            with torch.cuda.stream(streams[vol_id]):
+                features = volume(query_points[mask])
+            features_list[vol_id] = (mask, features)
+
+        # 同步流
+        torch.cuda.synchronize()
+
+        # 写入结果
+        for vol_id, item in enumerate(features_list):
+            if item is None:
+                continue
+            mask, features = item
+            all_features[mask, vol_id] = 0.1 * background_features[mask] + 0.9 * features
+        end_time = time.time()
+        logger.debug(f"duration:{end_time-start_time}")
         return all_features
 
     def forward_background(self, query_points: torch.Tensor) -> torch.Tensor:
@@ -135,7 +149,7 @@ class Encoder(nn.Module):
         background_features = self.background.forward(surf_points) # (B, D)
         #dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters()))
         #dot.render("feature_extraction", format="png") # 将计算图保存为 PDF 文件
-        return 0.9 * background_features + 0.1 * patch_features
+        return 0.1 * background_features + 0.9 * patch_features
 
     def to(self, device):
         super().to(device)