You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
225 lines
7.7 KiB
225 lines
7.7 KiB
2 months ago
|
import os
|
||
|
import sys
|
||
|
import logging
|
||
|
import traceback
|
||
|
from datetime import datetime
|
||
|
from functools import wraps
|
||
|
from dataclasses import dataclass
|
||
|
import torch
|
||
|
import time
|
||
|
|
||
|
|
||
|
class ColoredFormatter(logging.Formatter):
    """Formatter that exposes ANSI-colored fields to the format string.

    Before delegating to ``logging.Formatter.format`` it sets three extra
    attributes on the record: ``colored_time``, ``colored_levelname`` and
    ``colored_file_info``. If the record carries no ``caller_info``
    attribute, one is derived from the current call stack.
    """

    # ANSI escape sequences keyed by level name
    COLORS = {
        'DEBUG': '\033[36m',     # cyan
        'INFO': '\033[32m',      # green
        'WARNING': '\033[33m',   # yellow
        'ERROR': '\033[31m',     # red
        'CRITICAL': '\033[35m',  # magenta
        'RESET': '\033[0m',      # reset
    }

    def format(self, record):
        """Attach the colored attributes to ``record`` and format it.

        Args:
            record: the ``logging.LogRecord`` being emitted.

        Returns:
            The formatted log line produced by the base formatter.
        """
        # Records that were not created with a caller_info attribute get one
        # here. The former debug print was removed: it wrote to stdout for
        # every such record.
        if not hasattr(record, 'caller_info'):
            record.caller_info = self._get_caller_info()

        reset = self.COLORS['RESET']

        # Timestamp in blue; honor a datefmt passed to the constructor
        time_color = '\033[34m'
        timestamp = self.formatTime(record, self.datefmt)
        record.colored_time = f"{time_color}{timestamp}{reset}"

        # Level name padded to 8 characters; unknown levels are left uncolored
        level_color = self.COLORS.get(record.levelname, reset)
        record.colored_levelname = f"{level_color}{record.levelname:8}{reset}"

        # Caller information in cyan
        file_color = '\033[36m'
        record.colored_file_info = f"{file_color}{record.caller_info}{reset}"

        return super().format(record)

    def _get_caller_info(self):
        """Return ``"file:function:line"`` for the nearest non-logging frame.

        Walks the call stack from the innermost frame outwards and skips
        frames that belong to this module, to the standard ``logging``
        package, or to ``<string>`` code.

        Returns:
            The caller description, or ``"Unknown:Unknown:0"`` when no
            suitable frame is found or the stack cannot be inspected.
        """
        try:
            # Frames of the stdlib logging package live in
            # ".../logging/__init__.py", so a basename prefix test alone does
            # not recognise them; compare the containing directory instead.
            logging_dir = os.path.dirname(
                os.path.normcase(os.path.abspath(logging.__file__)))
            # Also skip this module under whatever file name it is saved as.
            own_file = globals().get('__file__')
            if own_file:
                own_file = os.path.normcase(os.path.abspath(own_file))

            stack = traceback.extract_stack()

            # stack[:-1] drops the frame of this method itself
            for frame in reversed(stack[:-1]):
                filename = os.path.basename(frame.filename)
                full_path = os.path.normcase(os.path.abspath(frame.filename))
                is_logging_frame = (
                    filename == 'logger.py'
                    or filename.startswith('logging')
                    or filename == '<string>'
                    or full_path == own_file
                    or os.path.dirname(full_path) == logging_dir
                )
                if not is_logging_frame:
                    return f"{filename}:{frame.name}:{frame.lineno}"

            return "Unknown:Unknown:0"
        except Exception as e:
            # Best effort: formatting a record must never fail because the
            # stack could not be inspected.
            print(f"[ERROR] 获取caller_info失败: {e}")
            return "Unknown:Unknown:0"
|
||
|
|
||
|
|
||
|
class Logger:
    """Singleton wrapper around a standard ``logging.Logger``.

    The first construction configures a file handler and a console handler
    from the supplied config; later constructions return the same instance
    and ignore their ``config`` argument.

    Every record gets a ``caller_info`` attribute
    (``"file:function:line"``) that the formatters reference.
    """

    _instance = None

    def __new__(cls, config=None):
        """Return the shared instance, creating and initializing it once."""
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._initialize_logger(config)
            # Publish only after initialization succeeded, so a failed setup
            # does not leave a half-built singleton behind.
            cls._instance = instance
        return cls._instance

    def _get_caller_info(self, stack_offset=3):
        """Return ``"file:function:line"`` for the code that issued the log call.

        Args:
            stack_offset: number of innermost stack frames to drop before
                searching (this method, ``_log`` and the public log method
                for the default of 3).

        Returns:
            The caller description, or ``"Unknown:Unknown:0"`` when no
            suitable frame is found or the stack cannot be inspected.
        """
        try:
            # Frames of the stdlib logging package live in
            # ".../logging/__init__.py"; detect them by directory.
            logging_dir = os.path.dirname(
                os.path.normcase(os.path.abspath(logging.__file__)))

            stack = traceback.extract_stack()

            # Explicit end index: a plain stack[:-0] would be empty.
            end = max(len(stack) - stack_offset, 0)
            for frame in reversed(stack[:end]):
                filename = os.path.basename(frame.filename)
                frame_dir = os.path.dirname(
                    os.path.normcase(os.path.abspath(frame.filename)))
                is_logging_frame = (
                    filename == 'logger.py'
                    or filename.startswith('logging')
                    or filename == '<string>'
                    or frame_dir == logging_dir
                )
                if not is_logging_frame:
                    return f"{filename}:{frame.name}:{frame.lineno}"

            return "Unknown:Unknown:0"
        except Exception as e:
            # Best effort: logging must not fail because of stack inspection.
            print(f"[ERROR] 获取caller_info失败: {e}")
            return "Unknown:Unknown:0"

    def _get_caller_filepath(self, stack_offset=3):
        """Return the file path of the frame ``stack_offset`` levels above this one.

        With the default of 3 and a call chain of
        ``caller -> debug/info/... -> _log -> _get_caller_filepath`` this is
        the caller's file. Returns ``"unknown_file.py"`` on failure (for
        example when the stack is shallower than requested).
        """
        try:
            stack = traceback.extract_stack()
            frame = stack[-(stack_offset + 1)]
            return frame.filename
        except Exception as e:
            print(f"[ERROR] 获取caller_filepath失败: {e}")
            return "unknown_file.py"

    @staticmethod
    def _normalize_exc_info(exc_info):
        """Convert a user-supplied ``exc_info`` value to what LogRecord expects.

        ``logging.LogRecord`` needs ``None`` or a ``(type, value, traceback)``
        tuple; formatters index into it, so a bare ``True`` would raise
        inside the handler and the message would be lost.

        Returns:
            An exception tuple, or ``None`` when nothing was requested or no
            exception is currently being handled.
        """
        if not exc_info:
            return None
        if isinstance(exc_info, BaseException):
            return (type(exc_info), exc_info, exc_info.__traceback__)
        if not isinstance(exc_info, tuple):
            exc_info = sys.exc_info()
        if exc_info[0] is None:
            # No active exception: there is no traceback to attach.
            return None
        return exc_info

    def _log(self, level, msg, **kwargs):
        """Build a ``LogRecord`` for ``msg`` and hand it to the wrapped logger.

        Args:
            level: numeric logging level (e.g. ``logging.INFO``).
            msg: the log message; used as-is, without %-args.
            **kwargs: only ``exc_info`` is read (bool, exception instance or
                exception tuple).
        """
        # NOTE(review): the record goes straight to Logger.handle, which
        # appears to bypass the logger-level check done by Logger.log, so
        # only the handler levels filter records — confirm this is intended.
        record = logging.LogRecord(
            name=self.logger.name,
            level=level,
            pathname=self._get_caller_filepath(),
            lineno=0,
            msg=msg,
            args=(),
            exc_info=self._normalize_exc_info(kwargs.get('exc_info')),
        )

        # Both formatters reference %(caller_info)s / colored_file_info.
        record.caller_info = self._get_caller_info(3)

        self.logger.handle(record)

    def debug(self, msg):
        """Log ``msg`` at DEBUG level."""
        self._log(logging.DEBUG, msg)

    def info(self, msg):
        """Log ``msg`` at INFO level."""
        self._log(logging.INFO, msg)

    def warning(self, msg):
        """Log ``msg`` at WARNING level."""
        self._log(logging.WARNING, msg)

    def error(self, msg, include_trace=True):
        """Log ``msg`` at ERROR level.

        Args:
            msg: the log message.
            include_trace: when true and an exception is currently being
                handled, its traceback is appended to the message.
        """
        self._log(logging.ERROR, msg, exc_info=include_trace)

    def exception(self, msg):
        """Log ``msg`` at ERROR level with the active exception's traceback."""
        self._log(logging.ERROR, msg, exc_info=True)

    def _initialize_logger(self, config=None):
        """Create the wrapped logger and attach file and console handlers.

        Args:
            config: object providing ``project_name``, ``log_dir``,
                ``log_level``, ``console_level`` and ``file_level``. When
                ``None``, a default ``LogConfig()`` is used.

        Side effects:
            Creates ``config.log_dir`` if needed and opens a new timestamped
            log file in it, unless the named logger already has handlers.
        """
        if config is None:
            # LogConfig is defined later in this module and is resolved at
            # call time.
            config = LogConfig()

        # Default: include the call stack
        self.include_trace = True

        # Make sure the log directory exists
        log_dir = config.log_dir
        os.makedirs(log_dir, exist_ok=True)

        # Create (or fetch) the named logger
        self.logger = logging.getLogger(config.project_name)
        self.logger.setLevel(getattr(logging, config.log_level.upper()))

        # If the logger already has handlers, do not add them a second time
        if self.logger.handlers:
            return

        # Formatters: colored for the console, plain for the file
        console_formatter = ColoredFormatter(
            '%(colored_time)s | %(colored_levelname)s | %(colored_file_info)s - %(message)s'
        )
        file_formatter = logging.Formatter(
            '%(asctime)s | %(levelname)-8s | %(caller_info)s - %(message)s'
        )

        # File handler (detailed log), one file per initialization
        current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = os.path.join(log_dir, f'{config.project_name}_{current_time}.log')
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setLevel(getattr(logging, config.file_level.upper()))
        file_handler.setFormatter(file_formatter)

        # Console handler (brief log) writing to stdout
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(getattr(logging, config.console_level.upper()))
        console_handler.setFormatter(console_formatter)

        self.logger.addHandler(file_handler)
        self.logger.addHandler(console_handler)

        # Initial messages go through self.info / self.debug so they carry
        # caller_info like every other record.
        self.info(f"{config.project_name} Logger initialized")
        self.debug(f"Log file: {log_file}")
