chore: initial commit
This commit is contained in:
5
src/config/__init__.py
Normal file
5
src/config/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""配置管理模块"""
|
||||
from .settings import Settings
|
||||
from .loader import ConfigLoader
|
||||
|
||||
__all__ = ['Settings', 'ConfigLoader']
|
||||
155
src/config/loader.py
Normal file
155
src/config/loader.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""配置加载器"""
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
from .settings import (
|
||||
Settings, TelegramConfig, DatabaseConfig,
|
||||
LoggingConfig, BusinessConfig, SecurityConfig, FeatureFlags
|
||||
)
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
"""配置加载器"""
|
||||
|
||||
@staticmethod
|
||||
def load_env_file(env_path: str = None) -> None:
|
||||
"""加载环境变量文件"""
|
||||
if env_path:
|
||||
load_dotenv(env_path)
|
||||
else:
|
||||
# 查找 .env 文件
|
||||
current_dir = Path.cwd()
|
||||
env_file = current_dir / ".env"
|
||||
if env_file.exists():
|
||||
load_dotenv(env_file)
|
||||
else:
|
||||
# 向上查找
|
||||
for parent in current_dir.parents:
|
||||
env_file = parent / ".env"
|
||||
if env_file.exists():
|
||||
load_dotenv(env_file)
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def get_env(key: str, default: Any = None, cast_type: type = str) -> Any:
|
||||
"""获取环境变量并转换类型"""
|
||||
value = os.getenv(key, default)
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if cast_type == bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ('true', '1', 'yes', 'on')
|
||||
return bool(value)
|
||||
elif cast_type == int:
|
||||
return int(value)
|
||||
elif cast_type == float:
|
||||
return float(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def load_from_env(cls) -> Settings:
|
||||
"""从环境变量加载配置"""
|
||||
cls.load_env_file()
|
||||
|
||||
# Telegram 配置
|
||||
telegram_config = TelegramConfig(
|
||||
bot_token=cls.get_env('BOT_TOKEN', ''),
|
||||
admin_id=cls.get_env('ADMIN_ID', 0, int),
|
||||
admin_username=cls.get_env('ADMIN_USERNAME', ''),
|
||||
bot_name=cls.get_env('BOT_NAME', 'Customer Service Bot')
|
||||
)
|
||||
|
||||
# 数据库配置
|
||||
database_config = DatabaseConfig(
|
||||
type=cls.get_env('DATABASE_TYPE', 'sqlite'),
|
||||
path=cls.get_env('DATABASE_PATH', './data/bot.db'),
|
||||
host=cls.get_env('DATABASE_HOST'),
|
||||
port=cls.get_env('DATABASE_PORT', cast_type=int),
|
||||
user=cls.get_env('DATABASE_USER'),
|
||||
password=cls.get_env('DATABASE_PASSWORD'),
|
||||
database=cls.get_env('DATABASE_NAME')
|
||||
)
|
||||
|
||||
# 日志配置
|
||||
logging_config = LoggingConfig(
|
||||
level=cls.get_env('LOG_LEVEL', 'INFO'),
|
||||
file=cls.get_env('LOG_FILE', './logs/bot.log'),
|
||||
max_size=cls.get_env('LOG_MAX_SIZE', 10485760, int),
|
||||
backup_count=cls.get_env('LOG_BACKUP_COUNT', 5, int)
|
||||
)
|
||||
|
||||
# 业务配置
|
||||
business_config = BusinessConfig(
|
||||
business_hours_start=cls.get_env('BUSINESS_HOURS_START', '09:00'),
|
||||
business_hours_end=cls.get_env('BUSINESS_HOURS_END', '18:00'),
|
||||
timezone=cls.get_env('TIMEZONE', 'Asia/Shanghai'),
|
||||
auto_reply_delay=cls.get_env('AUTO_REPLY_DELAY', 1, int),
|
||||
welcome_message=cls.get_env('WELCOME_MESSAGE',
|
||||
'您好!我是客服助手,正在为您转接人工客服,请稍候...'),
|
||||
offline_message=cls.get_env('OFFLINE_MESSAGE',
|
||||
'非常抱歉,现在是非工作时间。我们的工作时间是 {start} - {end}。您的消息已记录,我们会在工作时间尽快回复您。')
|
||||
)
|
||||
|
||||
# 安全配置
|
||||
security_config = SecurityConfig(
|
||||
max_messages_per_minute=cls.get_env('MAX_MESSAGES_PER_MINUTE', 30, int),
|
||||
session_timeout=cls.get_env('SESSION_TIMEOUT', 3600, int),
|
||||
enable_encryption=cls.get_env('ENABLE_ENCRYPTION', False, bool),
|
||||
blocked_words=cls.get_env('BLOCKED_WORDS', '').split(',') if cls.get_env('BLOCKED_WORDS') else []
|
||||
)
|
||||
|
||||
# 功能开关
|
||||
feature_flags = FeatureFlags(
|
||||
enable_auto_reply=cls.get_env('ENABLE_AUTO_REPLY', True, bool),
|
||||
enable_statistics=cls.get_env('ENABLE_STATISTICS', True, bool),
|
||||
enable_customer_history=cls.get_env('ENABLE_CUSTOMER_HISTORY', True, bool),
|
||||
enable_multi_admin=cls.get_env('ENABLE_MULTI_ADMIN', False, bool),
|
||||
enable_file_transfer=cls.get_env('ENABLE_FILE_TRANSFER', True, bool),
|
||||
enable_voice_message=cls.get_env('ENABLE_VOICE_MESSAGE', True, bool),
|
||||
enable_location_sharing=cls.get_env('ENABLE_LOCATION_SHARING', False, bool)
|
||||
)
|
||||
|
||||
# 创建设置对象
|
||||
settings = Settings(
|
||||
telegram=telegram_config,
|
||||
database=database_config,
|
||||
logging=logging_config,
|
||||
business=business_config,
|
||||
security=security_config,
|
||||
features=feature_flags,
|
||||
debug=cls.get_env('DEBUG', False, bool),
|
||||
testing=cls.get_env('TESTING', False, bool),
|
||||
version=cls.get_env('VERSION', '1.0.0')
|
||||
)
|
||||
|
||||
# 验证配置
|
||||
settings.validate()
|
||||
return settings
|
||||
|
||||
@classmethod
|
||||
def load_from_dict(cls, config_dict: Dict[str, Any]) -> Settings:
|
||||
"""从字典加载配置"""
|
||||
telegram_config = TelegramConfig(**config_dict.get('telegram', {}))
|
||||
database_config = DatabaseConfig(**config_dict.get('database', {}))
|
||||
logging_config = LoggingConfig(**config_dict.get('logging', {}))
|
||||
business_config = BusinessConfig(**config_dict.get('business', {}))
|
||||
security_config = SecurityConfig(**config_dict.get('security', {}))
|
||||
feature_flags = FeatureFlags(**config_dict.get('features', {}))
|
||||
|
||||
settings = Settings(
|
||||
telegram=telegram_config,
|
||||
database=database_config,
|
||||
logging=logging_config,
|
||||
business=business_config,
|
||||
security=security_config,
|
||||
features=feature_flags,
|
||||
**config_dict.get('runtime', {})
|
||||
)
|
||||
|
||||
settings.validate()
|
||||
return settings
|
||||
140
src/config/settings.py
Normal file
140
src/config/settings.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""系统配置定义"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
from datetime import time
|
||||
import os
|
||||
|
||||
|
||||
@dataclass
|
||||
class TelegramConfig:
|
||||
"""Telegram 相关配置"""
|
||||
bot_token: str
|
||||
admin_id: int
|
||||
admin_username: str
|
||||
bot_name: str = "Customer Service Bot"
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.bot_token:
|
||||
raise ValueError("Bot token is required")
|
||||
if not self.admin_id:
|
||||
raise ValueError("Admin ID is required")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""数据库配置"""
|
||||
type: str = "sqlite"
|
||||
path: str = "./data/bot.db"
|
||||
host: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
database: Optional[str] = None
|
||||
|
||||
def get_connection_string(self) -> str:
|
||||
"""获取数据库连接字符串"""
|
||||
if self.type == "sqlite":
|
||||
return f"sqlite:///{self.path}"
|
||||
elif self.type == "postgresql":
|
||||
return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
|
||||
elif self.type == "mysql":
|
||||
return f"mysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.type}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoggingConfig:
|
||||
"""日志配置"""
|
||||
level: str = "INFO"
|
||||
file: str = "./logs/bot.log"
|
||||
max_size: int = 10485760 # 10MB
|
||||
backup_count: int = 5
|
||||
format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
def __post_init__(self):
|
||||
# 确保日志目录存在
|
||||
log_dir = os.path.dirname(self.file)
|
||||
if log_dir and not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BusinessConfig:
|
||||
"""业务配置"""
|
||||
business_hours_start: str = "09:00"
|
||||
business_hours_end: str = "18:00"
|
||||
timezone: str = "Asia/Shanghai"
|
||||
auto_reply_delay: int = 1 # 秒
|
||||
welcome_message: str = "您好!我是客服助手,正在为您转接人工客服,请稍候..."
|
||||
offline_message: str = "非常抱歉,现在是非工作时间。我们的工作时间是 {start} - {end}。您的消息已记录,我们会在工作时间尽快回复您。"
|
||||
|
||||
def get_business_hours(self) -> tuple[time, time]:
|
||||
"""获取营业时间"""
|
||||
start = time.fromisoformat(self.business_hours_start)
|
||||
end = time.fromisoformat(self.business_hours_end)
|
||||
return start, end
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityConfig:
|
||||
"""安全配置"""
|
||||
max_messages_per_minute: int = 30
|
||||
session_timeout: int = 3600 # 秒
|
||||
enable_encryption: bool = False
|
||||
blocked_words: List[str] = field(default_factory=list)
|
||||
allowed_file_types: List[str] = field(default_factory=lambda: [
|
||||
'.jpg', '.jpeg', '.png', '.gif', '.pdf', '.doc', '.docx'
|
||||
])
|
||||
max_file_size: int = 10485760 # 10MB
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureFlags:
|
||||
"""功能开关"""
|
||||
enable_auto_reply: bool = True
|
||||
enable_statistics: bool = True
|
||||
enable_customer_history: bool = True
|
||||
enable_multi_admin: bool = False
|
||||
enable_file_transfer: bool = True
|
||||
enable_voice_message: bool = True
|
||||
enable_location_sharing: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Settings:
|
||||
"""主配置类"""
|
||||
telegram: TelegramConfig
|
||||
database: DatabaseConfig
|
||||
logging: LoggingConfig
|
||||
business: BusinessConfig
|
||||
security: SecurityConfig
|
||||
features: FeatureFlags
|
||||
|
||||
# 运行时配置
|
||||
debug: bool = False
|
||||
testing: bool = False
|
||||
version: str = "1.0.0"
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> 'Settings':
|
||||
"""从环境变量创建配置"""
|
||||
from .loader import ConfigLoader
|
||||
return ConfigLoader.load_from_env()
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证配置完整性"""
|
||||
try:
|
||||
# 验证必要配置
|
||||
assert self.telegram.bot_token, "Bot token is required"
|
||||
assert self.telegram.admin_id, "Admin ID is required"
|
||||
|
||||
# 验证路径
|
||||
if self.database.type == "sqlite":
|
||||
db_dir = os.path.dirname(self.database.path)
|
||||
if db_dir and not os.path.exists(db_dir):
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
raise ValueError(f"Configuration validation failed: {e}")
|
||||
6
src/core/__init__.py
Normal file
6
src/core/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""核心模块"""
|
||||
from .bot import CustomerServiceBot
|
||||
from .router import MessageRouter
|
||||
from .handlers import BaseHandler, HandlerContext
|
||||
|
||||
__all__ = ['CustomerServiceBot', 'MessageRouter', 'BaseHandler', 'HandlerContext']
|
||||
693
src/core/bot.py
Normal file
693
src/core/bot.py
Normal file
@@ -0,0 +1,693 @@
|
||||
"""客服机器人主类"""
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from telegram import Update, Bot, BotCommand, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram.ext import Application, CommandHandler, MessageHandler, CallbackQueryHandler, filters
|
||||
|
||||
from ..config.settings import Settings
|
||||
from ..utils.logger import get_logger, Logger
|
||||
from ..utils.exceptions import BotException, ErrorHandler
|
||||
from ..utils.decorators import log_action, measure_performance
|
||||
from .router import MessageRouter, RouteBuilder, MessageContext
|
||||
from .handlers import BaseHandler, HandlerContext
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CustomerServiceBot:
|
||||
"""客服机器人"""
|
||||
|
||||
def __init__(self, config: Settings = None):
|
||||
"""初始化机器人"""
|
||||
# 加载配置
|
||||
self.config = config or Settings.from_env()
|
||||
self.config.validate()
|
||||
|
||||
# 初始化日志系统
|
||||
Logger(self.config)
|
||||
self.logger = get_logger(self.__class__.__name__, self.config)
|
||||
|
||||
# 初始化组件
|
||||
self.application: Optional[Application] = None
|
||||
self.router = MessageRouter(self.config)
|
||||
self.route_builder = RouteBuilder(self.router)
|
||||
self.handlers: Dict[str, BaseHandler] = {}
|
||||
self.active_sessions: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 当前会话管理
|
||||
self.current_customer = None # 当前正在对话的客户
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
'messages_received': 0,
|
||||
'messages_forwarded': 0,
|
||||
'replies_sent': 0,
|
||||
'errors': 0,
|
||||
'start_time': datetime.now()
|
||||
}
|
||||
|
||||
self.logger.info(f"Bot initialized with version {self.config.version}")
|
||||
|
||||
async def initialize(self):
|
||||
"""异步初始化"""
|
||||
try:
|
||||
# 创建应用
|
||||
self.application = Application.builder().token(
|
||||
self.config.telegram.bot_token
|
||||
).build()
|
||||
|
||||
# 设置命令
|
||||
await self.setup_commands()
|
||||
|
||||
# 注册处理器
|
||||
self.register_handlers()
|
||||
|
||||
# 初始化数据库(如果需要)
|
||||
if self.config.features.enable_customer_history:
|
||||
from ..modules.storage import DatabaseManager
|
||||
self.db_manager = DatabaseManager(self.config)
|
||||
await self.db_manager.initialize()
|
||||
|
||||
|
||||
self.logger.info("Bot initialization completed")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize bot: {e}")
|
||||
raise
|
||||
|
||||
async def setup_commands(self):
|
||||
"""设置机器人命令"""
|
||||
commands = [
|
||||
BotCommand("start", "开始使用机器人"),
|
||||
BotCommand("help", "获取帮助信息"),
|
||||
BotCommand("status", "查看机器人状态"),
|
||||
BotCommand("contact", "联系人工客服"),
|
||||
]
|
||||
|
||||
# 管理员命令
|
||||
admin_commands = commands + [
|
||||
BotCommand("stats", "查看统计信息"),
|
||||
BotCommand("sessions", "查看活跃会话"),
|
||||
BotCommand("reply", "回复客户消息"),
|
||||
BotCommand("broadcast", "广播消息"),
|
||||
BotCommand("settings", "机器人设置"),
|
||||
]
|
||||
|
||||
# 设置命令
|
||||
await self.application.bot.set_my_commands(commands)
|
||||
|
||||
# 为管理员设置特殊命令
|
||||
await self.application.bot.set_my_commands(
|
||||
admin_commands,
|
||||
scope={"type": "chat", "chat_id": self.config.telegram.admin_id}
|
||||
)
|
||||
|
||||
def register_handlers(self):
|
||||
"""注册消息处理器"""
|
||||
# 命令处理器
|
||||
self.application.add_handler(CommandHandler("start", self.handle_start))
|
||||
self.application.add_handler(CommandHandler("help", self.handle_help))
|
||||
self.application.add_handler(CommandHandler("status", self.handle_status))
|
||||
self.application.add_handler(CommandHandler("contact", self.handle_contact))
|
||||
|
||||
# 管理员命令
|
||||
self.application.add_handler(CommandHandler("stats", self.handle_stats))
|
||||
self.application.add_handler(CommandHandler("sessions", self.handle_sessions))
|
||||
self.application.add_handler(CommandHandler("reply", self.handle_reply))
|
||||
self.application.add_handler(CommandHandler("broadcast", self.handle_broadcast))
|
||||
self.application.add_handler(CommandHandler("settings", self.handle_settings))
|
||||
|
||||
# 消息处理器 - 处理所有消息(包括搜索指令)
|
||||
# 只排除机器人自己处理的命令,其他命令(如搜索指令)也会转发
|
||||
self.application.add_handler(MessageHandler(
|
||||
filters.ALL,
|
||||
self.handle_message
|
||||
))
|
||||
|
||||
# 回调查询处理器
|
||||
self.application.add_handler(CallbackQueryHandler(self.handle_callback))
|
||||
|
||||
# 错误处理器
|
||||
self.application.add_error_handler(self.handle_error)
|
||||
|
||||
@log_action("start_command")
|
||||
async def handle_start(self, update: Update, context):
|
||||
"""处理 /start 命令"""
|
||||
user = update.effective_user
|
||||
is_admin = user.id == self.config.telegram.admin_id
|
||||
|
||||
if is_admin:
|
||||
text = (
|
||||
f"👋 欢迎,管理员 {user.first_name}!\n\n"
|
||||
"🤖 客服机器人已就绪\n"
|
||||
"📊 使用 /stats 查看统计\n"
|
||||
"💬 使用 /sessions 查看会话\n"
|
||||
"⚙️ 使用 /settings 进行设置"
|
||||
)
|
||||
else:
|
||||
text = (
|
||||
f"👋 您好 {user.first_name}!\n\n"
|
||||
"暂时支持的搜索指令:\n\n"
|
||||
"- 群组目录 /topchat\n"
|
||||
"- 群组搜索 /search\n"
|
||||
"- 按消息文本搜索 /text\n"
|
||||
"- 按名称搜索 /human\n\n"
|
||||
"您可以使用以上指令进行搜索,或直接发送消息联系客服。"
|
||||
)
|
||||
|
||||
# 通知管理员
|
||||
await self.notify_admin_new_customer(user)
|
||||
|
||||
await update.message.reply_text(text)
|
||||
self.stats['messages_received'] += 1
|
||||
|
||||
async def handle_help(self, update: Update, context):
|
||||
"""处理 /help 命令"""
|
||||
user = update.effective_user
|
||||
is_admin = user.id == self.config.telegram.admin_id
|
||||
|
||||
if is_admin:
|
||||
text = self._get_admin_help()
|
||||
else:
|
||||
text = self._get_user_help()
|
||||
|
||||
await update.message.reply_text(text, parse_mode='Markdown')
|
||||
|
||||
async def handle_status(self, update: Update, context):
|
||||
"""处理 /status 命令"""
|
||||
uptime = datetime.now() - self.stats['start_time']
|
||||
hours = uptime.total_seconds() / 3600
|
||||
|
||||
text = (
|
||||
"✅ 机器人运行正常\n\n"
|
||||
f"⏱ 运行时间:{hours:.1f} 小时\n"
|
||||
f"📊 处理消息:{self.stats['messages_received']} 条\n"
|
||||
f"👥 活跃会话:{len(self.active_sessions)} 个"
|
||||
)
|
||||
|
||||
await update.message.reply_text(text)
|
||||
|
||||
async def handle_contact(self, update: Update, context):
|
||||
"""处理 /contact 命令"""
|
||||
await update.message.reply_text(
|
||||
"正在为您转接人工客服,请稍候...\n"
|
||||
"您可以直接发送消息,客服会尽快回复您。"
|
||||
)
|
||||
# 修复:传递正确的 context 参数
|
||||
await self.forward_customer_message(update, context)
|
||||
|
||||
@measure_performance
|
||||
async def handle_message(self, update: Update, context):
|
||||
"""处理普通消息"""
|
||||
try:
|
||||
user = update.effective_user
|
||||
message = update.effective_message
|
||||
is_admin = user.id == self.config.telegram.admin_id
|
||||
|
||||
self.stats['messages_received'] += 1
|
||||
|
||||
if is_admin:
|
||||
# 管理员消息 - 检查是否是回复
|
||||
if message.reply_to_message:
|
||||
await self.handle_admin_reply(update, context)
|
||||
elif self.current_customer:
|
||||
# 如果有当前客户,直接发送给当前客户
|
||||
await self.reply_to_current_customer(update, context)
|
||||
else:
|
||||
# 没有当前客户时,提示管理员
|
||||
await message.reply_text(
|
||||
"💡 提示:暂无活跃客户\n\n"
|
||||
"等待客户发送消息,或使用:\n"
|
||||
"• 直接回复转发的客户消息\n"
|
||||
"• /sessions 查看所有会话\n"
|
||||
"• /reply <用户ID> <消息> 回复指定用户"
|
||||
)
|
||||
else:
|
||||
# 客户消息 - 转发给管理员(包括搜索指令)
|
||||
# 处理所有客户消息,包括 /topchat, /search, /text, /human 等指令
|
||||
await self.forward_customer_message(update, context)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error handling message: {e}")
|
||||
await self.send_error_message(update, e)
|
||||
|
||||
async def forward_customer_message(self, update: Update, context):
|
||||
"""转发客户消息给管理员"""
|
||||
user = update.effective_user
|
||||
message = update.effective_message
|
||||
chat = update.effective_chat
|
||||
|
||||
# 创建或更新会话
|
||||
session_id = f"{chat.id}_{user.id}"
|
||||
if session_id not in self.active_sessions:
|
||||
self.active_sessions[session_id] = {
|
||||
'user_id': user.id,
|
||||
'username': user.username,
|
||||
'first_name': user.first_name,
|
||||
'chat_id': chat.id,
|
||||
'messages': [],
|
||||
'started_at': datetime.now()
|
||||
}
|
||||
|
||||
# 记录消息
|
||||
self.active_sessions[session_id]['messages'].append({
|
||||
'message_id': message.message_id,
|
||||
'text': message.text or "[非文本消息]",
|
||||
'timestamp': datetime.now()
|
||||
})
|
||||
|
||||
# 设置为当前客户
|
||||
self.current_customer = {
|
||||
'user_id': user.id,
|
||||
'chat_id': chat.id,
|
||||
'username': user.username,
|
||||
'first_name': user.first_name,
|
||||
'session_id': session_id
|
||||
}
|
||||
|
||||
# 构建用户信息 - 转义特殊字符
|
||||
def escape_markdown(text):
|
||||
"""转义 Markdown 特殊字符"""
|
||||
if text is None:
|
||||
return ''
|
||||
# 转义特殊字符
|
||||
special_chars = ['_', '*', '[', ']', '(', ')', '~', '`', '>', '#', '+', '-', '=', '|', '{', '}', '.', '!']
|
||||
for char in special_chars:
|
||||
text = str(text).replace(char, f'\\{char}')
|
||||
return text
|
||||
|
||||
first_name = escape_markdown(user.first_name)
|
||||
last_name = escape_markdown(user.last_name) if user.last_name else ''
|
||||
username = escape_markdown(user.username) if user.username else 'N/A'
|
||||
|
||||
# 构建用户信息
|
||||
user_info = (
|
||||
f"📨 来自客户的消息\n"
|
||||
f"👤 姓名:{first_name} {last_name}\n"
|
||||
f"🆔 ID:`{user.id}`\n"
|
||||
f"📱 用户名:@{username}\n"
|
||||
f"💬 会话:`{session_id}`\n"
|
||||
f"━━━━━━━━━━━━━━━━"
|
||||
)
|
||||
|
||||
# 发送用户信息
|
||||
await context.bot.send_message(
|
||||
chat_id=self.config.telegram.admin_id,
|
||||
text=user_info,
|
||||
parse_mode='MarkdownV2'
|
||||
)
|
||||
|
||||
# 转发原始消息
|
||||
forwarded = await context.bot.forward_message(
|
||||
chat_id=self.config.telegram.admin_id,
|
||||
from_chat_id=chat.id,
|
||||
message_id=message.message_id
|
||||
)
|
||||
|
||||
# 保存转发消息ID映射
|
||||
context.bot_data.setdefault('message_map', {})[forwarded.message_id] = {
|
||||
'original_chat': chat.id,
|
||||
'original_user': user.id,
|
||||
'session_id': session_id
|
||||
}
|
||||
|
||||
# 提示管理员可以直接输入文字回复
|
||||
await context.bot.send_message(
|
||||
chat_id=self.config.telegram.admin_id,
|
||||
text="💬 现在可以直接输入文字回复此客户,或回复上方转发的消息"
|
||||
)
|
||||
|
||||
# 自动回复(如果启用)
|
||||
if self.config.features.enable_auto_reply and not is_business_hours(self.config):
|
||||
await self.send_auto_reply(update, context)
|
||||
|
||||
self.stats['messages_forwarded'] += 1
|
||||
|
||||
async def handle_admin_reply(self, update: Update, context):
|
||||
"""处理管理员回复"""
|
||||
replied_to = update.message.reply_to_message
|
||||
|
||||
# 查找原始消息信息
|
||||
message_map = context.bot_data.get('message_map', {})
|
||||
if replied_to.message_id not in message_map:
|
||||
await update.message.reply_text("⚠️ 无法找到原始消息信息")
|
||||
return
|
||||
|
||||
original_info = message_map[replied_to.message_id]
|
||||
original_chat = original_info['original_chat']
|
||||
session_id = original_info['session_id']
|
||||
|
||||
# 发送回复给客户
|
||||
try:
|
||||
if update.message.text:
|
||||
await context.bot.send_message(
|
||||
chat_id=original_chat,
|
||||
text=update.message.text
|
||||
)
|
||||
elif update.message.photo:
|
||||
await context.bot.send_photo(
|
||||
chat_id=original_chat,
|
||||
photo=update.message.photo[-1].file_id,
|
||||
caption=update.message.caption
|
||||
)
|
||||
elif update.message.document:
|
||||
await context.bot.send_document(
|
||||
chat_id=original_chat,
|
||||
document=update.message.document.file_id,
|
||||
caption=update.message.caption
|
||||
)
|
||||
|
||||
# 确认发送
|
||||
await update.message.reply_text("✅ 消息已发送给客户")
|
||||
|
||||
# 更新会话
|
||||
if session_id in self.active_sessions:
|
||||
self.active_sessions[session_id]['last_reply'] = datetime.now()
|
||||
|
||||
self.stats['replies_sent'] += 1
|
||||
|
||||
except Exception as e:
|
||||
await update.message.reply_text(f"❌ 发送失败:{e}")
|
||||
self.logger.error(f"Failed to send reply: {e}")
|
||||
|
||||
async def handle_stats(self, update: Update, context):
|
||||
"""处理 /stats 命令(管理员)"""
|
||||
if update.effective_user.id != self.config.telegram.admin_id:
|
||||
return
|
||||
|
||||
uptime = datetime.now() - self.stats['start_time']
|
||||
days = uptime.days
|
||||
hours = uptime.seconds // 3600
|
||||
|
||||
text = (
|
||||
"📊 **统计信息**\n\n"
|
||||
f"⏱ 运行时间:{days} 天 {hours} 小时\n"
|
||||
f"📨 接收消息:{self.stats['messages_received']} 条\n"
|
||||
f"📤 转发消息:{self.stats['messages_forwarded']} 条\n"
|
||||
f"💬 回复消息:{self.stats['replies_sent']} 条\n"
|
||||
f"❌ 错误次数:{self.stats['errors']} 次\n"
|
||||
f"👥 活跃会话:{len(self.active_sessions)} 个\n"
|
||||
f"📅 启动时间:{self.stats['start_time'].strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
|
||||
await update.message.reply_text(text, parse_mode='Markdown')
|
||||
|
||||
async def handle_sessions(self, update: Update, context):
|
||||
"""处理 /sessions 命令(管理员)"""
|
||||
if update.effective_user.id != self.config.telegram.admin_id:
|
||||
return
|
||||
|
||||
if not self.active_sessions:
|
||||
await update.message.reply_text("当前没有活跃会话")
|
||||
return
|
||||
|
||||
text = "👥 **活跃会话**\n\n"
|
||||
for session_id, session in self.active_sessions.items():
|
||||
duration = datetime.now() - session['started_at']
|
||||
text += (
|
||||
f"会话 `{session_id}`\n"
|
||||
f"👤 {session['first_name']} (@{session['username'] or 'N/A'})\n"
|
||||
f"💬 消息数:{len(session['messages'])}\n"
|
||||
f"⏱ 时长:{duration.seconds // 60} 分钟\n"
|
||||
f"━━━━━━━━━━━━━━━━\n"
|
||||
)
|
||||
|
||||
await update.message.reply_text(text, parse_mode='Markdown')
|
||||
|
||||
async def handle_reply(self, update: Update, context):
|
||||
"""处理 /reply 命令(管理员)"""
|
||||
if update.effective_user.id != self.config.telegram.admin_id:
|
||||
return
|
||||
|
||||
if len(context.args) < 2:
|
||||
await update.message.reply_text(
|
||||
"用法:/reply <用户ID> <消息>\n"
|
||||
"示例:/reply 123456789 您好,有什么可以帮助您?"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
user_id = int(context.args[0])
|
||||
message = ' '.join(context.args[1:])
|
||||
|
||||
await context.bot.send_message(chat_id=user_id, text=message)
|
||||
await update.message.reply_text(f"✅ 消息已发送给用户 {user_id}")
|
||||
self.stats['replies_sent'] += 1
|
||||
|
||||
except Exception as e:
|
||||
await update.message.reply_text(f"❌ 发送失败:{e}")
|
||||
|
||||
|
||||
async def reply_to_current_customer(self, update: Update, context):
|
||||
"""回复当前客户"""
|
||||
if not self.current_customer:
|
||||
await update.message.reply_text("❌ 没有选中的客户")
|
||||
return
|
||||
|
||||
try:
|
||||
message = update.effective_message
|
||||
|
||||
# 发送消息给当前客户
|
||||
if message.text:
|
||||
await context.bot.send_message(
|
||||
chat_id=self.current_customer['chat_id'],
|
||||
text=message.text
|
||||
)
|
||||
elif message.photo:
|
||||
await context.bot.send_photo(
|
||||
chat_id=self.current_customer['chat_id'],
|
||||
photo=message.photo[-1].file_id,
|
||||
caption=message.caption
|
||||
)
|
||||
elif message.document:
|
||||
await context.bot.send_document(
|
||||
chat_id=self.current_customer['chat_id'],
|
||||
document=message.document.file_id,
|
||||
caption=message.caption
|
||||
)
|
||||
|
||||
# 简洁确认消息
|
||||
await update.message.reply_text(f"✅ → {self.current_customer['first_name']}")
|
||||
self.stats['replies_sent'] += 1
|
||||
|
||||
except Exception as e:
|
||||
await update.message.reply_text(f"❌ 发送失败:{e}")
|
||||
self.logger.error(f"Failed to send reply: {e}")
|
||||
|
||||
async def handle_broadcast(self, update: Update, context):
|
||||
"""处理 /broadcast 命令(管理员)"""
|
||||
if update.effective_user.id != self.config.telegram.admin_id:
|
||||
return
|
||||
|
||||
if not context.args:
|
||||
await update.message.reply_text(
|
||||
"用法:/broadcast <消息>\n"
|
||||
"示例:/broadcast 系统维护通知:今晚10点进行系统维护"
|
||||
)
|
||||
return
|
||||
|
||||
message = ' '.join(context.args)
|
||||
sent = 0
|
||||
failed = 0
|
||||
|
||||
for session_id, session in self.active_sessions.items():
|
||||
try:
|
||||
await context.bot.send_message(
|
||||
chat_id=session['chat_id'],
|
||||
text=message
|
||||
)
|
||||
sent += 1
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
self.logger.error(f"Failed to broadcast to {session['chat_id']}: {e}")
|
||||
|
||||
await update.message.reply_text(
|
||||
f"✅ 广播完成\n"
|
||||
f"成功:{sent} 个\n"
|
||||
f"失败:{failed} 个"
|
||||
)
|
||||
|
||||
async def handle_settings(self, update: Update, context):
|
||||
"""处理 /settings 命令(管理员)"""
|
||||
if update.effective_user.id != self.config.telegram.admin_id:
|
||||
return
|
||||
|
||||
keyboard = [
|
||||
[
|
||||
InlineKeyboardButton(
|
||||
f"{'✅' if self.config.features.enable_auto_reply else '❌'} 自动回复",
|
||||
callback_data="toggle_auto_reply"
|
||||
)
|
||||
],
|
||||
[
|
||||
InlineKeyboardButton(
|
||||
f"{'✅' if self.config.features.enable_statistics else '❌'} 统计功能",
|
||||
callback_data="toggle_statistics"
|
||||
)
|
||||
],
|
||||
[
|
||||
InlineKeyboardButton("📊 查看所有设置", callback_data="view_settings")
|
||||
]
|
||||
]
|
||||
reply_markup = InlineKeyboardMarkup(keyboard)
|
||||
|
||||
await update.message.reply_text(
|
||||
"⚙️ **机器人设置**\n\n点击按钮切换功能:",
|
||||
reply_markup=reply_markup,
|
||||
parse_mode='Markdown'
|
||||
)
|
||||
|
||||
async def handle_callback(self, update: Update, context):
|
||||
"""处理回调查询"""
|
||||
query = update.callback_query
|
||||
await query.answer()
|
||||
|
||||
data = query.data
|
||||
if data.startswith("done_"):
|
||||
session_id = data.replace("done_", "")
|
||||
await query.edit_message_text(f"✅ 会话 {session_id} 已标记为完成")
|
||||
elif data.startswith("later_"):
|
||||
session_id = data.replace("later_", "")
|
||||
await query.edit_message_text(f"⏸ 会话 {session_id} 已标记为稍后处理")
|
||||
elif data == "toggle_auto_reply":
|
||||
self.config.features.enable_auto_reply = not self.config.features.enable_auto_reply
|
||||
await query.edit_message_text(
|
||||
f"自动回复已{'启用' if self.config.features.enable_auto_reply else '禁用'}"
|
||||
)
|
||||
|
||||
async def handle_error(self, update: Update, context):
|
||||
"""处理错误"""
|
||||
self.stats['errors'] += 1
|
||||
error_info = await ErrorHandler.handle_error(context.error)
|
||||
self.logger.error(f"Update {update} caused error {context.error}")
|
||||
|
||||
if update and update.effective_message:
|
||||
user_message = ErrorHandler.create_user_message(context.error)
|
||||
await update.effective_message.reply_text(user_message)
|
||||
|
||||
async def notify_admin_new_customer(self, user):
|
||||
"""通知管理员有新客户"""
|
||||
def escape_markdown(text):
|
||||
"""转义 Markdown 特殊字符"""
|
||||
if text is None:
|
||||
return ''
|
||||
# 转义特殊字符
|
||||
special_chars = ['_', '*', '[', ']', '(', ')', '~', '`', '>', '#', '+', '-', '=', '|', '{', '}', '.', '!']
|
||||
for char in special_chars:
|
||||
text = str(text).replace(char, f'\\{char}')
|
||||
return text
|
||||
|
||||
first_name = escape_markdown(user.first_name)
|
||||
last_name = escape_markdown(user.last_name) if user.last_name else ''
|
||||
username = escape_markdown(user.username) if user.username else 'N/A'
|
||||
|
||||
text = (
|
||||
f"🆕 新客户加入\n"
|
||||
f"👤 姓名:{first_name} {last_name}\n"
|
||||
f"🆔 ID:`{user.id}`\n"
|
||||
f"📱 用户名:@{username}"
|
||||
)
|
||||
|
||||
try:
|
||||
await self.application.bot.send_message(
|
||||
chat_id=self.config.telegram.admin_id,
|
||||
text=text,
|
||||
parse_mode='MarkdownV2'
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to notify admin: {e}")
|
||||
|
||||
async def send_auto_reply(self, update: Update, context):
|
||||
"""发送自动回复"""
|
||||
import pytz
|
||||
from datetime import time
|
||||
|
||||
# 检查营业时间
|
||||
tz = pytz.timezone(self.config.business.timezone)
|
||||
now = datetime.now(tz).time()
|
||||
start_time = time.fromisoformat(self.config.business.business_hours_start)
|
||||
end_time = time.fromisoformat(self.config.business.business_hours_end)
|
||||
|
||||
if not (start_time <= now <= end_time):
|
||||
message = self.config.business.offline_message.format(
|
||||
start=self.config.business.business_hours_start,
|
||||
end=self.config.business.business_hours_end
|
||||
)
|
||||
await update.message.reply_text(message)
|
||||
|
||||
async def send_error_message(self, update: Update, error: Exception):
|
||||
"""发送错误消息给用户"""
|
||||
user_message = ErrorHandler.create_user_message(error)
|
||||
await update.message.reply_text(user_message)
|
||||
|
||||
def _get_admin_help(self) -> str:
|
||||
"""获取管理员帮助信息"""
|
||||
return (
|
||||
"📚 **管理员帮助**\n\n"
|
||||
"**基础命令**\n"
|
||||
"/start - 启动机器人\n"
|
||||
"/help - 显示帮助\n"
|
||||
"/status - 查看状态\n\n"
|
||||
"**管理命令**\n"
|
||||
"/stats - 查看统计信息\n"
|
||||
"/sessions - 查看活跃会话\n"
|
||||
"/reply <用户ID> <消息> - 回复指定用户\n"
|
||||
"/broadcast <消息> - 广播消息给所有用户\n"
|
||||
"/settings - 机器人设置\n\n"
|
||||
"**快速回复客户**\n"
|
||||
"• 直接输入文字 - 自动发送给最近的客户\n"
|
||||
"• 回复转发消息 - 回复特定客户"
|
||||
)
|
||||
|
||||
def _get_user_help(self) -> str:
|
||||
"""获取用户帮助信息"""
|
||||
return (
|
||||
"📚 **帮助信息**\n\n"
|
||||
"**使用方法**\n"
|
||||
"• 直接发送消息,客服会尽快回复您\n"
|
||||
"• 支持发送文字、图片、文件等\n"
|
||||
"• /contact - 联系人工客服\n"
|
||||
"• /status - 查看服务状态\n\n"
|
||||
f"**工作时间**\n"
|
||||
f"{self.config.business.business_hours_start} - {self.config.business.business_hours_end}\n\n"
|
||||
"如有紧急情况,请留言,我们会尽快处理"
|
||||
)
|
||||
|
||||
|
||||
def run(self):
|
||||
"""运行机器人"""
|
||||
try:
|
||||
# 同步初始化
|
||||
asyncio.get_event_loop().run_until_complete(self.initialize())
|
||||
|
||||
self.logger.info("Starting bot...")
|
||||
self.application.run_polling(allowed_updates=Update.ALL_TYPES)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
self.logger.info("Bot stopped by user")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Bot crashed: {e}")
|
||||
raise
|
||||
finally:
|
||||
self.cleanup()
|
||||
|
||||
def cleanup(self):
|
||||
"""清理资源"""
|
||||
self.logger.info("Cleaning up resources...")
|
||||
# 保存统计信息、关闭数据库等
|
||||
pass
|
||||
|
||||
|
||||
def is_business_hours(config: Settings) -> bool:
|
||||
"""检查是否在营业时间"""
|
||||
import pytz
|
||||
from datetime import time
|
||||
|
||||
tz = pytz.timezone(config.business.timezone)
|
||||
now = datetime.now(tz).time()
|
||||
start_time = time.fromisoformat(config.business.business_hours_start)
|
||||
end_time = time.fromisoformat(config.business.business_hours_end)
|
||||
|
||||
return start_time <= now <= end_time
|
||||
157
src/core/handlers.py
Normal file
157
src/core/handlers.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""处理器基类和上下文"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Dict, List
|
||||
from dataclasses import dataclass, field
|
||||
from telegram import Update, Message
|
||||
from telegram.ext import ContextTypes
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..config.settings import Settings
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HandlerContext:
|
||||
"""处理器上下文"""
|
||||
update: Update
|
||||
context: ContextTypes.DEFAULT_TYPE
|
||||
config: Settings
|
||||
user_data: Dict[str, Any] = field(default_factory=dict)
|
||||
chat_data: Dict[str, Any] = field(default_factory=dict)
|
||||
session_data: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def message(self) -> Message:
|
||||
"""获取消息"""
|
||||
return self.update.effective_message
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
"""获取用户"""
|
||||
return self.update.effective_user
|
||||
|
||||
@property
|
||||
def chat(self):
|
||||
"""获取聊天"""
|
||||
return self.update.effective_chat
|
||||
|
||||
def get_session_id(self) -> str:
|
||||
"""获取会话ID"""
|
||||
return f"{self.chat.id}_{self.user.id}"
|
||||
|
||||
|
||||
class BaseHandler(ABC):
|
||||
"""处理器基类"""
|
||||
|
||||
def __init__(self, config: Settings):
|
||||
self.config = config
|
||||
self.logger = get_logger(self.__class__.__name__)
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, handler_context: HandlerContext) -> Any:
|
||||
"""处理消息"""
|
||||
pass
|
||||
|
||||
async def __call__(self, update: Update, context: ContextTypes.DEFAULT_TYPE,
|
||||
message_context: Any = None) -> Any:
|
||||
"""调用处理器"""
|
||||
handler_context = HandlerContext(
|
||||
update=update,
|
||||
context=context,
|
||||
config=self.config,
|
||||
user_data=context.user_data,
|
||||
chat_data=context.chat_data
|
||||
)
|
||||
|
||||
try:
|
||||
self.logger.debug(f"Handling message from user {handler_context.user.id}")
|
||||
result = await self.handle(handler_context)
|
||||
return result
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handler: {e}")
|
||||
raise
|
||||
|
||||
async def reply_text(self, context: HandlerContext, text: str, **kwargs) -> Message:
|
||||
"""回复文本消息"""
|
||||
return await context.message.reply_text(text, **kwargs)
|
||||
|
||||
async def reply_photo(self, context: HandlerContext, photo, caption: str = None, **kwargs) -> Message:
|
||||
"""回复图片"""
|
||||
return await context.message.reply_photo(photo, caption=caption, **kwargs)
|
||||
|
||||
async def reply_document(self, context: HandlerContext, document, caption: str = None, **kwargs) -> Message:
|
||||
"""回复文档"""
|
||||
return await context.message.reply_document(document, caption=caption, **kwargs)
|
||||
|
||||
async def forward_to_admin(self, context: HandlerContext) -> Message:
|
||||
"""转发消息给管理员"""
|
||||
return await context.context.bot.forward_message(
|
||||
chat_id=self.config.telegram.admin_id,
|
||||
from_chat_id=context.chat.id,
|
||||
message_id=context.message.message_id
|
||||
)
|
||||
|
||||
async def send_to_admin(self, context: HandlerContext, text: str, **kwargs) -> Message:
|
||||
"""发送消息给管理员"""
|
||||
return await context.context.bot.send_message(
|
||||
chat_id=self.config.telegram.admin_id,
|
||||
text=text,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
class CompositeHandler(BaseHandler):
|
||||
"""组合处理器"""
|
||||
|
||||
def __init__(self, config: Settings):
|
||||
super().__init__(config)
|
||||
self.handlers: List[BaseHandler] = []
|
||||
|
||||
def add_handler(self, handler: BaseHandler):
|
||||
"""添加处理器"""
|
||||
self.handlers.append(handler)
|
||||
|
||||
async def handle(self, handler_context: HandlerContext) -> Any:
|
||||
"""依次执行所有处理器"""
|
||||
results = []
|
||||
for handler in self.handlers:
|
||||
try:
|
||||
result = await handler.handle(handler_context)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in composite handler {handler.__class__.__name__}: {e}")
|
||||
# 可以选择继续或中断
|
||||
raise
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class ConditionalHandler(BaseHandler):
|
||||
"""条件处理器"""
|
||||
|
||||
def __init__(self, config: Settings, condition_func):
|
||||
super().__init__(config)
|
||||
self.condition_func = condition_func
|
||||
self.true_handler: Optional[BaseHandler] = None
|
||||
self.false_handler: Optional[BaseHandler] = None
|
||||
|
||||
def set_true_handler(self, handler: BaseHandler):
|
||||
"""设置条件为真时的处理器"""
|
||||
self.true_handler = handler
|
||||
|
||||
def set_false_handler(self, handler: BaseHandler):
|
||||
"""设置条件为假时的处理器"""
|
||||
self.false_handler = handler
|
||||
|
||||
async def handle(self, handler_context: HandlerContext) -> Any:
|
||||
"""根据条件执行处理器"""
|
||||
if await self.condition_func(handler_context):
|
||||
if self.true_handler:
|
||||
return await self.true_handler.handle(handler_context)
|
||||
else:
|
||||
if self.false_handler:
|
||||
return await self.false_handler.handle(handler_context)
|
||||
|
||||
return None
|
||||
321
src/core/router.py
Normal file
321
src/core/router.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""消息路由系统"""
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Callable, Any, Type
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from telegram import Update, Message, User, Chat
|
||||
from telegram.ext import ContextTypes
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.exceptions import MessageRoutingError
|
||||
from ..utils.decorators import log_action, measure_performance
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""消息类型枚举"""
|
||||
TEXT = "text"
|
||||
PHOTO = "photo"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
VOICE = "voice"
|
||||
DOCUMENT = "document"
|
||||
STICKER = "sticker"
|
||||
LOCATION = "location"
|
||||
CONTACT = "contact"
|
||||
POLL = "poll"
|
||||
COMMAND = "command"
|
||||
CALLBACK = "callback"
|
||||
INLINE = "inline"
|
||||
|
||||
|
||||
class RoutePriority(Enum):
|
||||
"""路由优先级"""
|
||||
CRITICAL = 0
|
||||
HIGH = 1
|
||||
NORMAL = 2
|
||||
LOW = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoutePattern:
|
||||
"""路由模式"""
|
||||
pattern: str
|
||||
type: MessageType
|
||||
priority: RoutePriority = RoutePriority.NORMAL
|
||||
conditions: List[Callable] = field(default_factory=list)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def matches(self, message: Message) -> bool:
|
||||
"""检查消息是否匹配模式"""
|
||||
# 检查消息类型
|
||||
if not self._check_message_type(message):
|
||||
return False
|
||||
|
||||
# 检查模式匹配
|
||||
if self.type == MessageType.TEXT and message.text:
|
||||
if not self._match_text_pattern(message.text):
|
||||
return False
|
||||
|
||||
# 检查条件
|
||||
for condition in self.conditions:
|
||||
if not condition(message):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _check_message_type(self, message: Message) -> bool:
|
||||
"""检查消息类型是否匹配"""
|
||||
type_map = {
|
||||
MessageType.TEXT: lambda m: m.text is not None,
|
||||
MessageType.PHOTO: lambda m: m.photo is not None,
|
||||
MessageType.VIDEO: lambda m: m.video is not None,
|
||||
MessageType.AUDIO: lambda m: m.audio is not None,
|
||||
MessageType.VOICE: lambda m: m.voice is not None,
|
||||
MessageType.DOCUMENT: lambda m: m.document is not None,
|
||||
MessageType.STICKER: lambda m: m.sticker is not None,
|
||||
MessageType.LOCATION: lambda m: m.location is not None,
|
||||
MessageType.CONTACT: lambda m: m.contact is not None,
|
||||
MessageType.POLL: lambda m: m.poll is not None,
|
||||
}
|
||||
|
||||
check_func = type_map.get(self.type)
|
||||
return check_func(message) if check_func else False
|
||||
|
||||
def _match_text_pattern(self, text: str) -> bool:
|
||||
"""匹配文本模式"""
|
||||
import re
|
||||
if self.pattern.startswith("^") or self.pattern.endswith("$"):
|
||||
# 正则表达式
|
||||
return bool(re.match(self.pattern, text))
|
||||
else:
|
||||
# 简单包含检查
|
||||
return self.pattern in text
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageContext:
|
||||
"""消息上下文"""
|
||||
message_id: str
|
||||
user_id: int
|
||||
chat_id: int
|
||||
username: Optional[str]
|
||||
first_name: Optional[str]
|
||||
last_name: Optional[str]
|
||||
message_type: MessageType
|
||||
content: Any
|
||||
timestamp: datetime
|
||||
is_admin: bool = False
|
||||
session_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_update(cls, update: Update, admin_id: int) -> 'MessageContext':
|
||||
"""从更新创建上下文"""
|
||||
message = update.effective_message
|
||||
user = update.effective_user
|
||||
chat = update.effective_chat
|
||||
|
||||
# 确定消息类型
|
||||
if message.text and message.text.startswith('/'):
|
||||
msg_type = MessageType.COMMAND
|
||||
elif message.text:
|
||||
msg_type = MessageType.TEXT
|
||||
elif message.photo:
|
||||
msg_type = MessageType.PHOTO
|
||||
elif message.video:
|
||||
msg_type = MessageType.VIDEO
|
||||
elif message.voice:
|
||||
msg_type = MessageType.VOICE
|
||||
elif message.document:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
elif message.location:
|
||||
msg_type = MessageType.LOCATION
|
||||
else:
|
||||
msg_type = MessageType.TEXT
|
||||
|
||||
# 提取内容
|
||||
content = message.text or message.caption or ""
|
||||
if message.photo:
|
||||
content = message.photo[-1].file_id
|
||||
elif message.document:
|
||||
content = message.document.file_id
|
||||
elif message.voice:
|
||||
content = message.voice.file_id
|
||||
elif message.video:
|
||||
content = message.video.file_id
|
||||
|
||||
return cls(
|
||||
message_id=str(message.message_id),
|
||||
user_id=user.id,
|
||||
chat_id=chat.id,
|
||||
username=user.username,
|
||||
first_name=user.first_name,
|
||||
last_name=user.last_name,
|
||||
message_type=msg_type,
|
||||
content=content,
|
||||
timestamp=datetime.now(),
|
||||
is_admin=(user.id == admin_id),
|
||||
session_id=f"{chat.id}_{user.id}"
|
||||
)
|
||||
|
||||
|
||||
class MessageRouter:
|
||||
"""消息路由器"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.routes: Dict[RoutePriority, List[tuple[RoutePattern, Callable]]] = {
|
||||
priority: [] for priority in RoutePriority
|
||||
}
|
||||
self.middleware: List[Callable] = []
|
||||
self.default_handler: Optional[Callable] = None
|
||||
self.error_handler: Optional[Callable] = None
|
||||
|
||||
def add_route(self, pattern: RoutePattern, handler: Callable):
|
||||
"""添加路由"""
|
||||
self.routes[pattern.priority].append((pattern, handler))
|
||||
logger.debug(f"Added route: {pattern.pattern} with priority {pattern.priority}")
|
||||
|
||||
def add_middleware(self, middleware: Callable):
|
||||
"""添加中间件"""
|
||||
self.middleware.append(middleware)
|
||||
logger.debug(f"Added middleware: {middleware.__name__}")
|
||||
|
||||
def set_default_handler(self, handler: Callable):
|
||||
"""设置默认处理器"""
|
||||
self.default_handler = handler
|
||||
|
||||
def set_error_handler(self, handler: Callable):
|
||||
"""设置错误处理器"""
|
||||
self.error_handler = handler
|
||||
|
||||
@measure_performance
|
||||
@log_action("route_message")
|
||||
async def route(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> Any:
|
||||
"""路由消息"""
|
||||
try:
|
||||
# 创建消息上下文
|
||||
msg_context = MessageContext.from_update(
|
||||
update,
|
||||
self.config.telegram.admin_id
|
||||
)
|
||||
|
||||
# 应用中间件
|
||||
for middleware in self.middleware:
|
||||
result = await middleware(msg_context, context)
|
||||
if result is False:
|
||||
logger.debug(f"Middleware {middleware.__name__} blocked message")
|
||||
return None
|
||||
|
||||
# 查找匹配的路由
|
||||
handler = await self._find_handler(update.effective_message, msg_context)
|
||||
|
||||
if handler:
|
||||
logger.info(
|
||||
f"Routing message to {handler.__name__}",
|
||||
extra={'user_id': msg_context.user_id, 'handler': handler.__name__}
|
||||
)
|
||||
return await handler(update, context, msg_context)
|
||||
elif self.default_handler:
|
||||
logger.info(
|
||||
f"Using default handler",
|
||||
extra={'user_id': msg_context.user_id}
|
||||
)
|
||||
return await self.default_handler(update, context, msg_context)
|
||||
else:
|
||||
logger.warning(f"No handler found for message from user {msg_context.user_id}")
|
||||
raise MessageRoutingError("No handler found for this message type")
|
||||
|
||||
except Exception as e:
|
||||
if self.error_handler:
|
||||
return await self.error_handler(update, context, e)
|
||||
else:
|
||||
logger.error(f"Error in message routing: {e}")
|
||||
raise
|
||||
|
||||
async def _find_handler(self, message: Message, context: MessageContext) -> Optional[Callable]:
|
||||
"""查找合适的处理器"""
|
||||
# 按优先级顺序检查路由
|
||||
for priority in RoutePriority:
|
||||
for pattern, handler in self.routes[priority]:
|
||||
if pattern.matches(message):
|
||||
return handler
|
||||
return None
|
||||
|
||||
|
||||
class RouteBuilder:
|
||||
"""路由构建器"""
|
||||
|
||||
def __init__(self, router: MessageRouter):
|
||||
self.router = router
|
||||
|
||||
def text(self, pattern: str = None, priority: RoutePriority = RoutePriority.NORMAL):
|
||||
"""文本消息路由装饰器"""
|
||||
def decorator(handler: Callable):
|
||||
route_pattern = RoutePattern(
|
||||
pattern=pattern or ".*",
|
||||
type=MessageType.TEXT,
|
||||
priority=priority
|
||||
)
|
||||
self.router.add_route(route_pattern, handler)
|
||||
return handler
|
||||
return decorator
|
||||
|
||||
def command(self, command: str, priority: RoutePriority = RoutePriority.HIGH):
|
||||
"""命令路由装饰器"""
|
||||
def decorator(handler: Callable):
|
||||
route_pattern = RoutePattern(
|
||||
pattern=f"^/{command}",
|
||||
type=MessageType.TEXT,
|
||||
priority=priority
|
||||
)
|
||||
self.router.add_route(route_pattern, handler)
|
||||
return handler
|
||||
return decorator
|
||||
|
||||
def photo(self, priority: RoutePriority = RoutePriority.NORMAL):
|
||||
"""图片消息路由装饰器"""
|
||||
def decorator(handler: Callable):
|
||||
route_pattern = RoutePattern(
|
||||
pattern="",
|
||||
type=MessageType.PHOTO,
|
||||
priority=priority
|
||||
)
|
||||
self.router.add_route(route_pattern, handler)
|
||||
return handler
|
||||
return decorator
|
||||
|
||||
def document(self, priority: RoutePriority = RoutePriority.NORMAL):
|
||||
"""文档消息路由装饰器"""
|
||||
def decorator(handler: Callable):
|
||||
route_pattern = RoutePattern(
|
||||
pattern="",
|
||||
type=MessageType.DOCUMENT,
|
||||
priority=priority
|
||||
)
|
||||
self.router.add_route(route_pattern, handler)
|
||||
return handler
|
||||
return decorator
|
||||
|
||||
def voice(self, priority: RoutePriority = RoutePriority.NORMAL):
|
||||
"""语音消息路由装饰器"""
|
||||
def decorator(handler: Callable):
|
||||
route_pattern = RoutePattern(
|
||||
pattern="",
|
||||
type=MessageType.VOICE,
|
||||
priority=priority
|
||||
)
|
||||
self.router.add_route(route_pattern, handler)
|
||||
return handler
|
||||
return decorator
|
||||
|
||||
def middleware(self):
|
||||
"""中间件装饰器"""
|
||||
def decorator(handler: Callable):
|
||||
self.router.add_middleware(handler)
|
||||
return handler
|
||||
return decorator
|
||||
279
src/modules/mirror_search.py
Normal file
279
src/modules/mirror_search.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
搜索镜像模块 - 自动转发搜索指令到目标机器人并返回结果
|
||||
基于 jingxiang 项目的镜像机制
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Optional, Any
|
||||
from pyrogram import Client, filters
|
||||
from pyrogram.types import Message as PyrogramMessage
|
||||
from pyrogram.raw.functions.messages import GetBotCallbackAnswer
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram.ext import ContextTypes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MirrorSearchHandler:
|
||||
"""处理搜索指令的镜像转发"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.enabled = False
|
||||
|
||||
# Pyrogram配置(需要在.env中配置)
|
||||
self.api_id = None
|
||||
self.api_hash = None
|
||||
self.session_name = "search_mirror_session"
|
||||
self.target_bot = "@openaiw_bot" # 目标搜索机器人
|
||||
|
||||
# Pyrogram客户端
|
||||
self.pyrogram_client: Optional[Client] = None
|
||||
self.target_bot_id: Optional[int] = None
|
||||
|
||||
# 消息映射
|
||||
self.user_search_requests: Dict[int, Dict[str, Any]] = {} # user_id -> search_info
|
||||
self.pyrogram_to_user: Dict[int, int] = {} # pyrogram_msg_id -> user_id
|
||||
self.user_to_telegram: Dict[int, int] = {} # user_id -> telegram_msg_id
|
||||
|
||||
# 支持的搜索命令
|
||||
self.search_commands = ['/topchat', '/search', '/text', '/human']
|
||||
|
||||
async def initialize(self, api_id: int, api_hash: str):
|
||||
"""初始化Pyrogram客户端"""
|
||||
try:
|
||||
self.api_id = api_id
|
||||
self.api_hash = api_hash
|
||||
|
||||
self.pyrogram_client = Client(
|
||||
self.session_name,
|
||||
api_id=self.api_id,
|
||||
api_hash=self.api_hash
|
||||
)
|
||||
|
||||
await self.pyrogram_client.start()
|
||||
logger.info("✅ 搜索镜像客户端已启动")
|
||||
|
||||
# 获取目标机器人信息
|
||||
target = await self.pyrogram_client.get_users(self.target_bot)
|
||||
self.target_bot_id = target.id
|
||||
logger.info(f"✅ 连接到搜索机器人: {target.username} (ID: {target.id})")
|
||||
|
||||
# 设置消息监听器
|
||||
await self._setup_listeners()
|
||||
|
||||
self.enabled = True
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"镜像搜索初始化失败: {e}")
|
||||
self.enabled = False
|
||||
return False
|
||||
|
||||
async def _setup_listeners(self):
|
||||
"""设置Pyrogram消息监听器"""
|
||||
if not self.pyrogram_client:
|
||||
return
|
||||
|
||||
@self.pyrogram_client.on_message(filters.user(self.target_bot_id))
|
||||
async def on_bot_response(_, message: PyrogramMessage):
|
||||
"""当收到搜索机器人的响应时"""
|
||||
await self._handle_bot_response(message)
|
||||
|
||||
@self.pyrogram_client.on_edited_message(filters.user(self.target_bot_id))
|
||||
async def on_message_edited(_, message: PyrogramMessage):
|
||||
"""当搜索机器人编辑消息时(翻页)"""
|
||||
await self._handle_bot_response(message, is_edit=True)
|
||||
|
||||
logger.info("✅ 消息监听器已设置")
|
||||
|
||||
def is_search_command(self, text: str) -> bool:
|
||||
"""检查是否是搜索命令"""
|
||||
if not text:
|
||||
return False
|
||||
command = text.split()[0]
|
||||
return command in self.search_commands
|
||||
|
||||
async def process_search_command(
|
||||
self,
|
||||
update: Update,
|
||||
context: ContextTypes.DEFAULT_TYPE,
|
||||
user_id: int,
|
||||
command: str
|
||||
) -> bool:
|
||||
"""处理用户的搜索命令"""
|
||||
|
||||
if not self.enabled or not self.pyrogram_client:
|
||||
logger.warning("搜索镜像未启用")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 记录用户搜索请求
|
||||
self.user_search_requests[user_id] = {
|
||||
'command': command,
|
||||
'chat_id': update.effective_chat.id,
|
||||
'update': update,
|
||||
'context': context,
|
||||
'timestamp': asyncio.get_event_loop().time()
|
||||
}
|
||||
|
||||
# 通过Pyrogram发送命令给目标机器人
|
||||
sent_message = await self.pyrogram_client.send_message(
|
||||
self.target_bot,
|
||||
command
|
||||
)
|
||||
|
||||
# 记录映射关系
|
||||
if sent_message:
|
||||
logger.info(f"已发送搜索命令给 {self.target_bot}: {command}")
|
||||
# 等待响应会通过监听器处理
|
||||
|
||||
# 发送等待提示给用户
|
||||
waiting_msg = await update.message.reply_text(
|
||||
"🔍 正在搜索,请稍候..."
|
||||
)
|
||||
self.user_to_telegram[user_id] = waiting_msg.message_id
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送搜索命令失败: {e}")
|
||||
await update.message.reply_text(
|
||||
"❌ 搜索请求失败,请稍后重试或联系管理员"
|
||||
)
|
||||
return False
|
||||
|
||||
async def _handle_bot_response(self, message: PyrogramMessage, is_edit: bool = False):
|
||||
"""处理搜索机器人的响应"""
|
||||
try:
|
||||
# 查找对应的用户
|
||||
# 这里需要根据时间戳或其他方式匹配用户请求
|
||||
user_id = self._find_user_for_response(message)
|
||||
|
||||
if not user_id or user_id not in self.user_search_requests:
|
||||
logger.debug(f"未找到对应的用户请求")
|
||||
return
|
||||
|
||||
user_request = self.user_search_requests[user_id]
|
||||
|
||||
# 转换消息格式并发送给用户
|
||||
await self._forward_to_user(message, user_request, is_edit)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理机器人响应失败: {e}")
|
||||
|
||||
def _find_user_for_response(self, message: PyrogramMessage) -> Optional[int]:
|
||||
"""查找响应对应的用户"""
|
||||
# 简单的实现:返回最近的请求用户
|
||||
# 实际应用中可能需要更复杂的匹配逻辑
|
||||
if self.user_search_requests:
|
||||
# 获取最近的请求
|
||||
recent_user = max(
|
||||
self.user_search_requests.keys(),
|
||||
key=lambda k: self.user_search_requests[k].get('timestamp', 0)
|
||||
)
|
||||
return recent_user
|
||||
return None
|
||||
|
||||
async def _forward_to_user(
|
||||
self,
|
||||
pyrogram_msg: PyrogramMessage,
|
||||
user_request: Dict[str, Any],
|
||||
is_edit: bool = False
|
||||
):
|
||||
"""转发搜索结果给用户"""
|
||||
try:
|
||||
update = user_request['update']
|
||||
context = user_request['context']
|
||||
|
||||
# 提取消息内容
|
||||
text = self._extract_text(pyrogram_msg)
|
||||
keyboard = self._convert_keyboard(pyrogram_msg)
|
||||
|
||||
if is_edit and user_request['user_id'] in self.user_to_telegram:
|
||||
# 编辑现有消息
|
||||
telegram_msg_id = self.user_to_telegram[user_request['user_id']]
|
||||
await context.bot.edit_message_text(
|
||||
chat_id=user_request['chat_id'],
|
||||
message_id=telegram_msg_id,
|
||||
text=text,
|
||||
reply_markup=keyboard,
|
||||
parse_mode='HTML'
|
||||
)
|
||||
else:
|
||||
# 发送新消息
|
||||
sent = await context.bot.send_message(
|
||||
chat_id=user_request['chat_id'],
|
||||
text=text,
|
||||
reply_markup=keyboard,
|
||||
parse_mode='HTML'
|
||||
)
|
||||
self.user_to_telegram[user_request['user_id']] = sent.message_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转发消息给用户失败: {e}")
|
||||
|
||||
def _extract_text(self, message: PyrogramMessage) -> str:
|
||||
"""提取消息文本"""
|
||||
if message.text:
|
||||
return message.text
|
||||
elif message.caption:
|
||||
return message.caption
|
||||
return "(无文本内容)"
|
||||
|
||||
def _convert_keyboard(self, message: PyrogramMessage) -> Optional[InlineKeyboardMarkup]:
|
||||
"""转换Pyrogram键盘为Telegram键盘"""
|
||||
if not message.reply_markup:
|
||||
return None
|
||||
|
||||
try:
|
||||
buttons = []
|
||||
for row in message.reply_markup.inline_keyboard:
|
||||
button_row = []
|
||||
for button in row:
|
||||
if button.text:
|
||||
# 创建回调按钮
|
||||
callback_data = button.callback_data or f"mirror_{button.text}"
|
||||
if len(callback_data.encode()) > 64:
|
||||
# Telegram限制callback_data最大64字节
|
||||
callback_data = callback_data[:60] + "..."
|
||||
|
||||
button_row.append(
|
||||
InlineKeyboardButton(
|
||||
text=button.text,
|
||||
callback_data=callback_data
|
||||
)
|
||||
)
|
||||
if button_row:
|
||||
buttons.append(button_row)
|
||||
|
||||
return InlineKeyboardMarkup(buttons) if buttons else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换键盘失败: {e}")
|
||||
return None
|
||||
|
||||
async def handle_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""处理回调查询(翻页等)"""
|
||||
query = update.callback_query
|
||||
|
||||
if not query.data.startswith("mirror_"):
|
||||
return False
|
||||
|
||||
try:
|
||||
# 这里需要实现回调处理逻辑
|
||||
# 将回调转发给Pyrogram客户端
|
||||
await query.answer("处理中...")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理回调失败: {e}")
|
||||
await query.answer("操作失败", show_alert=True)
|
||||
return False
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理资源"""
|
||||
if self.pyrogram_client:
|
||||
await self.pyrogram_client.stop()
|
||||
logger.info("搜索镜像客户端已停止")
|
||||
5
src/modules/storage/__init__.py
Normal file
5
src/modules/storage/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""存储模块"""
|
||||
from .database import DatabaseManager
|
||||
from .models import Customer, Message, Session
|
||||
|
||||
__all__ = ['DatabaseManager', 'Customer', 'Message', 'Session']
|
||||
428
src/modules/storage/database.py
Normal file
428
src/modules/storage/database.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""数据库管理器"""
|
||||
import sqlite3
|
||||
import json
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from .models import Customer, Message, Session, CustomerStatus, SessionStatus, MessageDirection
|
||||
from ...utils.logger import get_logger
|
||||
from ...utils.exceptions import DatabaseError
|
||||
from ...config.settings import Settings
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""数据库管理器"""
|
||||
|
||||
def __init__(self, config: Settings):
|
||||
self.config = config
|
||||
self.db_path = Path(self.config.database.path)
|
||||
self.connection: Optional[sqlite3.Connection] = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# 确保数据库目录存在
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化数据库"""
|
||||
async with self._lock:
|
||||
try:
|
||||
self.connection = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
check_same_thread=False
|
||||
)
|
||||
self.connection.row_factory = sqlite3.Row
|
||||
await self._create_tables()
|
||||
logger.info(f"Database initialized at {self.db_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {e}")
|
||||
raise DatabaseError(f"Database initialization failed: {e}")
|
||||
|
||||
async def _create_tables(self):
|
||||
"""创建数据表"""
|
||||
cursor = self.connection.cursor()
|
||||
|
||||
# 客户表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS customers (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
username TEXT,
|
||||
first_name TEXT NOT NULL,
|
||||
last_name TEXT,
|
||||
language_code TEXT,
|
||||
status TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
metadata TEXT,
|
||||
tags TEXT,
|
||||
notes TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
# 消息表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
message_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
chat_id INTEGER NOT NULL,
|
||||
direction TEXT NOT NULL,
|
||||
content TEXT,
|
||||
content_type TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
is_read INTEGER DEFAULT 0,
|
||||
is_replied INTEGER DEFAULT 0,
|
||||
reply_to_message_id TEXT,
|
||||
metadata TEXT,
|
||||
FOREIGN KEY (user_id) REFERENCES customers (user_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# 会话表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
customer_id INTEGER NOT NULL,
|
||||
chat_id INTEGER NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
started_at TEXT NOT NULL,
|
||||
ended_at TEXT,
|
||||
last_message_at TEXT,
|
||||
message_count INTEGER DEFAULT 0,
|
||||
assigned_to INTEGER,
|
||||
tags TEXT,
|
||||
notes TEXT,
|
||||
metadata TEXT,
|
||||
FOREIGN KEY (customer_id) REFERENCES customers (user_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# 创建索引
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_messages_user ON messages(user_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_sessions_customer ON sessions(customer_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status)")
|
||||
|
||||
self.connection.commit()
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
"""事务上下文管理器"""
|
||||
async with self._lock:
|
||||
cursor = self.connection.cursor()
|
||||
try:
|
||||
yield cursor
|
||||
self.connection.commit()
|
||||
except Exception as e:
|
||||
self.connection.rollback()
|
||||
logger.error(f"Transaction failed: {e}")
|
||||
raise DatabaseError(f"Transaction failed: {e}")
|
||||
|
||||
async def save_customer(self, customer: Customer) -> bool:
|
||||
"""保存客户"""
|
||||
async with self.transaction() as cursor:
|
||||
cursor.execute("""
|
||||
INSERT OR REPLACE INTO customers (
|
||||
user_id, username, first_name, last_name, language_code,
|
||||
status, created_at, updated_at, metadata, tags, notes
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
customer.user_id,
|
||||
customer.username,
|
||||
customer.first_name,
|
||||
customer.last_name,
|
||||
customer.language_code,
|
||||
customer.status.value,
|
||||
customer.created_at.isoformat(),
|
||||
customer.updated_at.isoformat(),
|
||||
json.dumps(customer.metadata),
|
||||
json.dumps(customer.tags),
|
||||
customer.notes
|
||||
))
|
||||
return True
|
||||
|
||||
async def get_customer(self, user_id: int) -> Optional[Customer]:
|
||||
"""获取客户"""
|
||||
async with self._lock:
|
||||
cursor = self.connection.cursor()
|
||||
cursor.execute("SELECT * FROM customers WHERE user_id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
return Customer(
|
||||
user_id=row['user_id'],
|
||||
username=row['username'],
|
||||
first_name=row['first_name'],
|
||||
last_name=row['last_name'],
|
||||
language_code=row['language_code'],
|
||||
status=CustomerStatus(row['status']),
|
||||
created_at=datetime.fromisoformat(row['created_at']),
|
||||
updated_at=datetime.fromisoformat(row['updated_at']),
|
||||
metadata=json.loads(row['metadata'] or '{}'),
|
||||
tags=json.loads(row['tags'] or '[]'),
|
||||
notes=row['notes']
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_all_customers(self, status: Optional[CustomerStatus] = None) -> List[Customer]:
|
||||
"""获取所有客户"""
|
||||
async with self._lock:
|
||||
cursor = self.connection.cursor()
|
||||
|
||||
if status:
|
||||
cursor.execute("SELECT * FROM customers WHERE status = ?", (status.value,))
|
||||
else:
|
||||
cursor.execute("SELECT * FROM customers")
|
||||
|
||||
customers = []
|
||||
for row in cursor.fetchall():
|
||||
customers.append(Customer(
|
||||
user_id=row['user_id'],
|
||||
username=row['username'],
|
||||
first_name=row['first_name'],
|
||||
last_name=row['last_name'],
|
||||
language_code=row['language_code'],
|
||||
status=CustomerStatus(row['status']),
|
||||
created_at=datetime.fromisoformat(row['created_at']),
|
||||
updated_at=datetime.fromisoformat(row['updated_at']),
|
||||
metadata=json.loads(row['metadata'] or '{}'),
|
||||
tags=json.loads(row['tags'] or '[]'),
|
||||
notes=row['notes']
|
||||
))
|
||||
|
||||
return customers
|
||||
|
||||
async def save_message(self, message: Message) -> bool:
|
||||
"""保存消息"""
|
||||
async with self.transaction() as cursor:
|
||||
cursor.execute("""
|
||||
INSERT INTO messages (
|
||||
message_id, session_id, user_id, chat_id, direction,
|
||||
content, content_type, timestamp, is_read, is_replied,
|
||||
reply_to_message_id, metadata
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
message.message_id,
|
||||
message.session_id,
|
||||
message.user_id,
|
||||
message.chat_id,
|
||||
message.direction.value,
|
||||
message.content,
|
||||
message.content_type,
|
||||
message.timestamp.isoformat(),
|
||||
message.is_read,
|
||||
message.is_replied,
|
||||
message.reply_to_message_id,
|
||||
json.dumps(message.metadata)
|
||||
))
|
||||
|
||||
# 更新会话的最后消息时间和消息计数
|
||||
cursor.execute("""
|
||||
UPDATE sessions
|
||||
SET last_message_at = ?, message_count = message_count + 1
|
||||
WHERE session_id = ?
|
||||
""", (message.timestamp.isoformat(), message.session_id))
|
||||
|
||||
return True
|
||||
|
||||
async def get_messages(self, session_id: str, limit: int = 100) -> List[Message]:
|
||||
"""获取会话消息"""
|
||||
async with self._lock:
|
||||
cursor = self.connection.cursor()
|
||||
cursor.execute("""
|
||||
SELECT * FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""", (session_id, limit))
|
||||
|
||||
messages = []
|
||||
for row in cursor.fetchall():
|
||||
messages.append(Message(
|
||||
message_id=row['message_id'],
|
||||
session_id=row['session_id'],
|
||||
user_id=row['user_id'],
|
||||
chat_id=row['chat_id'],
|
||||
direction=MessageDirection(row['direction']),
|
||||
content=row['content'],
|
||||
content_type=row['content_type'],
|
||||
timestamp=datetime.fromisoformat(row['timestamp']),
|
||||
is_read=bool(row['is_read']),
|
||||
is_replied=bool(row['is_replied']),
|
||||
reply_to_message_id=row['reply_to_message_id'],
|
||||
metadata=json.loads(row['metadata'] or '{}')
|
||||
))
|
||||
|
||||
return list(reversed(messages)) # 返回时间顺序
|
||||
|
||||
async def save_session(self, session: Session) -> bool:
|
||||
"""保存会话"""
|
||||
async with self.transaction() as cursor:
|
||||
cursor.execute("""
|
||||
INSERT OR REPLACE INTO sessions (
|
||||
session_id, customer_id, chat_id, status, started_at,
|
||||
ended_at, last_message_at, message_count, assigned_to,
|
||||
tags, notes, metadata
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
session.session_id,
|
||||
session.customer_id,
|
||||
session.chat_id,
|
||||
session.status.value,
|
||||
session.started_at.isoformat(),
|
||||
session.ended_at.isoformat() if session.ended_at else None,
|
||||
session.last_message_at.isoformat() if session.last_message_at else None,
|
||||
session.message_count,
|
||||
session.assigned_to,
|
||||
json.dumps(session.tags),
|
||||
session.notes,
|
||||
json.dumps(session.metadata)
|
||||
))
|
||||
return True
|
||||
|
||||
async def get_session(self, session_id: str) -> Optional[Session]:
|
||||
"""获取会话"""
|
||||
async with self._lock:
|
||||
cursor = self.connection.cursor()
|
||||
cursor.execute("SELECT * FROM sessions WHERE session_id = ?", (session_id,))
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
return Session(
|
||||
session_id=row['session_id'],
|
||||
customer_id=row['customer_id'],
|
||||
chat_id=row['chat_id'],
|
||||
status=SessionStatus(row['status']),
|
||||
started_at=datetime.fromisoformat(row['started_at']),
|
||||
ended_at=datetime.fromisoformat(row['ended_at']) if row['ended_at'] else None,
|
||||
last_message_at=datetime.fromisoformat(row['last_message_at']) if row['last_message_at'] else None,
|
||||
message_count=row['message_count'],
|
||||
assigned_to=row['assigned_to'],
|
||||
tags=json.loads(row['tags'] or '[]'),
|
||||
notes=row['notes'],
|
||||
metadata=json.loads(row['metadata'] or '{}')
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_active_sessions(self) -> List[Session]:
|
||||
"""获取活跃会话"""
|
||||
async with self._lock:
|
||||
cursor = self.connection.cursor()
|
||||
cursor.execute("""
|
||||
SELECT * FROM sessions
|
||||
WHERE status = ?
|
||||
ORDER BY last_message_at DESC
|
||||
""", (SessionStatus.ACTIVE.value,))
|
||||
|
||||
sessions = []
|
||||
for row in cursor.fetchall():
|
||||
sessions.append(Session(
|
||||
session_id=row['session_id'],
|
||||
customer_id=row['customer_id'],
|
||||
chat_id=row['chat_id'],
|
||||
status=SessionStatus(row['status']),
|
||||
started_at=datetime.fromisoformat(row['started_at']),
|
||||
ended_at=datetime.fromisoformat(row['ended_at']) if row['ended_at'] else None,
|
||||
last_message_at=datetime.fromisoformat(row['last_message_at']) if row['last_message_at'] else None,
|
||||
message_count=row['message_count'],
|
||||
assigned_to=row['assigned_to'],
|
||||
tags=json.loads(row['tags'] or '[]'),
|
||||
notes=row['notes'],
|
||||
metadata=json.loads(row['metadata'] or '{}')
|
||||
))
|
||||
|
||||
return sessions
|
||||
|
||||
async def update_session_status(self, session_id: str, status: SessionStatus) -> bool:
|
||||
"""更新会话状态"""
|
||||
async with self.transaction() as cursor:
|
||||
ended_at = None
|
||||
if status in [SessionStatus.RESOLVED, SessionStatus.CLOSED]:
|
||||
ended_at = datetime.now().isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
UPDATE sessions
|
||||
SET status = ?, ended_at = ?
|
||||
WHERE session_id = ?
|
||||
""", (status.value, ended_at, session_id))
|
||||
|
||||
return cursor.rowcount > 0
|
||||
|
||||
async def get_statistics(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
async with self._lock:
|
||||
cursor = self.connection.cursor()
|
||||
|
||||
# 客户统计
|
||||
cursor.execute("SELECT COUNT(*) as count FROM customers")
|
||||
total_customers = cursor.fetchone()['count']
|
||||
|
||||
cursor.execute("SELECT COUNT(*) as count FROM customers WHERE status = ?",
|
||||
(CustomerStatus.ACTIVE.value,))
|
||||
active_customers = cursor.fetchone()['count']
|
||||
|
||||
# 会话统计
|
||||
cursor.execute("SELECT COUNT(*) as count FROM sessions")
|
||||
total_sessions = cursor.fetchone()['count']
|
||||
|
||||
cursor.execute("SELECT COUNT(*) as count FROM sessions WHERE status = ?",
|
||||
(SessionStatus.ACTIVE.value,))
|
||||
active_sessions = cursor.fetchone()['count']
|
||||
|
||||
# 消息统计
|
||||
cursor.execute("SELECT COUNT(*) as count FROM messages")
|
||||
total_messages = cursor.fetchone()['count']
|
||||
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) as count FROM messages
|
||||
WHERE direction = ? AND is_replied = 0
|
||||
""", (MessageDirection.INBOUND.value,))
|
||||
unreplied_messages = cursor.fetchone()['count']
|
||||
|
||||
return {
|
||||
'customers': {
|
||||
'total': total_customers,
|
||||
'active': active_customers
|
||||
},
|
||||
'sessions': {
|
||||
'total': total_sessions,
|
||||
'active': active_sessions
|
||||
},
|
||||
'messages': {
|
||||
'total': total_messages,
|
||||
'unreplied': unreplied_messages
|
||||
}
|
||||
}
|
||||
|
||||
async def cleanup_old_sessions(self, days: int = 30):
|
||||
"""清理旧会话"""
|
||||
async with self.transaction() as cursor:
|
||||
cutoff_date = datetime.now().timestamp() - (days * 24 * 60 * 60)
|
||||
cutoff_date_str = datetime.fromtimestamp(cutoff_date).isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
DELETE FROM messages
|
||||
WHERE session_id IN (
|
||||
SELECT session_id FROM sessions
|
||||
WHERE ended_at < ? AND status IN (?, ?)
|
||||
)
|
||||
""", (cutoff_date_str, SessionStatus.RESOLVED.value, SessionStatus.CLOSED.value))
|
||||
|
||||
cursor.execute("""
|
||||
DELETE FROM sessions
|
||||
WHERE ended_at < ? AND status IN (?, ?)
|
||||
""", (cutoff_date_str, SessionStatus.RESOLVED.value, SessionStatus.CLOSED.value))
|
||||
|
||||
logger.info(f"Cleaned up sessions older than {days} days")
|
||||
|
||||
def close(self):
|
||||
"""关闭数据库连接"""
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
logger.info("Database connection closed")
|
||||
154
src/modules/storage/models.py
Normal file
154
src/modules/storage/models.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""数据模型"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CustomerStatus(Enum):
|
||||
"""客户状态"""
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
BLOCKED = "blocked"
|
||||
|
||||
|
||||
class SessionStatus(Enum):
|
||||
"""会话状态"""
|
||||
ACTIVE = "active"
|
||||
PENDING = "pending"
|
||||
RESOLVED = "resolved"
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class MessageDirection(Enum):
|
||||
"""消息方向"""
|
||||
INBOUND = "inbound" # 客户发送
|
||||
OUTBOUND = "outbound" # 管理员发送
|
||||
|
||||
|
||||
@dataclass
|
||||
class Customer:
|
||||
"""客户模型"""
|
||||
user_id: int
|
||||
username: Optional[str]
|
||||
first_name: str
|
||||
last_name: Optional[str]
|
||||
language_code: Optional[str]
|
||||
status: CustomerStatus = CustomerStatus.ACTIVE
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
tags: List[str] = field(default_factory=list)
|
||||
notes: Optional[str] = None
|
||||
|
||||
@property
|
||||
def full_name(self) -> str:
|
||||
"""获取全名"""
|
||||
parts = [self.first_name]
|
||||
if self.last_name:
|
||||
parts.append(self.last_name)
|
||||
return " ".join(parts)
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""获取显示名称"""
|
||||
if self.username:
|
||||
return f"@{self.username}"
|
||||
return self.full_name
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'user_id': self.user_id,
|
||||
'username': self.username,
|
||||
'first_name': self.first_name,
|
||||
'last_name': self.last_name,
|
||||
'language_code': self.language_code,
|
||||
'status': self.status.value,
|
||||
'created_at': self.created_at.isoformat(),
|
||||
'updated_at': self.updated_at.isoformat(),
|
||||
'metadata': self.metadata,
|
||||
'tags': self.tags,
|
||||
'notes': self.notes
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""消息模型"""
|
||||
message_id: str
|
||||
session_id: str
|
||||
user_id: int
|
||||
chat_id: int
|
||||
direction: MessageDirection
|
||||
content: str
|
||||
content_type: str # text, photo, document, voice, etc.
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
is_read: bool = False
|
||||
is_replied: bool = False
|
||||
reply_to_message_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'message_id': self.message_id,
|
||||
'session_id': self.session_id,
|
||||
'user_id': self.user_id,
|
||||
'chat_id': self.chat_id,
|
||||
'direction': self.direction.value,
|
||||
'content': self.content,
|
||||
'content_type': self.content_type,
|
||||
'timestamp': self.timestamp.isoformat(),
|
||||
'is_read': self.is_read,
|
||||
'is_replied': self.is_replied,
|
||||
'reply_to_message_id': self.reply_to_message_id,
|
||||
'metadata': self.metadata
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""会话模型"""
|
||||
session_id: str
|
||||
customer_id: int
|
||||
chat_id: int
|
||||
status: SessionStatus = SessionStatus.ACTIVE
|
||||
started_at: datetime = field(default_factory=datetime.now)
|
||||
ended_at: Optional[datetime] = None
|
||||
last_message_at: Optional[datetime] = None
|
||||
message_count: int = 0
|
||||
assigned_to: Optional[int] = None # 分配给哪个管理员
|
||||
tags: List[str] = field(default_factory=list)
|
||||
notes: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def duration(self) -> Optional[float]:
|
||||
"""获取会话时长(秒)"""
|
||||
if self.ended_at:
|
||||
return (self.ended_at - self.started_at).total_seconds()
|
||||
return (datetime.now() - self.started_at).total_seconds()
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
"""是否活跃"""
|
||||
return self.status == SessionStatus.ACTIVE
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'session_id': self.session_id,
|
||||
'customer_id': self.customer_id,
|
||||
'chat_id': self.chat_id,
|
||||
'status': self.status.value,
|
||||
'started_at': self.started_at.isoformat(),
|
||||
'ended_at': self.ended_at.isoformat() if self.ended_at else None,
|
||||
'last_message_at': self.last_message_at.isoformat() if self.last_message_at else None,
|
||||
'message_count': self.message_count,
|
||||
'assigned_to': self.assigned_to,
|
||||
'tags': self.tags,
|
||||
'notes': self.notes,
|
||||
'metadata': self.metadata,
|
||||
'duration': self.duration
|
||||
}
|
||||
6
src/utils/__init__.py
Normal file
6
src/utils/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""工具模块"""
|
||||
from .logger import Logger, get_logger
|
||||
from .exceptions import *
|
||||
from .decorators import *
|
||||
|
||||
__all__ = ['Logger', 'get_logger']
|
||||
233
src/utils/decorators.py
Normal file
233
src/utils/decorators.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""装饰器工具"""
|
||||
import functools
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Callable, Any, Optional, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
from .logger import get_logger
|
||||
from .exceptions import RateLimitError, AuthorizationError, ValidationError
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def async_retry(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0,
|
||||
exceptions: tuple = (Exception,)):
|
||||
"""异步重试装饰器"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
attempt = 1
|
||||
current_delay = delay
|
||||
|
||||
while attempt <= max_attempts:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except exceptions as e:
|
||||
if attempt == max_attempts:
|
||||
logger.error(f"Max retries ({max_attempts}) reached for {func.__name__}")
|
||||
raise
|
||||
|
||||
logger.warning(
|
||||
f"Attempt {attempt}/{max_attempts} failed for {func.__name__}: {e}. "
|
||||
f"Retrying in {current_delay:.2f}s..."
|
||||
)
|
||||
await asyncio.sleep(current_delay)
|
||||
current_delay *= backoff
|
||||
attempt += 1
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def rate_limit(max_calls: int, period: float):
|
||||
"""速率限制装饰器"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
calls = defaultdict(list)
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 获取调用者标识(假设第一个参数是 self,第二个是 update)
|
||||
caller_id = None
|
||||
if len(args) >= 2 and hasattr(args[1], 'effective_user'):
|
||||
caller_id = args[1].effective_user.id
|
||||
else:
|
||||
caller_id = 'global'
|
||||
|
||||
now = time.time()
|
||||
calls[caller_id] = [t for t in calls[caller_id] if now - t < period]
|
||||
|
||||
if len(calls[caller_id]) >= max_calls:
|
||||
raise RateLimitError(
|
||||
f"Rate limit exceeded: {max_calls} calls per {period} seconds",
|
||||
details={'caller_id': caller_id, 'limit': max_calls, 'period': period}
|
||||
)
|
||||
|
||||
calls[caller_id].append(now)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def require_admin(func: Callable) -> Callable:
|
||||
"""需要管理员权限装饰器"""
|
||||
@functools.wraps(func)
|
||||
async def wrapper(self, update, context, *args, **kwargs):
|
||||
user_id = update.effective_user.id
|
||||
if user_id != self.config.telegram.admin_id:
|
||||
if not (hasattr(self, 'is_admin') and await self.is_admin(user_id)):
|
||||
raise AuthorizationError(
|
||||
"Admin privileges required",
|
||||
details={'user_id': user_id}
|
||||
)
|
||||
return await func(self, update, context, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def log_action(action_type: str = None):
|
||||
"""记录操作日志装饰器"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
action = action_type or func.__name__
|
||||
|
||||
# 提取用户信息
|
||||
user_info = {}
|
||||
if len(args) >= 2 and hasattr(args[1], 'effective_user'):
|
||||
user = args[1].effective_user
|
||||
user_info = {
|
||||
'user_id': user.id,
|
||||
'username': user.username,
|
||||
'name': user.first_name
|
||||
}
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
duration = time.time() - start_time
|
||||
|
||||
# 创建额外信息,避免覆盖保留字段
|
||||
extra_info = {
|
||||
'action': action,
|
||||
'duration': duration,
|
||||
'status': 'success'
|
||||
}
|
||||
# 添加用户信息,使用前缀避免冲突
|
||||
for k, v in user_info.items():
|
||||
extra_info[f'user_{k}' if k in ['name'] else k] = v
|
||||
|
||||
logger.info(
|
||||
f"Action completed: {action}",
|
||||
extra=extra_info
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
|
||||
# 创建额外信息,避免覆盖保留字段
|
||||
extra_info = {
|
||||
'action': action,
|
||||
'duration': duration,
|
||||
'status': 'failed',
|
||||
'error': str(e)
|
||||
}
|
||||
# 添加用户信息,使用前缀避免冲突
|
||||
for k, v in user_info.items():
|
||||
extra_info[f'user_{k}' if k in ['name'] else k] = v
|
||||
|
||||
logger.error(
|
||||
f"Action failed: {action}",
|
||||
extra=extra_info
|
||||
)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def validate_input(**validators):
|
||||
"""输入验证装饰器"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 合并位置参数和关键字参数
|
||||
bound_args = func.__code__.co_varnames[:func.__code__.co_argcount]
|
||||
all_args = dict(zip(bound_args, args))
|
||||
all_args.update(kwargs)
|
||||
|
||||
# 执行验证
|
||||
for param_name, validator in validators.items():
|
||||
if param_name in all_args:
|
||||
value = all_args[param_name]
|
||||
if not validator(value):
|
||||
raise ValidationError(
|
||||
f"Validation failed for parameter: {param_name}",
|
||||
details={'parameter': param_name, 'value': value}
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def cache_result(ttl: int = 300):
|
||||
"""结果缓存装饰器"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
cache: Dict[str, tuple[Any, datetime]] = {}
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 创建缓存键
|
||||
cache_key = f"{func.__name__}:{str(args)}:{str(kwargs)}"
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in cache:
|
||||
result, timestamp = cache[cache_key]
|
||||
if datetime.now() - timestamp < timedelta(seconds=ttl):
|
||||
logger.debug(f"Cache hit for {func.__name__}")
|
||||
return result
|
||||
|
||||
# 执行函数
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
# 存储结果
|
||||
cache[cache_key] = (result, datetime.now())
|
||||
logger.debug(f"Cache miss for {func.__name__}, cached for {ttl}s")
|
||||
|
||||
return result
|
||||
|
||||
# 添加清除缓存方法
|
||||
wrapper.clear_cache = lambda: cache.clear()
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def measure_performance(func: Callable) -> Callable:
|
||||
"""性能测量装饰器"""
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
start_memory = 0 # 可以添加内存测量
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
return result
|
||||
finally:
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
if duration > 1.0: # 超过1秒的操作记录警告
|
||||
logger.warning(
|
||||
f"Slow operation detected: {func.__name__} took {duration:.2f}s",
|
||||
extra={
|
||||
'function': func.__name__,
|
||||
'duration': duration
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.debug(f"{func.__name__} completed in {duration:.4f}s")
|
||||
|
||||
return wrapper
|
||||
122
src/utils/exceptions.py
Normal file
122
src/utils/exceptions.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""自定义异常类"""
|
||||
from typing import Optional, Any, Dict
|
||||
|
||||
|
||||
class BotException(Exception):
|
||||
"""机器人基础异常"""
|
||||
|
||||
def __init__(self, message: str, code: str = None, details: Dict[str, Any] = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.code = code or self.__class__.__name__
|
||||
self.details = details or {}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'error': self.code,
|
||||
'message': self.message,
|
||||
'details': self.details
|
||||
}
|
||||
|
||||
|
||||
class ConfigurationError(BotException):
|
||||
"""配置错误"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseError(BotException):
|
||||
"""数据库错误"""
|
||||
pass
|
||||
|
||||
|
||||
class TelegramError(BotException):
|
||||
"""Telegram API 错误"""
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(BotException):
|
||||
"""认证错误"""
|
||||
pass
|
||||
|
||||
|
||||
class AuthorizationError(BotException):
|
||||
"""授权错误"""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(BotException):
|
||||
"""验证错误"""
|
||||
pass
|
||||
|
||||
|
||||
class RateLimitError(BotException):
|
||||
"""速率限制错误"""
|
||||
pass
|
||||
|
||||
|
||||
class SessionError(BotException):
|
||||
"""会话错误"""
|
||||
pass
|
||||
|
||||
|
||||
class MessageRoutingError(BotException):
|
||||
"""消息路由错误"""
|
||||
pass
|
||||
|
||||
|
||||
class BusinessLogicError(BotException):
|
||||
"""业务逻辑错误"""
|
||||
pass
|
||||
|
||||
|
||||
class ExternalServiceError(BotException):
|
||||
"""外部服务错误"""
|
||||
pass
|
||||
|
||||
|
||||
class ErrorHandler:
|
||||
"""错误处理器"""
|
||||
|
||||
@staticmethod
|
||||
async def handle_error(error: Exception, context: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""处理错误"""
|
||||
from ..utils.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
error_info = {
|
||||
'type': type(error).__name__,
|
||||
'message': str(error),
|
||||
'context': context or {}
|
||||
}
|
||||
|
||||
if isinstance(error, BotException):
|
||||
# 自定义异常
|
||||
error_info.update(error.to_dict())
|
||||
logger.error(f"Bot error: {error.message}", extra={'error_details': error_info})
|
||||
else:
|
||||
# 未知异常
|
||||
logger.exception(f"Unexpected error: {error}", extra={'error_details': error_info})
|
||||
error_info['message'] = "An unexpected error occurred"
|
||||
|
||||
return error_info
|
||||
|
||||
@staticmethod
|
||||
def create_user_message(error: Exception) -> str:
|
||||
"""创建用户友好的错误消息"""
|
||||
if isinstance(error, AuthenticationError):
|
||||
return "❌ 认证失败,请重新登录"
|
||||
elif isinstance(error, AuthorizationError):
|
||||
return "❌ 您没有权限执行此操作"
|
||||
elif isinstance(error, ValidationError):
|
||||
return f"❌ 输入无效:{error.message}"
|
||||
elif isinstance(error, RateLimitError):
|
||||
return "⚠️ 操作太频繁,请稍后再试"
|
||||
elif isinstance(error, SessionError):
|
||||
return "❌ 会话已过期,请重新开始"
|
||||
elif isinstance(error, BusinessLogicError):
|
||||
return f"❌ 操作失败:{error.message}"
|
||||
elif isinstance(error, ExternalServiceError):
|
||||
return "❌ 外部服务暂时不可用,请稍后再试"
|
||||
else:
|
||||
return "❌ 系统错误,请稍后再试或联系管理员"
|
||||
167
src/utils/logger.py
Normal file
167
src/utils/logger.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""日志系统"""
|
||||
import logging
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class CustomFormatter(logging.Formatter):
|
||||
"""自定义日志格式化器"""
|
||||
|
||||
# 颜色代码
|
||||
COLORS = {
|
||||
'DEBUG': '\033[36m', # 青色
|
||||
'INFO': '\033[32m', # 绿色
|
||||
'WARNING': '\033[33m', # 黄色
|
||||
'ERROR': '\033[31m', # 红色
|
||||
'CRITICAL': '\033[35m', # 紫色
|
||||
}
|
||||
RESET = '\033[0m'
|
||||
|
||||
def __init__(self, use_color: bool = True):
|
||||
super().__init__()
|
||||
self.use_color = use_color and sys.stderr.isatty()
|
||||
|
||||
def format(self, record):
|
||||
# 基础格式
|
||||
log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
|
||||
|
||||
# 添加额外信息
|
||||
if hasattr(record, 'user_id'):
|
||||
log_format = f"%(asctime)s | %(levelname)-8s | %(name)s | User:{record.user_id} | %(message)s"
|
||||
|
||||
if hasattr(record, 'chat_id'):
|
||||
log_format = f"%(asctime)s | %(levelname)-8s | %(name)s | Chat:{record.chat_id} | %(message)s"
|
||||
|
||||
# 应用颜色
|
||||
if self.use_color:
|
||||
levelname = record.levelname
|
||||
if levelname in self.COLORS:
|
||||
log_format = log_format.replace(
|
||||
'%(levelname)-8s',
|
||||
f"{self.COLORS[levelname]}%(levelname)-8s{self.RESET}"
|
||||
)
|
||||
|
||||
formatter = logging.Formatter(log_format, datefmt='%Y-%m-%d %H:%M:%S')
|
||||
return formatter.format(record)
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
"""JSON 格式化器用于结构化日志"""
|
||||
|
||||
def format(self, record):
|
||||
log_data = {
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
'level': record.levelname,
|
||||
'logger': record.name,
|
||||
'message': record.getMessage(),
|
||||
'module': record.module,
|
||||
'function': record.funcName,
|
||||
'line': record.lineno
|
||||
}
|
||||
|
||||
# 添加额外字段
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in ['name', 'msg', 'args', 'created', 'filename',
|
||||
'funcName', 'levelname', 'levelno', 'lineno',
|
||||
'module', 'msecs', 'message', 'pathname', 'process',
|
||||
'processName', 'relativeCreated', 'thread', 'threadName']:
|
||||
log_data[key] = value
|
||||
|
||||
# 添加异常信息
|
||||
if record.exc_info:
|
||||
log_data['exception'] = self.formatException(record.exc_info)
|
||||
|
||||
return json.dumps(log_data, ensure_ascii=False)
|
||||
|
||||
|
||||
class Logger:
|
||||
"""日志管理器"""
|
||||
|
||||
_instance = None
|
||||
_loggers = {}
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config=None):
|
||||
if not hasattr(self, '_initialized'):
|
||||
self._initialized = True
|
||||
self.config = config
|
||||
self.setup_logging()
|
||||
|
||||
def setup_logging(self):
|
||||
"""设置日志系统"""
|
||||
# 根日志配置
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(logging.DEBUG)
|
||||
|
||||
# 移除默认处理器
|
||||
root_logger.handlers = []
|
||||
|
||||
# 控制台处理器
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(
|
||||
getattr(logging, self.config.logging.level if self.config else 'INFO')
|
||||
)
|
||||
console_handler.setFormatter(CustomFormatter(use_color=True))
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# 文件处理器
|
||||
if self.config and self.config.logging.file:
|
||||
file_path = Path(self.config.logging.file)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_handler = RotatingFileHandler(
|
||||
filename=str(file_path),
|
||||
maxBytes=self.config.logging.max_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_handler.setFormatter(CustomFormatter(use_color=False))
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
# JSON 日志文件(用于分析)
|
||||
json_file_path = file_path.with_suffix('.json')
|
||||
json_handler = RotatingFileHandler(
|
||||
filename=str(json_file_path),
|
||||
maxBytes=self.config.logging.max_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding='utf-8'
|
||||
)
|
||||
json_handler.setLevel(logging.INFO)
|
||||
json_handler.setFormatter(JsonFormatter())
|
||||
root_logger.addHandler(json_handler)
|
||||
|
||||
@classmethod
|
||||
def get_logger(cls, name: str, config=None) -> logging.Logger:
|
||||
"""获取日志器"""
|
||||
if name not in cls._loggers:
|
||||
if not cls._instance:
|
||||
cls(config)
|
||||
cls._loggers[name] = logging.getLogger(name)
|
||||
return cls._loggers[name]
|
||||
|
||||
|
||||
def get_logger(name: str, config=None) -> logging.Logger:
|
||||
"""获取日志器的便捷方法"""
|
||||
return Logger.get_logger(name, config)
|
||||
|
||||
|
||||
class LoggerContextFilter(logging.Filter):
|
||||
"""日志上下文过滤器"""
|
||||
|
||||
def __init__(self, **context):
|
||||
super().__init__()
|
||||
self.context = context
|
||||
|
||||
def filter(self, record):
|
||||
for key, value in self.context.items():
|
||||
setattr(record, key, value)
|
||||
return True
|
||||
Reference in New Issue
Block a user