chore: initial commit

This commit is contained in:
你的用户名
2025-11-01 21:58:31 +08:00
commit 0406b5664f
101 changed files with 20458 additions and 0 deletions

5
src/config/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""配置管理模块"""
from .settings import Settings
from .loader import ConfigLoader
__all__ = ['Settings', 'ConfigLoader']

155
src/config/loader.py Normal file
View File

@@ -0,0 +1,155 @@
"""配置加载器"""
import os
from typing import Any, Dict
from pathlib import Path
from dotenv import load_dotenv
from .settings import (
Settings, TelegramConfig, DatabaseConfig,
LoggingConfig, BusinessConfig, SecurityConfig, FeatureFlags
)
class ConfigLoader:
"""配置加载器"""
@staticmethod
def load_env_file(env_path: str = None) -> None:
"""加载环境变量文件"""
if env_path:
load_dotenv(env_path)
else:
# 查找 .env 文件
current_dir = Path.cwd()
env_file = current_dir / ".env"
if env_file.exists():
load_dotenv(env_file)
else:
# 向上查找
for parent in current_dir.parents:
env_file = parent / ".env"
if env_file.exists():
load_dotenv(env_file)
break
@staticmethod
def get_env(key: str, default: Any = None, cast_type: type = str) -> Any:
"""获取环境变量并转换类型"""
value = os.getenv(key, default)
if value is None:
return None
if cast_type == bool:
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in ('true', '1', 'yes', 'on')
return bool(value)
elif cast_type == int:
return int(value)
elif cast_type == float:
return float(value)
else:
return value
@classmethod
def load_from_env(cls) -> Settings:
"""从环境变量加载配置"""
cls.load_env_file()
# Telegram 配置
telegram_config = TelegramConfig(
bot_token=cls.get_env('BOT_TOKEN', ''),
admin_id=cls.get_env('ADMIN_ID', 0, int),
admin_username=cls.get_env('ADMIN_USERNAME', ''),
bot_name=cls.get_env('BOT_NAME', 'Customer Service Bot')
)
# 数据库配置
database_config = DatabaseConfig(
type=cls.get_env('DATABASE_TYPE', 'sqlite'),
path=cls.get_env('DATABASE_PATH', './data/bot.db'),
host=cls.get_env('DATABASE_HOST'),
port=cls.get_env('DATABASE_PORT', cast_type=int),
user=cls.get_env('DATABASE_USER'),
password=cls.get_env('DATABASE_PASSWORD'),
database=cls.get_env('DATABASE_NAME')
)
# 日志配置
logging_config = LoggingConfig(
level=cls.get_env('LOG_LEVEL', 'INFO'),
file=cls.get_env('LOG_FILE', './logs/bot.log'),
max_size=cls.get_env('LOG_MAX_SIZE', 10485760, int),
backup_count=cls.get_env('LOG_BACKUP_COUNT', 5, int)
)
# 业务配置
business_config = BusinessConfig(
business_hours_start=cls.get_env('BUSINESS_HOURS_START', '09:00'),
business_hours_end=cls.get_env('BUSINESS_HOURS_END', '18:00'),
timezone=cls.get_env('TIMEZONE', 'Asia/Shanghai'),
auto_reply_delay=cls.get_env('AUTO_REPLY_DELAY', 1, int),
welcome_message=cls.get_env('WELCOME_MESSAGE',
'您好!我是客服助手,正在为您转接人工客服,请稍候...'),
offline_message=cls.get_env('OFFLINE_MESSAGE',
'非常抱歉,现在是非工作时间。我们的工作时间是 {start} - {end}。您的消息已记录,我们会在工作时间尽快回复您。')
)
# 安全配置
security_config = SecurityConfig(
max_messages_per_minute=cls.get_env('MAX_MESSAGES_PER_MINUTE', 30, int),
session_timeout=cls.get_env('SESSION_TIMEOUT', 3600, int),
enable_encryption=cls.get_env('ENABLE_ENCRYPTION', False, bool),
blocked_words=cls.get_env('BLOCKED_WORDS', '').split(',') if cls.get_env('BLOCKED_WORDS') else []
)
# 功能开关
feature_flags = FeatureFlags(
enable_auto_reply=cls.get_env('ENABLE_AUTO_REPLY', True, bool),
enable_statistics=cls.get_env('ENABLE_STATISTICS', True, bool),
enable_customer_history=cls.get_env('ENABLE_CUSTOMER_HISTORY', True, bool),
enable_multi_admin=cls.get_env('ENABLE_MULTI_ADMIN', False, bool),
enable_file_transfer=cls.get_env('ENABLE_FILE_TRANSFER', True, bool),
enable_voice_message=cls.get_env('ENABLE_VOICE_MESSAGE', True, bool),
enable_location_sharing=cls.get_env('ENABLE_LOCATION_SHARING', False, bool)
)
# 创建设置对象
settings = Settings(
telegram=telegram_config,
database=database_config,
logging=logging_config,
business=business_config,
security=security_config,
features=feature_flags,
debug=cls.get_env('DEBUG', False, bool),
testing=cls.get_env('TESTING', False, bool),
version=cls.get_env('VERSION', '1.0.0')
)
# 验证配置
settings.validate()
return settings
@classmethod
def load_from_dict(cls, config_dict: Dict[str, Any]) -> Settings:
"""从字典加载配置"""
telegram_config = TelegramConfig(**config_dict.get('telegram', {}))
database_config = DatabaseConfig(**config_dict.get('database', {}))
logging_config = LoggingConfig(**config_dict.get('logging', {}))
business_config = BusinessConfig(**config_dict.get('business', {}))
security_config = SecurityConfig(**config_dict.get('security', {}))
feature_flags = FeatureFlags(**config_dict.get('features', {}))
settings = Settings(
telegram=telegram_config,
database=database_config,
logging=logging_config,
business=business_config,
security=security_config,
features=feature_flags,
**config_dict.get('runtime', {})
)
settings.validate()
return settings

140
src/config/settings.py Normal file
View File

@@ -0,0 +1,140 @@
"""系统配置定义"""
from dataclasses import dataclass, field
from typing import Optional, List
from datetime import time
import os
@dataclass
class TelegramConfig:
"""Telegram 相关配置"""
bot_token: str
admin_id: int
admin_username: str
bot_name: str = "Customer Service Bot"
def __post_init__(self):
if not self.bot_token:
raise ValueError("Bot token is required")
if not self.admin_id:
raise ValueError("Admin ID is required")
@dataclass
class DatabaseConfig:
"""数据库配置"""
type: str = "sqlite"
path: str = "./data/bot.db"
host: Optional[str] = None
port: Optional[int] = None
user: Optional[str] = None
password: Optional[str] = None
database: Optional[str] = None
def get_connection_string(self) -> str:
"""获取数据库连接字符串"""
if self.type == "sqlite":
return f"sqlite:///{self.path}"
elif self.type == "postgresql":
return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
elif self.type == "mysql":
return f"mysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
else:
raise ValueError(f"Unsupported database type: {self.type}")
@dataclass
class LoggingConfig:
"""日志配置"""
level: str = "INFO"
file: str = "./logs/bot.log"
max_size: int = 10485760 # 10MB
backup_count: int = 5
format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
def __post_init__(self):
# 确保日志目录存在
log_dir = os.path.dirname(self.file)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
@dataclass
class BusinessConfig:
"""业务配置"""
business_hours_start: str = "09:00"
business_hours_end: str = "18:00"
timezone: str = "Asia/Shanghai"
auto_reply_delay: int = 1 # 秒
welcome_message: str = "您好!我是客服助手,正在为您转接人工客服,请稍候..."
offline_message: str = "非常抱歉,现在是非工作时间。我们的工作时间是 {start} - {end}。您的消息已记录,我们会在工作时间尽快回复您。"
def get_business_hours(self) -> tuple[time, time]:
"""获取营业时间"""
start = time.fromisoformat(self.business_hours_start)
end = time.fromisoformat(self.business_hours_end)
return start, end
@dataclass
class SecurityConfig:
"""安全配置"""
max_messages_per_minute: int = 30
session_timeout: int = 3600 # 秒
enable_encryption: bool = False
blocked_words: List[str] = field(default_factory=list)
allowed_file_types: List[str] = field(default_factory=lambda: [
'.jpg', '.jpeg', '.png', '.gif', '.pdf', '.doc', '.docx'
])
max_file_size: int = 10485760 # 10MB
@dataclass
class FeatureFlags:
"""功能开关"""
enable_auto_reply: bool = True
enable_statistics: bool = True
enable_customer_history: bool = True
enable_multi_admin: bool = False
enable_file_transfer: bool = True
enable_voice_message: bool = True
enable_location_sharing: bool = False
@dataclass
class Settings:
"""主配置类"""
telegram: TelegramConfig
database: DatabaseConfig
logging: LoggingConfig
business: BusinessConfig
security: SecurityConfig
features: FeatureFlags
# 运行时配置
debug: bool = False
testing: bool = False
version: str = "1.0.0"
@classmethod
def from_env(cls) -> 'Settings':
"""从环境变量创建配置"""
from .loader import ConfigLoader
return ConfigLoader.load_from_env()
def validate(self) -> bool:
"""验证配置完整性"""
try:
# 验证必要配置
assert self.telegram.bot_token, "Bot token is required"
assert self.telegram.admin_id, "Admin ID is required"
# 验证路径
if self.database.type == "sqlite":
db_dir = os.path.dirname(self.database.path)
if db_dir and not os.path.exists(db_dir):
os.makedirs(db_dir, exist_ok=True)
return True
except Exception as e:
raise ValueError(f"Configuration validation failed: {e}")

6
src/core/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""核心模块"""
from .bot import CustomerServiceBot
from .router import MessageRouter
from .handlers import BaseHandler, HandlerContext
__all__ = ['CustomerServiceBot', 'MessageRouter', 'BaseHandler', 'HandlerContext']

