You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

115 lines
4.5 KiB

import torch
import torch.nn as nn
from .octree import OctreeNode
from .feature_volume import PatchFeatureVolume
from brep2sdf.utils.logger import logger
class Encoder(nn.Module):
def __init__(self, volume_bboxs:torch.tensor, feature_dim: int = 32):
"""
分离后的编码器,接收预构建的八叉树
参数:
volume_bboxs: 所有面片的边界框集合,形状为 (N, 2, 3)
feature_dim: 特征维度
"""
super().__init__()
self.feature_dim = feature_dim
# 批量计算所有bbox的分辨率
resolutions = self._batch_calculate_resolution(volume_bboxs)
# 初始化多个特征体积
self.feature_volumes = nn.ModuleList([
PatchFeatureVolume(
bbox=bbox,
resolution=int(resolutions[i]),
feature_dim=feature_dim
) for i, bbox in enumerate(volume_bboxs)
])
print(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}")
print(f"Model parameters memory usage: {sum(p.numel() * p.element_size() for p in self.parameters()) / 1024**2:.2f} MB")
def _batch_calculate_resolution(self, bboxes: torch.Tensor) -> torch.Tensor:
"""
批量计算归一化bboxes的分辨率
参数:
bboxes: 归一化边界框张量,形状为 (N, 2, 3)
返回:
分辨率张量 (N,)
"""
with torch.no_grad():
# 计算每个bbox的对角线长度(归一化后范围约为0.0-1.732)
diagonals = torch.norm(bboxes[:,3:6] - bboxes[:,0:3], dim=1)
# 根据归一化后的对角线长度调整分辨率
resolutions = torch.zeros_like(diagonals, dtype=torch.long)
resolutions[diagonals > 1.0] = 16 # 大尺寸
resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 8 # 中等尺寸
resolutions[diagonals <= 0.5] = 4 # 小尺寸
return resolutions
def forward(self, query_points: torch.Tensor, volume_indices: torch.Tensor) -> torch.Tensor:
"""
修改后的前向传播,返回所有关联volume的特征矩阵
参数:
query_points: 查询点坐标 (B, 3)
volume_indices: 关联的volume索引矩阵 (B, K)
返回:
特征张量 (B, K, D)
"""
batch_size, num_volumes = volume_indices.shape
all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,
device=query_points.device)
# 遍历每个volume索引
for k in range(num_volumes):
# 获取当前volume的索引 (B,)
current_indices = volume_indices[:, k]
# 遍历所有存在的volume
for vol_id in range(len(self.feature_volumes)):
# 创建掩码 (B,)
mask = (current_indices == vol_id)
if mask.any():
# 获取对应volume的特征 (M, D)
features = self.feature_volumes[vol_id](query_points[mask])
all_features[mask, k] = features
return all_features
def _optimized_trilinear(self, points, bboxes, features):
"""优化后的向量化三线性插值"""
# 添加显式类型转换确保计算稳定性
min_coords = bboxes[..., :3].to(torch.float32)
max_coords = bboxes[..., 3:].to(torch.float32)
normalized = (points - min_coords) / (max_coords - min_coords + 1e-8)
# 使用爱因斯坦求和代替分步计算
wx = torch.stack([1 - normalized[...,0], normalized[...,0]], -1) # (B,2)
wy = torch.stack([1 - normalized[...,1], normalized[...,1]], -1)
wz = torch.stack([1 - normalized[...,2], normalized[...,2]], -1)
# 合并所有计算步骤 (B,8,D) * (B,8,1) -> (B,D)
return torch.einsum('bcd,bc->bd', features,
torch.einsum('bi,bj,bk->bijk', wx, wy, wz).view(-1,8))
# 原batch_trilinear_interpolation方法可以删除
def to(self, device):
super().to(device)
def _move_node(node):
if isinstance(node.bbox, torch.Tensor):
node.bbox = node.bbox.to(device)
for child in node.children:
_move_node(child)
_move_node(self.octree.root)
return self