|
||
|
|
||
|
|
||
|
def timeit(func):
    """Decorator that logs the wall-clock duration of each call to ``func``.

    The duration is measured with ``time.perf_counter`` and reported through
    the module-level ``logger`` at DEBUG level. The wrapped function's
    return value is passed through unchanged; nothing is logged when the
    call raises.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        began = time.perf_counter()
        result = func(*args, **kwargs)
        total_time = time.perf_counter() - began
        logger.debug(f"{func.__name__} took {total_time:.4f} seconds")
        return result

    return wrapper
|
||
|
|
||
|
|
||
|
def setup_logger(config=None):
    """Convenience factory: build (or fetch) the ``Logger`` for ``config``.

    Args:
        config: configuration object forwarded unchanged to ``Logger``.

    Returns:
        The object returned by ``Logger(config)``.
    """
    logger_instance = Logger(config)
    return logger_instance
|
||
|
|
||
|
|
||
|
# Default configuration used to create the global logger instance
|
||
|
@dataclass
class LogConfig:
    """Logging settings read by ``Logger`` when it sets up its handlers."""

    # Local logging
    # Used as the logger name and as the log-file name prefix
    project_name: str = "NH_Rep"
    # Directory the log files are written to; the default is computed from
    # the working directory at the time this class is defined
    log_dir: str = os.path.join(os.path.abspath(os.getcwd()), "logs")
    # Level applied to the logger itself
    log_level: str = "INFO"
    # Level applied to the console handler
    console_level: str = "INFO"
    # Level applied to the file handler
    file_level: str = "DEBUG"
|
||
|
|
||
|
# Module-level logger shared by this file (e.g. by the timeit decorator).
# It is built at import time from the default LogConfig, so importing this
# module runs the Logger initialization.
logger = setup_logger(LogConfig())
|