693
src/core/bot.py Normal file
View File

@@ -0,0 +1,693 @@
"""客服机器人主类"""
import asyncio
from typing import Optional, Dict, Any, List
from datetime import datetime
from telegram import Update, Bot, BotCommand, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import Application, CommandHandler, MessageHandler, CallbackQueryHandler, filters
from ..config.settings import Settings
from ..utils.logger import get_logger, Logger
from ..utils.exceptions import BotException, ErrorHandler
from ..utils.decorators import log_action, measure_performance
from .router import MessageRouter, RouteBuilder, MessageContext
from .handlers import BaseHandler, HandlerContext
logger = get_logger(__name__)
class CustomerServiceBot:
"""客服机器人"""
def __init__(self, config: Settings = None):
"""初始化机器人"""
# 加载配置
self.config = config or Settings.from_env()
self.config.validate()
# 初始化日志系统
Logger(self.config)
self.logger = get_logger(self.__class__.__name__, self.config)
# 初始化组件
self.application: Optional[Application] = None
self.router = MessageRouter(self.config)
self.route_builder = RouteBuilder(self.router)
self.handlers: Dict[str, BaseHandler] = {}
self.active_sessions: Dict[str, Dict[str, Any]] = {}
# 当前会话管理
self.current_customer = None # 当前正在对话的客户
# 统计信息
self.stats = {
'messages_received': 0,
'messages_forwarded': 0,
'replies_sent': 0,
'errors': 0,
'start_time': datetime.now()
}
self.logger.info(f"Bot initialized with version {self.config.version}")
async def initialize(self):
"""异步初始化"""
try:
# 创建应用
self.application = Application.builder().token(
self.config.telegram.bot_token
).build()
# 设置命令
await self.setup_commands()
# 注册处理器
self.register_handlers()
# 初始化数据库(如果需要)
if self.config.features.enable_customer_history:
from ..modules.storage import DatabaseManager
self.db_manager = DatabaseManager(self.config)
await self.db_manager.initialize()
self.logger.info("Bot initialization completed")
except Exception as e:
self.logger.error(f"Failed to initialize bot: {e}")
raise
async def setup_commands(self):
"""设置机器人命令"""
commands = [
BotCommand("start", "开始使用机器人"),
BotCommand("help", "获取帮助信息"),
BotCommand("status", "查看机器人状态"),
BotCommand("contact", "联系人工客服"),
]
# 管理员命令
admin_commands = commands + [
BotCommand("stats", "查看统计信息"),
BotCommand("sessions", "查看活跃会话"),
BotCommand("reply", "回复客户消息"),
BotCommand("broadcast", "广播消息"),
BotCommand("settings", "机器人设置"),
]
# 设置命令
await self.application.bot.set_my_commands(commands)
# 为管理员设置特殊命令
await self.application.bot.set_my_commands(
admin_commands,
scope={"type": "chat", "chat_id": self.config.telegram.admin_id}
)
def register_handlers(self):
"""注册消息处理器"""
# 命令处理器
self.application.add_handler(CommandHandler("start", self.handle_start))
self.application.add_handler(CommandHandler("help", self.handle_help))
self.application.add_handler(CommandHandler("status", self.handle_status))
self.application.add_handler(CommandHandler("contact", self.handle_contact))
# 管理员命令
self.application.add_handler(CommandHandler("stats", self.handle_stats))
self.application.add_handler(CommandHandler("sessions", self.handle_sessions))
self.application.add_handler(CommandHandler("reply", self.handle_reply))
self.application.add_handler(CommandHandler("broadcast", self.handle_broadcast))
self.application.add_handler(CommandHandler("settings", self.handle_settings))
# 消息处理器 - 处理所有消息(包括搜索指令)
# 只排除机器人自己处理的命令,其他命令(如搜索指令)也会转发
self.application.add_handler(MessageHandler(
filters.ALL,
self.handle_message
))
# 回调查询处理器
self.application.add_handler(CallbackQueryHandler(self.handle_callback))
# 错误处理器
self.application.add_error_handler(self.handle_error)
@log_action("start_command")
async def handle_start(self, update: Update, context):
"""处理 /start 命令"""
user = update.effective_user
is_admin = user.id == self.config.telegram.admin_id
if is_admin:
text = (
f"👋 欢迎,管理员 {user.first_name}\n\n"
"🤖 客服机器人已就绪\n"
"📊 使用 /stats 查看统计\n"
"💬 使用 /sessions 查看会话\n"
"⚙️ 使用 /settings 进行设置"
)
else:
text = (
f"👋 您好 {user.first_name}\n\n"
"暂时支持的搜索指令:\n\n"
"- 群组目录 /topchat\n"
"- 群组搜索 /search\n"
"- 按消息文本搜索 /text\n"
"- 按名称搜索 /human\n\n"
"您可以使用以上指令进行搜索,或直接发送消息联系客服。"
)
# 通知管理员
await self.notify_admin_new_customer(user)
await update.message.reply_text(text)
self.stats['messages_received'] += 1
async def handle_help(self, update: Update, context):
"""处理 /help 命令"""
user = update.effective_user
is_admin = user.id == self.config.telegram.admin_id
if is_admin:
text = self._get_admin_help()
else:
text = self._get_user_help()
await update.message.reply_text(text, parse_mode='Markdown')
async def handle_status(self, update: Update, context):
"""处理 /status 命令"""
uptime = datetime.now() - self.stats['start_time']
hours = uptime.total_seconds() / 3600
text = (
"✅ 机器人运行正常\n\n"
f"⏱ 运行时间:{hours:.1f} 小时\n"
f"📊 处理消息:{self.stats['messages_received']}\n"
f"👥 活跃会话:{len(self.active_sessions)}"
)
await update.message.reply_text(text)
async def handle_contact(self, update: Update, context):
"""处理 /contact 命令"""
await update.message.reply_text(
"正在为您转接人工客服,请稍候...\n"
"您可以直接发送消息,客服会尽快回复您。"
)
# 修复:传递正确的 context 参数
await self.forward_customer_message(update, context)
@measure_performance
async def handle_message(self, update: Update, context):
"""处理普通消息"""
try:
user = update.effective_user
message = update.effective_message
is_admin = user.id == self.config.telegram.admin_id
self.stats['messages_received'] += 1
if is_admin:
# 管理员消息 - 检查是否是回复
if message.reply_to_message:
await self.handle_admin_reply(update, context)
elif self.current_customer:
# 如果有当前客户,直接发送给当前客户
await self.reply_to_current_customer(update, context)
else:
# 没有当前客户时,提示管理员
await message.reply_text(
"💡 提示:暂无活跃客户\n\n"
"等待客户发送消息,或使用:\n"
"• 直接回复转发的客户消息\n"
"• /sessions 查看所有会话\n"
"• /reply <用户ID> <消息> 回复指定用户"
)
else:
# 客户消息 - 转发给管理员(包括搜索指令)
# 处理所有客户消息,包括 /topchat, /search, /text, /human 等指令
await self.forward_customer_message(update, context)
except Exception as e:
self.logger.error(f"Error handling message: {e}")
await self.send_error_message(update, e)
async def forward_customer_message(self, update: Update, context):
"""转发客户消息给管理员"""
user = update.effective_user
message = update.effective_message
chat = update.effective_chat
# 创建或更新会话
session_id = f"{chat.id}_{user.id}"
if session_id not in self.active_sessions:
self.active_sessions[session_id] = {
'user_id': user.id,
'username': user.username,
'first_name': user.first_name,
'chat_id': chat.id,
'messages': [],
'started_at': datetime.now()
}
# 记录消息
self.active_sessions[session_id]['messages'].append({
'message_id': message.message_id,
'text': message.text or "[非文本消息]",
'timestamp': datetime.now()
})
# 设置为当前客户
self.current_customer = {
'user_id': user.id,
'chat_id': chat.id,
'username': user.username,
'first_name': user.first_name,
'session_id': session_id
}
# 构建用户信息 - 转义特殊字符
def escape_markdown(text):
"""转义 Markdown 特殊字符"""
if text is None:
return ''
# 转义特殊字符
special_chars = ['_', '*', '[', ']', '(', ')', '~', '`', '>', '#', '+', '-', '=', '|', '{', '}', '.', '!']
for char in special_chars:
text = str(text).replace(char, f'\\{char}')
return text
first_name = escape_markdown(user.first_name)
last_name = escape_markdown(user.last_name) if user.last_name else ''
username = escape_markdown(user.username) if user.username else 'N/A'
# 构建用户信息
user_info = (
f"📨 来自客户的消息\n"
f"👤 姓名:{first_name} {last_name}\n"
f"🆔 ID`{user.id}`\n"
f"📱 用户名:@{username}\n"
f"💬 会话:`{session_id}`\n"
f"━━━━━━━━━━━━━━━━"
)
# 发送用户信息
await context.bot.send_message(
chat_id=self.config.telegram.admin_id,
text=user_info,
parse_mode='MarkdownV2'
)
# 转发原始消息
forwarded = await context.bot.forward_message(
chat_id=self.config.telegram.admin_id,
from_chat_id=chat.id,
message_id=message.message_id
)
# 保存转发消息ID映射
context.bot_data.setdefault('message_map', {})[forwarded.message_id] = {
'original_chat': chat.id,
'original_user': user.id,
'session_id': session_id
}
# 提示管理员可以直接输入文字回复
await context.bot.send_message(
chat_id=self.config.telegram.admin_id,
text="💬 现在可以直接输入文字回复此客户,或回复上方转发的消息"
)
# 自动回复(如果启用)
if self.config.features.enable_auto_reply and not is_business_hours(self.config):
await self.send_auto_reply(update, context)
self.stats['messages_forwarded'] += 1
async def handle_admin_reply(self, update: Update, context):
"""处理管理员回复"""
replied_to = update.message.reply_to_message
# 查找原始消息信息
message_map = context.bot_data.get('message_map', {})
if replied_to.message_id not in message_map:
await update.message.reply_text("⚠️ 无法找到原始消息信息")
return
original_info = message_map[replied_to.message_id]
original_chat = original_info['original_chat']
session_id = original_info['session_id']
# 发送回复给客户
try:
if update.message.text:
await context.bot.send_message(
chat_id=original_chat,
text=update.message.text
)
elif update.message.photo:
await context.bot.send_photo(
chat_id=original_chat,
photo=update.message.photo[-1].file_id,
caption=update.message.caption
)
elif update.message.document:
await context.bot.send_document(
chat_id=original_chat,
document=update.message.document.file_id,
caption=update.message.caption
)
# 确认发送
await update.message.reply_text("✅ 消息已发送给客户")
# 更新会话
if session_id in self.active_sessions:
self.active_sessions[session_id]['last_reply'] = datetime.now()
self.stats['replies_sent'] += 1
except Exception as e:
await update.message.reply_text(f"❌ 发送失败:{e}")
self.logger.error(f"Failed to send reply: {e}")
async def handle_stats(self, update: Update, context):
"""处理 /stats 命令(管理员)"""
if update.effective_user.id != self.config.telegram.admin_id:
return
uptime = datetime.now() - self.stats['start_time']
days = uptime.days
hours = uptime.seconds // 3600
text = (
"📊 **统计信息**\n\n"
f"⏱ 运行时间:{days}{hours} 小时\n"
f"📨 接收消息:{self.stats['messages_received']}\n"
f"📤 转发消息:{self.stats['messages_forwarded']}\n"
f"💬 回复消息:{self.stats['replies_sent']}\n"
f"❌ 错误次数:{self.stats['errors']}\n"
f"👥 活跃会话:{len(self.active_sessions)}\n"
f"📅 启动时间:{self.stats['start_time'].strftime('%Y-%m-%d %H:%M:%S')}"
)
await update.message.reply_text(text, parse_mode='Markdown')
async def handle_sessions(self, update: Update, context):
"""处理 /sessions 命令(管理员)"""
if update.effective_user.id != self.config.telegram.admin_id:
return
if not self.active_sessions:
await update.message.reply_text("当前没有活跃会话")
return
text = "👥 **活跃会话**\n\n"
for session_id, session in self.active_sessions.items():
duration = datetime.now() - session['started_at']
text += (
f"会话 `{session_id}`\n"
f"👤 {session['first_name']} (@{session['username'] or 'N/A'})\n"
f"💬 消息数:{len(session['messages'])}\n"
f"⏱ 时长:{duration.seconds // 60} 分钟\n"
f"━━━━━━━━━━━━━━━━\n"
)
await update.message.reply_text(text, parse_mode='Markdown')
async def handle_reply(self, update: Update, context):
"""处理 /reply 命令(管理员)"""
if update.effective_user.id != self.config.telegram.admin_id:
return
if len(context.args) < 2:
await update.message.reply_text(
"用法:/reply <用户ID> <消息>\n"
"示例:/reply 123456789 您好,有什么可以帮助您?"
)
return
try:
user_id = int(context.args[0])
message = ' '.join(context.args[1:])
await context.bot.send_message(chat_id=user_id, text=message)
await update.message.reply_text(f"✅ 消息已发送给用户 {user_id}")
self.stats['replies_sent'] += 1
except Exception as e:
await update.message.reply_text(f"❌ 发送失败:{e}")
async def reply_to_current_customer(self, update: Update, context):
"""回复当前客户"""
if not self.current_customer:
await update.message.reply_text("❌ 没有选中的客户")
return
try:
message = update.effective_message
# 发送消息给当前客户
if message.text:
await context.bot.send_message(
chat_id=self.current_customer['chat_id'],
text=message.text
)
elif message.photo:
await context.bot.send_photo(
chat_id=self.current_customer['chat_id'],
photo=message.photo[-1].file_id,
caption=message.caption
)
elif message.document:
await context.bot.send_document(
chat_id=self.current_customer['chat_id'],
document=message.document.file_id,
caption=message.caption
)
# 简洁确认消息
await update.message.reply_text(f"✅ → {self.current_customer['first_name']}")
self.stats['replies_sent'] += 1
except Exception as e:
await update.message.reply_text(f"❌ 发送失败:{e}")
self.logger.error(f"Failed to send reply: {e}")
async def handle_broadcast(self, update: Update, context):
"""处理 /broadcast 命令(管理员)"""
if update.effective_user.id != self.config.telegram.admin_id:
return
if not context.args:
await update.message.reply_text(
"用法:/broadcast <消息>\n"
"示例:/broadcast 系统维护通知今晚10点进行系统维护"
)
return
message = ' '.join(context.args)
sent = 0
failed = 0
for session_id, session in self.active_sessions.items():
try:
await context.bot.send_message(
chat_id=session['chat_id'],
text=message
)
sent += 1
except Exception as e:
failed += 1
self.logger.error(f"Failed to broadcast to {session['chat_id']}: {e}")
await update.message.reply_text(
f"✅ 广播完成\n"
f"成功:{sent}\n"
f"失败:{failed}"
)
async def handle_settings(self, update: Update, context):
"""处理 /settings 命令(管理员)"""
if update.effective_user.id != self.config.telegram.admin_id:
return
keyboard = [
[
InlineKeyboardButton(
f"{'' if self.config.features.enable_auto_reply else ''} 自动回复",
callback_data="toggle_auto_reply"
)
],
[
InlineKeyboardButton(
f"{'' if self.config.features.enable_statistics else ''} 统计功能",
callback_data="toggle_statistics"
)
],
[
InlineKeyboardButton("📊 查看所有设置", callback_data="view_settings")
]
]
reply_markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(
"⚙️ **机器人设置**\n\n点击按钮切换功能:",
reply_markup=reply_markup,
parse_mode='Markdown'
)
async def handle_callback(self, update: Update, context):
"""处理回调查询"""
query = update.callback_query
await query.answer()
data = query.data
if data.startswith("done_"):
session_id = data.replace("done_", "")
await query.edit_message_text(f"✅ 会话 {session_id} 已标记为完成")
elif data.startswith("later_"):
session_id = data.replace("later_", "")
await query.edit_message_text(f"⏸ 会话 {session_id} 已标记为稍后处理")
elif data == "toggle_auto_reply":
self.config.features.enable_auto_reply = not self.config.features.enable_auto_reply
await query.edit_message_text(
f"自动回复已{'启用' if self.config.features.enable_auto_reply else '禁用'}"
)
async def handle_error(self, update: Update, context):
"""处理错误"""
self.stats['errors'] += 1
error_info = await ErrorHandler.handle_error(context.error)
self.logger.error(f"Update {update} caused error {context.error}")
if update and update.effective_message:
user_message = ErrorHandler.create_user_message(context.error)
await update.effective_message.reply_text(user_message)
async def notify_admin_new_customer(self, user):
"""通知管理员有新客户"""
def escape_markdown(text):
"""转义 Markdown 特殊字符"""
if text is None:
return ''
# 转义特殊字符
special_chars = ['_', '*', '[', ']', '(', ')', '~', '`', '>', '#', '+', '-', '=', '|', '{', '}', '.', '!']
for char in special_chars:
text = str(text).replace(char, f'\\{char}')
return text
first_name = escape_markdown(user.first_name)
last_name = escape_markdown(user.last_name) if user.last_name else ''
username = escape_markdown(user.username) if user.username else 'N/A'
text = (
f"🆕 新客户加入\n"
f"👤 姓名:{first_name} {last_name}\n"
f"🆔 ID`{user.id}`\n"
f"📱 用户名:@{username}"
)
try:
await self.application.bot.send_message(
chat_id=self.config.telegram.admin_id,
text=text,
parse_mode='MarkdownV2'
)
except Exception as e:
self.logger.error(f"Failed to notify admin: {e}")
async def send_auto_reply(self, update: Update, context):
"""发送自动回复"""
import pytz
from datetime import time
# 检查营业时间
tz = pytz.timezone(self.config.business.timezone)
now = datetime.now(tz).time()
start_time = time.fromisoformat(self.config.business.business_hours_start)
end_time = time.fromisoformat(self.config.business.business_hours_end)
if not (start_time <= now <= end_time):
message = self.config.business.offline_message.format(
start=self.config.business.business_hours_start,
end=self.config.business.business_hours_end
)
await update.message.reply_text(message)
async def send_error_message(self, update: Update, error: Exception):
"""发送错误消息给用户"""
user_message = ErrorHandler.create_user_message(error)
await update.message.reply_text(user_message)
def _get_admin_help(self) -> str:
"""获取管理员帮助信息"""
return (
"📚 **管理员帮助**\n\n"
"**基础命令**\n"
"/start - 启动机器人\n"
"/help - 显示帮助\n"
"/status - 查看状态\n\n"
"**管理命令**\n"
"/stats - 查看统计信息\n"
"/sessions - 查看活跃会话\n"
"/reply <用户ID> <消息> - 回复指定用户\n"
"/broadcast <消息> - 广播消息给所有用户\n"
"/settings - 机器人设置\n\n"
"**快速回复客户**\n"
"• 直接输入文字 - 自动发送给最近的客户\n"
"• 回复转发消息 - 回复特定客户"
)
def _get_user_help(self) -> str:
"""获取用户帮助信息"""
return (
"📚 **帮助信息**\n\n"
"**使用方法**\n"
"• 直接发送消息,客服会尽快回复您\n"
"• 支持发送文字、图片、文件等\n"
"• /contact - 联系人工客服\n"
"• /status - 查看服务状态\n\n"
f"**工作时间**\n"
f"{self.config.business.business_hours_start} - {self.config.business.business_hours_end}\n\n"
"如有紧急情况,请留言,我们会尽快处理"
)
def run(self):
"""运行机器人"""
try:
# 同步初始化
asyncio.get_event_loop().run_until_complete(self.initialize())
self.logger.info("Starting bot...")
self.application.run_polling(allowed_updates=Update.ALL_TYPES)
except KeyboardInterrupt:
self.logger.info("Bot stopped by user")
except Exception as e:
self.logger.error(f"Bot crashed: {e}")
raise
finally:
self.cleanup()
def cleanup(self):
"""清理资源"""
self.logger.info("Cleaning up resources...")
# 保存统计信息、关闭数据库等
pass
def is_business_hours(config: Settings) -> bool:
"""检查是否在营业时间"""
import pytz
from datetime import time
tz = pytz.timezone(config.business.timezone)
now = datetime.now(tz).time()
start_time = time.fromisoformat(config.business.business_hours_start)
end_time = time.fromisoformat(config.business.business_hours_end)
return start_time <= now <= end_time

