| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -83,35 +83,18 @@ class Net(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #self.csg_combiner = CSGCombiner(flag_convex=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    @torch.jit.export | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def forward(self, query_points): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        前向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        参数: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            query_point: 查询点的位置坐标 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        返回: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            output: 解码后的输出结果 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 批量查询所有点的索引和bbox | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        _,face_indices_mask,operator = self.octree_module.forward(query_points) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 编码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        feature_vectors = self.encoder.forward(query_points,face_indices_mask) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print("feature_vector:", feature_vectors.shape) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 解码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.gpu_memory_stats("encoder farward后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        f_i = self.decoder(feature_vectors) # (B, P) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.gpu_memory_stats("decoder farward后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def process_sdf(self,f_i, face_indices_mask, operator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        output = f_i[:,0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 提取有效值并填充到固定大小 (B, max_patches) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device)  # (B, max_patches) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for i in range(f_i.shape[0]): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sample_valid_values = f_i[i][face_indices_mask[i]]  # (N,), N <= P | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            num_valid = min(len(sample_valid_values), 2) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            padded_f_i[i, :num_valid] = sample_valid_values[:num_valid] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        valid_mask = face_indices_mask.bool()  # 确保是布尔类型 (B, P) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        masked_f_i = torch.where(valid_mask, f_i, torch.tensor(float('inf'), device=f_i.device))  # 将无效值设置为 inf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 对每个样本取前 max_patches 个有效值 (B, max_patches) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        valid_values, _ = torch.topk(masked_f_i, k=2, dim=1, largest=False)  # 提取前两个有效值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 填充到固定大小 (B, max_patches) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        padded_f_i[:, :2] = valid_values  # (B, max_patches) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 找到需要组合的行 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        mask_concave = (operator == 0) | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -125,10 +108,36 @@ class Net(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if mask_convex.any(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.debug("step over") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.gpu_memory_stats("combine后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return output | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    @torch.jit.export | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def forward(self, query_points): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        前向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        参数: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            query_point: 查询点的位置坐标 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        返回: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            output: 解码后的输出结果 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 批量查询所有点的索引和bbox | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.debug("step octree") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        _,face_indices_mask,operator = self.octree_module.forward(query_points) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.debug("step encode") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 编码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        feature_vectors = self.encoder.forward(query_points,face_indices_mask) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #print("feature_vector:", feature_vectors.shape) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 解码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.debug("step decode") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.gpu_memory_stats("encoder farward后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        f_i = self.decoder(feature_vectors) # (B, P) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.gpu_memory_stats("decoder farward后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.debug("step combine") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return self.process_sdf(f_i, face_indices_mask, operator) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					       | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    @torch.jit.export | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def forward_training_volumes(self, surf_points, patch_id:int): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |