import asyncio
import json
import hashlib
from typing import Any, Dict, List, Optional

import asyncpg

from models import PageRecord


class StorageManager:
    """Async persistence layer for bot command results backed by PostgreSQL.

    Stores paginated command responses (``mcp_results``) and their extracted
    entities (``mcp_entities``), deduplicating on content hashes, and tracks
    sync state so an external process can batch-upload unsynced rows.
    """

    def __init__(self, database_url: str):
        self.database_url = database_url
        self.pool: Optional[asyncpg.pool.Pool] = None
        # Guards pool creation: without it, two concurrent initialize()
        # calls can both pass the "if self.pool" check and leak a pool.
        self._lock = asyncio.Lock()

    async def initialize(self):
        """Create the connection pool and ensure the schema exists.

        Idempotent and safe under concurrent callers: pool creation is
        serialized by ``self._lock`` with a re-check after acquisition.
        """
        if self.pool:
            return
        async with self._lock:
            # Re-check under the lock — another coroutine may have finished
            # initializing while we were waiting.
            if self.pool:
                return
            self.pool = await asyncpg.create_pool(
                self.database_url, min_size=1, max_size=5, timeout=10
            )
            await self._ensure_schema()

    async def close(self):
        """Close the pool and drop the reference so initialize() can run again."""
        if self.pool:
            await self.pool.close()
            self.pool = None

    async def _ensure_schema(self):
        """Create the results and entities tables if they do not exist."""
        assert self.pool is not None
        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS mcp_results (
                    id BIGSERIAL PRIMARY KEY,
                    command TEXT NOT NULL,
                    bot_command TEXT NOT NULL,
                    arguments JSONB NOT NULL,
                    arguments_hash TEXT NOT NULL,
                    page_number INTEGER NOT NULL,
                    raw_text TEXT NOT NULL,
                    raw_text_hash TEXT NOT NULL,
                    entities JSONB,
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    source_account TEXT,
                    synced BOOLEAN NOT NULL DEFAULT FALSE,
                    sync_batch_id TEXT,
                    last_sync_at TIMESTAMPTZ,
                    sync_attempts INTEGER NOT NULL DEFAULT 0,
                    UNIQUE (command, bot_command, page_number, arguments_hash, raw_text_hash)
                );
                """
            )
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS mcp_entities (
                    id BIGSERIAL PRIMARY KEY,
                    result_id BIGINT NOT NULL REFERENCES mcp_results (id) ON DELETE CASCADE,
                    entity_type TEXT NOT NULL,
                    entity_value TEXT NOT NULL,
                    metadata JSONB,
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    UNIQUE (result_id, entity_type, entity_value)
                );
                """
            )

    async def save_response(
        self,
        command: str,
        bot_command: str,
        arguments: Dict[str, Any],
        pages: List[PageRecord],
        source_account: Optional[str] = None
    ) -> List[int]:
        """Persist one page set for a command invocation.

        Duplicate pages (same command/arguments/page/text hashes) are
        skipped via ``ON CONFLICT DO NOTHING``. All pages are written in a
        single transaction.

        Returns the ids of rows actually inserted (empty if everything was
        a duplicate, ``pages`` is empty, or the pool is not initialized).
        """
        if not pages or not self.pool:
            return []
        arguments_json = self._normalize_arguments(arguments)
        arguments_hash = self._hash_value(arguments_json)
        inserted_ids: List[int] = []
        async with self.pool.acquire() as conn:
            async with conn.transaction():
                for page in pages:
                    raw_text = page.text or ""
                    raw_hash = self._hash_value(raw_text)
                    record = await conn.fetchrow(
                        """
                        INSERT INTO mcp_results (
                            command, bot_command, arguments, arguments_hash,
                            page_number, raw_text, raw_text_hash, entities, source_account
                        )
                        VALUES ($1,$2,$3::jsonb,$4,$5,$6,$7,$8::jsonb,$9)
                        ON CONFLICT (command, bot_command, page_number, arguments_hash, raw_text_hash)
                        DO NOTHING
                        RETURNING id;
                        """,
                        command,
                        bot_command,
                        arguments_json,
                        arguments_hash,
                        page.page_number,
                        raw_text,
                        raw_hash,
                        json.dumps(page.entities or [], ensure_ascii=False),
                        source_account,
                    )
                    # fetchrow returns None when the conflict clause suppressed
                    # the insert — only new rows get entity rows attached.
                    if record:
                        result_id = record["id"]
                        inserted_ids.append(result_id)
                        if page.entities:
                            await self._upsert_entities(conn, result_id, page.entities)
        return inserted_ids

    async def fetch_unsynced_results(self, limit: int = 200) -> List[Dict[str, Any]]:
        """Return up to ``limit`` unsynced results (oldest first).

        Each payload dict carries the result columns plus ``entities``: a
        JSON array aggregated from ``mcp_entities`` ('[]' when none).
        """
        if not self.pool:
            return []
        async with self.pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT r.id, r.command, r.bot_command, r.arguments, r.page_number,
                       r.raw_text, r.entities, r.created_at, r.source_account,
                       COALESCE(
                           json_agg(
                               json_build_object(
                                   'entity_type', e.entity_type,
                                   'entity_value', e.entity_value,
                                   'metadata', e.metadata
                               )
                           ) FILTER (WHERE e.id IS NOT NULL),
                           '[]'::json
                       ) AS entity_list
                FROM mcp_results r
                LEFT JOIN mcp_entities e ON e.result_id = r.id
                WHERE r.synced = FALSE
                GROUP BY r.id
                ORDER BY r.created_at ASC
                LIMIT $1;
                """,
                limit,
            )
        payload = []
        for row in rows:
            payload.append({
                "id": row["id"],
                "command": row["command"],
                "bot_command": row["bot_command"],
                "arguments": row["arguments"],
                "page_number": row["page_number"],
                "raw_text": row["raw_text"],
                "entities": row["entity_list"],
                "created_at": row["created_at"].isoformat() if row["created_at"] else None,
                "source_account": row["source_account"],
            })
        return payload

    async def mark_synced(self, ids: List[int], batch_id: str):
        """Mark the given result ids as synced under ``batch_id``.

        Resets ``sync_attempts`` to 0 and stamps ``last_sync_at``.
        """
        if not ids or not self.pool:
            return
        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE mcp_results
                SET synced = TRUE,
                    sync_batch_id = $1,
                    last_sync_at = NOW(),
                    sync_attempts = 0
                WHERE id = ANY($2::bigint[]);
                """,
                batch_id,
                ids,
            )

    async def mark_failed_sync(self, ids: List[int]):
        """Increment the sync-attempt counter for the given result ids."""
        if not ids or not self.pool:
            return
        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE mcp_results
                SET sync_attempts = sync_attempts + 1
                WHERE id = ANY($1::bigint[]);
                """,
                ids,
            )

    async def _upsert_entities(self, conn: asyncpg.Connection, result_id: int, entities: List[Dict[str, Any]]):
        """Insert extracted entities for one result, skipping duplicates.

        Batched via ``executemany`` (one prepared statement, one round trip
        per batch) instead of one ``execute`` per entity. Runs inside the
        caller's transaction.
        """
        rows = [
            (
                result_id,
                entity.get("type", "unknown"),
                str(entity.get("value", "")),
                json.dumps(entity.get("metadata", {}), ensure_ascii=False),
            )
            for entity in entities
        ]
        await conn.executemany(
            """
            INSERT INTO mcp_entities (result_id, entity_type, entity_value, metadata)
            VALUES ($1, $2, $3, $4::jsonb)
            ON CONFLICT (result_id, entity_type, entity_value) DO NOTHING;
            """,
            rows,
        )

    @staticmethod
    def _normalize_arguments(arguments: Dict[str, Any]) -> str:
        """Serialize arguments to canonical JSON (sorted keys) for stable hashing."""
        return json.dumps(arguments or {}, ensure_ascii=False, sort_keys=True, default=str)

    @staticmethod
    def _hash_value(value: str) -> str:
        """Return the hex SHA-256 digest of ``value`` (UTF-8 encoded)."""
        return hashlib.sha256(value.encode("utf-8")).hexdigest()