157
src/core/handlers.py Normal file
View File

@@ -0,0 +1,157 @@
"""处理器基类和上下文"""
from abc import ABC, abstractmethod
from typing import Any, Optional, Dict, List
from dataclasses import dataclass, field
from telegram import Update, Message
from telegram.ext import ContextTypes
from ..utils.logger import get_logger
from ..config.settings import Settings
logger = get_logger(__name__)
@dataclass
class HandlerContext:
"""处理器上下文"""
update: Update
context: ContextTypes.DEFAULT_TYPE
config: Settings
user_data: Dict[str, Any] = field(default_factory=dict)
chat_data: Dict[str, Any] = field(default_factory=dict)
session_data: Dict[str, Any] = field(default_factory=dict)
@property
def message(self) -> Message:
"""获取消息"""
return self.update.effective_message
@property
def user(self):
"""获取用户"""
return self.update.effective_user
@property
def chat(self):
"""获取聊天"""
return self.update.effective_chat
def get_session_id(self) -> str:
"""获取会话ID"""
return f"{self.chat.id}_{self.user.id}"
class BaseHandler(ABC):
"""处理器基类"""
def __init__(self, config: Settings):
self.config = config
self.logger = get_logger(self.__class__.__name__)
@abstractmethod
async def handle(self, handler_context: HandlerContext) -> Any:
"""处理消息"""
pass
async def __call__(self, update: Update, context: ContextTypes.DEFAULT_TYPE,
message_context: Any = None) -> Any:
"""调用处理器"""
handler_context = HandlerContext(
update=update,
context=context,
config=self.config,
user_data=context.user_data,
chat_data=context.chat_data
)
try:
self.logger.debug(f"Handling message from user {handler_context.user.id}")
result = await self.handle(handler_context)
return result
except Exception as e:
self.logger.error(f"Error in handler: {e}")
raise
async def reply_text(self, context: HandlerContext, text: str, **kwargs) -> Message:
"""回复文本消息"""
return await context.message.reply_text(text, **kwargs)
async def reply_photo(self, context: HandlerContext, photo, caption: str = None, **kwargs) -> Message:
"""回复图片"""
return await context.message.reply_photo(photo, caption=caption, **kwargs)
async def reply_document(self, context: HandlerContext, document, caption: str = None, **kwargs) -> Message:
"""回复文档"""
return await context.message.reply_document(document, caption=caption, **kwargs)
async def forward_to_admin(self, context: HandlerContext) -> Message:
"""转发消息给管理员"""
return await context.context.bot.forward_message(
chat_id=self.config.telegram.admin_id,
from_chat_id=context.chat.id,
message_id=context.message.message_id
)
async def send_to_admin(self, context: HandlerContext, text: str, **kwargs) -> Message:
"""发送消息给管理员"""
return await context.context.bot.send_message(
chat_id=self.config.telegram.admin_id,
text=text,
**kwargs
)
class CompositeHandler(BaseHandler):
"""组合处理器"""
def __init__(self, config: Settings):
super().__init__(config)
self.handlers: List[BaseHandler] = []
def add_handler(self, handler: BaseHandler):
"""添加处理器"""
self.handlers.append(handler)
async def handle(self, handler_context: HandlerContext) -> Any:
"""依次执行所有处理器"""
results = []
for handler in self.handlers:
try:
result = await handler.handle(handler_context)
results.append(result)
except Exception as e:
self.logger.error(f"Error in composite handler {handler.__class__.__name__}: {e}")
# 可以选择继续或中断
raise
return results
class ConditionalHandler(BaseHandler):
"""条件处理器"""
def __init__(self, config: Settings, condition_func):
super().__init__(config)
self.condition_func = condition_func
self.true_handler: Optional[BaseHandler] = None
self.false_handler: Optional[BaseHandler] = None
def set_true_handler(self, handler: BaseHandler):
"""设置条件为真时的处理器"""
self.true_handler = handler
def set_false_handler(self, handler: BaseHandler):
"""设置条件为假时的处理器"""
self.false_handler = handler
async def handle(self, handler_context: HandlerContext) -> Any:
"""根据条件执行处理器"""
if await self.condition_func(handler_context):
if self.true_handler:
return await self.true_handler.handle(handler_context)
else:
if self.false_handler:
return await self.false_handler.handle(handler_context)
return None

321
src/core/router.py Normal file
View File

@@ -0,0 +1,321 @@
"""消息路由系统"""
import asyncio
from typing import Dict, List, Optional, Callable, Any, Type
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from telegram import Update, Message, User, Chat
from telegram.ext import ContextTypes
from ..utils.logger import get_logger
from ..utils.exceptions import MessageRoutingError
from ..utils.decorators import log_action, measure_performance
logger = get_logger(__name__)
class MessageType(Enum):
"""消息类型枚举"""
TEXT = "text"
PHOTO = "photo"
VIDEO = "video"
AUDIO = "audio"
VOICE = "voice"
DOCUMENT = "document"
STICKER = "sticker"
LOCATION = "location"
CONTACT = "contact"
POLL = "poll"
COMMAND = "command"
CALLBACK = "callback"
INLINE = "inline"
class RoutePriority(Enum):
"""路由优先级"""
CRITICAL = 0
HIGH = 1
NORMAL = 2
LOW = 3
@dataclass
class RoutePattern:
"""路由模式"""
pattern: str
type: MessageType
priority: RoutePriority = RoutePriority.NORMAL
conditions: List[Callable] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
def matches(self, message: Message) -> bool:
"""检查消息是否匹配模式"""
# 检查消息类型
if not self._check_message_type(message):
return False
# 检查模式匹配
if self.type == MessageType.TEXT and message.text:
if not self._match_text_pattern(message.text):
return False
# 检查条件
for condition in self.conditions:
if not condition(message):
return False
return True
def _check_message_type(self, message: Message) -> bool:
"""检查消息类型是否匹配"""
type_map = {
MessageType.TEXT: lambda m: m.text is not None,
MessageType.PHOTO: lambda m: m.photo is not None,
MessageType.VIDEO: lambda m: m.video is not None,
MessageType.AUDIO: lambda m: m.audio is not None,
MessageType.VOICE: lambda m: m.voice is not None,
MessageType.DOCUMENT: lambda m: m.document is not None,
MessageType.STICKER: lambda m: m.sticker is not None,
MessageType.LOCATION: lambda m: m.location is not None,
MessageType.CONTACT: lambda m: m.contact is not None,
MessageType.POLL: lambda m: m.poll is not None,
}
check_func = type_map.get(self.type)
return check_func(message) if check_func else False
def _match_text_pattern(self, text: str) -> bool:
"""匹配文本模式"""
import re
if self.pattern.startswith("^") or self.pattern.endswith("$"):
# 正则表达式
return bool(re.match(self.pattern, text))
else:
# 简单包含检查
return self.pattern in text
@dataclass
class MessageContext:
"""消息上下文"""
message_id: str
user_id: int
chat_id: int
username: Optional[str]
first_name: Optional[str]
last_name: Optional[str]
message_type: MessageType
content: Any
timestamp: datetime
is_admin: bool = False
session_id: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@classmethod
def from_update(cls, update: Update, admin_id: int) -> 'MessageContext':
"""从更新创建上下文"""
message = update.effective_message
user = update.effective_user
chat = update.effective_chat
# 确定消息类型
if message.text and message.text.startswith('/'):
msg_type = MessageType.COMMAND
elif message.text:
msg_type = MessageType.TEXT
elif message.photo:
msg_type = MessageType.PHOTO
elif message.video:
msg_type = MessageType.VIDEO
elif message.voice:
msg_type = MessageType.VOICE
elif message.document:
msg_type = MessageType.DOCUMENT
elif message.location:
msg_type = MessageType.LOCATION
else:
msg_type = MessageType.TEXT
# 提取内容
content = message.text or message.caption or ""
if message.photo:
content = message.photo[-1].file_id
elif message.document:
content = message.document.file_id
elif message.voice:
content = message.voice.file_id
elif message.video:
content = message.video.file_id
return cls(
message_id=str(message.message_id),
user_id=user.id,
chat_id=chat.id,
username=user.username,
first_name=user.first_name,
last_name=user.last_name,
message_type=msg_type,
content=content,
timestamp=datetime.now(),
is_admin=(user.id == admin_id),
session_id=f"{chat.id}_{user.id}"
)
class MessageRouter:
"""消息路由器"""
def __init__(self, config):
self.config = config
self.routes: Dict[RoutePriority, List[tuple[RoutePattern, Callable]]] = {
priority: [] for priority in RoutePriority
}
self.middleware: List[Callable] = []
self.default_handler: Optional[Callable] = None
self.error_handler: Optional[Callable] = None
def add_route(self, pattern: RoutePattern, handler: Callable):
"""添加路由"""
self.routes[pattern.priority].append((pattern, handler))
logger.debug(f"Added route: {pattern.pattern} with priority {pattern.priority}")
def add_middleware(self, middleware: Callable):
"""添加中间件"""
self.middleware.append(middleware)
logger.debug(f"Added middleware: {middleware.__name__}")
def set_default_handler(self, handler: Callable):
"""设置默认处理器"""
self.default_handler = handler
def set_error_handler(self, handler: Callable):
"""设置错误处理器"""
self.error_handler = handler
@measure_performance
@log_action("route_message")
async def route(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> Any:
"""路由消息"""
try:
# 创建消息上下文
msg_context = MessageContext.from_update(
update,
self.config.telegram.admin_id
)
# 应用中间件
for middleware in self.middleware:
result = await middleware(msg_context, context)
if result is False:
logger.debug(f"Middleware {middleware.__name__} blocked message")
return None
# 查找匹配的路由
handler = await self._find_handler(update.effective_message, msg_context)
if handler:
logger.info(
f"Routing message to {handler.__name__}",
extra={'user_id': msg_context.user_id, 'handler': handler.__name__}
)
return await handler(update, context, msg_context)
elif self.default_handler:
logger.info(
f"Using default handler",
extra={'user_id': msg_context.user_id}
)
return await self.default_handler(update, context, msg_context)
else:
logger.warning(f"No handler found for message from user {msg_context.user_id}")
raise MessageRoutingError("No handler found for this message type")
except Exception as e:
if self.error_handler:
return await self.error_handler(update, context, e)
else:
logger.error(f"Error in message routing: {e}")
raise
async def _find_handler(self, message: Message, context: MessageContext) -> Optional[Callable]:
"""查找合适的处理器"""
# 按优先级顺序检查路由
for priority in RoutePriority:
for pattern, handler in self.routes[priority]:
if pattern.matches(message):
return handler
return None
class RouteBuilder:
"""路由构建器"""
def __init__(self, router: MessageRouter):
self.router = router
def text(self, pattern: str = None, priority: RoutePriority = RoutePriority.NORMAL):
"""文本消息路由装饰器"""
def decorator(handler: Callable):
route_pattern = RoutePattern(
pattern=pattern or ".*",
type=MessageType.TEXT,
priority=priority
)
self.router.add_route(route_pattern, handler)
return handler
return decorator
def command(self, command: str, priority: RoutePriority = RoutePriority.HIGH):
"""命令路由装饰器"""
def decorator(handler: Callable):
route_pattern = RoutePattern(
pattern=f"^/{command}",
type=MessageType.TEXT,
priority=priority
)
self.router.add_route(route_pattern, handler)
return handler
return decorator
def photo(self, priority: RoutePriority = RoutePriority.NORMAL):
"""图片消息路由装饰器"""
def decorator(handler: Callable):
route_pattern = RoutePattern(
pattern="",
type=MessageType.PHOTO,
priority=priority
)
self.router.add_route(route_pattern, handler)
return handler
return decorator
def document(self, priority: RoutePriority = RoutePriority.NORMAL):
"""文档消息路由装饰器"""
def decorator(handler: Callable):
route_pattern = RoutePattern(
pattern="",
type=MessageType.DOCUMENT,
priority=priority
)
self.router.add_route(route_pattern, handler)
return handler
return decorator
def voice(self, priority: RoutePriority = RoutePriority.NORMAL):
"""语音消息路由装饰器"""
def decorator(handler: Callable):
route_pattern = RoutePattern(
pattern="",
type=MessageType.VOICE,
priority=priority
)
self.router.add_route(route_pattern, handler)
return handler
return decorator
def middleware(self):
"""中间件装饰器"""
def decorator(handler: Callable):
self.router.add_middleware(handler)
return handler
return decorator

View File

@@ -0,0 +1,279 @@
"""
搜索镜像模块 - 自动转发搜索指令到目标机器人并返回结果
基于 jingxiang 项目的镜像机制
"""
import asyncio
import logging
from typing import Dict, Optional, Any
from pyrogram import Client, filters
from pyrogram.types import Message as PyrogramMessage
from pyrogram.raw.functions.messages import GetBotCallbackAnswer
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import ContextTypes
logger = logging.getLogger(__name__)
class MirrorSearchHandler:
"""处理搜索指令的镜像转发"""
def __init__(self, config):
self.config = config
self.enabled = False
# Pyrogram配置需要在.env中配置
self.api_id = None
self.api_hash = None
self.session_name = "search_mirror_session"
self.target_bot = "@openaiw_bot" # 目标搜索机器人
# Pyrogram客户端
self.pyrogram_client: Optional[Client] = None
self.target_bot_id: Optional[int] = None
# 消息映射
self.user_search_requests: Dict[int, Dict[str, Any]] = {} # user_id -> search_info
self.pyrogram_to_user: Dict[int, int] = {} # pyrogram_msg_id -> user_id
self.user_to_telegram: Dict[int, int] = {} # user_id -> telegram_msg_id
# 支持的搜索命令
self.search_commands = ['/topchat', '/search', '/text', '/human']
async def initialize(self, api_id: int, api_hash: str):
"""初始化Pyrogram客户端"""
try:
self.api_id = api_id
self.api_hash = api_hash
self.pyrogram_client = Client(
self.session_name,
api_id=self.api_id,
api_hash=self.api_hash
)
await self.pyrogram_client.start()
logger.info("✅ 搜索镜像客户端已启动")
# 获取目标机器人信息
target = await self.pyrogram_client.get_users(self.target_bot)
self.target_bot_id = target.id
logger.info(f"✅ 连接到搜索机器人: {target.username} (ID: {target.id})")
# 设置消息监听器
await self._setup_listeners()
self.enabled = True
return True
except Exception as e:
logger.error(f"镜像搜索初始化失败: {e}")
self.enabled = False
return False
async def _setup_listeners(self):
"""设置Pyrogram消息监听器"""
if not self.pyrogram_client:
return
@self.pyrogram_client.on_message(filters.user(self.target_bot_id))
async def on_bot_response(_, message: PyrogramMessage):
"""当收到搜索机器人的响应时"""
await self._handle_bot_response(message)
@self.pyrogram_client.on_edited_message(filters.user(self.target_bot_id))
async def on_message_edited(_, message: PyrogramMessage):
"""当搜索机器人编辑消息时(翻页)"""
await self._handle_bot_response(message, is_edit=True)
logger.info("✅ 消息监听器已设置")
def is_search_command(self, text: str) -> bool:
"""检查是否是搜索命令"""
if not text:
return False
command = text.split()[0]
return command in self.search_commands
async def process_search_command(
self,
update: Update,
context: ContextTypes.DEFAULT_TYPE,
user_id: int,
command: str
) -> bool:
"""处理用户的搜索命令"""
if not self.enabled or not self.pyrogram_client:
logger.warning("搜索镜像未启用")
return False
try:
# 记录用户搜索请求
self.user_search_requests[user_id] = {
'command': command,
'chat_id': update.effective_chat.id,
'update': update,
'context': context,
'timestamp': asyncio.get_event_loop().time()
}
# 通过Pyrogram发送命令给目标机器人
sent_message = await self.pyrogram_client.send_message(
self.target_bot,
command
)
# 记录映射关系
if sent_message:
logger.info(f"已发送搜索命令给 {self.target_bot}: {command}")
# 等待响应会通过监听器处理
# 发送等待提示给用户
waiting_msg = await update.message.reply_text(
"🔍 正在搜索,请稍候..."
)
self.user_to_telegram[user_id] = waiting_msg.message_id
return True
except Exception as e:
logger.error(f"发送搜索命令失败: {e}")
await update.message.reply_text(
"❌ 搜索请求失败,请稍后重试或联系管理员"
)
return False
async def _handle_bot_response(self, message: PyrogramMessage, is_edit: bool = False):
"""处理搜索机器人的响应"""
try:
# 查找对应的用户
# 这里需要根据时间戳或其他方式匹配用户请求
user_id = self._find_user_for_response(message)
if not user_id or user_id not in self.user_search_requests:
logger.debug(f"未找到对应的用户请求")
return
user_request = self.user_search_requests[user_id]
# 转换消息格式并发送给用户
await self._forward_to_user(message, user_request, is_edit)
except Exception as e:
logger.error(f"处理机器人响应失败: {e}")
def _find_user_for_response(self, message: PyrogramMessage) -> Optional[int]:
"""查找响应对应的用户"""
# 简单的实现:返回最近的请求用户
# 实际应用中可能需要更复杂的匹配逻辑
if self.user_search_requests:
# 获取最近的请求
recent_user = max(
self.user_search_requests.keys(),
key=lambda k: self.user_search_requests[k].get('timestamp', 0)
)
return recent_user
return None
async def _forward_to_user(
self,
pyrogram_msg: PyrogramMessage,
user_request: Dict[str, Any],
is_edit: bool = False
):
"""转发搜索结果给用户"""
try:
update = user_request['update']
context = user_request['context']
# 提取消息内容
text = self._extract_text(pyrogram_msg)
keyboard = self._convert_keyboard(pyrogram_msg)
if is_edit and user_request['user_id'] in self.user_to_telegram:
# 编辑现有消息
telegram_msg_id = self.user_to_telegram[user_request['user_id']]
await context.bot.edit_message_text(
chat_id=user_request['chat_id'],
message_id=telegram_msg_id,
text=text,
reply_markup=keyboard,
parse_mode='HTML'
)
else:
# 发送新消息
sent = await context.bot.send_message(
chat_id=user_request['chat_id'],
text=text,
reply_markup=keyboard,
parse_mode='HTML'
)
self.user_to_telegram[user_request['user_id']] = sent.message_id
except Exception as e:
logger.error(f"转发消息给用户失败: {e}")
def _extract_text(self, message: PyrogramMessage) -> str:
"""提取消息文本"""
if message.text:
return message.text
elif message.caption:
return message.caption
return "(无文本内容)"
def _convert_keyboard(self, message: PyrogramMessage) -> Optional[InlineKeyboardMarkup]:
"""转换Pyrogram键盘为Telegram键盘"""
if not message.reply_markup:
return None
try:
buttons = []
for row in message.reply_markup.inline_keyboard:
button_row = []
for button in row:
if button.text:
# 创建回调按钮
callback_data = button.callback_data or f"mirror_{button.text}"
if len(callback_data.encode()) > 64:
# Telegram限制callback_data最大64字节
callback_data = callback_data[:60] + "..."
button_row.append(
InlineKeyboardButton(
text=button.text,
callback_data=callback_data
)
)
if button_row:
buttons.append(button_row)
return InlineKeyboardMarkup(buttons) if buttons else None
except Exception as e:
logger.error(f"转换键盘失败: {e}")
return None
async def handle_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
"""处理回调查询(翻页等)"""
query = update.callback_query
if not query.data.startswith("mirror_"):
return False
try:
# 这里需要实现回调处理逻辑
# 将回调转发给Pyrogram客户端
await query.answer("处理中...")
return True
except Exception as e:
logger.error(f"处理回调失败: {e}")
await query.answer("操作失败", show_alert=True)
return False
async def cleanup(self):
"""清理资源"""
if self.pyrogram_client:
await self.pyrogram_client.stop()
logger.info("搜索镜像客户端已停止")

View File

@@ -0,0 +1,5 @@
"""存储模块"""
from .database import DatabaseManager
from .models import Customer, Message, Session
__all__ = ['DatabaseManager', 'Customer', 'Message', 'Session']

View File

@@ -0,0 +1,428 @@
"""数据库管理器"""
import sqlite3
import json
from typing import Optional, List, Dict, Any
from datetime import datetime
from pathlib import Path
import asyncio
from contextlib import asynccontextmanager
from .models import Customer, Message, Session, CustomerStatus, SessionStatus, MessageDirection
from ...utils.logger import get_logger
from ...utils.exceptions import DatabaseError
from ...config.settings import Settings
logger = get_logger(__name__)
class DatabaseManager:
"""数据库管理器"""
def __init__(self, config: Settings):
self.config = config
self.db_path = Path(self.config.database.path)
self.connection: Optional[sqlite3.Connection] = None
self._lock = asyncio.Lock()
# 确保数据库目录存在
self.db_path.parent.mkdir(parents=True, exist_ok=True)
async def initialize(self):
"""初始化数据库"""
async with self._lock:
try:
self.connection = sqlite3.connect(
str(self.db_path),
check_same_thread=False
)
self.connection.row_factory = sqlite3.Row
await self._create_tables()
logger.info(f"Database initialized at {self.db_path}")
except Exception as e:
logger.error(f"Failed to initialize database: {e}")
raise DatabaseError(f"Database initialization failed: {e}")
async def _create_tables(self):
"""创建数据表"""
cursor = self.connection.cursor()
# 客户表
cursor.execute("""
CREATE TABLE IF NOT EXISTS customers (
user_id INTEGER PRIMARY KEY,
username TEXT,
first_name TEXT NOT NULL,
last_name TEXT,
language_code TEXT,
status TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
metadata TEXT,
tags TEXT,
notes TEXT
)
""")
# 消息表
cursor.execute("""
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
message_id TEXT NOT NULL,
session_id TEXT NOT NULL,
user_id INTEGER NOT NULL,
chat_id INTEGER NOT NULL,
direction TEXT NOT NULL,
content TEXT,
content_type TEXT NOT NULL,
timestamp TEXT NOT NULL,
is_read INTEGER DEFAULT 0,
is_replied INTEGER DEFAULT 0,
reply_to_message_id TEXT,
metadata TEXT,
FOREIGN KEY (user_id) REFERENCES customers (user_id)
)
""")
# 会话表
cursor.execute("""
CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
customer_id INTEGER NOT NULL,
chat_id INTEGER NOT NULL,
status TEXT NOT NULL,
started_at TEXT NOT NULL,
ended_at TEXT,
last_message_at TEXT,
message_count INTEGER DEFAULT 0,
assigned_to INTEGER,
tags TEXT,
notes TEXT,
metadata TEXT,
FOREIGN KEY (customer_id) REFERENCES customers (user_id)
)
""")
# 创建索引
cursor.execute("CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_messages_user ON messages(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_sessions_customer ON sessions(customer_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status)")
self.connection.commit()
@asynccontextmanager
async def transaction(self):
"""事务上下文管理器"""
async with self._lock:
cursor = self.connection.cursor()
try:
yield cursor
self.connection.commit()
except Exception as e:
self.connection.rollback()
logger.error(f"Transaction failed: {e}")
raise DatabaseError(f"Transaction failed: {e}")
async def save_customer(self, customer: Customer) -> bool:
"""保存客户"""
async with self.transaction() as cursor:
cursor.execute("""
INSERT OR REPLACE INTO customers (
user_id, username, first_name, last_name, language_code,
status, created_at, updated_at, metadata, tags, notes
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
customer.user_id,
customer.username,
customer.first_name,
customer.last_name,
customer.language_code,
customer.status.value,
customer.created_at.isoformat(),
customer.updated_at.isoformat(),
json.dumps(customer.metadata),
json.dumps(customer.tags),
customer.notes
))
return True
async def get_customer(self, user_id: int) -> Optional[Customer]:
"""获取客户"""
async with self._lock:
cursor = self.connection.cursor()
cursor.execute("SELECT * FROM customers WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
if row:
return Customer(
user_id=row['user_id'],
username=row['username'],
first_name=row['first_name'],
last_name=row['last_name'],
language_code=row['language_code'],
status=CustomerStatus(row['status']),
created_at=datetime.fromisoformat(row['created_at']),
updated_at=datetime.fromisoformat(row['updated_at']),
metadata=json.loads(row['metadata'] or '{}'),
tags=json.loads(row['tags'] or '[]'),
notes=row['notes']
)
return None
async def get_all_customers(self, status: Optional[CustomerStatus] = None) -> List[Customer]:
"""获取所有客户"""
async with self._lock:
cursor = self.connection.cursor()
if status:
cursor.execute("SELECT * FROM customers WHERE status = ?", (status.value,))
else:
cursor.execute("SELECT * FROM customers")
customers = []
for row in cursor.fetchall():
customers.append(Customer(
user_id=row['user_id'],
username=row['username'],
first_name=row['first_name'],
last_name=row['last_name'],
language_code=row['language_code'],
status=CustomerStatus(row['status']),
created_at=datetime.fromisoformat(row['created_at']),
updated_at=datetime.fromisoformat(row['updated_at']),
metadata=json.loads(row['metadata'] or '{}'),
tags=json.loads(row['tags'] or '[]'),
notes=row['notes']
))
return customers
async def save_message(self, message: Message) -> bool:
"""保存消息"""
async with self.transaction() as cursor:
cursor.execute("""
INSERT INTO messages (
message_id, session_id, user_id, chat_id, direction,
content, content_type, timestamp, is_read, is_replied,
reply_to_message_id, metadata
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
message.message_id,
message.session_id,
message.user_id,
message.chat_id,
message.direction.value,
message.content,
message.content_type,
message.timestamp.isoformat(),
message.is_read,
message.is_replied,
message.reply_to_message_id,
json.dumps(message.metadata)
))
# 更新会话的最后消息时间和消息计数
cursor.execute("""
UPDATE sessions
SET last_message_at = ?, message_count = message_count + 1
WHERE session_id = ?
""", (message.timestamp.isoformat(), message.session_id))
return True
async def get_messages(self, session_id: str, limit: int = 100) -> List[Message]:
"""获取会话消息"""
async with self._lock:
cursor = self.connection.cursor()
cursor.execute("""
SELECT * FROM messages
WHERE session_id = ?
ORDER BY timestamp DESC
LIMIT ?
""", (session_id, limit))
messages = []
for row in cursor.fetchall():
messages.append(Message(
message_id=row['message_id'],
session_id=row['session_id'],
user_id=row['user_id'],
chat_id=row['chat_id'],
direction=MessageDirection(row['direction']),
content=row['content'],
content_type=row['content_type'],
timestamp=datetime.fromisoformat(row['timestamp']),
is_read=bool(row['is_read']),
is_replied=bool(row['is_replied']),
reply_to_message_id=row['reply_to_message_id'],
metadata=json.loads(row['metadata'] or '{}')
))
return list(reversed(messages)) # 返回时间顺序
async def save_session(self, session: Session) -> bool:
"""保存会话"""
async with self.transaction() as cursor:
cursor.execute("""
INSERT OR REPLACE INTO sessions (
session_id, customer_id, chat_id, status, started_at,
ended_at, last_message_at, message_count, assigned_to,
tags, notes, metadata
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
session.session_id,
session.customer_id,
session.chat_id,
session.status.value,
session.started_at.isoformat(),
session.ended_at.isoformat() if session.ended_at else None,
session.last_message_at.isoformat() if session.last_message_at else None,
session.message_count,
session.assigned_to,
json.dumps(session.tags),
session.notes,
json.dumps(session.metadata)
))
return True
async def get_session(self, session_id: str) -> Optional[Session]:
"""获取会话"""
async with self._lock:
cursor = self.connection.cursor()
cursor.execute("SELECT * FROM sessions WHERE session_id = ?", (session_id,))
row = cursor.fetchone()
if row:
return Session(
session_id=row['session_id'],
customer_id=row['customer_id'],
chat_id=row['chat_id'],
status=SessionStatus(row['status']),
started_at=datetime.fromisoformat(row['started_at']),
ended_at=datetime.fromisoformat(row['ended_at']) if row['ended_at'] else None,
last_message_at=datetime.fromisoformat(row['last_message_at']) if row['last_message_at'] else None,
message_count=row['message_count'],
assigned_to=row['assigned_to'],
tags=json.loads(row['tags'] or '[]'),
notes=row['notes'],
metadata=json.loads(row['metadata'] or '{}')
)
return None
async def get_active_sessions(self) -> List[Session]:
"""获取活跃会话"""
async with self._lock:
cursor = self.connection.cursor()
cursor.execute("""
SELECT * FROM sessions
WHERE status = ?
ORDER BY last_message_at DESC
""", (SessionStatus.ACTIVE.value,))
sessions = []
for row in cursor.fetchall():
sessions.append(Session(
session_id=row['session_id'],
customer_id=row['customer_id'],
chat_id=row['chat_id'],
status=SessionStatus(row['status']),
started_at=datetime.fromisoformat(row['started_at']),
ended_at=datetime.fromisoformat(row['ended_at']) if row['ended_at'] else None,
last_message_at=datetime.fromisoformat(row['last_message_at']) if row['last_message_at'] else None,
message_count=row['message_count'],
assigned_to=row['assigned_to'],
tags=json.loads(row['tags'] or '[]'),
notes=row['notes'],
metadata=json.loads(row['metadata'] or '{}')
))
return sessions
async def update_session_status(self, session_id: str, status: SessionStatus) -> bool:
"""更新会话状态"""
async with self.transaction() as cursor:
ended_at = None
if status in [SessionStatus.RESOLVED, SessionStatus.CLOSED]:
ended_at = datetime.now().isoformat()
cursor.execute("""
UPDATE sessions
SET status = ?, ended_at = ?
WHERE session_id = ?
""", (status.value, ended_at, session_id))
return cursor.rowcount > 0
async def get_statistics(self) -> Dict[str, Any]:
"""获取统计信息"""
async with self._lock:
cursor = self.connection.cursor()
# 客户统计
cursor.execute("SELECT COUNT(*) as count FROM customers")
total_customers = cursor.fetchone()['count']
cursor.execute("SELECT COUNT(*) as count FROM customers WHERE status = ?",
(CustomerStatus.ACTIVE.value,))
active_customers = cursor.fetchone()['count']
# 会话统计
cursor.execute("SELECT COUNT(*) as count FROM sessions")
total_sessions = cursor.fetchone()['count']
cursor.execute("SELECT COUNT(*) as count FROM sessions WHERE status = ?",
(SessionStatus.ACTIVE.value,))
active_sessions = cursor.fetchone()['count']
# 消息统计
cursor.execute("SELECT COUNT(*) as count FROM messages")
total_messages = cursor.fetchone()['count']
cursor.execute("""
SELECT COUNT(*) as count FROM messages
WHERE direction = ? AND is_replied = 0
""", (MessageDirection.INBOUND.value,))
unreplied_messages = cursor.fetchone()['count']
return {
'customers': {
'total': total_customers,
'active': active_customers
},
'sessions': {
'total': total_sessions,
'active': active_sessions
},
'messages': {
'total': total_messages,
'unreplied': unreplied_messages
}
}
async def cleanup_old_sessions(self, days: int = 30):
"""清理旧会话"""
async with self.transaction() as cursor:
cutoff_date = datetime.now().timestamp() - (days * 24 * 60 * 60)
cutoff_date_str = datetime.fromtimestamp(cutoff_date).isoformat()
cursor.execute("""
DELETE FROM messages
WHERE session_id IN (
SELECT session_id FROM sessions
WHERE ended_at < ? AND status IN (?, ?)
)
""", (cutoff_date_str, SessionStatus.RESOLVED.value, SessionStatus.CLOSED.value))
cursor.execute("""
DELETE FROM sessions
WHERE ended_at < ? AND status IN (?, ?)
""", (cutoff_date_str, SessionStatus.RESOLVED.value, SessionStatus.CLOSED.value))
logger.info(f"Cleaned up sessions older than {days} days")
def close(self):
"""关闭数据库连接"""
if self.connection:
self.connection.close()
logger.info("Database connection closed")

View File

@@ -0,0 +1,154 @@
"""数据模型"""
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
from datetime import datetime
from enum import Enum
class CustomerStatus(Enum):
"""客户状态"""
ACTIVE = "active"
INACTIVE = "inactive"
BLOCKED = "blocked"
class SessionStatus(Enum):
"""会话状态"""
ACTIVE = "active"
PENDING = "pending"
RESOLVED = "resolved"
CLOSED = "closed"
class MessageDirection(Enum):
"""消息方向"""
INBOUND = "inbound" # 客户发送
OUTBOUND = "outbound" # 管理员发送
@dataclass
class Customer:
"""客户模型"""
user_id: int
username: Optional[str]
first_name: str
last_name: Optional[str]
language_code: Optional[str]
status: CustomerStatus = CustomerStatus.ACTIVE
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
metadata: Dict[str, Any] = field(default_factory=dict)
tags: List[str] = field(default_factory=list)
notes: Optional[str] = None
@property
def full_name(self) -> str:
"""获取全名"""
parts = [self.first_name]
if self.last_name:
parts.append(self.last_name)
return " ".join(parts)
@property
def display_name(self) -> str:
"""获取显示名称"""
if self.username:
return f"@{self.username}"
return self.full_name
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
'user_id': self.user_id,
'username': self.username,
'first_name': self.first_name,
'last_name': self.last_name,
'language_code': self.language_code,
'status': self.status.value,
'created_at': self.created_at.isoformat(),
'updated_at': self.updated_at.isoformat(),
'metadata': self.metadata,
'tags': self.tags,
'notes': self.notes
}
@dataclass
class Message:
"""消息模型"""
message_id: str
session_id: str
user_id: int
chat_id: int
direction: MessageDirection
content: str
content_type: str # text, photo, document, voice, etc.
timestamp: datetime = field(default_factory=datetime.now)
is_read: bool = False
is_replied: bool = False
reply_to_message_id: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
'message_id': self.message_id,
'session_id': self.session_id,
'user_id': self.user_id,
'chat_id': self.chat_id,
'direction': self.direction.value,
'content': self.content,
'content_type': self.content_type,
'timestamp': self.timestamp.isoformat(),
'is_read': self.is_read,
'is_replied': self.is_replied,
'reply_to_message_id': self.reply_to_message_id,
'metadata': self.metadata
}
@dataclass
class Session:
"""会话模型"""
session_id: str
customer_id: int
chat_id: int
status: SessionStatus = SessionStatus.ACTIVE
started_at: datetime = field(default_factory=datetime.now)
ended_at: Optional[datetime] = None
last_message_at: Optional[datetime] = None
message_count: int = 0
assigned_to: Optional[int] = None # 分配给哪个管理员
tags: List[str] = field(default_factory=list)
notes: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@property
def duration(self) -> Optional[float]:
"""获取会话时长(秒)"""
if self.ended_at:
return (self.ended_at - self.started_at).total_seconds()
return (datetime.now() - self.started_at).total_seconds()
@property
def is_active(self) -> bool:
"""是否活跃"""
return self.status == SessionStatus.ACTIVE
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
'session_id': self.session_id,
'customer_id': self.customer_id,
'chat_id': self.chat_id,
'status': self.status.value,
'started_at': self.started_at.isoformat(),
'ended_at': self.ended_at.isoformat() if self.ended_at else None,
'last_message_at': self.last_message_at.isoformat() if self.last_message_at else None,
'message_count': self.message_count,
'assigned_to': self.assigned_to,
'tags': self.tags,
'notes': self.notes,
'metadata': self.metadata,
'duration': self.duration
}

6
src/utils/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""工具模块"""
from .logger import Logger, get_logger
from .exceptions import *
from .decorators import *
__all__ = ['Logger', 'get_logger']

233
src/utils/decorators.py Normal file
View File

@@ -0,0 +1,233 @@
"""装饰器工具"""
import functools
import time
import asyncio
from typing import Callable, Any, Optional, Dict
from datetime import datetime, timedelta
from collections import defaultdict
from .logger import get_logger
from .exceptions import RateLimitError, AuthorizationError, ValidationError
logger = get_logger(__name__)
def async_retry(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0,
exceptions: tuple = (Exception,)):
"""异步重试装饰器"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
attempt = 1
current_delay = delay
while attempt <= max_attempts:
try:
return await func(*args, **kwargs)
except exceptions as e:
if attempt == max_attempts:
logger.error(f"Max retries ({max_attempts}) reached for {func.__name__}")
raise
logger.warning(
f"Attempt {attempt}/{max_attempts} failed for {func.__name__}: {e}. "
f"Retrying in {current_delay:.2f}s..."
)
await asyncio.sleep(current_delay)
current_delay *= backoff
attempt += 1
return wrapper
return decorator
def rate_limit(max_calls: int, period: float):
"""速率限制装饰器"""
def decorator(func: Callable) -> Callable:
calls = defaultdict(list)
@functools.wraps(func)
async def wrapper(*args, **kwargs):
# 获取调用者标识(假设第一个参数是 self第二个是 update
caller_id = None
if len(args) >= 2 and hasattr(args[1], 'effective_user'):
caller_id = args[1].effective_user.id
else:
caller_id = 'global'
now = time.time()
calls[caller_id] = [t for t in calls[caller_id] if now - t < period]
if len(calls[caller_id]) >= max_calls:
raise RateLimitError(
f"Rate limit exceeded: {max_calls} calls per {period} seconds",
details={'caller_id': caller_id, 'limit': max_calls, 'period': period}
)
calls[caller_id].append(now)
return await func(*args, **kwargs)
return wrapper
return decorator
def require_admin(func: Callable) -> Callable:
"""需要管理员权限装饰器"""
@functools.wraps(func)
async def wrapper(self, update, context, *args, **kwargs):
user_id = update.effective_user.id
if user_id != self.config.telegram.admin_id:
if not (hasattr(self, 'is_admin') and await self.is_admin(user_id)):
raise AuthorizationError(
"Admin privileges required",
details={'user_id': user_id}
)
return await func(self, update, context, *args, **kwargs)
return wrapper
def log_action(action_type: str = None):
"""记录操作日志装饰器"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
start_time = time.time()
action = action_type or func.__name__
# 提取用户信息
user_info = {}
if len(args) >= 2 and hasattr(args[1], 'effective_user'):
user = args[1].effective_user
user_info = {
'user_id': user.id,
'username': user.username,
'name': user.first_name
}
try:
result = await func(*args, **kwargs)
duration = time.time() - start_time
# 创建额外信息,避免覆盖保留字段
extra_info = {
'action': action,
'duration': duration,
'status': 'success'
}
# 添加用户信息,使用前缀避免冲突
for k, v in user_info.items():
extra_info[f'user_{k}' if k in ['name'] else k] = v
logger.info(
f"Action completed: {action}",
extra=extra_info
)
return result
except Exception as e:
duration = time.time() - start_time
# 创建额外信息,避免覆盖保留字段
extra_info = {
'action': action,
'duration': duration,
'status': 'failed',
'error': str(e)
}
# 添加用户信息,使用前缀避免冲突
for k, v in user_info.items():
extra_info[f'user_{k}' if k in ['name'] else k] = v
logger.error(
f"Action failed: {action}",
extra=extra_info
)
raise
return wrapper
return decorator
def validate_input(**validators):
"""输入验证装饰器"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
# 合并位置参数和关键字参数
bound_args = func.__code__.co_varnames[:func.__code__.co_argcount]
all_args = dict(zip(bound_args, args))
all_args.update(kwargs)
# 执行验证
for param_name, validator in validators.items():
if param_name in all_args:
value = all_args[param_name]
if not validator(value):
raise ValidationError(
f"Validation failed for parameter: {param_name}",
details={'parameter': param_name, 'value': value}
)
return await func(*args, **kwargs)
return wrapper
return decorator
def cache_result(ttl: int = 300):
"""结果缓存装饰器"""
def decorator(func: Callable) -> Callable:
cache: Dict[str, tuple[Any, datetime]] = {}
@functools.wraps(func)
async def wrapper(*args, **kwargs):
# 创建缓存键
cache_key = f"{func.__name__}:{str(args)}:{str(kwargs)}"
# 检查缓存
if cache_key in cache:
result, timestamp = cache[cache_key]
if datetime.now() - timestamp < timedelta(seconds=ttl):
logger.debug(f"Cache hit for {func.__name__}")
return result
# 执行函数
result = await func(*args, **kwargs)
# 存储结果
cache[cache_key] = (result, datetime.now())
logger.debug(f"Cache miss for {func.__name__}, cached for {ttl}s")
return result
# 添加清除缓存方法
wrapper.clear_cache = lambda: cache.clear()
return wrapper
return decorator
def measure_performance(func: Callable) -> Callable:
"""性能测量装饰器"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
start_time = time.perf_counter()
start_memory = 0 # 可以添加内存测量
try:
result = await func(*args, **kwargs)
return result
finally:
end_time = time.perf_counter()
duration = end_time - start_time
if duration > 1.0: # 超过1秒的操作记录警告
logger.warning(
f"Slow operation detected: {func.__name__} took {duration:.2f}s",
extra={
'function': func.__name__,
'duration': duration
}
)
else:
logger.debug(f"{func.__name__} completed in {duration:.4f}s")
return wrapper

122
src/utils/exceptions.py Normal file
View File

@@ -0,0 +1,122 @@
"""自定义异常类"""
from typing import Optional, Any, Dict
class BotException(Exception):
"""机器人基础异常"""
def __init__(self, message: str, code: str = None, details: Dict[str, Any] = None):
super().__init__(message)
self.message = message
self.code = code or self.__class__.__name__
self.details = details or {}
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
'error': self.code,
'message': self.message,
'details': self.details
}
class ConfigurationError(BotException):
"""配置错误"""
pass
class DatabaseError(BotException):
"""数据库错误"""
pass
class TelegramError(BotException):
"""Telegram API 错误"""
pass
class AuthenticationError(BotException):
"""认证错误"""
pass
class AuthorizationError(BotException):
"""授权错误"""
pass
class ValidationError(BotException):
"""验证错误"""
pass
class RateLimitError(BotException):
"""速率限制错误"""
pass
class SessionError(BotException):
"""会话错误"""
pass
class MessageRoutingError(BotException):
"""消息路由错误"""
pass
class BusinessLogicError(BotException):
"""业务逻辑错误"""
pass
class ExternalServiceError(BotException):
"""外部服务错误"""
pass
class ErrorHandler:
"""错误处理器"""
@staticmethod
async def handle_error(error: Exception, context: Dict[str, Any] = None) -> Dict[str, Any]:
"""处理错误"""
from ..utils.logger import get_logger
logger = get_logger(__name__)
error_info = {
'type': type(error).__name__,
'message': str(error),
'context': context or {}
}
if isinstance(error, BotException):
# 自定义异常
error_info.update(error.to_dict())
logger.error(f"Bot error: {error.message}", extra={'error_details': error_info})
else:
# 未知异常
logger.exception(f"Unexpected error: {error}", extra={'error_details': error_info})
error_info['message'] = "An unexpected error occurred"
return error_info
@staticmethod
def create_user_message(error: Exception) -> str:
"""创建用户友好的错误消息"""
if isinstance(error, AuthenticationError):
return "❌ 认证失败,请重新登录"
elif isinstance(error, AuthorizationError):
return "❌ 您没有权限执行此操作"
elif isinstance(error, ValidationError):
return f"❌ 输入无效:{error.message}"
elif isinstance(error, RateLimitError):
return "⚠️ 操作太频繁,请稍后再试"
elif isinstance(error, SessionError):
return "❌ 会话已过期,请重新开始"
elif isinstance(error, BusinessLogicError):
return f"❌ 操作失败:{error.message}"
elif isinstance(error, ExternalServiceError):
return "❌ 外部服务暂时不可用,请稍后再试"
else:
return "❌ 系统错误,请稍后再试或联系管理员"

167
src/utils/logger.py Normal file
View File

@@ -0,0 +1,167 @@
"""日志系统"""
import logging
import sys
from logging.handlers import RotatingFileHandler
from typing import Optional
from pathlib import Path
import json
from datetime import datetime
class CustomFormatter(logging.Formatter):
"""自定义日志格式化器"""
# 颜色代码
COLORS = {
'DEBUG': '\033[36m', # 青色
'INFO': '\033[32m', # 绿色
'WARNING': '\033[33m', # 黄色
'ERROR': '\033[31m', # 红色
'CRITICAL': '\033[35m', # 紫色
}
RESET = '\033[0m'
def __init__(self, use_color: bool = True):
super().__init__()
self.use_color = use_color and sys.stderr.isatty()
def format(self, record):
# 基础格式
log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
# 添加额外信息
if hasattr(record, 'user_id'):
log_format = f"%(asctime)s | %(levelname)-8s | %(name)s | User:{record.user_id} | %(message)s"
if hasattr(record, 'chat_id'):
log_format = f"%(asctime)s | %(levelname)-8s | %(name)s | Chat:{record.chat_id} | %(message)s"
# 应用颜色
if self.use_color:
levelname = record.levelname
if levelname in self.COLORS:
log_format = log_format.replace(
'%(levelname)-8s',
f"{self.COLORS[levelname]}%(levelname)-8s{self.RESET}"
)
formatter = logging.Formatter(log_format, datefmt='%Y-%m-%d %H:%M:%S')
return formatter.format(record)
class JsonFormatter(logging.Formatter):
"""JSON 格式化器用于结构化日志"""
def format(self, record):
log_data = {
'timestamp': datetime.utcnow().isoformat(),
'level': record.levelname,
'logger': record.name,
'message': record.getMessage(),
'module': record.module,
'function': record.funcName,
'line': record.lineno
}
# 添加额外字段
for key, value in record.__dict__.items():
if key not in ['name', 'msg', 'args', 'created', 'filename',
'funcName', 'levelname', 'levelno', 'lineno',
'module', 'msecs', 'message', 'pathname', 'process',
'processName', 'relativeCreated', 'thread', 'threadName']:
log_data[key] = value
# 添加异常信息
if record.exc_info:
log_data['exception'] = self.formatException(record.exc_info)
return json.dumps(log_data, ensure_ascii=False)
class Logger:
"""日志管理器"""
_instance = None
_loggers = {}
def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, config=None):
if not hasattr(self, '_initialized'):
self._initialized = True
self.config = config
self.setup_logging()
def setup_logging(self):
"""设置日志系统"""
# 根日志配置
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)
# 移除默认处理器
root_logger.handlers = []
# 控制台处理器
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(
getattr(logging, self.config.logging.level if self.config else 'INFO')
)
console_handler.setFormatter(CustomFormatter(use_color=True))
root_logger.addHandler(console_handler)
# 文件处理器
if self.config and self.config.logging.file:
file_path = Path(self.config.logging.file)
file_path.parent.mkdir(parents=True, exist_ok=True)
file_handler = RotatingFileHandler(
filename=str(file_path),
maxBytes=self.config.logging.max_size,
backupCount=self.config.logging.backup_count,
encoding='utf-8'
)
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(CustomFormatter(use_color=False))
root_logger.addHandler(file_handler)
# JSON 日志文件(用于分析)
json_file_path = file_path.with_suffix('.json')
json_handler = RotatingFileHandler(
filename=str(json_file_path),
maxBytes=self.config.logging.max_size,
backupCount=self.config.logging.backup_count,
encoding='utf-8'
)
json_handler.setLevel(logging.INFO)
json_handler.setFormatter(JsonFormatter())
root_logger.addHandler(json_handler)
@classmethod
def get_logger(cls, name: str, config=None) -> logging.Logger:
"""获取日志器"""
if name not in cls._loggers:
if not cls._instance:
cls(config)
cls._loggers[name] = logging.getLogger(name)
return cls._loggers[name]
def get_logger(name: str, config=None) -> logging.Logger:
"""获取日志器的便捷方法"""
return Logger.get_logger(name, config)
class LoggerContextFilter(logging.Filter):
"""日志上下文过滤器"""
def __init__(self, **context):
super().__init__()
self.context = context
def filter(self, record):
for key, value in self.context.items():
setattr(record, key, value)
return True