Browse Source

优化encoder farward 循环

final
mckay 1 month ago
parent
commit
1cbf7ffffc
  1. 29
      brep2sdf/networks/encoder.py

29
brep2sdf/networks/encoder.py

@ -53,9 +53,9 @@ class Encoder(nn.Module):
# 根据归一化后的对角线长度调整分辨率
resolutions = torch.zeros_like(diagonals, dtype=torch.long)
resolutions[diagonals > 1.0] = 16 # 大尺寸
resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 8 # 中等尺寸
resolutions[diagonals <= 0.5] = 4 # 小尺寸
resolutions[diagonals > 1.0] = 64 # 大尺寸
resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 32 # 中等尺寸
resolutions[diagonals <= 0.5] = 16 # 小尺寸
return resolutions
@ -75,20 +75,17 @@ class Encoder(nn.Module):
all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,
device=query_points.device)
background_features = self.background.forward(query_points) # (B, D)
# 遍历每个volume索引
for p in range(num_volumes):
# 获取当前volume的索引 (B,)
current_indices = volume_indices[:, p]
# 遍历所有存在的volume
for vol_id, volume in enumerate(self.feature_volumes):
# 创建掩码 (B,)
mask = (current_indices == vol_id)
if mask.any():
# 获取对应volume的特征 (M, D)
features = volume.forward(query_points[mask])
all_features[mask, p] = 0.7 * features + 0.3 * background_features[mask]
for vol_id, volume in enumerate(self.feature_volumes):
current_indices = volume_indices[:, vol_id]
# 创建掩码 (B,)
mask = (current_indices == vol_id)
if mask.any():
# 获取对应volume的特征 (M, D)
features = volume.forward(query_points[mask])
all_features[mask, vol_id] = 0.7 * features + 0.3 * background_features[mask]
return all_features
@torch.jit.export

Loading…
Cancel
Save