You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

101 lines
4.4 KiB

from typing import Tuple, List
import torch
import torch.nn as nn
class PatchFeatureVolume(nn.Module):
    """Learnable dense feature grid over a padded bounding box.

    One ``feature_dim``-sized vector is stored per vertex of a regular
    ``resolution ** 3`` lattice spanning the (padded) box. Point queries are
    answered by trilinear interpolation of the eight surrounding vertices;
    queries outside the padded box are clamped to the nearest border vertex.

    Attributes:
        resolution: Number of lattice vertices along each axis.
        original_bbox: ``(2, 3)`` tensor ``[min, max]`` of the unpadded box.
            It is a plain attribute, not a registered buffer, so it is not
            part of the state dict.
        grid: ``(R, R, R, 3)`` buffer with the coordinates of every vertex.
        feature_volume: ``(R, R, R, feature_dim)`` learnable parameter.
    """

    # Offsets of the eight corners of a lattice cell in (x, y, z) order.
    # Both the interpolation weights and the gather indices are derived from
    # this single table so their ordering can never disagree.
    _CORNER_OFFSETS = (
        (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
        (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1),
    )

    def __init__(self, bbox: torch.Tensor, resolution=64, feature_dim=8, padding_ratio=0.05):
        """Build the lattice and initialise the features.

        Args:
            bbox: Tensor of six values ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
            resolution: Number of vertices per axis.
            feature_dim: Length of the feature vector stored at each vertex.
            padding_ratio: Relative enlargement of the box around its centre
                (``0.05`` grows every half-extent by 5%).
        """
        super().__init__()
        self.resolution = resolution

        # Split the flat bbox into its [min, max] corners.
        min_coords = bbox[:3]
        max_coords = bbox[3:]
        self.original_bbox = torch.stack([min_coords, max_coords])
        expanded_bbox = self._expand_bbox(min_coords, max_coords, padding_ratio)

        # Regular 3-D lattice of vertex coordinates over the padded box.
        x = torch.linspace(expanded_bbox[0][0], expanded_bbox[1][0], resolution)
        y = torch.linspace(expanded_bbox[0][1], expanded_bbox[1][1], resolution)
        z = torch.linspace(expanded_bbox[0][2], expanded_bbox[1][2], resolution)
        # "ij" indexing is stated explicitly: it is the layout the feature
        # lookup relies on (grid[i, j, k] == (x[i], y[j], z[k])) and omitting
        # the argument is deprecated.
        grid_x, grid_y, grid_z = torch.meshgrid(x, y, z, indexing="ij")
        self.register_buffer('grid', torch.stack([grid_x, grid_y, grid_z], dim=-1))

        # Features start as small Gaussian noise (std 0.01; tune if needed).
        self.feature_volume = nn.Parameter(
            torch.empty(resolution, resolution, resolution, feature_dim)
        )
        torch.nn.init.normal_(self.feature_volume, mean=0.0, std=0.01)

    def _expand_bbox(self, min_coords, max_coords, ratio):
        """Return ``stack([min, max])`` of the box scaled by ``1 + ratio`` about its centre."""
        center = (min_coords + max_coords) / 2
        expanded_min = center - (center - min_coords) * (1 + ratio)
        expanded_max = center + (max_coords - center) * (1 + ratio)
        return torch.stack([expanded_min, expanded_max])

    def forward(self, query_points: torch.Tensor) -> torch.Tensor:
        """Trilinearly interpolate features at the given points.

        Args:
            query_points: Coordinates of shape ``(..., 3)``; the usual case
                is ``(B, 3)``.

        Returns:
            Features of shape ``(..., feature_dim)`` in the dtype of
            ``feature_volume``.
        """
        grid_min = self.grid[0, 0, 0]
        grid_max = self.grid[-1, -1, -1]
        # Map the padded box onto [0, 1]^3; the epsilon guards against a
        # zero-extent axis.
        normalized = (query_points - grid_min) / (grid_max - grid_min + 1e-8)
        return self._batched_trilinear(normalized)

    def _batched_trilinear(self, normalized: torch.Tensor) -> torch.Tensor:
        """Interpolate features at normalised coordinates.

        Args:
            normalized: Coordinates of shape ``(..., 3)`` where ``[0, 1]``
                spans the padded box on each axis.

        Returns:
            Interpolated features of shape ``(..., feature_dim)``.
        """
        # Continuous lattice coordinates, split into cell index + fraction.
        uvw = normalized * (self.resolution - 1)
        base = torch.floor(uvw)
        frac = (uvw - base).unsqueeze(-2)  # (..., 1, 3)

        offsets = torch.tensor(
            self._CORNER_OFFSETS, dtype=torch.long, device=uvw.device
        )  # (8, 3)

        # Per-axis factor is `frac` for the far corner and `1 - frac` for the
        # near one; the corner weight is the product over the three axes.
        per_axis = torch.where(offsets.bool(), frac, 1 - frac)  # (..., 8, 3)
        weights = per_axis[..., 0] * per_axis[..., 1] * per_axis[..., 2]  # (..., 8)

        # Clamping after adding the offsets collapses out-of-range cells onto
        # the border vertex, i.e. constant extrapolation outside the box.
        corner_idx = (base.long().unsqueeze(-2) + offsets).clamp(0, self.resolution - 1)
        features = self.feature_volume[
            corner_idx[..., 0], corner_idx[..., 1], corner_idx[..., 2]
        ]  # (..., 8, D)

        # Match the feature dtype so e.g. float64 queries do not clash with
        # float32 features.
        weights = weights.to(features.dtype)
        return (features * weights.unsqueeze(-1)).sum(dim=-2)
class SimpleFeatureEncoder(nn.Module):
    """Plain MLP that maps point coordinates to a feature vector.

    The network is ``Linear -> ReLU`` for every hidden width followed by a
    final ``Linear`` projection to ``feature_dim`` (no activation on the
    output).
    """

    def __init__(self, input_dim=3, feature_dim=64, hidden_dims: Tuple[int, ...] = (128, 256, 512)):
        """Build the encoder.

        Args:
            input_dim: Size of the last dimension of the input.
            feature_dim: Size of the produced feature vector.
            hidden_dims: Widths of the hidden layers, in order. The default
                reproduces the original 128-256-512 architecture; an empty
                sequence yields a single linear layer.
        """
        super().__init__()
        widths = [input_dim, *hidden_dims]
        layers = []
        for in_width, out_width in zip(widths, widths[1:]):
            layers.append(nn.Linear(in_width, out_width))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(widths[-1], feature_dim))
        # Layers are stored positionally in one Sequential, so with the
        # default widths the state-dict keys stay `encoder.0`, `encoder.2`, ...
        self.encoder = nn.Sequential(*layers)

    def forward(self, query_points: torch.Tensor) -> torch.Tensor:
        """Encode the query points.

        Args:
            query_points: Tensor of shape ``(B, input_dim)``.

        Returns:
            Tensor of shape ``(B, feature_dim)``.
        """
        return self.encoder(query_points)