chore: initial commit
This commit is contained in:
6
src/utils/__init__.py
Normal file
6
src/utils/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""工具模块"""
|
||||
from .logger import Logger, get_logger
|
||||
from .exceptions import *
|
||||
from .decorators import *
|
||||
|
||||
__all__ = ['Logger', 'get_logger']
|
||||
233
src/utils/decorators.py
Normal file
233
src/utils/decorators.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""装饰器工具"""
|
||||
import functools
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Callable, Any, Optional, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
from .logger import get_logger
|
||||
from .exceptions import RateLimitError, AuthorizationError, ValidationError
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def async_retry(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0,
|
||||
exceptions: tuple = (Exception,)):
|
||||
"""异步重试装饰器"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
attempt = 1
|
||||
current_delay = delay
|
||||
|
||||
while attempt <= max_attempts:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except exceptions as e:
|
||||
if attempt == max_attempts:
|
||||
logger.error(f"Max retries ({max_attempts}) reached for {func.__name__}")
|
||||
raise
|
||||
|
||||
logger.warning(
|
||||
f"Attempt {attempt}/{max_attempts} failed for {func.__name__}: {e}. "
|
||||
f"Retrying in {current_delay:.2f}s..."
|
||||
)
|
||||
await asyncio.sleep(current_delay)
|
||||
current_delay *= backoff
|
||||
attempt += 1
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def rate_limit(max_calls: int, period: float):
|
||||
"""速率限制装饰器"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
calls = defaultdict(list)
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 获取调用者标识(假设第一个参数是 self,第二个是 update)
|
||||
caller_id = None
|
||||
if len(args) >= 2 and hasattr(args[1], 'effective_user'):
|
||||
caller_id = args[1].effective_user.id
|
||||
else:
|
||||
caller_id = 'global'
|
||||
|
||||
now = time.time()
|
||||
calls[caller_id] = [t for t in calls[caller_id] if now - t < period]
|
||||
|
||||
if len(calls[caller_id]) >= max_calls:
|
||||
raise RateLimitError(
|
||||
f"Rate limit exceeded: {max_calls} calls per {period} seconds",
|
||||
details={'caller_id': caller_id, 'limit': max_calls, 'period': period}
|
||||
)
|
||||
|
||||
calls[caller_id].append(now)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def require_admin(func: Callable) -> Callable:
|
||||
"""需要管理员权限装饰器"""
|
||||
@functools.wraps(func)
|
||||
async def wrapper(self, update, context, *args, **kwargs):
|
||||
user_id = update.effective_user.id
|
||||
if user_id != self.config.telegram.admin_id:
|
||||
if not (hasattr(self, 'is_admin') and await self.is_admin(user_id)):
|
||||
raise AuthorizationError(
|
||||
"Admin privileges required",
|
||||
details={'user_id': user_id}
|
||||
)
|
||||
return await func(self, update, context, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def log_action(action_type: str = None):
|
||||
"""记录操作日志装饰器"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
action = action_type or func.__name__
|
||||
|
||||
# 提取用户信息
|
||||
user_info = {}
|
||||
if len(args) >= 2 and hasattr(args[1], 'effective_user'):
|
||||
user = args[1].effective_user
|
||||
user_info = {
|
||||
'user_id': user.id,
|
||||
'username': user.username,
|
||||
'name': user.first_name
|
||||
}
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
duration = time.time() - start_time
|
||||
|
||||
# 创建额外信息,避免覆盖保留字段
|
||||
extra_info = {
|
||||
'action': action,
|
||||
'duration': duration,
|
||||
'status': 'success'
|
||||
}
|
||||
# 添加用户信息,使用前缀避免冲突
|
||||
for k, v in user_info.items():
|
||||
extra_info[f'user_{k}' if k in ['name'] else k] = v
|
||||
|
||||
logger.info(
|
||||
f"Action completed: {action}",
|
||||
extra=extra_info
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
|
||||
# 创建额外信息,避免覆盖保留字段
|
||||
extra_info = {
|
||||
'action': action,
|
||||
'duration': duration,
|
||||
'status': 'failed',
|
||||
'error': str(e)
|
||||
}
|
||||
# 添加用户信息,使用前缀避免冲突
|
||||
for k, v in user_info.items():
|
||||
extra_info[f'user_{k}' if k in ['name'] else k] = v
|
||||
|
||||
logger.error(
|
||||
f"Action failed: {action}",
|
||||
extra=extra_info
|
||||
)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def validate_input(**validators):
|
||||
"""输入验证装饰器"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 合并位置参数和关键字参数
|
||||
bound_args = func.__code__.co_varnames[:func.__code__.co_argcount]
|
||||
all_args = dict(zip(bound_args, args))
|
||||
all_args.update(kwargs)
|
||||
|
||||
# 执行验证
|
||||
for param_name, validator in validators.items():
|
||||
if param_name in all_args:
|
||||
value = all_args[param_name]
|
||||
if not validator(value):
|
||||
raise ValidationError(
|
||||
f"Validation failed for parameter: {param_name}",
|
||||
details={'parameter': param_name, 'value': value}
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def cache_result(ttl: int = 300):
|
||||
"""结果缓存装饰器"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
cache: Dict[str, tuple[Any, datetime]] = {}
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 创建缓存键
|
||||
cache_key = f"{func.__name__}:{str(args)}:{str(kwargs)}"
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in cache:
|
||||
result, timestamp = cache[cache_key]
|
||||
if datetime.now() - timestamp < timedelta(seconds=ttl):
|
||||
logger.debug(f"Cache hit for {func.__name__}")
|
||||
return result
|
||||
|
||||
# 执行函数
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
# 存储结果
|
||||
cache[cache_key] = (result, datetime.now())
|
||||
logger.debug(f"Cache miss for {func.__name__}, cached for {ttl}s")
|
||||
|
||||
return result
|
||||
|
||||
# 添加清除缓存方法
|
||||
wrapper.clear_cache = lambda: cache.clear()
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def measure_performance(func: Callable) -> Callable:
|
||||
"""性能测量装饰器"""
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
start_memory = 0 # 可以添加内存测量
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
return result
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
if duration > 1.0: # 超过1秒的操作记录警告
|
||||
logger.warning(
|
||||
f"Slow operation detected: {func.__name__} took {duration:.2f}s",
|
||||
extra={
|
||||
'function': func.__name__,
|
||||
'duration': duration
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.debug(f"{func.__name__} completed in {duration:.4f}s")
|
||||
|
||||
return wrapper
|
||||
122
src/utils/exceptions.py
Normal file
122
src/utils/exceptions.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""自定义异常类"""
|
||||
from typing import Optional, Any, Dict
|
||||
|
||||
|
||||
class BotException(Exception):
|
||||
"""机器人基础异常"""
|
||||
|
||||
def __init__(self, message: str, code: str = None, details: Dict[str, Any] = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.code = code or self.__class__.__name__
|
||||
self.details = details or {}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'error': self.code,
|
||||
'message': self.message,
|
||||
'details': self.details
|
||||
}
|
||||
|
||||
|
||||
class ConfigurationError(BotException):
|
||||
"""配置错误"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseError(BotException):
|
||||
"""数据库错误"""
|
||||
pass
|
||||
|
||||
|
||||
class TelegramError(BotException):
|
||||
"""Telegram API 错误"""
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(BotException):
|
||||
"""认证错误"""
|
||||
pass
|
||||
|
||||
|
||||
class AuthorizationError(BotException):
|
||||
"""授权错误"""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(BotException):
|
||||
"""验证错误"""
|
||||
pass
|
||||
|
||||
|
||||
class RateLimitError(BotException):
|
||||
"""速率限制错误"""
|
||||
pass
|
||||
|
||||
|
||||
class SessionError(BotException):
|
||||
"""会话错误"""
|
||||
pass
|
||||
|
||||
|
||||
class MessageRoutingError(BotException):
|
||||
"""消息路由错误"""
|
||||
pass
|
||||
|
||||
|
||||
class BusinessLogicError(BotException):
|
||||
"""业务逻辑错误"""
|
||||
pass
|
||||
|
||||
|
||||
class ExternalServiceError(BotException):
|
||||
"""外部服务错误"""
|
||||
pass
|
||||
|
||||
|
||||
class ErrorHandler:
|
||||
"""错误处理器"""
|
||||
|
||||
@staticmethod
|
||||
async def handle_error(error: Exception, context: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""处理错误"""
|
||||
from ..utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
error_info = {
|
||||
'type': type(error).__name__,
|
||||
'message': str(error),
|
||||
'context': context or {}
|
||||
}
|
||||
|
||||
if isinstance(error, BotException):
|
||||
# 自定义异常
|
||||
error_info.update(error.to_dict())
|
||||
logger.error(f"Bot error: {error.message}", extra={'error_details': error_info})
|
||||
else:
|
||||
# 未知异常
|
||||
logger.exception(f"Unexpected error: {error}", extra={'error_details': error_info})
|
||||
error_info['message'] = "An unexpected error occurred"
|
||||
|
||||
return error_info
|
||||
|
||||
@staticmethod
|
||||
def create_user_message(error: Exception) -> str:
|
||||
"""创建用户友好的错误消息"""
|
||||
if isinstance(error, AuthenticationError):
|
||||
return "❌ 认证失败,请重新登录"
|
||||
elif isinstance(error, AuthorizationError):
|
||||
return "❌ 您没有权限执行此操作"
|
||||
elif isinstance(error, ValidationError):
|
||||
return f"❌ 输入无效:{error.message}"
|
||||
elif isinstance(error, RateLimitError):
|
||||
return "⚠️ 操作太频繁,请稍后再试"
|
||||
elif isinstance(error, SessionError):
|
||||
return "❌ 会话已过期,请重新开始"
|
||||
elif isinstance(error, BusinessLogicError):
|
||||
return f"❌ 操作失败:{error.message}"
|
||||
elif isinstance(error, ExternalServiceError):
|
||||
return "❌ 外部服务暂时不可用,请稍后再试"
|
||||
else:
|
||||
return "❌ 系统错误,请稍后再试或联系管理员"
|
||||
167
src/utils/logger.py
Normal file
167
src/utils/logger.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""日志系统"""
|
||||
import logging
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class CustomFormatter(logging.Formatter):
|
||||
"""自定义日志格式化器"""
|
||||
|
||||
# 颜色代码
|
||||
COLORS = {
|
||||
'DEBUG': '\033[36m', # 青色
|
||||
'INFO': '\033[32m', # 绿色
|
||||
'WARNING': '\033[33m', # 黄色
|
||||
'ERROR': '\033[31m', # 红色
|
||||
'CRITICAL': '\033[35m', # 紫色
|
||||
}
|
||||
RESET = '\033[0m'
|
||||
|
||||
def __init__(self, use_color: bool = True):
|
||||
super().__init__()
|
||||
self.use_color = use_color and sys.stderr.isatty()
|
||||
|
||||
def format(self, record):
|
||||
# 基础格式
|
||||
log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
|
||||
|
||||
# 添加额外信息
|
||||
if hasattr(record, 'user_id'):
|
||||
log_format = f"%(asctime)s | %(levelname)-8s | %(name)s | User:{record.user_id} | %(message)s"
|
||||
|
||||
if hasattr(record, 'chat_id'):
|
||||
log_format = f"%(asctime)s | %(levelname)-8s | %(name)s | Chat:{record.chat_id} | %(message)s"
|
||||
|
||||
# 应用颜色
|
||||
if self.use_color:
|
||||
levelname = record.levelname
|
||||
if levelname in self.COLORS:
|
||||
log_format = log_format.replace(
|
||||
'%(levelname)-8s',
|
||||
f"{self.COLORS[levelname]}%(levelname)-8s{self.RESET}"
|
||||
)
|
||||
|
||||
formatter = logging.Formatter(log_format, datefmt='%Y-%m-%d %H:%M:%S')
|
||||
return formatter.format(record)
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
"""JSON 格式化器用于结构化日志"""
|
||||
|
||||
def format(self, record):
|
||||
log_data = {
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
'level': record.levelname,
|
||||
'logger': record.name,
|
||||
'message': record.getMessage(),
|
||||
'module': record.module,
|
||||
'function': record.funcName,
|
||||
'line': record.lineno
|
||||
}
|
||||
|
||||
# 添加额外字段
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in ['name', 'msg', 'args', 'created', 'filename',
|
||||
'funcName', 'levelname', 'levelno', 'lineno',
|
||||
'module', 'msecs', 'message', 'pathname', 'process',
|
||||
'processName', 'relativeCreated', 'thread', 'threadName']:
|
||||
log_data[key] = value
|
||||
|
||||
# 添加异常信息
|
||||
if record.exc_info:
|
||||
log_data['exception'] = self.formatException(record.exc_info)
|
||||
|
||||
return json.dumps(log_data, ensure_ascii=False)
|
||||
|
||||
|
||||
class Logger:
|
||||
"""日志管理器"""
|
||||
|
||||
_instance = None
|
||||
_loggers = {}
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config=None):
|
||||
if not hasattr(self, '_initialized'):
|
||||
self._initialized = True
|
||||
self.config = config
|
||||
self.setup_logging()
|
||||
|
||||
def setup_logging(self):
|
||||
"""设置日志系统"""
|
||||
# 根日志配置
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(logging.DEBUG)
|
||||
|
||||
# 移除默认处理器
|
||||
root_logger.handlers = []
|
||||
|
||||
# 控制台处理器
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(
|
||||
getattr(logging, self.config.logging.level if self.config else 'INFO')
|
||||
)
|
||||
console_handler.setFormatter(CustomFormatter(use_color=True))
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# 文件处理器
|
||||
if self.config and self.config.logging.file:
|
||||
file_path = Path(self.config.logging.file)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_handler = RotatingFileHandler(
|
||||
filename=str(file_path),
|
||||
maxBytes=self.config.logging.max_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_handler.setFormatter(CustomFormatter(use_color=False))
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
# JSON 日志文件(用于分析)
|
||||
json_file_path = file_path.with_suffix('.json')
|
||||
json_handler = RotatingFileHandler(
|
||||
filename=str(json_file_path),
|
||||
maxBytes=self.config.logging.max_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
json_handler.setLevel(logging.INFO)
|
||||
json_handler.setFormatter(JsonFormatter())
|
||||
root_logger.addHandler(json_handler)
|
||||
|
||||
@classmethod
|
||||
def get_logger(cls, name: str, config=None) -> logging.Logger:
|
||||
"""获取日志器"""
|
||||
if name not in cls._loggers:
|
||||
if not cls._instance:
|
||||
cls(config)
|
||||
cls._loggers[name] = logging.getLogger(name)
|
||||
return cls._loggers[name]
|
||||
|
||||
|
||||
def get_logger(name: str, config=None) -> logging.Logger:
|
||||
"""获取日志器的便捷方法"""
|
||||
return Logger.get_logger(name, config)
|
||||
|
||||
|
||||
class LoggerContextFilter(logging.Filter):
|
||||
"""日志上下文过滤器"""
|
||||
|
||||
def __init__(self, **context):
|
||||
super().__init__()
|
||||
self.context = context
|
||||
|
||||
def filter(self, record):
|
||||
for key, value in self.context.items():
|
||||
setattr(record, key, value)
|
||||
return True
|
||||
Reference in New Issue
Block a user