| 
						
						
						
					 | 
					@ -1,7 +1,7 @@ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					import torch | 
					 | 
					 | 
					import torch | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					import torch.nn as nn | 
					 | 
					 | 
					import torch.nn as nn | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					import numpy as np | 
					 | 
					 | 
					import numpy as np | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					
 | 
					 | 
					 | 
					import time | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					from .octree import OctreeNode | 
					 | 
					 | 
					from .octree import OctreeNode | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder | 
					 | 
					 | 
					from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.utils.logger import logger | 
					 | 
					 | 
					from brep2sdf.utils.logger import logger | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -95,17 +95,31 @@ class Encoder(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,  | 
					 | 
					 | 
					        all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,  | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                                  device=query_points.device) | 
					 | 
					 | 
					                                  device=query_points.device) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        background_features = self.background.forward(query_points)  # (B, D) | 
					 | 
					 | 
					        background_features = self.background.forward(query_points)  # (B, D) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        start_time = time.time() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 创建 CUDA 流 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        streams = [torch.cuda.Stream() for _ in range(len(self.feature_volumes))] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        features_list = [None] * len(self.feature_volumes) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 遍历每个volume索引 | 
					 | 
					 | 
					        # 并行计算 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        for vol_id, volume in enumerate(self.feature_volumes): | 
					 | 
					 | 
					        for vol_id, volume in enumerate(self.feature_volumes): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            mask = volume_indices_mask[:, vol_id].squeeze() | 
					 | 
					 | 
					            mask = volume_indices_mask[:, vol_id].squeeze() | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            #logger.debug(f"mask:{mask},shape:{mask.shape},mask.any():{mask.any()}") | 
					 | 
					 | 
					            if not mask.any(): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            if mask.any(): | 
					 | 
					 | 
					                continue | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                # 获取对应volume的特征 (M, D) | 
					 | 
					 | 
					            with torch.cuda.stream(streams[vol_id]): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                features = volume.forward(query_points[mask]) | 
					 | 
					 | 
					                features = volume(query_points[mask]) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                all_features[mask, vol_id] = 0.9 * background_features[mask] + 0.1 * features | 
					 | 
					 | 
					                features_list[vol_id] = (mask, features) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					     | 
					 | 
					 | 
					         | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        #all_features[:, :] = background_features.unsqueeze(1)  | 
					 | 
					 | 
					        # 同步流 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        torch.cuda.synchronize() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 写入结果 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        for vol_id, item in enumerate(features_list): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            if item is None: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                continue | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            mask, features = item | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            all_features[mask, vol_id] = 0.1 * background_features[mask] + 0.9 * features | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        end_time = time.time() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.debug(f"duration:{end_time-start_time}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return all_features | 
					 | 
					 | 
					        return all_features | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def forward_background(self, query_points: torch.Tensor) -> torch.Tensor: | 
					 | 
					 | 
					    def forward_background(self, query_points: torch.Tensor) -> torch.Tensor: | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -135,7 +149,7 @@ class Encoder(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        background_features = self.background.forward(surf_points)  # (B, D) | 
					 | 
					 | 
					        background_features = self.background.forward(surf_points)  # (B, D) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        #dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters())) | 
					 | 
					 | 
					        #dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters())) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        #dot.render("feature_extraction", format="png")  # 将计算图保存为 PDF 文件 | 
					 | 
					 | 
					        #dot.render("feature_extraction", format="png")  # 将计算图保存为 PDF 文件 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        return 0.9 * background_features + 0.1 * patch_features | 
					 | 
					 | 
					        return 0.1 * background_features + 0.9 * patch_features | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def to(self, device): | 
					 | 
					 | 
					    def to(self, device): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        super().to(device) | 
					 | 
					 | 
					        super().to(device) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |