|
|
@ -1,7 +1,7 @@ |
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
import time |
|
|
|
from .octree import OctreeNode |
|
|
|
from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder |
|
|
|
from brep2sdf.utils.logger import logger |
|
|
@ -95,17 +95,31 @@ class Encoder(nn.Module): |
|
|
|
all_features = torch.zeros(batch_size, num_volumes, self.feature_dim, |
|
|
|
device=query_points.device) |
|
|
|
background_features = self.background.forward(query_points) # (B, D) |
|
|
|
start_time = time.time() |
|
|
|
# 创建 CUDA 流 |
|
|
|
streams = [torch.cuda.Stream() for _ in range(len(self.feature_volumes))] |
|
|
|
features_list = [None] * len(self.feature_volumes) |
|
|
|
|
|
|
|
# 遍历每个volume索引 |
|
|
|
# 并行计算 |
|
|
|
for vol_id, volume in enumerate(self.feature_volumes): |
|
|
|
mask = volume_indices_mask[:, vol_id].squeeze() |
|
|
|
#logger.debug(f"mask:{mask},shape:{mask.shape},mask.any():{mask.any()}") |
|
|
|
if mask.any(): |
|
|
|
# 获取对应volume的特征 (M, D) |
|
|
|
features = volume.forward(query_points[mask]) |
|
|
|
all_features[mask, vol_id] = 0.9 * background_features[mask] + 0.1 * features |
|
|
|
|
|
|
|
#all_features[:, :] = background_features.unsqueeze(1) |
|
|
|
if not mask.any(): |
|
|
|
continue |
|
|
|
with torch.cuda.stream(streams[vol_id]): |
|
|
|
features = volume(query_points[mask]) |
|
|
|
features_list[vol_id] = (mask, features) |
|
|
|
|
|
|
|
# 同步流 |
|
|
|
torch.cuda.synchronize() |
|
|
|
|
|
|
|
# 写入结果 |
|
|
|
for vol_id, item in enumerate(features_list): |
|
|
|
if item is None: |
|
|
|
continue |
|
|
|
mask, features = item |
|
|
|
all_features[mask, vol_id] = 0.1 * background_features[mask] + 0.9 * features |
|
|
|
end_time = time.time() |
|
|
|
logger.debug(f"duration:{end_time-start_time}") |
|
|
|
return all_features |
|
|
|
|
|
|
|
def forward_background(self, query_points: torch.Tensor) -> torch.Tensor: |
|
|
@ -135,7 +149,7 @@ class Encoder(nn.Module): |
|
|
|
background_features = self.background.forward(surf_points) # (B, D) |
|
|
|
#dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters())) |
|
|
|
#dot.render("feature_extraction", format="png") # 将计算图保存为 PDF 文件 |
|
|
|
return 0.9 * background_features + 0.1 * patch_features |
|
|
|
return 0.1 * background_features + 0.9 * patch_features |
|
|
|
|
|
|
|
def to(self, device): |
|
|
|
super().to(device) |
|
|
|