Initial commit: FunStat MCP Server Go implementation
- Telegram integration for customer statistics - MCP server implementation with rate limiting - Cache system for performance optimization - Multi-language support - RESTful API endpoints 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
153
internal/app/app.go
Normal file
153
internal/app/app.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
tgclient "funstatmcp/internal/telegram"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
cfg Config
|
||||
client *tgclient.Client
|
||||
}
|
||||
|
||||
func New(cfg Config) (*App, error) {
|
||||
client, err := tgclient.New(cfg.Telegram)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &App{
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *App) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) SendCommand(ctx context.Context, command string, useCache bool) (string, error) {
|
||||
return a.client.SendCommand(ctx, command, useCache)
|
||||
}
|
||||
|
||||
func (a *App) CallTool(ctx context.Context, name string, args map[string]any) (string, error) {
|
||||
switch name {
|
||||
case "funstat_search":
|
||||
query, err := requireString(args, "query")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return a.client.SendCommand(ctx, fmt.Sprintf("/search %s", query), true)
|
||||
|
||||
case "funstat_topchat":
|
||||
category := optionalString(args, "category")
|
||||
if category != "" {
|
||||
return a.client.SendCommand(ctx, fmt.Sprintf("/topchat %s", category), true)
|
||||
}
|
||||
return a.client.SendCommand(ctx, "/topchat", true)
|
||||
|
||||
case "funstat_text":
|
||||
text, err := requireString(args, "text")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return a.client.SendCommand(ctx, fmt.Sprintf("/text %s", text), true)
|
||||
|
||||
case "funstat_human":
|
||||
nameArg, err := requireString(args, "name")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return a.client.SendCommand(ctx, fmt.Sprintf("/human %s", nameArg), true)
|
||||
|
||||
case "funstat_user_info":
|
||||
identifier, err := requireString(args, "identifier")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
if identifier == "" {
|
||||
return "", fmt.Errorf("identifier cannot be empty")
|
||||
}
|
||||
return a.client.SendCommand(ctx, fmt.Sprintf("/user_info %s", identifier), true)
|
||||
|
||||
case "funstat_user_messages":
|
||||
identifier, err := requireString(args, "identifier")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var maxPagesPtr *int
|
||||
if value, ok := args["max_pages"]; ok {
|
||||
v, err := toInt(value)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("max_pages must be an integer: %w", err)
|
||||
}
|
||||
maxPagesPtr = &v
|
||||
}
|
||||
return a.client.FetchUserMessages(ctx, identifier, maxPagesPtr)
|
||||
|
||||
case "funstat_balance":
|
||||
return a.client.SendCommand(ctx, "/balance", true)
|
||||
|
||||
case "funstat_menu":
|
||||
return a.client.SendCommand(ctx, "/menu", true)
|
||||
|
||||
case "funstat_start":
|
||||
return a.client.SendCommand(ctx, "/start", true)
|
||||
|
||||
default:
|
||||
return "", fmt.Errorf("unknown tool: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
func requireString(args map[string]any, key string) (string, error) {
|
||||
value, ok := args[key]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing required argument: %s", key)
|
||||
}
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("argument %s must be a string", key)
|
||||
}
|
||||
str = strings.TrimSpace(str)
|
||||
if str == "" {
|
||||
return "", fmt.Errorf("argument %s cannot be empty", key)
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
func optionalString(args map[string]any, key string) string {
|
||||
if value, ok := args[key]; ok {
|
||||
if str, ok := value.(string); ok {
|
||||
return strings.TrimSpace(str)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func toInt(value any) (int, error) {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return int(v), nil
|
||||
case float32:
|
||||
return int(v), nil
|
||||
case int:
|
||||
return v, nil
|
||||
case int32:
|
||||
return int(v), nil
|
||||
case int64:
|
||||
return int(v), nil
|
||||
case string:
|
||||
parsed, err := strconv.Atoi(strings.TrimSpace(v))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return parsed, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported type %T", value)
|
||||
}
|
||||
}
|
||||
158
internal/app/config.go
Normal file
158
internal/app/config.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tgconfig "funstatmcp/internal/telegram"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Telegram tgconfig.Config
|
||||
Host string
|
||||
Port int
|
||||
RequireSession bool
|
||||
}
|
||||
|
||||
func FromEnv() (Config, error) {
|
||||
var cfg Config
|
||||
|
||||
telegramCfg, err := loadTelegramConfig()
|
||||
if err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
cfg.Telegram = telegramCfg
|
||||
cfg.Host = getEnvDefault("FUNSTAT_HOST", "127.0.0.1")
|
||||
|
||||
portStr := getEnvDefault("FUNSTAT_PORT", "8091")
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return cfg, fmt.Errorf("invalid FUNSTAT_PORT: %w", err)
|
||||
}
|
||||
cfg.Port = port
|
||||
|
||||
cfg.RequireSession = parseBool(getEnvDefault("FUNSTAT_REQUIRE_SESSION", "false"))
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func loadTelegramConfig() (tgconfig.Config, error) {
|
||||
var cfg tgconfig.Config
|
||||
|
||||
apiIDStr := os.Getenv("TELEGRAM_API_ID")
|
||||
if apiIDStr == "" {
|
||||
return cfg, fmt.Errorf("TELEGRAM_API_ID is required")
|
||||
}
|
||||
apiID, err := strconv.Atoi(apiIDStr)
|
||||
if err != nil {
|
||||
return cfg, fmt.Errorf("invalid TELEGRAM_API_ID: %w", err)
|
||||
}
|
||||
cfg.APIID = apiID
|
||||
|
||||
cfg.APIHash = strings.TrimSpace(os.Getenv("TELEGRAM_API_HASH"))
|
||||
if cfg.APIHash == "" {
|
||||
return cfg, fmt.Errorf("TELEGRAM_API_HASH is required")
|
||||
}
|
||||
|
||||
cfg.BotUsername = getEnvDefault("FUNSTAT_BOT_USERNAME", "@openaiw_bot")
|
||||
cfg.SessionString = strings.TrimSpace(os.Getenv("TELEGRAM_SESSION_STRING"))
|
||||
sessionStringFile := strings.TrimSpace(os.Getenv("TELEGRAM_SESSION_STRING_FILE"))
|
||||
if cfg.SessionString == "" && sessionStringFile != "" {
|
||||
data, err := os.ReadFile(expandPath(sessionStringFile))
|
||||
if err != nil {
|
||||
return cfg, fmt.Errorf("read TELEGRAM_SESSION_STRING_FILE: %w", err)
|
||||
}
|
||||
cfg.SessionString = strings.TrimSpace(string(data))
|
||||
}
|
||||
|
||||
sessionPath := strings.TrimSpace(os.Getenv("TELEGRAM_SESSION_PATH"))
|
||||
if sessionPath != "" {
|
||||
if !strings.HasSuffix(sessionPath, ".session") {
|
||||
sessionPath = sessionPath + ".session"
|
||||
}
|
||||
cfg.SessionStorage = expandPath(sessionPath)
|
||||
} else {
|
||||
cfg.SessionStorage = defaultSessionPath()
|
||||
}
|
||||
|
||||
if value := strings.TrimSpace(os.Getenv("FUNSTAT_RATE_LIMIT_PER_SECOND")); value != "" {
|
||||
parsed, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return cfg, fmt.Errorf("invalid FUNSTAT_RATE_LIMIT_PER_SECOND: %w", err)
|
||||
}
|
||||
cfg.RateLimit = parsed
|
||||
}
|
||||
|
||||
if value := strings.TrimSpace(os.Getenv("FUNSTAT_RATE_LIMIT_WINDOW")); value != "" {
|
||||
duration, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return cfg, fmt.Errorf("invalid FUNSTAT_RATE_LIMIT_WINDOW: %w", err)
|
||||
}
|
||||
cfg.RateLimitWindow = duration
|
||||
}
|
||||
|
||||
if value := strings.TrimSpace(os.Getenv("FUNSTAT_CACHE_TTL")); value != "" {
|
||||
duration, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return cfg, fmt.Errorf("invalid FUNSTAT_CACHE_TTL: %w", err)
|
||||
}
|
||||
cfg.CacheTTL = duration
|
||||
}
|
||||
|
||||
proxyHost := strings.TrimSpace(os.Getenv("FUNSTAT_PROXY_HOST"))
|
||||
proxyPort := strings.TrimSpace(os.Getenv("FUNSTAT_PROXY_PORT"))
|
||||
if proxyHost != "" && proxyPort != "" {
|
||||
port, err := strconv.Atoi(proxyPort)
|
||||
if err != nil {
|
||||
return cfg, fmt.Errorf("invalid FUNSTAT_PROXY_PORT: %w", err)
|
||||
}
|
||||
cfg.Proxy = &tgconfig.ProxyConfig{
|
||||
Type: getEnvDefault("FUNSTAT_PROXY_TYPE", "socks5"),
|
||||
Host: proxyHost,
|
||||
Port: port,
|
||||
Username: strings.TrimSpace(os.Getenv("FUNSTAT_PROXY_USERNAME")),
|
||||
Password: strings.TrimSpace(os.Getenv("FUNSTAT_PROXY_PASSWORD")),
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func getEnvDefault(key, fallback string) string {
|
||||
if value := strings.TrimSpace(os.Getenv(key)); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func parseBool(value string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func defaultSessionPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return filepath.Join(os.TempDir(), "funstatmcp", "session.json")
|
||||
}
|
||||
return filepath.Join(home, ".funstatmcp", "session.json")
|
||||
}
|
||||
|
||||
func expandPath(path string) string {
|
||||
if strings.HasPrefix(path, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
return filepath.Join(home, strings.TrimPrefix(path, "~"))
|
||||
}
|
||||
}
|
||||
return path
|
||||
}
|
||||
124
internal/app/tools.go
Normal file
124
internal/app/tools.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package app
|
||||
|
||||
type ToolDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema map[string]any `json:"inputSchema"`
|
||||
}
|
||||
|
||||
func ToolDefinitions() []ToolDefinition {
|
||||
return []ToolDefinition{
|
||||
{
|
||||
Name: "funstat_search",
|
||||
Description: "搜索 Telegram 群组、频道。支持关键词搜索,返回相关的群组列表",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "搜索关键词,例如: 'python', '区块链', 'AI'",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "funstat_topchat",
|
||||
Description: "获取热门群组/频道列表,按成员数或活跃度排序",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"category": map[string]any{
|
||||
"type": "string",
|
||||
"description": "分类筛选(可选),例如: 'tech', 'crypto', 'news'",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "funstat_text",
|
||||
Description: "通过消息文本搜索,查找包含特定文本的消息和来源群组",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "要搜索的文本内容",
|
||||
},
|
||||
},
|
||||
"required": []string{"text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "funstat_human",
|
||||
Description: "通过姓名搜索,查找包含特定用户的群组和消息",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"description": "用户姓名",
|
||||
},
|
||||
},
|
||||
"required": []string{"name"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "funstat_user_info",
|
||||
Description: "查询用户详细信息,支持通过用户名、用户ID、联系人等方式查询",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"identifier": map[string]any{
|
||||
"type": "string",
|
||||
"description": "用户标识: 用户名(@username)、用户ID、或手机号",
|
||||
},
|
||||
},
|
||||
"required": []string{"identifier"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "funstat_user_messages",
|
||||
Description: "获取指定用户的历史消息列表,并自动翻页汇总",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"identifier": map[string]any{
|
||||
"type": "string",
|
||||
"description": "用户标识: 用户名(@username) 或用户ID",
|
||||
},
|
||||
"max_pages": map[string]any{
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"description": "可选,限制抓取的最大页数",
|
||||
},
|
||||
},
|
||||
"required": []string{"identifier"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "funstat_balance",
|
||||
Description: "查询当前账号的积分余额和使用统计",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "funstat_menu",
|
||||
Description: "显示 funstat BOT 的主菜单和所有可用功能",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "funstat_start",
|
||||
Description: "获取 funstat BOT 的欢迎信息和使用说明",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
70
internal/cache/cache.go
vendored
Normal file
70
internal/cache/cache.go
vendored
Normal file
@@ -0,0 +1,70 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type entry struct {
|
||||
value string
|
||||
expires time.Time
|
||||
}
|
||||
|
||||
type Cache struct {
|
||||
ttl time.Duration
|
||||
mu sync.RWMutex
|
||||
values map[string]entry
|
||||
}
|
||||
|
||||
func New(ttl time.Duration) *Cache {
|
||||
if ttl <= 0 {
|
||||
ttl = time.Hour
|
||||
}
|
||||
return &Cache{
|
||||
ttl: ttl,
|
||||
values: make(map[string]entry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) Get(key string) (string, bool) {
|
||||
c.mu.RLock()
|
||||
e, ok := c.values[key]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if time.Now().After(e.expires) {
|
||||
c.mu.Lock()
|
||||
delete(c.values, key)
|
||||
c.mu.Unlock()
|
||||
return "", false
|
||||
}
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
func (c *Cache) Set(key string, value string) {
|
||||
c.mu.Lock()
|
||||
c.values[key] = entry{
|
||||
value: value,
|
||||
expires: time.Now().Add(c.ttl),
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *Cache) ClearExpired() int {
|
||||
now := time.Now()
|
||||
c.mu.Lock()
|
||||
removed := 0
|
||||
for k, v := range c.values {
|
||||
if now.After(v.expires) {
|
||||
delete(c.values, k)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return removed
|
||||
}
|
||||
|
||||
func (c *Cache) TTL() time.Duration {
|
||||
return c.ttl
|
||||
}
|
||||
71
internal/ratelimit/limiter.go
Normal file
71
internal/ratelimit/limiter.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Limiter struct {
|
||||
maxRequests int
|
||||
window time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
timestamps []time.Time
|
||||
}
|
||||
|
||||
func New(maxRequests int, window time.Duration) *Limiter {
|
||||
if maxRequests <= 0 {
|
||||
maxRequests = 1
|
||||
}
|
||||
if window <= 0 {
|
||||
window = time.Second
|
||||
}
|
||||
|
||||
return &Limiter{
|
||||
maxRequests: maxRequests,
|
||||
window: window,
|
||||
timestamps: make([]time.Time, 0, maxRequests),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Limiter) Wait(ctx context.Context) error {
|
||||
for {
|
||||
l.mu.Lock()
|
||||
now := time.Now()
|
||||
|
||||
cutoff := now.Add(-l.window)
|
||||
idx := 0
|
||||
for ; idx < len(l.timestamps); idx++ {
|
||||
if l.timestamps[idx].After(cutoff) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if idx > 0 {
|
||||
l.timestamps = append([]time.Time(nil), l.timestamps[idx:]...)
|
||||
}
|
||||
|
||||
if len(l.timestamps) < l.maxRequests {
|
||||
l.timestamps = append(l.timestamps, now)
|
||||
l.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
waitUntil := l.timestamps[0].Add(l.window)
|
||||
waitDuration := time.Until(waitUntil)
|
||||
l.mu.Unlock()
|
||||
|
||||
if waitDuration <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
timer := time.NewTimer(waitDuration)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
71
internal/telegram/buttons.go
Normal file
71
internal/telegram/buttons.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
func findCallbackButton(message *tg.Message, keyword string) (*tg.KeyboardButtonCallback, error) {
|
||||
markup, ok := message.ReplyMarkup.(*tg.ReplyInlineMarkup)
|
||||
if !ok || len(markup.Rows) == 0 {
|
||||
return nil, fmt.Errorf("message has no interactive buttons")
|
||||
}
|
||||
|
||||
normalizedKeyword := strings.ToLower(normalizeButtonText(keyword))
|
||||
available := make([]string, 0)
|
||||
|
||||
for _, row := range markup.Rows {
|
||||
for _, button := range row.Buttons {
|
||||
callback, ok := button.(*tg.KeyboardButtonCallback)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
text := callback.Text
|
||||
normalized := strings.ToLower(normalizeButtonText(text))
|
||||
available = append(available, normalizeButtonText(text))
|
||||
|
||||
if strings.Contains(normalized, normalizedKeyword) {
|
||||
return callback, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("button containing '%s' not found (available: %s)", keyword, strings.Join(available, ", "))
|
||||
}
|
||||
|
||||
func extractTotalPages(message *tg.Message) int {
|
||||
markup, ok := message.ReplyMarkup.(*tg.ReplyInlineMarkup)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
for _, row := range markup.Rows {
|
||||
for _, button := range row.Buttons {
|
||||
callback, ok := button.(*tg.KeyboardButtonCallback)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(callback.Text, "⏭") {
|
||||
digits := strings.Builder{}
|
||||
for _, r := range normalizeButtonText(callback.Text) {
|
||||
if unicode.IsDigit(r) {
|
||||
digits.WriteRune(r)
|
||||
}
|
||||
}
|
||||
if digits.Len() > 0 {
|
||||
if value, err := strconv.Atoi(digits.String()); err == nil {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
497
internal/telegram/client.go
Normal file
497
internal/telegram/client.go
Normal file
@@ -0,0 +1,497 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
"github.com/gotd/td/session"
|
||||
"github.com/gotd/td/telegram"
|
||||
"github.com/gotd/td/telegram/dcs"
|
||||
"github.com/gotd/td/telegram/message"
|
||||
"github.com/gotd/td/tg"
|
||||
|
||||
"funstatmcp/internal/cache"
|
||||
"funstatmcp/internal/ratelimit"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
cfg Config
|
||||
limiter *ratelimit.Limiter
|
||||
cache *cache.Cache
|
||||
sessionPath string
|
||||
sessionOnce sync.Once
|
||||
}
|
||||
|
||||
func New(cfg Config) (*Client, error) {
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessionPath := cfg.SessionStorage
|
||||
if sessionPath == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get home dir: %w", err)
|
||||
}
|
||||
sessionPath = filepath.Join(home, ".funstatmcp", "session.json")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(sessionPath), 0o700); err != nil {
|
||||
return nil, fmt.Errorf("create session directory: %w", err)
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
cfg: cfg,
|
||||
limiter: ratelimit.New(cfg.RateLimitPerSecond(), cfg.RateLimitDuration()),
|
||||
cache: cache.New(cfg.CacheDuration()),
|
||||
sessionPath: sessionPath,
|
||||
}
|
||||
|
||||
if strings.TrimSpace(cfg.SessionString) != "" {
|
||||
if err := client.writeStringSession(strings.TrimSpace(cfg.SessionString)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) writeStringSession(sessionStr string) error {
|
||||
var result error
|
||||
c.sessionOnce.Do(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
data, err := session.TelethonSession(sessionStr)
|
||||
if err != nil {
|
||||
result = fmt.Errorf("decode telethon session: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
loader := session.Loader{Storage: &session.FileStorage{Path: c.sessionPath}}
|
||||
if err := loader.Save(ctx, data); err != nil {
|
||||
result = fmt.Errorf("save session: %w", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *Client) createOptions() (telegram.Options, error) {
|
||||
opts := telegram.Options{
|
||||
SessionStorage: &session.FileStorage{Path: c.sessionPath},
|
||||
NoUpdates: true,
|
||||
}
|
||||
|
||||
proxyCfg := c.cfg.Proxy
|
||||
if proxyCfg != nil && proxyCfg.Host != "" && proxyCfg.Port > 0 {
|
||||
address := fmt.Sprintf("%s:%d", proxyCfg.Host, proxyCfg.Port)
|
||||
switch strings.ToLower(proxyCfg.Type) {
|
||||
case "", "socks5", "socks":
|
||||
var auth *proxy.Auth
|
||||
if proxyCfg.Username != "" {
|
||||
auth = &proxy.Auth{User: proxyCfg.Username, Password: proxyCfg.Password}
|
||||
}
|
||||
dialer, err := proxy.SOCKS5("tcp", address, auth, proxy.Direct)
|
||||
if err != nil {
|
||||
return opts, fmt.Errorf("create SOCKS5 proxy: %w", err)
|
||||
}
|
||||
contextDialer, ok := dialer.(proxy.ContextDialer)
|
||||
if !ok {
|
||||
contextDialer = &contextDialerAdapter{Dialer: dialer}
|
||||
}
|
||||
opts.Resolver = dcs.Plain(dcs.PlainOptions{Dial: contextDialer.DialContext})
|
||||
default:
|
||||
return opts, fmt.Errorf("unsupported proxy type %q", proxyCfg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
type contextDialerAdapter struct {
|
||||
Dialer proxy.Dialer
|
||||
}
|
||||
|
||||
func (a *contextDialerAdapter) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
type dialResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
result := make(chan dialResult, 1)
|
||||
go func() {
|
||||
conn, err := a.Dialer.Dial(network, addr)
|
||||
result <- dialResult{conn: conn, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case res := <-result:
|
||||
return res.conn, res.err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) withClient(ctx context.Context, fn func(ctx context.Context, api *tg.Client, sender *message.Sender) error) error {
|
||||
opts, err := c.createOptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := telegram.NewClient(c.cfg.APIID, c.cfg.APIHash, opts)
|
||||
|
||||
return client.Run(ctx, func(runCtx context.Context) error {
|
||||
if err := c.ensureAuthorized(runCtx, client); err != nil {
|
||||
return err
|
||||
}
|
||||
raw := tg.NewClient(client)
|
||||
sender := message.NewSender(raw)
|
||||
return fn(runCtx, raw, sender)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) withPeer(ctx context.Context, fn func(ctx context.Context, api *tg.Client, sender *message.Sender, peer tg.InputPeerClass, botID int64) error) error {
|
||||
return c.withClient(ctx, func(runCtx context.Context, api *tg.Client, sender *message.Sender) error {
|
||||
peer, botID, err := c.resolvePeer(runCtx, sender)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fn(runCtx, api, sender, peer, botID)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) ensureAuthorized(ctx context.Context, client *telegram.Client) error {
|
||||
status, err := client.Auth().Status(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check auth status: %w", err)
|
||||
}
|
||||
if !status.Authorized {
|
||||
return errors.New("telegram session is not authorized; provide TELEGRAM_SESSION_STRING")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) resolvePeer(ctx context.Context, sender *message.Sender) (tg.InputPeerClass, int64, error) {
|
||||
builder := sender.Resolve(c.cfg.BotUsername)
|
||||
peer, err := builder.AsInputPeer(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("resolve bot peer: %w", err)
|
||||
}
|
||||
|
||||
inputUser, err := builder.AsInputUser(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("resolve bot user: %w", err)
|
||||
}
|
||||
|
||||
return peer, inputUser.UserID, nil
|
||||
}
|
||||
|
||||
func (c *Client) latestIncomingMessageID(ctx context.Context, api *tg.Client, peer tg.InputPeerClass, botID int64) (int, error) {
|
||||
resp, err := api.MessagesGetHistory(ctx, &tg.MessagesGetHistoryRequest{
|
||||
Peer: peer,
|
||||
Limit: 5,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
messages, err := extractMessages(resp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
last := 0
|
||||
for _, msg := range messages {
|
||||
if isFromBot(msg, botID) && msg.ID > last {
|
||||
last = msg.ID
|
||||
}
|
||||
}
|
||||
return last, nil
|
||||
}
|
||||
|
||||
func (c *Client) waitForMessage(ctx context.Context, api *tg.Client, peer tg.InputPeerClass, botID int64, lastID int, timeout time.Duration) (*tg.Message, error) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := api.MessagesGetHistory(ctx, &tg.MessagesGetHistoryRequest{
|
||||
Peer: peer,
|
||||
Limit: 5,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
messages, err := extractMessages(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.ID > lastID && isFromBot(msg, botID) {
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("timeout waiting for bot response")
|
||||
}
|
||||
|
||||
if err := sleepWithContext(ctx, 500*time.Millisecond); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) SendCommand(ctx context.Context, command string, useCache bool) (string, error) {
|
||||
cacheKey := fmt.Sprintf("cmd:%s", command)
|
||||
if useCache {
|
||||
if cached, ok := c.cache.Get(cacheKey); ok {
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.limiter.Wait(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var response string
|
||||
err := c.withPeer(ctx, func(runCtx context.Context, api *tg.Client, sender *message.Sender, peer tg.InputPeerClass, botID int64) error {
|
||||
lastID, err := c.latestIncomingMessageID(runCtx, api, peer, botID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := sender.Resolve(c.cfg.BotUsername).Text(runCtx, command); err != nil {
|
||||
return fmt.Errorf("send command: %w", err)
|
||||
}
|
||||
|
||||
msg, err := c.waitForMessage(runCtx, api, peer, botID, lastID, 15*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response = strings.TrimSpace(msg.Message)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if useCache {
|
||||
c.cache.Set(cacheKey, response)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (c *Client) SendCommandMessage(ctx context.Context, command string, timeout time.Duration) (*tg.Message, error) {
|
||||
if err := c.limiter.Wait(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result *tg.Message
|
||||
err := c.withPeer(ctx, func(runCtx context.Context, api *tg.Client, sender *message.Sender, peer tg.InputPeerClass, botID int64) error {
|
||||
lastID, err := c.latestIncomingMessageID(runCtx, api, peer, botID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := sender.Resolve(c.cfg.BotUsername).Text(runCtx, command); err != nil {
|
||||
return fmt.Errorf("send command: %w", err)
|
||||
}
|
||||
|
||||
msg, err := c.waitForMessage(runCtx, api, peer, botID, lastID, timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result = msg
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *Client) PressButton(ctx context.Context, msg *tg.Message, keyword string) (*tg.Message, error) {
|
||||
if msg == nil {
|
||||
return nil, errors.New("message cannot be nil")
|
||||
}
|
||||
|
||||
if err := c.limiter.Wait(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
normalizedKeyword := strings.ToLower(keyword)
|
||||
|
||||
var updated *tg.Message
|
||||
err := c.withPeer(ctx, func(runCtx context.Context, api *tg.Client, sender *message.Sender, peer tg.InputPeerClass, botID int64) error {
|
||||
button, err := findCallbackButton(msg, normalizedKeyword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &tg.MessagesGetBotCallbackAnswerRequest{
|
||||
Peer: peer,
|
||||
MsgID: msg.ID,
|
||||
}
|
||||
if len(button.Data) > 0 {
|
||||
req.SetData(button.Data)
|
||||
}
|
||||
|
||||
invoke := func() error {
|
||||
_, err := api.MessagesGetBotCallbackAnswer(runCtx, req)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := invoke(); err != nil {
|
||||
if wait, ok := telegram.AsFloodWait(err); ok {
|
||||
if err := sleepWithContext(runCtx, wait+time.Second); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := invoke(); err != nil {
|
||||
return fmt.Errorf("callback retry failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("press callback button: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := sleepWithContext(runCtx, 1200*time.Millisecond); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := api.MessagesGetMessages(runCtx, []tg.InputMessageClass{
|
||||
&tg.InputMessageID{ID: msg.ID},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch updated message: %w", err)
|
||||
}
|
||||
|
||||
refreshed, err := extractMessageByID(resp, msg.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updated = refreshed
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func (c *Client) FetchUserMessages(ctx context.Context, identifier string, maxPages *int) (string, error) {
|
||||
id := strings.TrimSpace(identifier)
|
||||
if id == "" {
|
||||
return "", errors.New("identifier cannot be empty")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(id, "/") {
|
||||
if !strings.HasPrefix(id, "@") && !isNumericIdentifier(id) {
|
||||
id = "@" + id
|
||||
}
|
||||
}
|
||||
|
||||
command := id
|
||||
if !strings.HasPrefix(command, "/") {
|
||||
command = fmt.Sprintf("/user_info %s", id)
|
||||
}
|
||||
|
||||
base, err := c.SendCommandMessage(ctx, command, 20*time.Second)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
stage, err := c.PressButton(ctx, base, "messages")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
current, err := c.PressButton(ctx, stage, "all")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
pages := make([]string, 0)
|
||||
seen := map[string]struct{}{}
|
||||
currentPage := 1
|
||||
totalPages := extractTotalPages(current)
|
||||
|
||||
var limit int
|
||||
if maxPages != nil {
|
||||
if *maxPages <= 0 {
|
||||
return "", errors.New("maxPages must be greater than zero")
|
||||
}
|
||||
limit = *maxPages
|
||||
}
|
||||
|
||||
for {
|
||||
text := strings.TrimSpace(current.Message)
|
||||
if text != "" {
|
||||
if _, ok := seen[text]; !ok {
|
||||
header := fmt.Sprintf("第 %d 页", currentPage)
|
||||
if totalPages > 0 {
|
||||
header = fmt.Sprintf("%s/%d", header, totalPages)
|
||||
}
|
||||
entry := strings.Join([]string{header, "", text}, "\n")
|
||||
pages = append(pages, entry)
|
||||
seen[text] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if limit > 0 && currentPage >= limit {
|
||||
break
|
||||
}
|
||||
|
||||
updated, err := c.PressButton(ctx, current, "➡")
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
newText := strings.TrimSpace(updated.Message)
|
||||
if _, ok := seen[newText]; ok {
|
||||
break
|
||||
}
|
||||
|
||||
current = updated
|
||||
currentPage++
|
||||
}
|
||||
|
||||
if len(pages) == 0 {
|
||||
return fmt.Sprintf("未找到 %s 的消息记录。", identifier), nil
|
||||
}
|
||||
|
||||
summary := fmt.Sprintf("共收集 %d 页消息", len(pages))
|
||||
if totalPages > 0 {
|
||||
summary = fmt.Sprintf("%s(存在 %d 页)", summary, totalPages)
|
||||
}
|
||||
|
||||
result := append([]string{summary, ""}, pages...)
|
||||
return strings.Join(result, "\n\n"), nil
|
||||
}
|
||||
|
||||
func isNumericIdentifier(value string) bool {
|
||||
for _, r := range value {
|
||||
if r != '+' && (r < '0' || r > '9') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return value != ""
|
||||
}
|
||||
60
internal/telegram/config.go
Normal file
60
internal/telegram/config.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProxyConfig struct {
|
||||
Type string
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
APIID int
|
||||
APIHash string
|
||||
BotUsername string
|
||||
SessionString string
|
||||
SessionStorage string
|
||||
RateLimit int
|
||||
RateLimitWindow time.Duration
|
||||
CacheTTL time.Duration
|
||||
Proxy *ProxyConfig
|
||||
}
|
||||
|
||||
func (c Config) Validate() error {
|
||||
if c.APIID == 0 {
|
||||
return fmt.Errorf("APIID must be provided")
|
||||
}
|
||||
if c.APIHash == "" {
|
||||
return fmt.Errorf("APIHash must be provided")
|
||||
}
|
||||
if c.BotUsername == "" {
|
||||
return fmt.Errorf("BotUsername must be provided")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Config) RateLimitPerSecond() int {
|
||||
if c.RateLimit <= 0 {
|
||||
return 18
|
||||
}
|
||||
return c.RateLimit
|
||||
}
|
||||
|
||||
func (c Config) RateLimitDuration() time.Duration {
|
||||
if c.RateLimitWindow <= 0 {
|
||||
return time.Second
|
||||
}
|
||||
return c.RateLimitWindow
|
||||
}
|
||||
|
||||
func (c Config) CacheDuration() time.Duration {
|
||||
if c.CacheTTL <= 0 {
|
||||
return time.Hour
|
||||
}
|
||||
return c.CacheTTL
|
||||
}
|
||||
63
internal/telegram/messages.go
Normal file
63
internal/telegram/messages.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
func extractMessages(resp tg.MessagesMessagesClass) ([]*tg.Message, error) {
|
||||
switch v := resp.(type) {
|
||||
case *tg.MessagesMessages:
|
||||
return filterMessages(v.Messages), nil
|
||||
case *tg.MessagesMessagesSlice:
|
||||
return filterMessages(v.Messages), nil
|
||||
case *tg.MessagesChannelMessages:
|
||||
return filterMessages(v.Messages), nil
|
||||
case *tg.MessagesMessagesNotModified:
|
||||
return nil, fmt.Errorf("messages not modified")
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported response type %T", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func filterMessages(values []tg.MessageClass) []*tg.Message {
|
||||
result := make([]*tg.Message, 0, len(values))
|
||||
for _, m := range values {
|
||||
if msg, ok := m.(*tg.Message); ok {
|
||||
result = append(result, msg)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func extractMessageByID(resp tg.MessagesMessagesClass, id int) (*tg.Message, error) {
|
||||
messages, err := extractMessages(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, msg := range messages {
|
||||
if msg.ID == id {
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("message %d not found", id)
|
||||
}
|
||||
|
||||
func isFromBot(msg *tg.Message, botID int64) bool {
|
||||
if msg == nil || msg.Out {
|
||||
return false
|
||||
}
|
||||
|
||||
if peer, ok := msg.GetPeerID().(*tg.PeerUser); ok && peer.UserID == botID {
|
||||
return true
|
||||
}
|
||||
|
||||
if fromClass, ok := msg.GetFromID(); ok {
|
||||
if from, ok := fromClass.(*tg.PeerUser); ok && from.UserID == botID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
45
internal/telegram/normalize.go
Normal file
45
internal/telegram/normalize.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package telegram
|
||||
|
||||
import "strings"
|
||||
|
||||
var buttonTextTranslations = map[rune]rune{
|
||||
'ƒ': 'f',
|
||||
'Μ': 'M',
|
||||
'τ': 't',
|
||||
'ѕ': 's',
|
||||
'η': 'n',
|
||||
'Ғ': 'F',
|
||||
'α': 'a',
|
||||
'ο': 'o',
|
||||
'ᴜ': 'u',
|
||||
'о': 'o',
|
||||
'е': 'e',
|
||||
'с': 'c',
|
||||
'℮': 'e',
|
||||
'Τ': 'T',
|
||||
'ρ': 'p',
|
||||
'Δ': 'D',
|
||||
'χ': 'x',
|
||||
'β': 'b',
|
||||
'λ': 'l',
|
||||
'γ': 'y',
|
||||
'Ν': 'N',
|
||||
'μ': 'm',
|
||||
'ψ': 'y',
|
||||
'Α': 'A',
|
||||
'Ρ': 'P',
|
||||
'С': 'C',
|
||||
'ё': 'e',
|
||||
'ł': 'l',
|
||||
'Ł': 'L',
|
||||
'ց': 'g',
|
||||
}
|
||||
|
||||
func normalizeButtonText(text string) string {
|
||||
return strings.Map(func(r rune) rune {
|
||||
if mapped, ok := buttonTextTranslations[r]; ok {
|
||||
return mapped
|
||||
}
|
||||
return r
|
||||
}, text)
|
||||
}
|
||||
22
internal/telegram/util.go
Normal file
22
internal/telegram/util.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||
if d <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
277
internal/transport/server.go
Normal file
277
internal/transport/server.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"funstatmcp/internal/app"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
app *app.App
|
||||
config app.Config
|
||||
|
||||
subscribersMu sync.Mutex
|
||||
subscribers map[int]chan []byte
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewServer(appInstance *app.App, cfg app.Config) *Server {
|
||||
return &Server{
|
||||
app: appInstance,
|
||||
config: cfg,
|
||||
subscribers: make(map[int]chan []byte),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/sse", s.handleSSE)
|
||||
mux.HandleFunc("/messages", s.handleMessages)
|
||||
mux.HandleFunc("/health", s.handleHealth)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", s.config.Host, s.config.Port),
|
||||
Handler: s.corsMiddleware(mux),
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
log.Printf("HTTP server shutdown error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("Funstat MCP Go server listening on http://%s:%d", s.config.Host, s.config.Port)
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-MCP-Session-ID")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"server": "funstat-mcp-go",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming unsupported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
subscriber := make(chan []byte, 16)
|
||||
id := s.addSubscriber(subscriber)
|
||||
defer s.removeSubscriber(id)
|
||||
|
||||
heartbeatTicker := time.NewTicker(15 * time.Second)
|
||||
defer heartbeatTicker.Stop()
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-heartbeatTicker.C:
|
||||
fmt.Fprint(w, ": ping\n\n")
|
||||
flusher.Flush()
|
||||
case data := <-subscriber:
|
||||
fmt.Fprintf(w, "event: message\n")
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) addSubscriber(ch chan []byte) int {
|
||||
s.subscribersMu.Lock()
|
||||
defer s.subscribersMu.Unlock()
|
||||
id := s.nextID
|
||||
s.nextID++
|
||||
s.subscribers[id] = ch
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *Server) removeSubscriber(id int) {
|
||||
s.subscribersMu.Lock()
|
||||
defer s.subscribersMu.Unlock()
|
||||
delete(s.subscribers, id)
|
||||
}
|
||||
|
||||
func (s *Server) broadcast(payload []byte) {
|
||||
s.subscribersMu.Lock()
|
||||
defer s.subscribersMu.Unlock()
|
||||
for _, ch := range s.subscribers {
|
||||
select {
|
||||
case ch <- payload:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type jsonRPCRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID any `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params"`
|
||||
}
|
||||
|
||||
type jsonRPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID any `json:"id"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error *jsonRPCError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type jsonRPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (s *Server) handleMessages(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var request jsonRPCRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
|
||||
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid JSON: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
response := s.handleRequest(r.Context(), request)
|
||||
|
||||
var buffer bytes.Buffer
|
||||
if err := json.NewEncoder(&buffer).Encode(response); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
payload := bytes.TrimSpace(buffer.Bytes())
|
||||
s.broadcast(payload)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(payload)
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(ctx context.Context, req jsonRPCRequest) jsonRPCResponse {
|
||||
response := jsonRPCResponse{
|
||||
JSONRPC: "2.0",
|
||||
ID: req.ID,
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.JSONRPC) != "2.0" {
|
||||
response.Error = &jsonRPCError{Code: -32600, Message: "invalid jsonrpc version"}
|
||||
return response
|
||||
}
|
||||
|
||||
switch req.Method {
|
||||
case "initialize":
|
||||
response.Result = s.initializeResult()
|
||||
case "list_tools":
|
||||
response.Result = map[string]any{
|
||||
"tools": app.ToolDefinitions(),
|
||||
}
|
||||
case "call_tool":
|
||||
var payload struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Params, &payload); err != nil {
|
||||
response.Error = &jsonRPCError{Code: -32602, Message: "invalid params"}
|
||||
return response
|
||||
}
|
||||
|
||||
if payload.Arguments == nil {
|
||||
payload.Arguments = make(map[string]any)
|
||||
}
|
||||
|
||||
result, err := s.app.CallTool(ctx, payload.Name, payload.Arguments)
|
||||
if err != nil {
|
||||
response.Result = map[string]any{
|
||||
"content": []map[string]string{
|
||||
{
|
||||
"type": "text",
|
||||
"text": fmt.Sprintf("❌ 错误: %s", err.Error()),
|
||||
},
|
||||
},
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
response.Result = map[string]any{
|
||||
"content": []map[string]string{
|
||||
{
|
||||
"type": "text",
|
||||
"text": result,
|
||||
},
|
||||
},
|
||||
}
|
||||
default:
|
||||
response.Error = &jsonRPCError{Code: -32601, Message: "method not found"}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func (s *Server) initializeResult() map[string]any {
|
||||
return map[string]any{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{},
|
||||
},
|
||||
"serverInfo": map[string]any{
|
||||
"name": "funstat-mcp-go",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, err error) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user