You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
4.4 KiB
101 lines
4.4 KiB
from typing import Tuple, List
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
class PatchFeatureVolume(nn.Module):
    """Learnable dense feature grid over a padded axis-aligned box.

    A regular ``resolution``^3 lattice is laid over the (slightly enlarged)
    bounding box.  Every lattice vertex owns a trainable ``feature_dim``-sized
    vector, and ``forward`` returns the trilinear blend of the eight vertices
    surrounding each query point.

    Args:
        bbox: 1-D tensor; the first three entries are used as the minimum
            corner and the remaining entries as the maximum corner.
        resolution: Number of lattice vertices along each axis.
        feature_dim: Length of the feature vector stored at each vertex.
        padding_ratio: Fraction by which each half-extent of the box is
            enlarged around its centre before the lattice is built.

    Attributes:
        grid: Buffer of shape ``(R, R, R, 3)`` holding vertex coordinates,
            indexed as ``grid[ix, iy, iz]``.
        feature_volume: Parameter of shape ``(R, R, R, feature_dim)``.
        original_bbox: ``(2, 3)`` tensor ``[min, max]`` of the unpadded box.
            It is a plain attribute, not a buffer, so it is neither saved in
            the state dict nor moved by ``.to()``.
    """

    # Offsets of the eight cell corners relative to the cell's lower corner.
    # The order must match the order of the blend weights built in
    # ``_batched_trilinear`` (z varies fastest, then y, then x).
    _CORNER_OFFSETS = (
        (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
        (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1),
    )

    def __init__(self, bbox: torch.Tensor, resolution=64, feature_dim=8, padding_ratio=0.05):
        super().__init__()
        self.resolution = resolution

        # Split the flat bbox into its [min, max] corners.
        min_coords = bbox[:3]
        max_coords = bbox[3:]
        self.original_bbox = torch.stack([min_coords, max_coords])
        expanded_bbox = self._expand_bbox(min_coords, max_coords, padding_ratio)

        # Build the regular 3-D lattice of vertex coordinates.
        x = torch.linspace(expanded_bbox[0][0], expanded_bbox[1][0], resolution)
        y = torch.linspace(expanded_bbox[0][1], expanded_bbox[1][1], resolution)
        z = torch.linspace(expanded_bbox[0][2], expanded_bbox[1][2], resolution)
        # 'ij' is stated explicitly: the lookup below indexes the volume as
        # [ix, iy, iz], and omitting the argument raises a deprecation warning.
        grid_x, grid_y, grid_z = torch.meshgrid(x, y, z, indexing='ij')
        self.register_buffer('grid', torch.stack([grid_x, grid_y, grid_z], dim=-1))

        # Start the features close to zero (small standard deviation).
        self.feature_volume = nn.Parameter(
            torch.empty(resolution, resolution, resolution, feature_dim)
        )
        torch.nn.init.normal_(self.feature_volume, mean=0.0, std=0.01)

    def _expand_bbox(self, min_coords, max_coords, ratio):
        """Scale the box about its centre by ``1 + ratio``; return ``[min, max]`` stacked."""
        center = (min_coords + max_coords) / 2
        expanded_min = center - (center - min_coords) * (1 + ratio)
        expanded_max = center + (max_coords - center) * (1 + ratio)
        return torch.stack([expanded_min, expanded_max])

    def forward(self, query_points: torch.Tensor) -> torch.Tensor:
        """Sample the feature volume at the given points.

        Args:
            query_points: Tensor of shape ``(..., 3)`` (typically ``(B, 3)``)
                with coordinates in the same frame as ``bbox``.

        Returns:
            Tensor of shape ``(..., D)`` with the interpolated features.
            Points outside the padded box receive the value at the nearest
            boundary of the volume, because vertex indices are clamped.
        """
        lower = self.grid[0, 0, 0]
        upper = self.grid[-1, -1, -1]
        # Map the padded box onto [0, 1]^3; the epsilon avoids division by
        # zero when the box is degenerate along an axis.
        normalized = (query_points - lower) / (upper - lower + 1e-8)
        return self._batched_trilinear(normalized)

    def _batched_trilinear(self, normalized: torch.Tensor) -> torch.Tensor:
        """Trilinearly interpolate ``feature_volume`` at normalized coordinates.

        Args:
            normalized: Tensor of shape ``(..., 3)`` where 0 and 1 correspond
                to the first and last lattice vertex along each axis.

        Returns:
            Tensor of shape ``(..., D)``.
        """
        lead_shape = normalized.shape[:-1]
        flat = normalized.reshape(-1, 3)  # (B, 3)

        # Continuous vertex coordinates, split into integer cell and fraction.
        uvw = flat * (self.resolution - 1)
        base = torch.floor(uvw).long()  # (B, 3)
        frac = uvw - base.float()  # (B, 3), each entry in [0, 1)

        fx, fy, fz = frac[..., 0], frac[..., 1], frac[..., 2]
        gx, gy, gz = 1 - fx, 1 - fy, 1 - fz

        # Blend weight for each of the eight corners, ordered as _CORNER_OFFSETS.
        corner_weights = torch.stack([
            gx * gy * gz,
            gx * gy * fz,
            gx * fy * gz,
            gx * fy * fz,
            fx * gy * gz,
            fx * gy * fz,
            fx * fy * gz,
            fx * fy * fz,
        ], dim=-1)  # (B, 8)

        # Integer indices of the eight surrounding vertices, clamped so that
        # out-of-range queries read the boundary vertices.
        offsets = torch.tensor(self._CORNER_OFFSETS, device=base.device)  # (8, 3)
        corners = base.unsqueeze(1) + offsets  # (B, 8, 3)
        corners = torch.clamp(corners, 0, self.resolution - 1)

        features = self.feature_volume[
            corners[..., 0], corners[..., 1], corners[..., 2]
        ]  # (B, 8, D)

        # Weighted sum over the eight corners, then restore the leading shape.
        blended = torch.einsum('bnd,bn->bd', features, corner_weights)  # (B, D)
        return blended.reshape(*lead_shape, blended.shape[-1])
|
|
|
|
|
|
class SimpleFeatureEncoder(nn.Module):
    """Multi-layer perceptron that encodes point coordinates into features.

    The network is ``input_dim -> 128 -> 256 -> 512 -> feature_dim`` with an
    in-place ReLU after every layer except the last.
    """

    def __init__(self, input_dim=3, feature_dim=64):
        super().__init__()

        # Widths of the hidden stack; each consecutive pair becomes one
        # Linear layer followed by a ReLU.
        widths = (input_dim, 128, 256, 512)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True)]
        # Final projection has no activation.
        stack.append(nn.Linear(widths[-1], feature_dim))

        self.encoder = nn.Sequential(*stack)

    def forward(self, query_points: torch.Tensor) -> torch.Tensor:
        """Encode a batch of points.

        Args:
            query_points: Query point coordinates of shape ``(B, 3)``
                (last dimension equal to ``input_dim``).

        Returns:
            Feature vectors of shape ``(B, feature_dim)``.
        """
        return self.encoder(query_points)
|