|
|
@ -53,9 +53,9 @@ class Encoder(nn.Module): |
|
|
|
|
|
|
|
# 根据归一化后的对角线长度调整分辨率 |
|
|
|
resolutions = torch.zeros_like(diagonals, dtype=torch.long) |
|
|
|
resolutions[diagonals > 1.0] = 16 # 大尺寸 |
|
|
|
resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 8 # 中等尺寸 |
|
|
|
resolutions[diagonals <= 0.5] = 4 # 小尺寸 |
|
|
|
resolutions[diagonals > 1.0] = 64 # 大尺寸 |
|
|
|
resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 32 # 中等尺寸 |
|
|
|
resolutions[diagonals <= 0.5] = 16 # 小尺寸 |
|
|
|
|
|
|
|
return resolutions |
|
|
|
|
|
|
@ -75,19 +75,16 @@ class Encoder(nn.Module): |
|
|
|
all_features = torch.zeros(batch_size, num_volumes, self.feature_dim, |
|
|
|
device=query_points.device) |
|
|
|
background_features = self.background.forward(query_points) # (B, D) |
|
|
|
# 遍历每个volume索引 |
|
|
|
for p in range(num_volumes): |
|
|
|
# 获取当前volume的索引 (B,) |
|
|
|
current_indices = volume_indices[:, p] |
|
|
|
|
|
|
|
# 遍历所有存在的volume |
|
|
|
# 遍历每个volume索引 |
|
|
|
for vol_id, volume in enumerate(self.feature_volumes): |
|
|
|
current_indices = volume_indices[:, vol_id] |
|
|
|
# 创建掩码 (B,) |
|
|
|
mask = (current_indices == vol_id) |
|
|
|
if mask.any(): |
|
|
|
# 获取对应volume的特征 (M, D) |
|
|
|
features = volume.forward(query_points[mask]) |
|
|
|
all_features[mask, p] = 0.7 * features + 0.3 * background_features[mask] |
|
|
|
all_features[mask, vol_id] = 0.7 * features + 0.3 * background_features[mask] |
|
|
|
|
|
|
|
return all_features |
|
|
|
|
|
|
|