diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index aedbb90..0874893 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -53,9 +53,9 @@ class Encoder(nn.Module): # 根据归一化后的对角线长度调整分辨率 resolutions = torch.zeros_like(diagonals, dtype=torch.long) - resolutions[diagonals > 1.0] = 16 # 大尺寸 - resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 8 # 中等尺寸 - resolutions[diagonals <= 0.5] = 4 # 小尺寸 + resolutions[diagonals > 1.0] = 64 # 大尺寸 + resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 32 # 中等尺寸 + resolutions[diagonals <= 0.5] = 16 # 小尺寸 return resolutions @@ -75,20 +75,17 @@ class Encoder(nn.Module): all_features = torch.zeros(batch_size, num_volumes, self.feature_dim, device=query_points.device) background_features = self.background.forward(query_points) # (B, D) + # 遍历每个volume索引 - for p in range(num_volumes): - # 获取当前volume的索引 (B,) - current_indices = volume_indices[:, p] - - # 遍历所有存在的volume - for vol_id, volume in enumerate(self.feature_volumes): - # 创建掩码 (B,) - mask = (current_indices == vol_id) - if mask.any(): - # 获取对应volume的特征 (M, D) - features = volume.forward(query_points[mask]) - all_features[mask, p] = 0.7 * features + 0.3 * background_features[mask] - + for vol_id, volume in enumerate(self.feature_volumes): + current_indices = volume_indices[:, vol_id] + # 创建掩码 (B,) + mask = (current_indices == vol_id) + if mask.any(): + # 获取对应volume的特征 (M, D) + features = volume.forward(query_points[mask]) + all_features[mask, vol_id] = 0.7 * features + 0.3 * background_features[mask] + return all_features @torch.jit.export