chore: initial commit

This commit is contained in:
你的用户名
2025-11-01 21:58:31 +08:00
commit 0406b5664f
101 changed files with 20458 additions and 0 deletions

6
src/utils/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""工具模块"""
from .logger import Logger, get_logger
from .exceptions import *
from .decorators import *
__all__ = ['Logger', 'get_logger']

233
src/utils/decorators.py Normal file
View File

@@ -0,0 +1,233 @@
"""装饰器工具"""
import functools
import time
import asyncio
from typing import Callable, Any, Optional, Dict
from datetime import datetime, timedelta
from collections import defaultdict
from .logger import get_logger
from .exceptions import RateLimitError, AuthorizationError, ValidationError
logger = get_logger(__name__)
def async_retry(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0,
exceptions: tuple = (Exception,)):
"""异步重试装饰器"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
attempt = 1
current_delay = delay
while attempt <= max_attempts:
try:
return await func(*args, **kwargs)
except exceptions as e:
if attempt == max_attempts:
logger.error(f"Max retries ({max_attempts}) reached for {func.__name__}")
raise
logger.warning(
f"Attempt {attempt}/{max_attempts} failed for {func.__name__}: {e}. "
f"Retrying in {current_delay:.2f}s..."
)
await asyncio.sleep(current_delay)
current_delay *= backoff
attempt += 1
return wrapper
return decorator
def rate_limit(max_calls: int, period: float):
"""速率限制装饰器"""
def decorator(func: Callable) -> Callable:
calls = defaultdict(list)
@functools.wraps(func)
async def wrapper(*args, **kwargs):
# 获取调用者标识(假设第一个参数是 self第二个是 update
caller_id = None
if len(args) >= 2 and hasattr(args[1], 'effective_user'):
caller_id = args[1].effective_user.id
else:
caller_id = 'global'
now = time.time()
calls[caller_id] = [t for t in calls[caller_id] if now - t < period]
if len(calls[caller_id]) >= max_calls:
raise RateLimitError(
f"Rate limit exceeded: {max_calls} calls per {period} seconds",
details={'caller_id': caller_id, 'limit': max_calls, 'period': period}
)
calls[caller_id].append(now)
return await func(*args, **kwargs)
return wrapper
return decorator
def require_admin(func: Callable) -> Callable:
"""需要管理员权限装饰器"""
@functools.wraps(func)
async def wrapper(self, update, context, *args, **kwargs):
user_id = update.effective_user.id
if user_id != self.config.telegram.admin_id:
if not (hasattr(self, 'is_admin') and await self.is_admin(user_id)):
raise AuthorizationError(
"Admin privileges required",
details={'user_id': user_id}
)
return await func(self, update, context, *args, **kwargs)
return wrapper
def log_action(action_type: str = None):
"""记录操作日志装饰器"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
start_time = time.time()
action = action_type or func.__name__
# 提取用户信息
user_info = {}
if len(args) >= 2 and hasattr(args[1], 'effective_user'):
user = args[1].effective_user
user_info = {
'user_id': user.id,
'username': user.username,
'name': user.first_name
}
try:
result = await func(*args, **kwargs)
duration = time.time() - start_time
# 创建额外信息,避免覆盖保留字段
extra_info = {
'action': action,
'duration': duration,
'status': 'success'
}
# 添加用户信息,使用前缀避免冲突
for k, v in user_info.items():
extra_info[f'user_{k}' if k in ['name'] else k] = v
logger.info(
f"Action completed: {action}",
extra=extra_info
)
return result
except Exception as e:
duration = time.time() - start_time
# 创建额外信息,避免覆盖保留字段
extra_info = {
'action': action,
'duration': duration,
'status': 'failed',
'error': str(e)
}
# 添加用户信息,使用前缀避免冲突
for k, v in user_info.items():
extra_info[f'user_{k}' if k in ['name'] else k] = v
logger.error(
f"Action failed: {action}",
extra=extra_info
)
raise
return wrapper
return decorator
def validate_input(**validators):
"""输入验证装饰器"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
# 合并位置参数和关键字参数
bound_args = func.__code__.co_varnames[:func.__code__.co_argcount]
all_args = dict(zip(bound_args, args))
all_args.update(kwargs)
# 执行验证
for param_name, validator in validators.items():
if param_name in all_args:
value = all_args[param_name]
if not validator(value):
raise ValidationError(
f"Validation failed for parameter: {param_name}",
details={'parameter': param_name, 'value': value}
)
return await func(*args, **kwargs)
return wrapper
return decorator
def cache_result(ttl: int = 300):
"""结果缓存装饰器"""
def decorator(func: Callable) -> Callable:
cache: Dict[str, tuple[Any, datetime]] = {}
@functools.wraps(func)
async def wrapper(*args, **kwargs):
# 创建缓存键
cache_key = f"{func.__name__}:{str(args)}:{str(kwargs)}"
# 检查缓存
if cache_key in cache:
result, timestamp = cache[cache_key]
if datetime.now() - timestamp < timedelta(seconds=ttl):
logger.debug(f"Cache hit for {func.__name__}")
return result
# 执行函数
result = await func(*args, **kwargs)
# 存储结果
cache[cache_key] = (result, datetime.now())
logger.debug(f"Cache miss for {func.__name__}, cached for {ttl}s")
return result
# 添加清除缓存方法
wrapper.clear_cache = lambda: cache.clear()
return wrapper
return decorator
def measure_performance(func: Callable) -> Callable:
"""性能测量装饰器"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
start_time = time.perf_counter()
start_memory = 0 # 可以添加内存测量
try:
result = await func(*args, **kwargs)
return result
finally:
end_time = time.perf_counter()
duration = end_time - start_time
if duration > 1.0: # 超过1秒的操作记录警告
logger.warning(
f"Slow operation detected: {func.__name__} took {duration:.2f}s",
extra={
'function': func.__name__,
'duration': duration
}
)
else:
logger.debug(f"{func.__name__} completed in {duration:.4f}s")
return wrapper

122
src/utils/exceptions.py Normal file
View File

@@ -0,0 +1,122 @@
"""自定义异常类"""
from typing import Optional, Any, Dict
class BotException(Exception):
"""机器人基础异常"""
def __init__(self, message: str, code: str = None, details: Dict[str, Any] = None):
super().__init__(message)
self.message = message
self.code = code or self.__class__.__name__
self.details = details or {}
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
'error': self.code,
'message': self.message,
'details': self.details
}
class ConfigurationError(BotException):
"""配置错误"""
pass
class DatabaseError(BotException):
"""数据库错误"""
pass
class TelegramError(BotException):
"""Telegram API 错误"""
pass
class AuthenticationError(BotException):
"""认证错误"""
pass
class AuthorizationError(BotException):
"""授权错误"""
pass
class ValidationError(BotException):
"""验证错误"""
pass
class RateLimitError(BotException):
"""速率限制错误"""
pass
class SessionError(BotException):
"""会话错误"""
pass
class MessageRoutingError(BotException):
"""消息路由错误"""
pass
class BusinessLogicError(BotException):
"""业务逻辑错误"""
pass
class ExternalServiceError(BotException):
"""外部服务错误"""
pass
class ErrorHandler:
"""错误处理器"""
@staticmethod
async def handle_error(error: Exception, context: Dict[str, Any] = None) -> Dict[str, Any]:
"""处理错误"""
from ..utils.logger import get_logger
logger = get_logger(__name__)
error_info = {
'type': type(error).__name__,
'message': str(error),
'context': context or {}
}
if isinstance(error, BotException):
# 自定义异常
error_info.update(error.to_dict())
logger.error(f"Bot error: {error.message}", extra={'error_details': error_info})
else:
# 未知异常
logger.exception(f"Unexpected error: {error}", extra={'error_details': error_info})
error_info['message'] = "An unexpected error occurred"
return error_info
@staticmethod
def create_user_message(error: Exception) -> str:
"""创建用户友好的错误消息"""
if isinstance(error, AuthenticationError):
return "❌ 认证失败,请重新登录"
elif isinstance(error, AuthorizationError):
return "❌ 您没有权限执行此操作"
elif isinstance(error, ValidationError):
return f"❌ 输入无效:{error.message}"
elif isinstance(error, RateLimitError):
return "⚠️ 操作太频繁,请稍后再试"
elif isinstance(error, SessionError):
return "❌ 会话已过期,请重新开始"
elif isinstance(error, BusinessLogicError):
return f"❌ 操作失败:{error.message}"
elif isinstance(error, ExternalServiceError):
return "❌ 外部服务暂时不可用,请稍后再试"
else:
return "❌ 系统错误,请稍后再试或联系管理员"

167
src/utils/logger.py Normal file
View File

@@ -0,0 +1,167 @@
"""日志系统"""
import logging
import sys
from logging.handlers import RotatingFileHandler
from typing import Optional
from pathlib import Path
import json
from datetime import datetime
class CustomFormatter(logging.Formatter):
"""自定义日志格式化器"""
# 颜色代码
COLORS = {
'DEBUG': '\033[36m', # 青色
'INFO': '\033[32m', # 绿色
'WARNING': '\033[33m', # 黄色
'ERROR': '\033[31m', # 红色
'CRITICAL': '\033[35m', # 紫色
}
RESET = '\033[0m'
def __init__(self, use_color: bool = True):
super().__init__()
self.use_color = use_color and sys.stderr.isatty()
def format(self, record):
# 基础格式
log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
# 添加额外信息
if hasattr(record, 'user_id'):
log_format = f"%(asctime)s | %(levelname)-8s | %(name)s | User:{record.user_id} | %(message)s"
if hasattr(record, 'chat_id'):
log_format = f"%(asctime)s | %(levelname)-8s | %(name)s | Chat:{record.chat_id} | %(message)s"
# 应用颜色
if self.use_color:
levelname = record.levelname
if levelname in self.COLORS:
log_format = log_format.replace(
'%(levelname)-8s',
f"{self.COLORS[levelname]}%(levelname)-8s{self.RESET}"
)
formatter = logging.Formatter(log_format, datefmt='%Y-%m-%d %H:%M:%S')
return formatter.format(record)
class JsonFormatter(logging.Formatter):
"""JSON 格式化器用于结构化日志"""
def format(self, record):
log_data = {
'timestamp': datetime.utcnow().isoformat(),
'level': record.levelname,
'logger': record.name,
'message': record.getMessage(),
'module': record.module,
'function': record.funcName,
'line': record.lineno
}
# 添加额外字段
for key, value in record.__dict__.items():
if key not in ['name', 'msg', 'args', 'created', 'filename',
'funcName', 'levelname', 'levelno', 'lineno',
'module', 'msecs', 'message', 'pathname', 'process',
'processName', 'relativeCreated', 'thread', 'threadName']:
log_data[key] = value
# 添加异常信息
if record.exc_info:
log_data['exception'] = self.formatException(record.exc_info)
return json.dumps(log_data, ensure_ascii=False)
class Logger:
"""日志管理器"""
_instance = None
_loggers = {}
def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, config=None):
if not hasattr(self, '_initialized'):
self._initialized = True
self.config = config
self.setup_logging()
def setup_logging(self):
"""设置日志系统"""
# 根日志配置
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)
# 移除默认处理器
root_logger.handlers = []
# 控制台处理器
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(
getattr(logging, self.config.logging.level if self.config else 'INFO')
)
console_handler.setFormatter(CustomFormatter(use_color=True))
root_logger.addHandler(console_handler)
# 文件处理器
if self.config and self.config.logging.file:
file_path = Path(self.config.logging.file)
file_path.parent.mkdir(parents=True, exist_ok=True)
file_handler = RotatingFileHandler(
filename=str(file_path),
maxBytes=self.config.logging.max_size,
backupCount=self.config.logging.backup_count,
encoding='utf-8'
)
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(CustomFormatter(use_color=False))
root_logger.addHandler(file_handler)
# JSON 日志文件(用于分析)
json_file_path = file_path.with_suffix('.json')
json_handler = RotatingFileHandler(
filename=str(json_file_path),
maxBytes=self.config.logging.max_size,
backupCount=self.config.logging.backup_count,
encoding='utf-8'
)
json_handler.setLevel(logging.INFO)
json_handler.setFormatter(JsonFormatter())
root_logger.addHandler(json_handler)
@classmethod
def get_logger(cls, name: str, config=None) -> logging.Logger:
"""获取日志器"""
if name not in cls._loggers:
if not cls._instance:
cls(config)
cls._loggers[name] = logging.getLogger(name)
return cls._loggers[name]
def get_logger(name: str, config=None) -> logging.Logger:
"""获取日志器的便捷方法"""
return Logger.get_logger(name, config)
class LoggerContextFilter(logging.Filter):
"""日志上下文过滤器"""
def __init__(self, **context):
super().__init__()
self.context = context
def filter(self, record):
for key, value in self.context.items():
setattr(record, key, value)
return True