Browse Source

cuda加速encoder forward,但是效果不好,做保存

final
mckay 2 weeks ago
parent
commit
60abaf6aeb
  1. 34
      brep2sdf/networks/encoder.py

34
brep2sdf/networks/encoder.py

@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import time
from .octree import OctreeNode from .octree import OctreeNode
from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
@ -95,17 +95,31 @@ class Encoder(nn.Module):
all_features = torch.zeros(batch_size, num_volumes, self.feature_dim, all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,
device=query_points.device) device=query_points.device)
background_features = self.background.forward(query_points) # (B, D) background_features = self.background.forward(query_points) # (B, D)
start_time = time.time()
# 创建 CUDA 流
streams = [torch.cuda.Stream() for _ in range(len(self.feature_volumes))]
features_list = [None] * len(self.feature_volumes)
# 遍历每个volume索引 # 并行计算
for vol_id, volume in enumerate(self.feature_volumes): for vol_id, volume in enumerate(self.feature_volumes):
mask = volume_indices_mask[:, vol_id].squeeze() mask = volume_indices_mask[:, vol_id].squeeze()
#logger.debug(f"mask:{mask},shape:{mask.shape},mask.any():{mask.any()}") if not mask.any():
if mask.any(): continue
# 获取对应volume的特征 (M, D) with torch.cuda.stream(streams[vol_id]):
features = volume.forward(query_points[mask]) features = volume(query_points[mask])
all_features[mask, vol_id] = 0.9 * background_features[mask] + 0.1 * features features_list[vol_id] = (mask, features)
#all_features[:, :] = background_features.unsqueeze(1) # 同步流
torch.cuda.synchronize()
# 写入结果
for vol_id, item in enumerate(features_list):
if item is None:
continue
mask, features = item
all_features[mask, vol_id] = 0.1 * background_features[mask] + 0.9 * features
end_time = time.time()
logger.debug(f"duration:{end_time-start_time}")
return all_features return all_features
def forward_background(self, query_points: torch.Tensor) -> torch.Tensor: def forward_background(self, query_points: torch.Tensor) -> torch.Tensor:
@ -135,7 +149,7 @@ class Encoder(nn.Module):
background_features = self.background.forward(surf_points) # (B, D) background_features = self.background.forward(surf_points) # (B, D)
#dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters())) #dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters()))
#dot.render("feature_extraction", format="png") # 将计算图保存为 PDF 文件 #dot.render("feature_extraction", format="png") # 将计算图保存为 PDF 文件
return 0.9 * background_features + 0.1 * patch_features return 0.1 * background_features + 0.9 * patch_features
def to(self, device): def to(self, device):
super().to(device) super().to(device)

Loading…
Cancel
Save