| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -111,8 +111,7 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        total_loss = 0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for batch_idx, batch in enumerate(self.train_loader): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 清空梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.optimizer.zero_grad() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取数据并移动到设备,同时设置梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取数据并移动到设备,同时保留计算图 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -130,6 +129,10 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_mask = batch['edge_mask'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            gt_sdf = batch['sdf'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 前向传播前清空梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.model.zero_grad()  # 清空模型梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.optimizer.zero_grad()  # 清空优化器梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 前向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pred_sdf = self.model( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_ncs=surf_ncs,  | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |