Browse Source

cuda加速encoder forward,但是效果不好,做保存

final
mckay 2 weeks ago
parent
commit
60abaf6aeb
  1. 34
      brep2sdf/networks/encoder.py

34
brep2sdf/networks/encoder.py

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import numpy as np
import time
from .octree import OctreeNode
from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder
from brep2sdf.utils.logger import logger
@ -95,17 +95,31 @@ class Encoder(nn.Module):
all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,
device=query_points.device)
background_features = self.background.forward(query_points) # (B, D)
start_time = time.time()
# 创建 CUDA 流
streams = [torch.cuda.Stream() for _ in range(len(self.feature_volumes))]
features_list = [None] * len(self.feature_volumes)
# 遍历每个volume索引
# 并行计算
for vol_id, volume in enumerate(self.feature_volumes):
mask = volume_indices_mask[:, vol_id].squeeze()
#logger.debug(f"mask:{mask},shape:{mask.shape},mask.any():{mask.any()}")
if mask.any():
# 获取对应volume的特征 (M, D)
features = volume.forward(query_points[mask])
all_features[mask, vol_id] = 0.9 * background_features[mask] + 0.1 * features
#all_features[:, :] = background_features.unsqueeze(1)
if not mask.any():
continue
with torch.cuda.stream(streams[vol_id]):
features = volume(query_points[mask])
features_list[vol_id] = (mask, features)
# 同步流
torch.cuda.synchronize()
# 写入结果
for vol_id, item in enumerate(features_list):
if item is None:
continue
mask, features = item
all_features[mask, vol_id] = 0.1 * background_features[mask] + 0.9 * features
end_time = time.time()
logger.debug(f"duration:{end_time-start_time}")
return all_features
def forward_background(self, query_points: torch.Tensor) -> torch.Tensor:
@ -135,7 +149,7 @@ class Encoder(nn.Module):
background_features = self.background.forward(surf_points) # (B, D)
#dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters()))
#dot.render("feature_extraction", format="png") # 将计算图保存为 PNG 文件
return 0.9 * background_features + 0.1 * patch_features
return 0.1 * background_features + 0.9 * patch_features
def to(self, device):
super().to(device)

Loading…
Cancel
Save