| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -118,21 +118,23 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_pos = batch['surf_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_pos = batch['edge_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            vertex_pos = batch['vertex_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf = batch['sdf'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_mask = batch['edge_mask'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            points = batch['points'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            gt_sdf = batch['sdf'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 前向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pred_sdf = self.model( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_ncs, edge_ncs, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_pos, edge_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                vertex_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf[:, :3]  # 只使用点坐标,不包括SDF值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_ncs=surf_ncs, edge_ncs=edge_ncs, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_pos=surf_pos, edge_pos=edge_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                vertex_pos=vertex_pos, edge_mask=edge_mask, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                query_points=points  # 只使用点坐标,不包括SDF值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 计算损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            loss = sdf_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                pred_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf[:, 3],  # 使用SDF值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf[:, :3],  # 使用点坐标 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                points, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                grad_weight=self.config.train.grad_weight | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -175,21 +177,23 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_pos = batch['surf_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_pos = batch['edge_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                vertex_pos = batch['vertex_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf = batch['sdf'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_mask = batch['edge_mask'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                points = batch['points'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                gt_sdf = batch['sdf'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 前向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                pred_sdf = self.model( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    surf_ncs, edge_ncs, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    surf_pos, edge_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    vertex_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    sdf[:, :3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    surf_ncs=surf_ncs, edge_ncs=edge_ncs, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    surf_pos=surf_pos, edge_pos=edge_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    vertex_pos=vertex_pos, edge_mask=edge_mask, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    query_points=points | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 计算损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                loss = sdf_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    pred_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    sdf[:, 3], | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    sdf[:, :3], | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    points, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    grad_weight=self.config.train.grad_weight | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |