You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

225 lines
7.7 KiB

import os
import sys
import logging
import traceback
from datetime import datetime
from functools import wraps
from dataclasses import dataclass
import torch
import time
class ColoredFormatter(logging.Formatter):
    """Formatter that exposes ANSI-colored fields for console output.

    ``format`` adds three attributes to every record before delegating to
    ``logging.Formatter.format``: ``colored_time``, ``colored_levelname`` and
    ``colored_file_info``. The format string given to this formatter can
    reference them as ``%(colored_time)s`` and so on.
    """

    # ANSI escape sequences keyed by level name.
    COLORS = {
        'DEBUG': '\033[36m',     # cyan
        'INFO': '\033[32m',      # green
        'WARNING': '\033[33m',   # yellow
        'ERROR': '\033[31m',     # red
        'CRITICAL': '\033[35m',  # magenta (purple)
        'RESET': '\033[0m',      # reset
    }

    # Directory of the stdlib logging package. Its frames live in
    # ``logging/__init__.py``, so a basename test alone cannot recognise them.
    _LOGGING_DIR = os.path.normcase(
        os.path.dirname(os.path.abspath(logging.__file__))
    )
    # Absolute path of this module, or None when ``__file__`` is not defined.
    _THIS_FILE = (
        os.path.normcase(os.path.abspath(globals()['__file__']))
        if '__file__' in globals() else None
    )

    def format(self, record):
        """Attach the colored fields to ``record`` and return the formatted text.

        If the record has no ``caller_info`` attribute, one is derived from the
        current call stack.
        """
        if not hasattr(record, 'caller_info'):
            record.caller_info = self._get_caller_info()

        reset = self.COLORS['RESET']

        # Timestamp in blue.
        time_color = '\033[34m'
        record.colored_time = f"{time_color}{self.formatTime(record)}{reset}"

        # Level name padded to 8 columns; unknown level names get no color.
        level_color = self.COLORS.get(record.levelname, reset)
        record.colored_levelname = f"{level_color}{record.levelname:8}{reset}"

        # Caller information in cyan.
        file_color = '\033[36m'
        record.colored_file_info = f"{file_color}{record.caller_info}{reset}"

        return super().format(record)

    @classmethod
    def _is_internal_frame(cls, path):
        """Return True if ``path`` belongs to logging machinery, not user code."""
        base = os.path.basename(path)
        # Basename checks kept from the original implementation.
        if base == 'logger.py' or base.startswith('logging') or base == '<string>':
            return True
        full = os.path.normcase(os.path.abspath(path))
        if cls._THIS_FILE is not None and full == cls._THIS_FILE:
            return True
        return full.startswith(cls._LOGGING_DIR + os.sep)

    def _get_caller_info(self):
        """Return ``"file:function:line"`` for the innermost non-logging frame.

        Returns ``"Unknown:Unknown:0"`` when no such frame exists or the stack
        cannot be inspected.
        """
        try:
            stack = traceback.extract_stack()
            # Walk from the innermost frame outwards, skipping this method's
            # own frame and every frame that belongs to the logging machinery.
            for frame in reversed(stack[:-1]):
                if not self._is_internal_frame(frame.filename):
                    name = os.path.basename(frame.filename)
                    return f"{name}:{frame.name}:{frame.lineno}"
            return "Unknown:Unknown:0"
        except Exception as e:
            # Best effort: formatting a record must never fail because of this.
            print(f"[ERROR] 获取caller_info失败: {e}")
            return "Unknown:Unknown:0"
class Logger:
_instance = None
def __new__(cls, config=None):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialize_logger(config)
return cls._instance
def _get_caller_info(self, stack_offset=3):
"""获取调用者信息
Args:
stack_offset: 需要跳过的调用栈层数
"""
try:
# 获取调用栈
stack = traceback.extract_stack()
# 从后往前遍历调用栈,跳过日志模块的调用
for frame in reversed(stack[:-stack_offset]):
filename = os.path.basename(frame.filename)
# 跳过日志模块相关的文件
if not (filename == 'logger.py' or
filename.startswith('logging') or
filename == '<string>'):
return f"{filename}:{frame.name}:{frame.lineno}"
return "Unknown:Unknown:0"
except Exception as e:
print(f"[ERROR] 获取caller_info失败: {e}") # 调试用
return "Unknown:Unknown:0"
def _get_caller_filepath(self, stack_offset=3):
"""获取调用者文件路径"""
try:
stack = traceback.extract_stack()
frame = stack[-(stack_offset + 1)]
return frame.filename
except Exception as e:
print(f"[ERROR] 获取caller_filepath失败: {e}") # 调试用
return "unknown_file.py"
def _log(self, level, msg, **kwargs):
"""统一的日志记录方法"""
# 创建LogRecord
record = logging.LogRecord(
name=self.logger.name,
level=level,
pathname=self._get_caller_filepath(),
lineno=0,
msg=msg,
args=(),
exc_info=kwargs.get('exc_info'),
)
# 确保设置 caller_info
if not hasattr(record, 'caller_info'):
record.caller_info = self._get_caller_info(3)
#print(f"[DEBUG] caller_info set to: {record.caller_info}") # 调试用
# 处理日志记录
self.logger.handle(record)
def debug(self, msg):
"""调试信息"""
self._log(logging.DEBUG, msg)
def info(self, msg):
"""信息"""
self._log(logging.INFO, msg)
def warning(self, msg):
"""警告信息"""
self._log(logging.WARNING, msg)
def error(self, msg, include_trace=True):
"""错误信息"""
self._log(logging.ERROR, msg, exc_info=include_trace)
def exception(self, msg):
"""异常信息"""
self._log(logging.ERROR, msg, exc_info=True)
def _initialize_logger(self, config=None):
"""初始化日志记录器"""
# 创建logs目录
log_dir = config.log_dir
os.makedirs(log_dir, exist_ok=True)
# 创建logger
self.logger = logging.getLogger(config.project_name)
self.logger.setLevel(getattr(logging, config.log_level.upper()))
# 如果logger已经有处理器,则返回
if self.logger.handlers:
return
# 创建格式化器
console_formatter = ColoredFormatter(
'%(colored_time)s | %(colored_levelname)s | %(colored_file_info)s - %(message)s'
)
file_formatter = logging.Formatter(
'%(asctime)s | %(levelname)-8s | %(caller_info)s - %(message)s'
)
# 创建文件处理器 (详细日志)
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = os.path.join(log_dir, f'{config.project_name}_{current_time}.log')
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setLevel(getattr(logging, config.file_level.upper()))
file_handler.setFormatter(file_formatter)
# 创建控制台处理器 (简略日志)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(getattr(logging, config.console_level.upper()))
console_handler.setFormatter(console_formatter)
# 添加处理器
self.logger.addHandler(file_handler)
self.logger.addHandler(console_handler)
# 保存配置
self.include_trace = True # 默认包含调用栈
# 记录初始信息(确保使用 self.info 和 self.debug)
self.info(f"{config.project_name} Logger initialized")
self.debug(f"Log file: {log_file}")
def timeit(func):
    """Decorator that logs the wall-clock duration of each call to ``func``.

    The duration is measured with ``time.perf_counter`` and written through
    the module-level ``logger`` at DEBUG level. The wrapped function's return
    value is passed through unchanged. If ``func`` raises, nothing is logged
    and the exception propagates.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        started = time.perf_counter()
        outcome = func(*args, **kwargs)
        elapsed = time.perf_counter() - started
        logger.debug(f"{func.__name__} took {elapsed:.4f} seconds")
        return outcome

    return wrapper
def setup_logger(config=None):
    """Convenience function: return the ``Logger`` obtained for ``config``."""
    logger_instance = Logger(config)
    return logger_instance
# Default logging configuration, used below to create the global logger instance
@dataclass
class LogConfig:
    """Settings for local logging.

    Attributes:
        project_name: name of the logger, also used as the log-file prefix.
        log_dir: directory log files are written to; defaults to ``logs``
            under the working directory at the time this class is defined.
        log_level: level name for the logger itself.
        console_level: level name for the console handler.
        file_level: level name for the file handler.
    """

    project_name: str = 'NH_Rep'
    log_dir: str = os.path.join(os.path.abspath(os.curdir), 'logs')
    log_level: str = 'INFO'
    console_level: str = 'INFO'
    file_level: str = 'DEBUG'
# Module-level logger built from the default LogConfig when this module is imported.
logger = setup_logger(config=LogConfig())