| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -113,22 +113,28 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 清空梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.optimizer.zero_grad() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取数据并移动到设备 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_ncs = batch['surf_ncs'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_ncs = batch['edge_ncs'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_pos = batch['surf_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_pos = batch['edge_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            vertex_pos = batch['vertex_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取数据并移动到设备,同时设置梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_pos = batch['edge_pos'].to(self.device).requires_grad_(True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            points = batch['points'].to(self.device).requires_grad_(True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 这些不需要梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_mask = batch['edge_mask'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            points = batch['points'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            gt_sdf = batch['sdf'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 前向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pred_sdf = self.model( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_ncs=surf_ncs, edge_ncs=edge_ncs, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_pos=surf_pos, edge_pos=edge_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                vertex_pos=vertex_pos, edge_mask=edge_mask, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                query_points=points  # 只使用点坐标,不包括SDF值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_ncs=surf_ncs,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_ncs=edge_ncs, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_pos=surf_pos,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_pos=edge_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                vertex_pos=vertex_pos,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_mask=edge_mask, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                query_points=points | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 计算损失 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |