| 
						
						
							
								
							
						
						
					 | 
					@ -53,9 +53,9 @@ class Encoder(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 根据归一化后的对角线长度调整分辨率 | 
					 | 
					 | 
					            # 根据归一化后的对角线长度调整分辨率 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            resolutions = torch.zeros_like(diagonals, dtype=torch.long) | 
					 | 
					 | 
					            resolutions = torch.zeros_like(diagonals, dtype=torch.long) | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            resolutions[diagonals > 1.0] = 16    # 大尺寸 | 
					 | 
					 | 
					            resolutions[diagonals > 1.0] = 64    # 大尺寸 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 8   # 中等尺寸 | 
					 | 
					 | 
					            resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 32   # 中等尺寸 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            resolutions[diagonals <= 0.5] = 4  # 小尺寸 | 
					 | 
					 | 
					            resolutions[diagonals <= 0.5] = 16  # 小尺寸 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            return resolutions | 
					 | 
					 | 
					            return resolutions | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -75,19 +75,16 @@ class Encoder(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,  | 
					 | 
					 | 
					        all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,  | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                                  device=query_points.device) | 
					 | 
					 | 
					                                  device=query_points.device) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        background_features = self.background.forward(query_points)  # (B, D) | 
					 | 
					 | 
					        background_features = self.background.forward(query_points)  # (B, D) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 遍历每个volume索引 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        for p in range(num_volumes): | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 获取当前volume的索引 (B,) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            current_indices = volume_indices[:, p] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            # 遍历所有存在的volume | 
					 | 
					 | 
					        # 遍历每个volume索引 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        for vol_id, volume in enumerate(self.feature_volumes): | 
					 | 
					 | 
					        for vol_id, volume in enumerate(self.feature_volumes): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            current_indices = volume_indices[:, vol_id] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 创建掩码 (B,) | 
					 | 
					 | 
					            # 创建掩码 (B,) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            mask = (current_indices == vol_id) | 
					 | 
					 | 
					            mask = (current_indices == vol_id) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if mask.any(): | 
					 | 
					 | 
					            if mask.any(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                # 获取对应volume的特征 (M, D) | 
					 | 
					 | 
					                # 获取对应volume的特征 (M, D) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                features = volume.forward(query_points[mask]) | 
					 | 
					 | 
					                features = volume.forward(query_points[mask]) | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    all_features[mask, p] = 0.7 * features + 0.3 * background_features[mask] | 
					 | 
					 | 
					                all_features[mask, vol_id] = 0.7 * features + 0.3 * background_features[mask] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return all_features | 
					 | 
					 | 
					        return all_features | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					       | 
					 | 
					 | 
					       | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |