549 lines
20 KiB
Python
549 lines
20 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Funstat BOT MCP Server(PostgreSQL + 远程同步版)
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
import time
|
||
from collections import deque
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from mcp.server import Server
|
||
from mcp.types import Tool, TextContent
|
||
from telethon import TelegramClient
|
||
from telethon.tl.types import Message
|
||
|
||
from config import get_settings
|
||
from parsers import extract_entities
|
||
from storage import StorageManager
|
||
from models import BotResponse, PageRecord
|
||
from uploader import RemoteUploader
|
||
|
||
|
||
# Process-wide logging: timestamped records at INFO level and above.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("funstat_mcp")

# Project settings object, loaded once at import time.
settings = get_settings()

# Rate-limit configuration: RATE_LIMIT_PER_SECOND requests are allowed per
# RATE_LIMIT_WINDOW seconds; CACHE_TTL is the default lifetime (in seconds)
# of a cached bot response.
RATE_LIMIT_WINDOW = 1.0
RATE_LIMIT_PER_SECOND = 18
CACHE_TTL = 3600

# Proxy configuration (optional), read from the environment. Everything
# except the proxy type is None when the variable is unset.
PROXY_TYPE = os.environ.get("FUNSTAT_PROXY_TYPE", "socks5")
PROXY_HOST = os.environ.get("FUNSTAT_PROXY_HOST")
PROXY_PORT = os.environ.get("FUNSTAT_PROXY_PORT")
PROXY_USERNAME = os.environ.get("FUNSTAT_PROXY_USERNAME")
PROXY_PASSWORD = os.environ.get("FUNSTAT_PROXY_PASSWORD")
|
||
|
||
|
||
class RateLimiter:
    """Sliding-window rate limiter for asyncio callers.

    At most ``max_requests`` acquisitions are granted inside any window of
    ``time_window`` seconds; additional callers are suspended until the
    oldest grant has left the window.
    """

    def __init__(self, max_requests: int, time_window: float):
        """Create a limiter.

        Args:
            max_requests: Number of grants allowed per window.
            time_window: Window length in seconds.
        """
        self.max_requests = max_requests
        self.time_window = time_window
        # Grant timestamps (monotonic clock), oldest on the left.
        self.requests = deque()

    async def acquire(self):
        """Wait until a slot is free in the current window, then claim it."""
        while True:
            # A monotonic clock keeps the window correct even when the
            # system wall clock is adjusted (NTP step, manual change).
            now = time.monotonic()
            cutoff = now - self.time_window

            # "<=" so an entry sitting exactly on the window boundary is
            # evicted; otherwise it would neither be dropped nor waited
            # for, and the limit could be exceeded by one.
            while self.requests and self.requests[0] <= cutoff:
                self.requests.popleft()

            if len(self.requests) < self.max_requests:
                self.requests.append(now)
                return

            # After eviction the oldest entry is strictly inside the
            # window, so this delay is always positive.
            sleep_time = self.requests[0] + self.time_window - now
            logger.info("速率限制: 等待 %.2f 秒", sleep_time)
            await asyncio.sleep(sleep_time)
|
||
|
||
|
||
class ResponseCache:
    """In-memory cache whose entries expire ``ttl`` seconds after being stored."""

    def __init__(self, ttl: int = CACHE_TTL):
        """Create an empty cache with the given time-to-live in seconds."""
        self.ttl = ttl
        # key -> (cached value, time.time() at which it was stored)
        self.cache: Dict[str, tuple[Any, float]] = {}

    def get(self, key: str) -> Optional[Any]:
        """Return the live value stored under ``key``, or None.

        An entry found to be expired is removed as a side effect.
        """
        if key not in self.cache:
            return None

        value, stored_at = self.cache[key]
        if time.time() - stored_at < self.ttl:
            logger.info("缓存命中: %s", key)
            return value

        # Expired: drop it so the dict does not keep stale data around.
        del self.cache[key]
        return None

    def set(self, key: str, value: Any):
        """Store ``value`` under ``key`` stamped with the current time."""
        stored_at = time.time()
        self.cache[key] = (value, stored_at)
        logger.info("缓存保存: %s", key)

    def clear_expired(self):
        """Delete every entry whose age has reached the TTL."""
        now = time.time()
        stale_keys = []
        for key, (_, stored_at) in self.cache.items():
            if now - stored_at >= self.ttl:
                stale_keys.append(key)

        # Deletion happens after the scan so the dict is not mutated
        # while it is being iterated.
        for key in stale_keys:
            del self.cache[key]

        if stale_keys:
            logger.info("清理了 %s 个过期缓存", len(stale_keys))
|
||
|
||
|
||
class FunstatMCPServer:
    """MCP server that relays tool calls to the funstat Telegram bot.

    Each tool call is turned into a bot command, sent through a Telethon
    user session, and the bot's reply (optionally paginated) is cached,
    persisted through ``StorageManager`` and returned as text.
    """

    def __init__(self):
        """Build the collaborators and register the MCP handlers.

        No network connection is made here; see ``initialize`` and ``run``.
        """
        self.server = Server("funstat-mcp")
        self.client: Optional[TelegramClient] = None
        self.bot_entity = None
        self.account_display: Optional[str] = None
        self.rate_limiter = RateLimiter(RATE_LIMIT_PER_SECOND, RATE_LIMIT_WINDOW)
        self.cache = ResponseCache()
        self.storage = StorageManager(settings.database_url)
        self.uploader = (
            RemoteUploader(self.storage, settings)
            if settings.remote_upload_enabled
            else None
        )

        self.server.list_tools()(self.list_tools)
        self.server.call_tool()(self.call_tool)

    async def initialize(self):
        """Start the Telegram client and resolve the bot entity.

        Raises:
            FileNotFoundError: If the configured session file does not exist.
        """
        logger.info("初始化 Telegram 客户端...")
        session_base = os.path.expanduser(settings.telegram_session_path)
        session_file = f"{session_base}.session"
        if not os.path.exists(session_file):
            raise FileNotFoundError(
                f"Session 文件不存在: {session_file}\n"
                "请先运行 create_session_safe.py 使用自己的 Telegram 账号创建 session 文件"
            )

        proxy = None
        if PROXY_HOST and PROXY_PORT:
            try:
                proxy_port = int(PROXY_PORT)
            except ValueError:
                logger.warning("代理端口无效,忽略代理配置: %s", PROXY_PORT)
            else:
                if PROXY_USERNAME:
                    # Telethon's positional proxy order is
                    # (type, addr, port, rdns, username, password). Without
                    # the rdns slot the username would be read as rdns and
                    # the password as the username.
                    proxy = (
                        PROXY_TYPE,
                        PROXY_HOST,
                        proxy_port,
                        True,
                        PROXY_USERNAME,
                        PROXY_PASSWORD or "",
                    )
                else:
                    proxy = (PROXY_TYPE, PROXY_HOST, proxy_port)
                logger.info("使用代理连接: %s://%s:%s", PROXY_TYPE, PROXY_HOST, proxy_port)

        self.client = TelegramClient(
            session_base,
            settings.telegram_api_id,
            settings.telegram_api_hash,
            proxy=proxy,
        )
        await self.client.start()

        logger.info("连接到 %s ...", settings.telegram_bot_username)
        self.bot_entity = await self.client.get_entity(settings.telegram_bot_username)
        logger.info(
            "✅ 已连接到: %s",
            getattr(self.bot_entity, "first_name", settings.telegram_bot_username),
        )

        me = await self.client.get_me()
        self.account_display = f"@{me.username}" if me.username else f"ID:{me.id}"
        logger.info("✅ 当前账号: %s (ID: %s)", self.account_display, me.id)

    async def send_command_and_wait(
        self,
        command: str,
        timeout: int = 12,
        use_cache: bool = True,
        paginate: bool = False,
        max_pages: Optional[int] = None
    ) -> BotResponse:
        """Send ``command`` to the bot and return its reply.

        Args:
            command: Text to send to the bot.
            timeout: Seconds to keep polling for a reply.
            use_cache: Read from / write to the response cache.
            paginate: Follow "next page" buttons (only honoured when
                ``settings.enable_pagination`` is set).
            max_pages: Optional cap, applied on top of
                ``settings.pagination_max_pages``.

        Returns:
            The bot's reply as a ``BotResponse``.

        Raises:
            RuntimeError: If the Telegram client is not initialised.
            TimeoutError: If no reply arrives within ``timeout`` seconds.
        """
        paginate = paginate and settings.enable_pagination
        effective_max_pages = max(1, settings.pagination_max_pages)
        if max_pages:
            effective_max_pages = min(effective_max_pages, max_pages)

        cache_key = f"cmd:{command}:paginate={int(paginate)}:max={effective_max_pages}"
        if use_cache:
            cached = self.cache.get(cache_key)
            if cached:
                return cached

        if not self.client:
            raise RuntimeError("Telegram 客户端尚未初始化")

        await self.rate_limiter.acquire()
        logger.info("📤 发送命令: %s", command)

        # Snapshot both id and text of the newest message BEFORE sending.
        # Knowing the baseline text is what lets the loop below tell an
        # in-place edit of that message apart from the unchanged previous
        # reply; without it the previous reply would be returned as the
        # answer whenever the bot is slow to respond.
        last_message_id, last_seen_text = await self._get_latest_message_snapshot()
        await self.client.send_message(self.bot_entity, command)
        await asyncio.sleep(1.5)

        start_time = time.time()
        while time.time() - start_time < timeout:
            async for message in self.client.iter_messages(self.bot_entity, limit=5):
                if message.out or not message.text:
                    continue

                is_new = message.id > last_message_id
                is_updated = (
                    message.id == last_message_id
                    and message.text != last_seen_text
                )
                if not (is_new or is_updated):
                    continue

                logger.info("✅ 收到响应 (%s 字符)", len(message.text))

                if paginate and message.reply_markup:
                    pages = await self._collect_paginated_pages(message, effective_max_pages)
                    response_text = self._format_pages(pages)
                else:
                    pages = [
                        PageRecord(
                            page_number=1,
                            text=message.text,
                            entities=extract_entities(message.text),
                        )
                    ]
                    response_text = message.text

                result = BotResponse(text=response_text, pages=pages)
                if use_cache:
                    self.cache.set(cache_key, result)
                return result

            await asyncio.sleep(0.5)

        raise TimeoutError(f"等待 BOT 响应超时 ({timeout}秒)")

    async def _get_latest_message_snapshot(self) -> tuple:
        """Return ``(id, text)`` of the newest message in the bot chat.

        Returns ``(0, None)`` when the chat has no messages.
        """
        assert self.client is not None
        async for message in self.client.iter_messages(self.bot_entity, limit=1):
            return message.id, message.text
        return 0, None

    async def _get_latest_message_id(self) -> int:
        """Return the id of the newest message in the bot chat, or 0 if empty."""
        latest_id, _ = await self._get_latest_message_snapshot()
        return latest_id

    async def _collect_paginated_pages(self, initial_message: Message, max_pages: int) -> List[PageRecord]:
        """Click through "next page" buttons and collect each page's text.

        Collection stops at ``max_pages``, when no next-page button is
        found, when a click fails, or when the message does not change
        after a click.
        """
        pages: List[PageRecord] = [
            PageRecord(
                page_number=1,
                text=initial_message.text or "",
                entities=extract_entities(initial_message.text or ""),
            )
        ]
        current_message = initial_message
        current_text = initial_message.text or ""

        for next_page in range(2, max_pages + 1):
            button_info = self._find_next_page_button(current_message)
            if not button_info:
                break

            logger.info("➡️ 点击翻页按钮: %s", button_info["text"])
            await self.rate_limiter.acquire()
            try:
                await current_message.click(button_info["index"])
            except Exception as exc:
                # Best effort: keep the pages gathered so far.
                logger.warning("翻页按钮点击失败: %s", exc)
                break

            await asyncio.sleep(settings.pagination_delay)
            new_message = await self._wait_for_updated_message(current_text)
            if not new_message or not new_message.text:
                break

            current_message = new_message
            current_text = new_message.text
            pages.append(
                PageRecord(
                    page_number=next_page,
                    text=current_text,
                    entities=extract_entities(current_text),
                )
            )

        logger.info("📚 自动翻页完成,共获取 %s 页", len(pages))
        return pages

    async def _wait_for_updated_message(self, previous_text: str, timeout: Optional[float] = None) -> Optional[Message]:
        """Poll the newest chat message until its text differs from ``previous_text``.

        Args:
            previous_text: Text of the page currently shown.
            timeout: Seconds to wait; a falsy value selects
                ``settings.pagination_timeout``.

        Returns:
            The changed incoming message, or None on timeout.
        """
        assert self.client is not None
        timeout = timeout or settings.pagination_timeout
        start_time = time.time()

        while time.time() - start_time < timeout:
            messages = await self.client.get_messages(self.bot_entity, limit=1)
            if not messages:
                await asyncio.sleep(0.4)
                continue

            candidate = messages[0]
            if candidate.out or not candidate.text:
                await asyncio.sleep(0.4)
                continue

            if candidate.text.strip() != (previous_text or "").strip():
                return candidate

            await asyncio.sleep(0.4)

        return None

    def _find_next_page_button(self, message: Message) -> Optional[Dict[str, Any]]:
        """Locate the first button whose label contains a pagination keyword.

        Returns:
            ``{"index": i, "text": label}`` where ``i`` counts buttons
            across all rows in order, or None when nothing matches.
        """
        if not message.reply_markup or not hasattr(message.reply_markup, "rows"):
            return None

        button_index = 0
        for row in message.reply_markup.rows:
            for button in row.buttons:
                text = getattr(button, "text", "")
                if text and any(keyword in text for keyword in settings.pagination_keywords):
                    return {"index": button_index, "text": text}
                button_index += 1
        return None

    @staticmethod
    def _format_pages(pages: List[PageRecord]) -> str:
        """Join pages into one string, each prefixed by a page-number header."""
        if not pages:
            return ""
        formatted = [
            f"【第{page.page_number}页】\n{page.text}"
            for page in pages
        ]
        return "\n\n".join(formatted)

    async def list_tools(self) -> List[Tool]:
        """Return the tool definitions exposed over MCP."""
        return [
            Tool(
                name="funstat_search",
                description="搜索 Telegram 群组、频道,支持自动翻页",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "query": {"type": "string", "description": "搜索关键词"},
                        "paginate": {
                            "type": "boolean",
                            "description": "是否自动翻页(默认开启)"
                        },
                        "max_pages": {
                            "type": "integer",
                            "minimum": 1,
                            "maximum": settings.pagination_max_pages,
                            "description": "最大翻页数量"
                        }
                    },
                    "required": ["query"]
                }
            ),
            Tool(
                name="funstat_topchat",
                description="获取热门群组/频道列表,按成员数或活跃度排序",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "category": {"type": "string", "description": "分类筛选(可选)"}
                    }
                }
            ),
            Tool(
                name="funstat_text",
                description="通过消息文本搜索,查找包含特定文本的消息和来源群组",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "text": {"type": "string", "description": "搜索文本"}
                    },
                    "required": ["text"]
                }
            ),
            Tool(
                name="funstat_human",
                description="通过姓名搜索用户,查找 Telegram 用户信息",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "name": {"type": "string", "description": "姓名或关键词"}
                    },
                    "required": ["name"]
                }
            ),
            Tool(
                name="funstat_user_info",
                description="查询用户详细信息,支持用户名、用户ID或手机号",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "identifier": {"type": "string", "description": "用户标识"}
                    },
                    "required": ["identifier"]
                }
            ),
            # inputSchema is a required field of the MCP Tool model, so the
            # argument-less tools declare an explicit empty object schema.
            Tool(
                name="funstat_balance",
                description="查询当前账号的积分余额和使用统计",
                inputSchema={"type": "object", "properties": {}},
            ),
            Tool(
                name="funstat_menu",
                description="显示 funstat BOT 的主菜单",
                inputSchema={"type": "object", "properties": {}},
            ),
            Tool(
                name="funstat_start",
                description="获取 funstat BOT 的欢迎信息",
                inputSchema={"type": "object", "properties": {}},
            ),
        ]

    @staticmethod
    def _resolve_tool_request(name: str, arguments: Dict[str, Any]) -> tuple:
        """Translate a tool call into ``(bot_command, stored_args, send_kwargs)``.

        ``stored_args`` is what gets persisted next to the response;
        ``send_kwargs`` are extra keyword arguments for
        ``send_command_and_wait``.

        Raises:
            KeyError: If a required argument is missing.
            ValueError: For an unknown tool or an empty user identifier.
        """
        if name == "funstat_search":
            query = arguments["query"]
            paginate_flag = arguments.get("paginate", True)
            max_pages = arguments.get("max_pages")
            return (
                f"/search {query}",
                {"query": query, "paginate": paginate_flag, "max_pages": max_pages},
                {"paginate": paginate_flag, "max_pages": max_pages},
            )

        if name == "funstat_topchat":
            # "or" also covers an explicit null category.
            category = (arguments.get("category") or "").strip()
            return (
                f"/topchat {category}".strip(),
                {"category": category} if category else {},
                {},
            )

        if name == "funstat_text":
            text = arguments["text"]
            return f"/text {text}", {"text": text}, {}

        if name == "funstat_human":
            name_query = arguments["name"]
            return f"/human {name_query}", {"name": name_query}, {}

        if name == "funstat_user_info":
            identifier = arguments["identifier"].strip()
            if not identifier:
                raise ValueError("用户标识不能为空")
            # User lookups always bypass the cache.
            return (
                f"/user_info {identifier}",
                {"identifier": identifier},
                {"use_cache": False},
            )

        argument_free = {
            # The balance changes with usage, so it is never cached.
            "funstat_balance": ("/balance", {"use_cache": False}),
            "funstat_menu": ("/menu", {}),
            "funstat_start": ("/start", {}),
        }
        if name in argument_free:
            bot_command, send_kwargs = argument_free[name]
            return bot_command, {}, send_kwargs

        raise ValueError(f"未知工具: {name}")

    async def call_tool(self, name: str, arguments: Dict[str, Any]) -> List[TextContent]:
        """Execute the MCP tool ``name`` and return the bot's reply as text.

        Any failure (bad arguments, timeout, storage error, ...) is logged
        and reported to the caller as a single error ``TextContent`` rather
        than raised.
        """
        logger.info("🔧 调用工具: %s 参数=%s", name, arguments)
        try:
            bot_command, stored_args, send_kwargs = self._resolve_tool_request(
                name, arguments or {}
            )
            response = await self.send_command_and_wait(bot_command, **send_kwargs)
            await self._persist_response(name, bot_command, stored_args, response)
            return [TextContent(type="text", text=response.text)]
        except Exception as exc:
            logger.error("❌ 工具调用失败: %s", exc, exc_info=exc)
            return [TextContent(type="text", text=f"❌ 错误: {exc}")]

    async def _persist_response(
        self,
        mcp_command: str,
        bot_command: str,
        arguments: Dict[str, Any],
        response: BotResponse
    ):
        """Save the response pages through the storage manager.

        Does nothing when the response has no pages.
        """
        if not response.pages:
            return
        inserted = await self.storage.save_response(
            command=mcp_command,
            bot_command=bot_command,
            arguments=arguments,
            pages=response.pages,
            source_account=self.account_display,
        )
        if inserted:
            logger.info("💾 已写入 %s 条记录到本地数据库", len(inserted))

    @staticmethod
    def _log_background_task_failure(task: "asyncio.Task") -> None:
        """Done-callback: log an exception that ended a background task."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("后台任务异常退出: %s", exc, exc_info=exc)

    async def run(self):
        """Initialise storage, uploader and Telegram, then serve MCP over HTTP.

        Returns when the uvicorn server stops; background tasks started
        here are cancelled before returning.
        """
        await self.storage.initialize()
        if self.uploader:
            await self.uploader.start()

        await self.initialize()

        async def cache_cleanup_task():
            # Purge expired cache entries every five minutes.
            while True:
                await asyncio.sleep(300)
                self.cache.clear_expired()

        # The event loop only keeps weak references to tasks, so strong
        # references are held here to stop them being garbage-collected
        # while still running.
        background_tasks: List[asyncio.Task] = []

        def spawn(coro):
            task = asyncio.create_task(coro)
            task.add_done_callback(self._log_background_task_failure)
            background_tasks.append(task)
            return task

        try:
            spawn(cache_cleanup_task())
            logger.info("🚀 Funstat MCP Server 已启动")

            from mcp.server.streamable_http import StreamableHTTPServerTransport
            from starlette.applications import Starlette
            import uvicorn
            import uuid

            require_session = os.getenv("FUNSTAT_REQUIRE_SESSION", "false").lower() in ("1", "true", "yes")
            session_id = str(uuid.uuid4()) if require_session else None
            transport = StreamableHTTPServerTransport(
                mcp_session_id=session_id,
                is_json_response_enabled=True,
            )

            async def run_mcp_server():
                async with transport.connect() as streams:
                    await self.server.run(
                        streams[0],
                        streams[1],
                        self.server.create_initialization_options(),
                    )

            spawn(run_mcp_server())

            app = Starlette()
            app.mount("/", transport.handle_request)

            # NOTE(review): the transport is streamable HTTP mounted at "/";
            # the /sse and /messages paths advertised below may be left over
            # from an earlier SSE transport -- confirm before relying on them.
            logger.info("🌐 启动 SSE 服务器: http://%s:%s", settings.host, settings.port)
            logger.info("📡 SSE 端点: http://%s:%s/sse", settings.host, settings.port)
            logger.info("📨 消息端点: http://%s:%s/messages", settings.host, settings.port)

            config = uvicorn.Config(
                app,
                host=settings.host,
                port=settings.port,
                log_level="info",
            )
            server_instance = uvicorn.Server(config)
            await server_instance.serve()
        finally:
            for task in background_tasks:
                task.cancel()
            # return_exceptions=True: cancellation or an earlier task
            # failure must not mask the exception that ended run().
            await asyncio.gather(*background_tasks, return_exceptions=True)
|
||
|
||
|
||
async def main():
    """Run the Funstat MCP server, closing its storage on the way out."""
    mcp_server = FunstatMCPServer()
    try:
        await mcp_server.run()
    finally:
        # Runs on normal shutdown and on errors alike.
        await mcp_server.storage.close()


# Script entry point: run the server on a fresh asyncio event loop.
if __name__ == "__main__":
    asyncio.run(main())
|