| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -45,6 +45,8 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, config): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.config = config | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        use_pin_memory = self.device.type == 'cuda'  # 根据设备类型决定是否使用pin_memory | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f"Using device: {self.device}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 初始化数据集 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -69,13 +71,15 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.train_dataset, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            batch_size=config.train.batch_size, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            shuffle=True, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            num_workers=config.train.num_workers | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            num_workers=config.train.num_workers, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pin_memory=use_pin_memory  # 根据设备类型设置,是否将数据固定在内存中 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.val_loader = DataLoader( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.val_dataset, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            batch_size=config.train.batch_size, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            shuffle=False, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            num_workers=config.train.num_workers | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            num_workers=config.train.num_workers, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pin_memory=use_pin_memory  # 根据设备类型设置,是否将数据固定在内存中 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 初始化模型 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -105,28 +109,34 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        total_loss = 0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for batch_idx, batch in enumerate(self.train_loader): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_z = batch['surf_z'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_z = batch['edge_z'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_p = batch['surf_p'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_p = batch['edge_p'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            vert_p = batch['vert_p'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            query_points = batch['points'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            gt_sdf = batch['sdf'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 清空梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.optimizer.zero_grad() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 获取数据并移动到设备 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_ncs = batch['surf_ncs'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_ncs = batch['edge_ncs'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_pos = batch['surf_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_pos = batch['edge_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            vertex_pos = batch['vertex_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf = batch['sdf'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 前向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pred_sdf = self.model( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_ncs, edge_ncs, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_pos, edge_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                vertex_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf[:, :3]  # 只使用点坐标,不包括SDF值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 计算损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            loss = sdf_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                pred_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                query_points, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf[:, 3],  # 使用SDF值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf[:, :3],  # 使用点坐标 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                grad_weight=self.config.train.grad_weight | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 反向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.optimizer.zero_grad() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 反向传播和优化 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            loss.backward() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 梯度裁剪 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -134,22 +144,21 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.model.parameters(),  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.config.train.max_grad_norm | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.optimizer.step() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            total_loss += loss.item() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 打印训练进度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 记录训练进度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if (batch_idx + 1) % self.config.log.log_interval == 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                          f'Loss: {loss.item():.6f}') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 记录到wandb | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if self.config.log.use_wandb and (batch_idx + 1) % self.config.log.log_interval == 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                wandb.log({ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'batch_loss': loss.item(), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'batch': batch_idx, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    'epoch': epoch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                }) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if self.config.log.use_wandb: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    wandb.log({ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        'batch_loss': loss.item(), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        'batch': batch_idx, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        'epoch': epoch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    }) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        avg_loss = total_loss / len(self.train_loader) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return avg_loss | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -160,23 +169,27 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        with torch.no_grad(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for batch in self.val_loader: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 获取数据 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_z = batch['surf_z'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_z = batch['edge_z'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_p = batch['surf_p'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_p = batch['edge_p'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                vert_p = batch['vert_p'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                query_points = batch['points'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                gt_sdf = batch['sdf'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 获取数据并移动到设备 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_ncs = batch['surf_ncs'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_ncs = batch['edge_ncs'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_pos = batch['surf_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_pos = batch['edge_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                vertex_pos = batch['vertex_pos'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                sdf = batch['sdf'].to(self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 前向传播 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                pred_sdf = self.model( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    surf_ncs, edge_ncs, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    surf_pos, edge_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    vertex_pos, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    sdf[:, :3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # 计算损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                loss = sdf_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    pred_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    query_points, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    sdf[:, 3], | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    sdf[:, :3], | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    grad_weight=self.config.train.grad_weight | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -195,11 +208,12 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def train(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        best_val_loss = float('inf') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("Starting training...") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        val_loss = float('inf')  # 初始化val_loss | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for epoch in range(1, self.config.train.num_epochs + 1): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            train_loss = self.train_epoch(epoch) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 定期验证 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 定期验证和保存 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if epoch % self.config.train.val_freq == 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                val_loss = self.validate(epoch) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |