Browse Source
			
			
			
			
				
		- Introduced a `Logger` class that implements a singleton pattern for logging. - Added a `ColoredFormatter` to provide colored log output based on log levels. - Implemented methods for logging at different levels (debug, info, warning, error, exception). - Included functionality to capture caller information and log it alongside messages. - Created a `LogConfig` dataclass for easy configuration of logging parameters. - Set up a global logger instance with default configuration. - Added a `timeit` decorator for measuring function execution time.main
				 1 changed files with 225 additions and 0 deletions
			
			
		@ -0,0 +1,225 @@ | 
				
			|||||
 | 
					import os | 
				
			||||
 | 
					import sys | 
				
			||||
 | 
					import logging | 
				
			||||
 | 
					import traceback | 
				
			||||
 | 
					from datetime import datetime | 
				
			||||
 | 
					from functools import wraps | 
				
			||||
 | 
					from dataclasses import dataclass | 
				
			||||
 | 
					import torch | 
				
			||||
 | 
					import time | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					class ColoredFormatter(logging.Formatter): | 
				
			||||
 | 
					    """自定义彩色格式化器""" | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    # ANSI转义序列颜色代码 | 
				
			||||
 | 
					    COLORS = { | 
				
			||||
 | 
					        'DEBUG': '\033[36m',    # 青色 | 
				
			||||
 | 
					        'INFO': '\033[32m',     # 绿色 | 
				
			||||
 | 
					        'WARNING': '\033[33m',   # 黄色 | 
				
			||||
 | 
					        'ERROR': '\033[31m',    # 红色 | 
				
			||||
 | 
					        'CRITICAL': '\033[35m',  # 紫色 | 
				
			||||
 | 
					        'RESET': '\033[0m'      # 重置 | 
				
			||||
 | 
					    } | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def format(self, record): | 
				
			||||
 | 
					        # 确保record有caller_info属性 | 
				
			||||
 | 
					        if not hasattr(record, 'caller_info'): | 
				
			||||
 | 
					            record.caller_info = self._get_caller_info() | 
				
			||||
 | 
					            print(f"[DEBUG] caller_info set in formatter: {record.caller_info}")  # 调试用 | 
				
			||||
 | 
					             | 
				
			||||
 | 
					        # 添加时间戳颜色 | 
				
			||||
 | 
					        time_color = '\033[34m'  # 蓝色 | 
				
			||||
 | 
					        record.colored_time = f"{time_color}{self.formatTime(record)}\033[0m" | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 添加日志级别颜色 | 
				
			||||
 | 
					        level_color = self.COLORS.get(record.levelname, self.COLORS['RESET']) | 
				
			||||
 | 
					        record.colored_levelname = f"{level_color}{record.levelname:8}\033[0m" | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 添加文件信息颜色 | 
				
			||||
 | 
					        file_color = '\033[36m'  # 青色 | 
				
			||||
 | 
					        record.colored_file_info = f"{file_color}{record.caller_info}\033[0m" | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        return super().format(record) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def _get_caller_info(self): | 
				
			||||
 | 
					        """获取调用者信息""" | 
				
			||||
 | 
					        try: | 
				
			||||
 | 
					            # 获取调用栈 | 
				
			||||
 | 
					            stack = traceback.extract_stack() | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            # 从后往前遍历调用栈,跳过日志模块的调用 | 
				
			||||
 | 
					            for frame in reversed(stack[:-1]): | 
				
			||||
 | 
					                filename = os.path.basename(frame.filename) | 
				
			||||
 | 
					                # 跳过日志模块相关的文件 | 
				
			||||
 | 
					                if not (filename == 'logger.py' or  | 
				
			||||
 | 
					                       filename.startswith('logging') or  | 
				
			||||
 | 
					                       filename == '<string>'): | 
				
			||||
 | 
					                    return f"{filename}:{frame.name}:{frame.lineno}" | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            return "Unknown:Unknown:0" | 
				
			||||
 | 
					        except Exception as e: | 
				
			||||
 | 
					            print(f"[ERROR] 获取caller_info失败: {e}")  # 调试用 | 
				
			||||
 | 
					            return "Unknown:Unknown:0" | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					class Logger: | 
				
			||||
 | 
					    _instance = None | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def __new__(cls, config=None): | 
				
			||||
 | 
					        if cls._instance is None: | 
				
			||||
 | 
					            cls._instance = super().__new__(cls) | 
				
			||||
 | 
					            cls._instance._initialize_logger(config) | 
				
			||||
 | 
					        return cls._instance | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def _get_caller_info(self, stack_offset=3): | 
				
			||||
 | 
					        """获取调用者信息 | 
				
			||||
 | 
					        Args: | 
				
			||||
 | 
					            stack_offset: 需要跳过的调用栈层数 | 
				
			||||
 | 
					        """ | 
				
			||||
 | 
					        try: | 
				
			||||
 | 
					            # 获取调用栈 | 
				
			||||
 | 
					            stack = traceback.extract_stack() | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            # 从后往前遍历调用栈,跳过日志模块的调用 | 
				
			||||
 | 
					            for frame in reversed(stack[:-stack_offset]): | 
				
			||||
 | 
					                filename = os.path.basename(frame.filename) | 
				
			||||
 | 
					                # 跳过日志模块相关的文件 | 
				
			||||
 | 
					                if not (filename == 'logger.py' or  | 
				
			||||
 | 
					                       filename.startswith('logging') or  | 
				
			||||
 | 
					                       filename == '<string>'): | 
				
			||||
 | 
					                    return f"{filename}:{frame.name}:{frame.lineno}" | 
				
			||||
 | 
					             | 
				
			||||
 | 
					            return "Unknown:Unknown:0" | 
				
			||||
 | 
					        except Exception as e: | 
				
			||||
 | 
					            print(f"[ERROR] 获取caller_info失败: {e}")  # 调试用 | 
				
			||||
 | 
					            return "Unknown:Unknown:0" | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def _get_caller_filepath(self, stack_offset=3): | 
				
			||||
 | 
					        """获取调用者文件路径""" | 
				
			||||
 | 
					        try: | 
				
			||||
 | 
					            stack = traceback.extract_stack() | 
				
			||||
 | 
					            frame = stack[-(stack_offset + 1)] | 
				
			||||
 | 
					            return frame.filename | 
				
			||||
 | 
					        except Exception as e: | 
				
			||||
 | 
					            print(f"[ERROR] 获取caller_filepath失败: {e}")  # 调试用 | 
				
			||||
 | 
					            return "unknown_file.py" | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def _log(self, level, msg, **kwargs): | 
				
			||||
 | 
					        """统一的日志记录方法""" | 
				
			||||
 | 
					        # 创建LogRecord | 
				
			||||
 | 
					        record = logging.LogRecord( | 
				
			||||
 | 
					            name=self.logger.name, | 
				
			||||
 | 
					            level=level, | 
				
			||||
 | 
					            pathname=self._get_caller_filepath(), | 
				
			||||
 | 
					            lineno=0, | 
				
			||||
 | 
					            msg=msg, | 
				
			||||
 | 
					            args=(), | 
				
			||||
 | 
					            exc_info=kwargs.get('exc_info'), | 
				
			||||
 | 
					        ) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 确保设置 caller_info | 
				
			||||
 | 
					        if not hasattr(record, 'caller_info'): | 
				
			||||
 | 
					            record.caller_info = self._get_caller_info(3) | 
				
			||||
 | 
					            #print(f"[DEBUG] caller_info set to: {record.caller_info}")  # 调试用 | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 处理日志记录 | 
				
			||||
 | 
					        self.logger.handle(record) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def debug(self, msg): | 
				
			||||
 | 
					        """调试信息""" | 
				
			||||
 | 
					        self._log(logging.DEBUG, msg) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def info(self, msg): | 
				
			||||
 | 
					        """信息""" | 
				
			||||
 | 
					        self._log(logging.INFO, msg) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def warning(self, msg): | 
				
			||||
 | 
					        """警告信息""" | 
				
			||||
 | 
					        self._log(logging.WARNING, msg) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def error(self, msg, include_trace=True): | 
				
			||||
 | 
					        """错误信息""" | 
				
			||||
 | 
					        self._log(logging.ERROR, msg, exc_info=include_trace) | 
				
			||||
 | 
					     | 
				
			||||
 | 
					    def exception(self, msg): | 
				
			||||
 | 
					        """异常信息""" | 
				
			||||
 | 
					        self._log(logging.ERROR, msg, exc_info=True) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					    def _initialize_logger(self, config=None): | 
				
			||||
 | 
					        """初始化日志记录器""" | 
				
			||||
 | 
					             | 
				
			||||
 | 
					        # 创建logs目录 | 
				
			||||
 | 
					        log_dir = config.log_dir | 
				
			||||
 | 
					        os.makedirs(log_dir, exist_ok=True) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 创建logger | 
				
			||||
 | 
					        self.logger = logging.getLogger(config.project_name) | 
				
			||||
 | 
					        self.logger.setLevel(getattr(logging, config.log_level.upper())) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 如果logger已经有处理器,则返回 | 
				
			||||
 | 
					        if self.logger.handlers: | 
				
			||||
 | 
					            return | 
				
			||||
 | 
					             | 
				
			||||
 | 
					        # 创建格式化器 | 
				
			||||
 | 
					        console_formatter = ColoredFormatter( | 
				
			||||
 | 
					            '%(colored_time)s | %(colored_levelname)s | %(colored_file_info)s - %(message)s' | 
				
			||||
 | 
					        ) | 
				
			||||
 | 
					        file_formatter = logging.Formatter( | 
				
			||||
 | 
					            '%(asctime)s | %(levelname)-8s | %(caller_info)s - %(message)s' | 
				
			||||
 | 
					        ) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 创建文件处理器 (详细日志) | 
				
			||||
 | 
					        current_time = datetime.now().strftime('%Y%m%d_%H%M%S') | 
				
			||||
 | 
					        log_file = os.path.join(log_dir, f'{config.project_name}_{current_time}.log') | 
				
			||||
 | 
					        file_handler = logging.FileHandler(log_file, encoding='utf-8') | 
				
			||||
 | 
					        file_handler.setLevel(getattr(logging, config.file_level.upper())) | 
				
			||||
 | 
					        file_handler.setFormatter(file_formatter) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 创建控制台处理器 (简略日志) | 
				
			||||
 | 
					        console_handler = logging.StreamHandler(sys.stdout) | 
				
			||||
 | 
					        console_handler.setLevel(getattr(logging, config.console_level.upper())) | 
				
			||||
 | 
					        console_handler.setFormatter(console_formatter) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 添加处理器 | 
				
			||||
 | 
					        self.logger.addHandler(file_handler) | 
				
			||||
 | 
					        self.logger.addHandler(console_handler) | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 保存配置 | 
				
			||||
 | 
					        self.include_trace = True  # 默认包含调用栈 | 
				
			||||
 | 
					         | 
				
			||||
 | 
					        # 记录初始信息(确保使用 self.info 和 self.debug) | 
				
			||||
 | 
					        self.info(f"{config.project_name} Logger initialized") | 
				
			||||
 | 
					        self.debug(f"Log file: {log_file}") | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					def timeit(func): | 
				
			||||
 | 
					    """计时装饰器""" | 
				
			||||
 | 
					    @wraps(func) | 
				
			||||
 | 
					    def wrapper(*args, **kwargs): | 
				
			||||
 | 
					        start_time = time.perf_counter() | 
				
			||||
 | 
					        result = func(*args, **kwargs) | 
				
			||||
 | 
					        end_time = time.perf_counter() | 
				
			||||
 | 
					        total_time = end_time - start_time | 
				
			||||
 | 
					        logger.debug(f"{func.__name__} took {total_time:.4f} seconds") | 
				
			||||
 | 
					        return result | 
				
			||||
 | 
					    return wrapper | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					def setup_logger(config=None): | 
				
			||||
 | 
					    """创建logger的便捷函数""" | 
				
			||||
 | 
					    return Logger(config) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					# 使用默认配置创建全局logger实例 | 
				
			||||
 | 
					@dataclass | 
				
			||||
 | 
					class LogConfig: | 
				
			||||
 | 
					    """日志相关配置"""     | 
				
			||||
 | 
					    # 本地日志 | 
				
			||||
 | 
					    project_name: str = 'NH_Rep' | 
				
			||||
 | 
					    log_dir: str = os.path.join(os.path.abspath(os.getcwd()), 'logs')                    # 日志保存目录 | 
				
			||||
 | 
					    log_level: str = 'INFO'                  # 日志级别 | 
				
			||||
 | 
					    console_level: str = 'INFO'              # 控制台日志级别 | 
				
			||||
 | 
					    file_level: str = 'DEBUG'                # 文件日志级别 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					logger = setup_logger(LogConfig()) | 
				
			||||
					Loading…
					
					
				
		Reference in new issue