| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -101,11 +101,11 @@ class Net(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 对 operator == 0 的样本取最大值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if mask_concave.any(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            output[mask_concave] = torch.max(padded_f_i[mask_concave], dim=1).values | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            output[mask_concave] = torch.min(padded_f_i[mask_concave], dim=1).values | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 对 operator == 1 的样本取最小值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if mask_convex.any(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            output[mask_convex] = torch.max(padded_f_i[mask_convex], dim=1).values | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #logger.gpu_memory_stats("combine后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return output | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -187,6 +187,19 @@ class Net(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return f_i.squeeze() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def freeze_stage1(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.encoder.freeze_stage1() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def freeze_stage2(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.encoder.freeze_stage2() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for param in self.decoder.parameters(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            param.requires_grad = False | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def unfreeze(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.encoder.unfreeze() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for param in self.decoder.parameters(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            param.requires_grad = True | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |