You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
115 lines
4.5 KiB
115 lines
4.5 KiB
import torch
|
|
import torch.nn as nn
|
|
|
|
from .octree import OctreeNode
|
|
from .feature_volume import PatchFeatureVolume
|
|
from brep2sdf.utils.logger import logger
|
|
|
|
class Encoder(nn.Module):
    """Per-patch feature encoder built from a set of axis-aligned bounding boxes.

    One ``PatchFeatureVolume`` is created per bounding box; ``forward`` gathers,
    for every query point, the features from each volume listed for that point.
    """

    def __init__(self, volume_bboxs: torch.Tensor, feature_dim: int = 32):
        """Build one feature volume per bounding box.

        Args:
            volume_bboxs: Bounding boxes of all patches, either ``(N, 6)`` laid
                out as ``[min_x, min_y, min_z, max_x, max_y, max_z]`` or
                ``(N, 2, 3)`` as ``[[min_xyz], [max_xyz]]``. Each entry
                ``volume_bboxs[i]`` is passed unchanged to ``PatchFeatureVolume``
                as ``bbox`` (NOTE(review): confirm which layout that class expects).
            feature_dim: Number of feature channels per volume.
        """
        super().__init__()
        self.feature_dim = feature_dim

        # One grid resolution per volume, chosen from the bbox diagonal length.
        resolutions = self._batch_calculate_resolution(volume_bboxs)

        self.feature_volumes = nn.ModuleList(
            PatchFeatureVolume(
                bbox=bbox,
                resolution=int(resolution),
                feature_dim=feature_dim,
            )
            for bbox, resolution in zip(volume_bboxs, resolutions.tolist())
        )

        param_bytes = sum(p.numel() * p.element_size() for p in self.parameters())
        logger.info(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}")
        logger.info(f"Model parameters memory usage: {param_bytes / 1024**2:.2f} MB")

    def _batch_calculate_resolution(self, bboxes: torch.Tensor) -> torch.Tensor:
        """Choose a grid resolution for each bounding box from its diagonal length.

        Args:
            bboxes: ``(N, 6)`` boxes as ``[min_xyz, max_xyz]`` or ``(N, 2, 3)``
                boxes as ``[[min_xyz], [max_xyz]]``. The thresholds below assume
                normalized boxes whose diagonal is at most about 1.732 (sqrt(3)).

        Returns:
            ``(N,)`` long tensor: 16 where the diagonal is > 1.0, 8 where it is in
            (0.5, 1.0], 4 where it is <= 0.5. NaN diagonals match none of these
            conditions and stay 0.
        """
        with torch.no_grad():
            # Accept the documented (N, 2, 3) layout by flattening it to (N, 6);
            # 2-D input is used exactly as given.
            if bboxes.dim() == 3:
                bboxes = bboxes.flatten(1)

            # Diagonal length of every box.
            diagonals = torch.linalg.vector_norm(bboxes[:, 3:6] - bboxes[:, 0:3], dim=1)

            resolutions = torch.zeros_like(diagonals, dtype=torch.long)
            resolutions[diagonals > 1.0] = 16  # large boxes
            resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 8  # medium boxes
            resolutions[diagonals <= 0.5] = 4  # small boxes

        return resolutions

    def forward(self, query_points: torch.Tensor, volume_indices: torch.Tensor) -> torch.Tensor:
        """Look up features for every (point, associated volume) pair.

        Args:
            query_points: ``(B, 3)`` query coordinates.
            volume_indices: ``(B, K)`` integer tensor; entry ``[b, k]`` is the index
                of the k-th volume associated with point ``b``. Indices outside
                ``[0, len(self.feature_volumes))`` (e.g. ``-1`` padding) are skipped.

        Returns:
            ``(B, K, feature_dim)`` features; slots with skipped indices stay zero.
        """
        batch_size, num_slots = volume_indices.shape
        num_feature_volumes = len(self.feature_volumes)
        all_features = torch.zeros(batch_size, num_slots, self.feature_dim,
                                   device=query_points.device)

        for k in range(num_slots):
            current_indices = volume_indices[:, k]  # (B,)
            # Visit only ids present in this column (ascending, like the full
            # range scan) instead of testing every volume, which needed one mask
            # and one host sync per volume per column.
            for vol_id in torch.unique(current_indices).tolist():
                if not 0 <= vol_id < num_feature_volumes:
                    continue
                mask = current_indices == vol_id  # (B,)
                # (M, D) features for the M points that use this volume in slot k.
                all_features[mask, k] = self.feature_volumes[vol_id](query_points[mask])

        return all_features

    def _optimized_trilinear(self, points, bboxes, features):
        """Vectorized trilinear interpolation of per-point corner features.

        Args:
            points: ``(B, 3)`` query coordinates.
            bboxes: ``(B, 6)`` boxes as ``[min_xyz, max_xyz]``.
            features: ``(B, 8, D)`` corner features. Corner index is
                ``4 * ix + 2 * iy + iz`` with each of ``ix, iy, iz`` being 0 for
                the min side and 1 for the max side of that axis.

        Returns:
            ``(B, D)`` interpolated features.
        """
        # Explicit float32 casts keep the normalization numerically stable.
        min_coords = bboxes[..., :3].to(torch.float32)
        max_coords = bboxes[..., 3:].to(torch.float32)
        # The epsilon avoids division by zero for degenerate (flat) boxes.
        normalized = (points - min_coords) / (max_coords - min_coords + 1e-8)

        # Per-axis linear weights, (B, 2): column 0 -> min side, column 1 -> max side.
        wx = torch.stack([1 - normalized[..., 0], normalized[..., 0]], -1)
        wy = torch.stack([1 - normalized[..., 1], normalized[..., 1]], -1)
        wz = torch.stack([1 - normalized[..., 2], normalized[..., 2]], -1)

        # Outer product -> (B, 2, 2, 2) corner weights, flattened to (B, 8).
        corner_weights = torch.einsum('bi,bj,bk->bijk', wx, wy, wz).flatten(1)
        # Weighted sum over the 8 corners: (B, 8, D) x (B, 8) -> (B, D).
        return torch.einsum('bcd,bc->bd', features, corner_weights)

    def to(self, *args, **kwargs):
        """Move/cast the module like ``nn.Module.to`` and move an attached octree.

        Accepts every ``nn.Module.to`` call form. ``__init__`` does not create an
        ``octree`` attribute; if a caller has attached one, the ``bbox`` tensor of
        every node reachable from ``octree.root`` is moved to the target device.
        A call that specifies no device (e.g. dtype-only) leaves the octree as is.

        Returns:
            ``self``.
        """
        super().to(*args, **kwargs)

        octree = getattr(self, "octree", None)
        if octree is None:
            return self

        # Resolve the target device from the positional / keyword call forms.
        target = kwargs.get("device", args[0] if args else None)
        if isinstance(target, torch.Tensor):
            target = target.device
        if target is None or isinstance(target, torch.dtype):
            return self

        # Iterative traversal: visit every node reachable from the root.
        pending = [octree.root]
        while pending:
            node = pending.pop()
            if isinstance(node.bbox, torch.Tensor):
                node.bbox = node.bbox.to(target)
            pending.extend(node.children)
        return self
|