Files
funstat-mcp/core/server.py

549 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Funstat BOT MCP Server (PostgreSQL + remote sync edition)
"""
import asyncio
import logging
import os
import time
from collections import deque
from typing import Any, Dict, List, Optional
from mcp.server import Server
from mcp.types import Tool, TextContent
from telethon import TelegramClient
from telethon.tl.types import Message
from config import get_settings
from parsers import extract_entities
from storage import StorageManager
from models import BotResponse, PageRecord
from uploader import RemoteUploader
# Configure the root logger once at import time; every record carries a
# timestamp, the logger name and the level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Module-wide logger used by every class below.
logger = logging.getLogger("funstat_mcp")
# Project settings, loaded once at import time via config.get_settings().
settings = get_settings()
# Rate-limit configuration: RateLimiter is built with at most
# RATE_LIMIT_PER_SECOND acquisitions per RATE_LIMIT_WINDOW seconds.
RATE_LIMIT_PER_SECOND = 18
RATE_LIMIT_WINDOW = 1.0
# Default lifetime of a ResponseCache entry; compared against time.time()
# differences, i.e. seconds.
CACHE_TTL = 3600
# Proxy configuration (optional). A proxy is only used when both
# FUNSTAT_PROXY_HOST and FUNSTAT_PROXY_PORT are set; the type defaults to
# "socks5" and the credentials are optional.
PROXY_TYPE = os.getenv("FUNSTAT_PROXY_TYPE", "socks5")
PROXY_HOST = os.getenv("FUNSTAT_PROXY_HOST")
PROXY_PORT = os.getenv("FUNSTAT_PROXY_PORT")
PROXY_USERNAME = os.getenv("FUNSTAT_PROXY_USERNAME")
PROXY_PASSWORD = os.getenv("FUNSTAT_PROXY_PASSWORD")
class RateLimiter:
    """Sliding-window rate limiter for asyncio code.

    At most ``max_requests`` calls to :meth:`acquire` are admitted within any
    ``time_window``-second span; additional callers sleep until the oldest
    admitted request leaves the window.
    """

    def __init__(self, max_requests: int, time_window: float):
        """Create a limiter.

        Args:
            max_requests: Maximum number of acquisitions per window (>= 1).
            time_window: Length of the sliding window in seconds.

        Raises:
            ValueError: If ``max_requests`` is smaller than 1. Such a limiter
                could never admit a request.
        """
        if max_requests < 1:
            raise ValueError("max_requests must be at least 1")
        self.max_requests = max_requests
        self.time_window = time_window
        # Monotonic timestamps of admitted requests, oldest first.
        self.requests = deque()

    async def acquire(self):
        """Wait until a request slot is free, then claim it."""
        # Loop instead of recursing so heavy contention cannot grow the
        # call stack.
        while True:
            # A monotonic clock keeps the window immune to wall-clock jumps
            # (NTP corrections, manual changes).
            now = time.monotonic()
            cutoff = now - self.time_window
            while self.requests and self.requests[0] <= cutoff:
                self.requests.popleft()
            if len(self.requests) < self.max_requests:
                self.requests.append(now)
                return
            sleep_time = self.requests[0] + self.time_window - now
            if sleep_time > 0:
                logger.info("速率限制: 等待 %.2f", sleep_time)
            # Always yield to the event loop, even for a zero-length wait,
            # so this loop can never busy-spin.
            await asyncio.sleep(max(sleep_time, 0.0))
class ResponseCache:
    """In-memory key/value cache whose entries expire after ``ttl`` seconds."""

    def __init__(self, ttl: int = CACHE_TTL):
        """Create an empty cache with the given time-to-live in seconds."""
        self.ttl = ttl
        # Maps key -> (value, time.time() at which the value was stored).
        self.cache: Dict[str, tuple[Any, float]] = {}

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or ``None`` if absent or expired.

        An expired entry is removed as a side effect of the lookup.
        """
        entry = self.cache.get(key)
        if entry is None:
            return None
        value, stored_at = entry
        if time.time() - stored_at >= self.ttl:
            # Too old: evict it so it is not checked again.
            del self.cache[key]
            return None
        logger.info("缓存命中: %s", key)
        return value

    def set(self, key: str, value: Any):
        """Store ``value`` under ``key`` stamped with the current time."""
        stored_at = time.time()
        self.cache[key] = (value, stored_at)
        logger.info("缓存保存: %s", key)

    def clear_expired(self):
        """Remove every entry whose age has reached the TTL."""
        now = time.time()
        stale_keys = []
        for key, (_, stored_at) in self.cache.items():
            if now - stored_at >= self.ttl:
                stale_keys.append(key)
        # Delete after the scan; a dict must not change size while iterated.
        for key in stale_keys:
            del self.cache[key]
        if stale_keys:
            logger.info("清理了 %s 个过期缓存", len(stale_keys))
class FunstatMCPServer:
    """MCP server that forwards tool calls to the funstat Telegram bot.

    Each tool call is translated into a bot command, sent through a Telethon
    user session, and the bot's reply (optionally several pages collected by
    clicking the "next page" inline button) is returned as text and handed to
    the storage layer.
    """

    def __init__(self):
        """Create the MCP server object and its helpers; no network I/O yet."""
        self.server = Server("funstat-mcp")
        self.client: Optional[TelegramClient] = None
        self.bot_entity = None
        self.account_display: Optional[str] = None
        self.rate_limiter = RateLimiter(RATE_LIMIT_PER_SECOND, RATE_LIMIT_WINDOW)
        self.cache = ResponseCache()
        self.storage = StorageManager(settings.database_url)
        self.uploader = RemoteUploader(self.storage, settings) if settings.remote_upload_enabled else None
        # Register the bound methods as the MCP list/call handlers.
        self.server.list_tools()(self.list_tools)
        self.server.call_tool()(self.call_tool)

    @staticmethod
    def _build_proxy_config(
        proxy_type: Optional[str],
        host: Optional[str],
        port: Optional[str],
        username: Optional[str] = None,
        password: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Build the Telethon ``proxy`` argument, or ``None`` for a direct connection.

        The keyword (dict) form is used on purpose: Telethon unpacks a tuple
        positionally as ``(proxy_type, addr, port, rdns, username, password)``,
        so a ``(type, host, port, user, pass)`` tuple would put the username
        in the ``rdns`` slot and the password in the ``username`` slot.

        Returns ``None`` when host or port is missing or the port is not an
        integer (a warning is logged in the latter case).
        """
        if not (host and port):
            return None
        try:
            proxy_port = int(port)
        except ValueError:
            logger.warning("代理端口无效,忽略代理配置: %s", port)
            return None
        proxy: Dict[str, Any] = {
            "proxy_type": proxy_type,
            "addr": host,
            "port": proxy_port,
        }
        if username:
            proxy["username"] = username
            proxy["password"] = password or ""
        return proxy

    async def initialize(self):
        """Start the Telegram client and resolve the bot entity.

        Raises:
            FileNotFoundError: If the configured ``.session`` file does not exist.
        """
        logger.info("初始化 Telegram 客户端...")
        session_base = os.path.expanduser(settings.telegram_session_path)
        session_file = f"{session_base}.session"
        if not os.path.exists(session_file):
            raise FileNotFoundError(
                f"Session 文件不存在: {session_file}\n"
                "请先运行 create_session_safe.py 使用自己的 Telegram 账号创建 session 文件"
            )
        proxy = self._build_proxy_config(
            PROXY_TYPE,
            PROXY_HOST,
            PROXY_PORT,
            PROXY_USERNAME,
            PROXY_PASSWORD
        )
        if proxy is not None:
            logger.info("使用代理连接: %s://%s:%s", PROXY_TYPE, PROXY_HOST, proxy["port"])
        self.client = TelegramClient(
            session_base,
            settings.telegram_api_id,
            settings.telegram_api_hash,
            proxy=proxy
        )
        await self.client.start()
        logger.info("连接到 %s ...", settings.telegram_bot_username)
        self.bot_entity = await self.client.get_entity(settings.telegram_bot_username)
        logger.info("✅ 已连接到: %s", getattr(self.bot_entity, "first_name", settings.telegram_bot_username))
        me = await self.client.get_me()
        self.account_display = f"@{me.username}" if me.username else f"ID:{me.id}"
        logger.info("✅ 当前账号: %s (ID: %s)", self.account_display, me.id)

    async def send_command_and_wait(
        self,
        command: str,
        timeout: int = 12,
        use_cache: bool = True,
        paginate: bool = False,
        max_pages: Optional[int] = None
    ) -> BotResponse:
        """Send ``command`` to the bot and return its reply.

        Args:
            command: Text sent to the bot verbatim.
            timeout: Seconds to poll for a reply after the initial delay.
            use_cache: Serve from / store into the response cache.
            paginate: Follow the "next page" button; only honoured when
                ``settings.enable_pagination`` is true.
            max_pages: Optional per-call page cap; never exceeds
                ``settings.pagination_max_pages``.

        Raises:
            RuntimeError: If the Telegram client has not been initialised.
            TimeoutError: If no new or edited bot message shows up in time.
        """
        paginate = paginate and settings.enable_pagination
        effective_max_pages = max(1, settings.pagination_max_pages)
        if max_pages:
            # Clamp at 1 so a negative request cannot produce a nonsensical cap.
            effective_max_pages = max(1, min(effective_max_pages, max_pages))
        cache_key = f"cmd:{command}:paginate={int(paginate)}:max={effective_max_pages}"
        if use_cache:
            cached = self.cache.get(cache_key)
            if cached is not None:
                return cached
        if not self.client:
            raise RuntimeError("Telegram 客户端尚未初始化")
        await self.rate_limiter.acquire()
        logger.info("📤 发送命令: %s", command)
        # Snapshot the newest message BEFORE sending. Its text is the baseline
        # for edit detection; without it the previous bot reply would look
        # "updated" on the first poll and be returned as a stale answer
        # whenever the bot is slower than the initial delay.
        baseline = await self._get_latest_message()
        last_message_id = baseline.id if baseline is not None else 0
        last_seen_text = baseline.text if baseline is not None and not baseline.out else None
        await self.client.send_message(self.bot_entity, command)
        await asyncio.sleep(1.5)
        start_time = time.time()
        while time.time() - start_time < timeout:
            async for message in self.client.iter_messages(self.bot_entity, limit=5):
                if message.out or not message.text:
                    continue
                is_new = message.id > last_message_id
                is_updated = message.id == last_message_id and message.text != last_seen_text
                if not (is_new or is_updated):
                    continue
                logger.info("✅ 收到响应 (%s 字符)", len(message.text))
                result = await self._build_response(message, paginate, effective_max_pages)
                if use_cache:
                    self.cache.set(cache_key, result)
                return result
            await asyncio.sleep(0.5)
        raise TimeoutError(f"等待 BOT 响应超时 ({timeout}秒)")

    async def _build_response(self, message: Message, paginate: bool, max_pages: int) -> BotResponse:
        """Turn a bot reply into a BotResponse, walking extra pages if requested."""
        if paginate and message.reply_markup:
            pages = await self._collect_paginated_pages(message, max_pages)
            return BotResponse(text=self._format_pages(pages), pages=pages)
        pages = [
            PageRecord(
                page_number=1,
                text=message.text,
                entities=extract_entities(message.text)
            )
        ]
        return BotResponse(text=message.text, pages=pages)

    async def _get_latest_message(self) -> Optional[Message]:
        """Return the newest message in the bot chat, or ``None`` if it is empty."""
        assert self.client is not None
        async for message in self.client.iter_messages(self.bot_entity, limit=1):
            return message
        return None

    async def _get_latest_message_id(self) -> int:
        """Return the id of the newest message in the bot chat, or 0 if empty."""
        latest = await self._get_latest_message()
        return latest.id if latest is not None else 0

    async def _collect_paginated_pages(self, initial_message: Message, max_pages: int) -> List[PageRecord]:
        """Collect up to ``max_pages`` pages by clicking the next-page button.

        Stops early when no next-page button is found, when the click fails,
        or when no changed message shows up after the click.
        """
        pages: List[PageRecord] = [
            PageRecord(
                page_number=1,
                text=initial_message.text or "",
                entities=extract_entities(initial_message.text or "")
            )
        ]
        current_message = initial_message
        current_text = initial_message.text or ""
        for next_page in range(2, max_pages + 1):
            button_info = self._find_next_page_button(current_message)
            if not button_info:
                break
            logger.info("➡️ 点击翻页按钮: %s", button_info["text"])
            # Button clicks are requests too, so they go through the limiter.
            await self.rate_limiter.acquire()
            try:
                await current_message.click(button_info["index"])
            except Exception as exc:
                # Best effort: keep the pages gathered so far.
                logger.warning("翻页按钮点击失败: %s", exc)
                break
            await asyncio.sleep(settings.pagination_delay)
            new_message = await self._wait_for_updated_message(current_text)
            if not new_message or not new_message.text:
                break
            current_message = new_message
            current_text = new_message.text
            pages.append(
                PageRecord(
                    page_number=next_page,
                    text=current_text,
                    entities=extract_entities(current_text)
                )
            )
        logger.info("📚 自动翻页完成,共获取 %s", len(pages))
        return pages

    async def _wait_for_updated_message(self, previous_text: str, timeout: Optional[float] = None) -> Optional[Message]:
        """Poll the newest bot message until its text differs from ``previous_text``.

        Returns the changed message, or ``None`` once ``timeout`` seconds
        (default ``settings.pagination_timeout``) have elapsed.
        """
        assert self.client is not None
        timeout = timeout or settings.pagination_timeout
        start_time = time.time()
        while time.time() - start_time < timeout:
            messages = await self.client.get_messages(self.bot_entity, limit=1)
            if not messages:
                await asyncio.sleep(0.4)
                continue
            candidate = messages[0]
            if candidate.out or not candidate.text:
                await asyncio.sleep(0.4)
                continue
            if candidate.text.strip() != (previous_text or "").strip():
                return candidate
            await asyncio.sleep(0.4)
        return None

    def _find_next_page_button(self, message: Message) -> Optional[Dict[str, Any]]:
        """Locate the first button whose label contains a pagination keyword.

        Returns ``{"index": flat_index, "text": label}`` where ``flat_index``
        counts buttons row by row, or ``None`` when nothing matches.
        """
        if not message.reply_markup or not hasattr(message.reply_markup, "rows"):
            return None
        button_index = 0
        for row in message.reply_markup.rows:
            for button in row.buttons:
                text = getattr(button, "text", "")
                if text and any(keyword in text for keyword in settings.pagination_keywords):
                    return {"index": button_index, "text": text}
                button_index += 1
        return None

    @staticmethod
    def _format_pages(pages: List[PageRecord]) -> str:
        """Join pages into one string, each prefixed with its page-number header."""
        if not pages:
            return ""
        formatted = [
            f"【第{page.page_number}页】\n{page.text}"
            for page in pages
        ]
        return "\n\n".join(formatted)

    async def list_tools(self) -> List[Tool]:
        """Return the tool catalogue advertised to MCP clients."""
        return [
            Tool(
                name="funstat_search",
                description="搜索 Telegram 群组、频道,支持自动翻页",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "query": {"type": "string", "description": "搜索关键词"},
                        "paginate": {
                            "type": "boolean",
                            "description": "是否自动翻页(默认开启)"
                        },
                        "max_pages": {
                            "type": "integer",
                            "minimum": 1,
                            "maximum": settings.pagination_max_pages,
                            "description": "最大翻页数量"
                        }
                    },
                    "required": ["query"]
                }
            ),
            Tool(
                name="funstat_topchat",
                description="获取热门群组/频道列表,按成员数或活跃度排序",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "category": {"type": "string", "description": "分类筛选(可选)"}
                    }
                }
            ),
            Tool(
                name="funstat_text",
                description="通过消息文本搜索,查找包含特定文本的消息和来源群组",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "text": {"type": "string", "description": "搜索文本"}
                    },
                    "required": ["text"]
                }
            ),
            Tool(
                name="funstat_human",
                description="通过姓名搜索用户,查找 Telegram 用户信息",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "name": {"type": "string", "description": "姓名或关键词"}
                    },
                    "required": ["name"]
                }
            ),
            Tool(
                name="funstat_user_info",
                description="查询用户详细信息支持用户名、用户ID或手机号",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "identifier": {"type": "string", "description": "用户标识"}
                    },
                    "required": ["identifier"]
                }
            ),
            # Argument-less tools still carry an (empty) object schema, since
            # Tool.inputSchema is treated here as a required field.
            Tool(
                name="funstat_balance",
                description="查询当前账号的积分余额和使用统计",
                inputSchema={"type": "object", "properties": {}}
            ),
            Tool(
                name="funstat_menu",
                description="显示 funstat BOT 的主菜单",
                inputSchema={"type": "object", "properties": {}}
            ),
            Tool(
                name="funstat_start",
                description="获取 funstat BOT 的欢迎信息",
                inputSchema={"type": "object", "properties": {}}
            ),
        ]

    @staticmethod
    def _plan_tool_call(
        name: str,
        arguments: Dict[str, Any]
    ) -> tuple[str, Dict[str, Any], Dict[str, Any]]:
        """Translate an MCP tool call into a bot command.

        Returns:
            ``(bot_command, stored_arguments, send_kwargs)`` where
            ``stored_arguments`` is what gets persisted and ``send_kwargs``
            are the extra keyword arguments for ``send_command_and_wait``.

        Raises:
            KeyError: If a required argument is missing.
            ValueError: If the tool is unknown or the user identifier is empty.
        """
        if name == "funstat_search":
            query = arguments["query"]
            paginate_flag = arguments.get("paginate", True)
            max_pages = arguments.get("max_pages")
            return (
                f"/search {query}",
                {"query": query, "paginate": paginate_flag, "max_pages": max_pages},
                {"paginate": paginate_flag, "max_pages": max_pages},
            )
        if name == "funstat_topchat":
            # ``or ""`` tolerates an explicit null category from the client.
            category = (arguments.get("category") or "").strip()
            return (
                f"/topchat {category}".strip(),
                {"category": category} if category else {},
                {},
            )
        if name == "funstat_text":
            text = arguments["text"]
            return f"/text {text}", {"text": text}, {}
        if name == "funstat_human":
            name_query = arguments["name"]
            return f"/human {name_query}", {"name": name_query}, {}
        if name == "funstat_user_info":
            identifier = arguments["identifier"].strip()
            if not identifier:
                raise ValueError("用户标识不能为空")
            # User lookups always bypass the cache.
            return f"/user_info {identifier}", {"identifier": identifier}, {"use_cache": False}
        if name == "funstat_balance":
            # The balance changes with usage, so it is never cached.
            return "/balance", {}, {"use_cache": False}
        if name == "funstat_menu":
            return "/menu", {}, {}
        if name == "funstat_start":
            return "/start", {}, {}
        raise ValueError(f"未知工具: {name}")

    async def call_tool(self, name: str, arguments: Dict[str, Any]) -> List[TextContent]:
        """Execute the MCP tool ``name`` and return the bot's reply as text.

        Any failure (bad arguments, timeout, storage error) is logged and
        reported back to the client as an error text instead of being raised.
        """
        logger.info("🔧 调用工具: %s 参数=%s", name, arguments)
        try:
            bot_command, stored_arguments, send_kwargs = self._plan_tool_call(name, arguments or {})
            response = await self.send_command_and_wait(bot_command, **send_kwargs)
            await self._persist_response(name, bot_command, stored_arguments, response)
            return [TextContent(type="text", text=response.text)]
        except Exception as exc:
            logger.error("❌ 工具调用失败: %s", exc, exc_info=exc)
            return [TextContent(type="text", text=f"❌ 错误: {exc}")]

    async def _persist_response(
        self,
        mcp_command: str,
        bot_command: str,
        arguments: Dict[str, Any],
        response: BotResponse
    ):
        """Hand the response pages to the storage layer; no-op without pages."""
        if not response.pages:
            return
        inserted = await self.storage.save_response(
            command=mcp_command,
            bot_command=bot_command,
            arguments=arguments,
            pages=response.pages,
            source_account=self.account_display
        )
        if inserted:
            logger.info("💾 已写入 %s 条记录到本地数据库", len(inserted))

    @staticmethod
    def _log_background_task_failure(task: asyncio.Task) -> None:
        """Done-callback that surfaces an exception raised by a background task."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("后台任务异常退出: %s", exc, exc_info=exc)

    async def run(self):
        """Initialise all components and serve MCP over HTTP until shutdown."""
        await self.storage.initialize()
        if self.uploader:
            # NOTE(review): the uploader is started here but never stopped in
            # this file; confirm whether RemoteUploader needs explicit shutdown.
            await self.uploader.start()
        await self.initialize()

        # Imported lazily, and before any task is spawned, so a missing
        # package cannot leave orphaned background tasks behind.
        from mcp.server.streamable_http import StreamableHTTPServerTransport
        from starlette.applications import Starlette
        import uvicorn
        import uuid

        async def cache_cleanup_task():
            while True:
                await asyncio.sleep(300)
                self.cache.clear_expired()

        # Keep strong references: the event loop only holds weak references
        # to tasks, so a discarded task may be garbage-collected mid-flight.
        background_tasks: List[asyncio.Task] = []

        def spawn(coro) -> None:
            task = asyncio.create_task(coro)
            task.add_done_callback(self._log_background_task_failure)
            background_tasks.append(task)

        try:
            spawn(cache_cleanup_task())
            logger.info("🚀 Funstat MCP Server 已启动")
            require_session = os.getenv("FUNSTAT_REQUIRE_SESSION", "false").lower() in ("1", "true", "yes")
            session_id = str(uuid.uuid4()) if require_session else None
            transport = StreamableHTTPServerTransport(
                mcp_session_id=session_id,
                is_json_response_enabled=True,
            )

            async def run_mcp_server():
                async with transport.connect() as streams:
                    await self.server.run(
                        streams[0],
                        streams[1],
                        self.server.create_initialization_options(),
                    )

            spawn(run_mcp_server())
            app = Starlette()
            app.mount("/", transport.handle_request)
            # NOTE(review): the messages below advertise SSE endpoints although
            # the transport is streamable HTTP mounted at "/" — confirm the
            # intended wording.
            logger.info("🌐 启动 SSE 服务器: http://%s:%s", settings.host, settings.port)
            logger.info("📡 SSE 端点: http://%s:%s/sse", settings.host, settings.port)
            logger.info("📨 消息端点: http://%s:%s/messages", settings.host, settings.port)
            config = uvicorn.Config(
                app,
                host=settings.host,
                port=settings.port,
                log_level="info"
            )
            server_instance = uvicorn.Server(config)
            await server_instance.serve()
        finally:
            # Stop the helpers once the HTTP server has exited (or failed).
            for task in background_tasks:
                task.cancel()
            await asyncio.gather(*background_tasks, return_exceptions=True)
async def main():
    """Run the server and release its resources on exit.

    The Telegram client (if one was created) is disconnected and the storage
    layer is closed even when ``run()`` raises or is cancelled.
    """
    server = FunstatMCPServer()
    try:
        await server.run()
    finally:
        try:
            # ``client`` stays None when initialisation failed before the
            # TelegramClient was constructed.
            if server.client is not None:
                await server.client.disconnect()
        finally:
            # Close storage even if disconnecting the client fails.
            await server.storage.close()


if __name__ == "__main__":
    asyncio.run(main())