| 
						
						
						
					 | 
				
				 | 
				
					@ -1,5 +1,7 @@ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import torch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import torch.optim as optim | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import time | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import os | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.config.default_config import get_default_config | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.data.data import load_brep_file,load_sdf_file | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -103,30 +105,69 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def train(self):
					        """Run the training loop for the configured number of epochs.

					        Each epoch from 1 to ``self.config.train.num_epochs`` calls
					        ``self.train_epoch(epoch)``. Whenever the epoch number is a multiple
					        of ``self.config.train.save_freq`` a checkpoint is written through
					        ``self._save_checkpoint`` and the event is logged. After the last
					        epoch the elapsed wall-clock time and ``best_val_loss`` are logged.

					        NOTE: the validation / best-model branch is disabled below, so
					        ``best_val_loss`` is never updated and is reported as ``inf``.
					        """
					        best_val_loss = float('inf')
					        logger.info("Starting training...")
					        started_at = time.time()

					        for epoch in range(1, self.config.train.num_epochs + 1):
					            # Train for one epoch.
					            train_loss = self.train_epoch(epoch)

					            # Validation is currently switched off. The disabled logic is
					            # kept here for reference:
					            #
					            # # Periodic validation.
					            # if epoch % self.config.train.val_freq == 0:
					            #     val_loss = self.validate(epoch)
					            #     logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')
					            #
					            #     # Save the best model.
					            #     if val_loss < best_val_loss:
					            #         best_val_loss = val_loss
					            #         self._save_model(epoch, val_loss)
					            #         logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}')
					            # else:
					            #     logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}')

					            # Only epochs that are a multiple of save_freq write a checkpoint.
					            if epoch % self.config.train.save_freq != 0:
					                continue
					            self._save_checkpoint(epoch, train_loss)
					            logger.info(f'Checkpoint saved at epoch {epoch}')

					        # Training finished: report the duration and the best validation loss.
					        total_time = time.time() - started_at
					        logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
					        logger.info(f'Best validation loss: {best_val_loss:.6f}')
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _save_model(self, epoch: int, val_loss: float):
					        """Save the best model into the checkpoint directory.

					        The file name is ``self.config.train.best_model_name`` formatted with
					        ``model_name`` and joined onto ``self.config.train.checkpoint_dir``.
					        The saved dict holds the epoch, the model and optimizer state dicts,
					        the validation loss (under ``'loss'``) and the config object.

					        Args:
					            epoch: Epoch number stored alongside the saved state.
					            val_loss: Validation loss stored under the ``'loss'`` key.
					        """
					        train_cfg = self.config.train
					        checkpoint_dir = train_cfg.checkpoint_dir
					        # Make sure the target directory exists before writing into it.
					        if checkpoint_dir:
					            os.makedirs(checkpoint_dir, exist_ok=True)
					        # model_name is read from self.config; the previous code referenced a
					        # bare `config`, which is not defined inside this method.
					        file_name = train_cfg.best_model_name.format(
					            model_name=train_cfg.model_name,
					        )
					        save_path = os.path.join(checkpoint_dir, file_name)
					        torch.save({
					            'epoch': epoch,
					            'model_state_dict': self.model.state_dict(),
					            'optimizer_state_dict': self.optimizer.state_dict(),
					            'loss': val_loss,
					            'config': self.config
					        }, save_path)
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _save_checkpoint(self, epoch: int, train_loss: float):
					        """Save a training checkpoint for the given epoch.

					        The file name is ``self.config.train.checkpoint_format`` formatted
					        with ``model_name`` and ``epoch`` and joined onto
					        ``self.config.train.checkpoint_dir``. The saved dict holds the epoch,
					        the model and optimizer state dicts, the training loss (under
					        ``'loss'``) and the config object.

					        Args:
					            epoch: Epoch number; used in the file name and stored in the dict.
					            train_loss: Training loss stored under the ``'loss'`` key.
					        """
					        train_cfg = self.config.train
					        checkpoint_dir = train_cfg.checkpoint_dir
					        # Make sure the target directory exists before writing into it.
					        if checkpoint_dir:
					            os.makedirs(checkpoint_dir, exist_ok=True)
					        file_name = train_cfg.checkpoint_format.format(
					            model_name=train_cfg.model_name,
					            epoch=epoch,
					        )
					        checkpoint_path = os.path.join(checkpoint_dir, file_name)
					        torch.save({
					            'epoch': epoch,
					            'model_state_dict': self.model.state_dict(),
					            'optimizer_state_dict': self.optimizer.state_dict(),
					            'loss': train_loss,
					            'config': self.config
					        }, checkpoint_path)
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def main(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # 这里需要初始化配置 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |