Files
funstat-mcp-go/internal/telegram/client.go
你的用户名 8d1ce4598d Initial commit: FunStat MCP Server Go implementation
- Telegram integration for customer statistics
- MCP server implementation with rate limiting
- Cache system for performance optimization
- Multi-language support
- RESTful API endpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-04 15:28:06 +08:00

498 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package telegram
import (
"context"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"golang.org/x/net/proxy"
"github.com/gotd/td/session"
"github.com/gotd/td/telegram"
"github.com/gotd/td/telegram/dcs"
"github.com/gotd/td/telegram/message"
"github.com/gotd/td/tg"
"funstatmcp/internal/cache"
"funstatmcp/internal/ratelimit"
)
// Client talks to a configured Telegram bot through a user session. Outgoing
// requests are throttled by a rate limiter and command replies can be cached.
type Client struct {
	// cfg is the configuration passed to New (validated there).
	cfg Config
	// limiter is waited on by SendCommandMessage and PressButton before each
	// request (SendCommand and FetchUserMessages go through those).
	limiter *ratelimit.Limiter
	// cache stores SendCommand replies under the key "cmd:<command>".
	cache *cache.Cache
	// sessionPath is the file used as gotd session storage.
	sessionPath string
	// sessionOnce guards writeStringSession so its body runs at most once.
	sessionOnce sync.Once
}
// New validates cfg and returns a ready-to-use Client.
//
// The session file path is cfg.SessionStorage, or ~/.funstatmcp/session.json
// when that is empty; its parent directory is created with mode 0700. When
// cfg.SessionString is non-blank it is trimmed, decoded as a Telethon string
// session and written to the session file before New returns.
func New(cfg Config) (*Client, error) {
	if err := cfg.Validate(); err != nil {
		return nil, err
	}

	path := cfg.SessionStorage
	if path == "" {
		home, err := os.UserHomeDir()
		if err != nil {
			return nil, fmt.Errorf("get home dir: %w", err)
		}
		path = filepath.Join(home, ".funstatmcp", "session.json")
	}
	if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
		return nil, fmt.Errorf("create session directory: %w", err)
	}

	c := &Client{
		cfg:         cfg,
		limiter:     ratelimit.New(cfg.RateLimitPerSecond(), cfg.RateLimitDuration()),
		cache:       cache.New(cfg.CacheDuration()),
		sessionPath: path,
	}

	if encoded := strings.TrimSpace(cfg.SessionString); encoded != "" {
		if err := c.writeStringSession(encoded); err != nil {
			return nil, err
		}
	}
	return c, nil
}
// writeStringSession decodes sessionStr as a Telethon string session and saves
// it to the file at c.sessionPath under a 5-second timeout.
//
// The body is guarded by c.sessionOnce, so it runs at most once per Client. Any
// later call returns nil without doing anything, even if the first call failed.
func (c *Client) writeStringSession(sessionStr string) error {
	var failure error
	c.sessionOnce.Do(func() {
		saveCtx, stop := context.WithTimeout(context.Background(), 5*time.Second)
		defer stop()

		data, err := session.TelethonSession(sessionStr)
		if err != nil {
			failure = fmt.Errorf("decode telethon session: %w", err)
			return
		}

		store := &session.FileStorage{Path: c.sessionPath}
		loader := session.Loader{Storage: store}
		if err = loader.Save(saveCtx, data); err != nil {
			failure = fmt.Errorf("save session: %w", err)
		}
	})
	return failure
}
// createOptions builds the gotd client options.
//
// The options use file-backed session storage at c.sessionPath with updates
// disabled. When c.cfg.Proxy is set with a non-empty host and a positive port,
// Telegram DCs are dialled through a SOCKS5 proxy. The proxy type is matched
// case-insensitively: "", "socks5" and "socks" all mean SOCKS5, and any other
// type is rejected with an error.
func (c *Client) createOptions() (telegram.Options, error) {
	opts := telegram.Options{
		SessionStorage: &session.FileStorage{Path: c.sessionPath},
		NoUpdates:      true,
	}

	p := c.cfg.Proxy
	if p == nil || p.Host == "" || p.Port <= 0 {
		// No usable proxy configured: connect directly.
		return opts, nil
	}

	switch strings.ToLower(p.Type) {
	case "", "socks5", "socks":
		// supported
	default:
		return opts, fmt.Errorf("unsupported proxy type %q", p.Type)
	}

	var auth *proxy.Auth
	if p.Username != "" {
		auth = &proxy.Auth{User: p.Username, Password: p.Password}
	}
	addr := fmt.Sprintf("%s:%d", p.Host, p.Port)
	base, err := proxy.SOCKS5("tcp", addr, auth, proxy.Direct)
	if err != nil {
		return opts, fmt.Errorf("create SOCKS5 proxy: %w", err)
	}

	// Prefer the dialer's own context support; fall back to the adapter.
	var dial proxy.ContextDialer = &contextDialerAdapter{Dialer: base}
	if native, ok := base.(proxy.ContextDialer); ok {
		dial = native
	}
	opts.Resolver = dcs.Plain(dcs.PlainOptions{Dial: dial.DialContext})
	return opts, nil
}
// contextDialerAdapter lets a plain proxy.Dialer be used where a context-aware
// dial function is required. Dialer.Dial itself cannot be cancelled, so the
// adapter runs it in a goroutine and stops waiting when the context ends.
type contextDialerAdapter struct {
	Dialer proxy.Dialer
}

// DialContext dials addr on network through the wrapped Dialer.
//
// If ctx is already done, or becomes done before the dial finishes, it returns
// ctx.Err(). A connection that completes after the caller has given up is
// closed rather than leaked.
func (a *contextDialerAdapter) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	type dialResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the dialing goroutine can always deliver its result and exit.
	result := make(chan dialResult, 1)
	go func() {
		conn, err := a.Dialer.Dial(network, addr)
		result <- dialResult{conn: conn, err: err}
	}()

	select {
	case res := <-result:
		return res.conn, res.err
	case <-ctx.Done():
		// Nobody will use the connection now; close it once the blocking dial
		// eventually returns so the socket is not leaked.
		go func() {
			if res := <-result; res.conn != nil {
				_ = res.conn.Close()
			}
		}()
		return nil, ctx.Err()
	}
}
// withClient creates a gotd client from c's options and runs fn inside
// telegram.Client.Run.
//
// Before fn is called, the stored session is checked with ensureAuthorized. fn
// receives the run context, a raw tg API client and a message.Sender built on
// that client. The error from Run is returned.
func (c *Client) withClient(ctx context.Context, fn func(ctx context.Context, api *tg.Client, sender *message.Sender) error) error {
	opts, err := c.createOptions()
	if err != nil {
		return err
	}

	tc := telegram.NewClient(c.cfg.APIID, c.cfg.APIHash, opts)
	run := func(runCtx context.Context) error {
		if err := c.ensureAuthorized(runCtx, tc); err != nil {
			return err
		}
		api := tg.NewClient(tc)
		return fn(runCtx, api, message.NewSender(api))
	}
	return tc.Run(ctx, run)
}
// withPeer runs fn through withClient after resolving the configured bot
// (c.cfg.BotUsername).
//
// fn receives the bot's input peer and user ID in addition to the API client
// and sender. A resolution failure is returned without calling fn.
func (c *Client) withPeer(ctx context.Context, fn func(ctx context.Context, api *tg.Client, sender *message.Sender, peer tg.InputPeerClass, botID int64) error) error {
	inner := func(runCtx context.Context, api *tg.Client, sender *message.Sender) error {
		botPeer, id, err := c.resolvePeer(runCtx, sender)
		if err != nil {
			return err
		}
		return fn(runCtx, api, sender, botPeer, id)
	}
	return c.withClient(ctx, inner)
}
// ensureAuthorized queries the auth status of client's session. It returns an
// error when the query fails or the session is not logged in. It does not
// attempt to log in itself.
func (c *Client) ensureAuthorized(ctx context.Context, client *telegram.Client) error {
	status, err := client.Auth().Status(ctx)
	switch {
	case err != nil:
		return fmt.Errorf("check auth status: %w", err)
	case !status.Authorized:
		return errors.New("telegram session is not authorized; provide TELEGRAM_SESSION_STRING")
	default:
		return nil
	}
}
// resolvePeer resolves c.cfg.BotUsername and returns the bot's input peer
// together with its user ID. It fails if the username cannot be resolved or
// does not refer to a user.
func (c *Client) resolvePeer(ctx context.Context, sender *message.Sender) (tg.InputPeerClass, int64, error) {
	// Resolve exactly once. Calling both AsInputPeer and AsInputUser on the
	// builder may perform two separate username lookups, each a network
	// request. The user ID is already carried by the resolved *tg.InputPeerUser.
	peer, err := sender.Resolve(c.cfg.BotUsername).AsInputPeer(ctx)
	if err != nil {
		return nil, 0, fmt.Errorf("resolve bot peer: %w", err)
	}
	user, ok := peer.(*tg.InputPeerUser)
	if !ok {
		return nil, 0, fmt.Errorf("resolve bot user: unexpected peer type %T", peer)
	}
	return peer, user.UserID, nil
}
// latestIncomingMessageID looks at the five most recent messages in the chat
// with peer. It returns the highest ID among those sent by botID, or 0 if none
// of them is from the bot. Senders use it as a baseline before sending a
// command, so that only newer bot messages count as the reply.
func (c *Client) latestIncomingMessageID(ctx context.Context, api *tg.Client, peer tg.InputPeerClass, botID int64) (int, error) {
	history, err := api.MessagesGetHistory(ctx, &tg.MessagesGetHistoryRequest{Peer: peer, Limit: 5})
	if err != nil {
		return 0, err
	}
	msgs, err := extractMessages(history)
	if err != nil {
		return 0, err
	}

	maxID := 0
	for _, m := range msgs {
		if !isFromBot(m, botID) || m.ID <= maxID {
			continue
		}
		maxID = m.ID
	}
	return maxID, nil
}
// waitForMessage polls the chat with peer for a new bot message.
//
// Each poll fetches the five most recent messages and returns the first one,
// in the order extractMessages yields them, that was sent by botID and has an
// ID greater than lastID. Polls repeat every 500ms. Once timeout has elapsed it
// gives up with an error; if ctx ends first, ctx's error is returned instead.
func (c *Client) waitForMessage(ctx context.Context, api *tg.Client, peer tg.InputPeerClass, botID int64, lastID int, timeout time.Duration) (*tg.Message, error) {
	giveUpAt := time.Now().Add(timeout)

	// poll returns the matching message, or nil when none is present yet.
	poll := func() (*tg.Message, error) {
		history, err := api.MessagesGetHistory(ctx, &tg.MessagesGetHistoryRequest{Peer: peer, Limit: 5})
		if err != nil {
			return nil, err
		}
		msgs, err := extractMessages(history)
		if err != nil {
			return nil, err
		}
		for _, m := range msgs {
			if m.ID > lastID && isFromBot(m, botID) {
				return m, nil
			}
		}
		return nil, nil
	}

	for {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		found, err := poll()
		if err != nil {
			return nil, err
		}
		if found != nil {
			return found, nil
		}
		if time.Now().After(giveUpAt) {
			return nil, fmt.Errorf("timeout waiting for bot response")
		}
		if err := sleepWithContext(ctx, 500*time.Millisecond); err != nil {
			return nil, err
		}
	}
}
// SendCommand sends command to the configured bot and returns the trimmed text
// of the bot's reply, waiting up to 15 seconds for it.
//
// When useCache is true, a cached reply stored under "cmd:<command>" is
// returned without contacting Telegram, and a freshly received reply is stored
// under that key.
func (c *Client) SendCommand(ctx context.Context, command string, useCache bool) (string, error) {
	cacheKey := fmt.Sprintf("cmd:%s", command)
	if useCache {
		if cached, ok := c.cache.Get(cacheKey); ok {
			return cached, nil
		}
	}

	// Reuse the rate-limited send-and-wait flow instead of duplicating it.
	msg, err := c.SendCommandMessage(ctx, command, 15*time.Second)
	if err != nil {
		return "", err
	}
	response := strings.TrimSpace(msg.Message)

	if useCache {
		c.cache.Set(cacheKey, response)
	}
	return response, nil
}
// SendCommandMessage sends command to the configured bot and returns the bot's
// reply message, waiting up to timeout for it to appear. It waits on the rate
// limiter before connecting.
func (c *Client) SendCommandMessage(ctx context.Context, command string, timeout time.Duration) (*tg.Message, error) {
	if err := c.limiter.Wait(ctx); err != nil {
		return nil, err
	}

	var reply *tg.Message
	err := c.withPeer(ctx, func(runCtx context.Context, api *tg.Client, sender *message.Sender, peer tg.InputPeerClass, botID int64) error {
		// Record the newest existing bot message so an earlier reply is not
		// mistaken for the answer to this command.
		lastID, err := c.latestIncomingMessageID(runCtx, api, peer, botID)
		if err != nil {
			return err
		}
		// Send to the peer withPeer already resolved instead of resolving the
		// bot username again.
		if _, err := sender.To(peer).Text(runCtx, command); err != nil {
			return fmt.Errorf("send command: %w", err)
		}
		msg, err := c.waitForMessage(runCtx, api, peer, botID, lastID, timeout)
		if err != nil {
			return err
		}
		reply = msg
		return nil
	})
	if err != nil {
		return nil, err
	}
	return reply, nil
}
// PressButton presses the inline callback button on msg that matches keyword.
//
// keyword is lower-cased before the lookup. The exact matching rule lives in
// findCallbackButton; presumably it compares against lower-cased labels, which
// should be confirmed. After pressing, it waits 1.2s and returns msg as
// re-fetched from Telegram, so the caller sees any edit the bot made. A
// FLOOD_WAIT on the press is honoured once: it sleeps for the requested
// duration plus one second, then retries.
func (c *Client) PressButton(ctx context.Context, msg *tg.Message, keyword string) (*tg.Message, error) {
	if msg == nil {
		return nil, errors.New("message cannot be nil")
	}

	// Find the button before taking a rate-limit slot or opening a connection:
	// a missing button is a purely local error.
	button, err := findCallbackButton(msg, strings.ToLower(keyword))
	if err != nil {
		return nil, err
	}

	if err := c.limiter.Wait(ctx); err != nil {
		return nil, err
	}

	var updated *tg.Message
	err = c.withPeer(ctx, func(runCtx context.Context, api *tg.Client, _ *message.Sender, peer tg.InputPeerClass, _ int64) error {
		req := &tg.MessagesGetBotCallbackAnswerRequest{
			Peer:  peer,
			MsgID: msg.ID,
		}
		if len(button.Data) > 0 {
			req.SetData(button.Data)
		}

		if _, err := api.MessagesGetBotCallbackAnswer(runCtx, req); err != nil {
			wait, ok := telegram.AsFloodWait(err)
			if !ok {
				return fmt.Errorf("press callback button: %w", err)
			}
			if err := sleepWithContext(runCtx, wait+time.Second); err != nil {
				return err
			}
			if _, err := api.MessagesGetBotCallbackAnswer(runCtx, req); err != nil {
				return fmt.Errorf("callback retry failed: %w", err)
			}
		}

		// Give the bot time to edit the message before re-reading it.
		if err := sleepWithContext(runCtx, 1200*time.Millisecond); err != nil {
			return err
		}
		resp, err := api.MessagesGetMessages(runCtx, []tg.InputMessageClass{
			&tg.InputMessageID{ID: msg.ID},
		})
		if err != nil {
			return fmt.Errorf("fetch updated message: %w", err)
		}
		refreshed, err := extractMessageByID(resp, msg.ID)
		if err != nil {
			return err
		}
		updated = refreshed
		return nil
	})
	if err != nil {
		return nil, err
	}
	return updated, nil
}
// FetchUserMessages collects the paginated message history the bot reports
// for a user.
//
// identifier may be any of:
//   - a username, with or without "@";
//   - a numeric / "+digits" identifier;
//   - a full command starting with "/".
//
// For anything but a command, FetchUserMessages sends "/user_info <identifier>".
// It then presses the "messages" and "all" buttons and keeps pressing "➡"
// until one of these happens:
//   - a page repeats;
//   - a press fails;
//   - maxPages pages have been visited (nil means no limit; a non-positive
//     value is rejected up front).
//
// The result is a summary line followed by each distinct non-empty page.
func (c *Client) FetchUserMessages(ctx context.Context, identifier string, maxPages *int) (string, error) {
	id := strings.TrimSpace(identifier)
	if id == "" {
		return "", errors.New("identifier cannot be empty")
	}
	// Validate arguments before any network round-trip.
	limit := 0
	if maxPages != nil {
		if *maxPages <= 0 {
			return "", errors.New("maxPages must be greater than zero")
		}
		limit = *maxPages
	}

	command := id
	if !strings.HasPrefix(id, "/") {
		if !strings.HasPrefix(id, "@") && !isNumericIdentifier(id) {
			id = "@" + id
		}
		command = fmt.Sprintf("/user_info %s", id)
	}

	base, err := c.SendCommandMessage(ctx, command, 20*time.Second)
	if err != nil {
		return "", err
	}
	stage, err := c.PressButton(ctx, base, "messages")
	if err != nil {
		return "", err
	}
	current, err := c.PressButton(ctx, stage, "all")
	if err != nil {
		return "", err
	}

	var pages []string
	seen := make(map[string]struct{})
	currentPage := 1
	totalPages := extractTotalPages(current)
	for {
		text := strings.TrimSpace(current.Message)
		if _, dup := seen[text]; !dup {
			// Record empty pages as well, so a run of empty pages ends the loop
			// instead of paging forever.
			seen[text] = struct{}{}
			if text != "" {
				header := fmt.Sprintf("第 %d 页", currentPage)
				if totalPages > 0 {
					header = fmt.Sprintf("%s/%d", header, totalPages)
				}
				pages = append(pages, strings.Join([]string{header, "", text}, "\n"))
			}
		}
		if limit > 0 && currentPage >= limit {
			break
		}

		next, err := c.PressButton(ctx, current, "➡")
		if err != nil {
			// Cancellation is a real failure. Any other press error presumably
			// means there is no further page, so stop with what was collected.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return "", ctxErr
			}
			break
		}
		if _, dup := seen[strings.TrimSpace(next.Message)]; dup {
			break
		}
		current = next
		currentPage++
	}

	if len(pages) == 0 {
		return fmt.Sprintf("未找到 %s 的消息记录。", identifier), nil
	}
	summary := fmt.Sprintf("共收集 %d 页消息", len(pages))
	if totalPages > 0 {
		summary = fmt.Sprintf("%s（总共 %d 页）", summary, totalPages)
	}
	result := append([]string{summary, ""}, pages...)
	return strings.Join(result, "\n\n"), nil
}
// isNumericIdentifier reports whether value is a numeric identifier: one or
// more ASCII digits, optionally preceded by a single leading '+' as in a phone
// number. FetchUserMessages passes such values through without adding an "@"
// prefix.
func isNumericIdentifier(value string) bool {
	// Only one leading '+' is meaningful; "+", "++1" and "1+2" are not numbers.
	digits := strings.TrimPrefix(value, "+")
	if digits == "" {
		return false
	}
	for _, r := range digits {
		if r < '0' || r > '9' {
			return false
		}
	}
	return true
}