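"""Persistence layer for paginated MCP bot responses.

StorageManager writes each page of a bot response into PostgreSQL via
asyncpg, deduplicates pages by content hash, and tracks which rows have
been handed off to a downstream sync consumer. (Docstring added for
orientation; the sync consumer itself lives outside this module.)
"""
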
import asyncio
import hashlib
import json
from typing import Any, Dict, List, Optional

import asyncpg

from models import PageRecord


class StorageManager:
    def __init__(self, database_url: str):
        self.database_url = database_url
        self.pool: Optional[asyncpg.Pool] = None
        # Guards initialize() so concurrent callers cannot race and
        # create two pools.
        self._lock = asyncio.Lock()

    async def initialize(self):
        # Double-checked locking: fast path if the pool already exists,
        # then re-check under the lock before creating it. (The original
        # created self._lock but never used it, leaving a race between
        # concurrent initialize() calls.)
        if self.pool:
            return
        async with self._lock:
            if self.pool:
                return
            self.pool = await asyncpg.create_pool(
                self.database_url,
                min_size=1,
                max_size=5,
                timeout=10,
            )
            await self._ensure_schema()

    async def close(self):
        if self.pool:
            await self.pool.close()
            self.pool = None

    async def _ensure_schema(self):
        assert self.pool is not None
        async with self.pool.acquire() as conn:
            # One row per page of a bot response; the UNIQUE constraint
            # backs the ON CONFLICT dedupe in save_response().
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS mcp_results (
                    id BIGSERIAL PRIMARY KEY,
                    command TEXT NOT NULL,
                    bot_command TEXT NOT NULL,
                    arguments JSONB NOT NULL,
                    arguments_hash TEXT NOT NULL,
                    page_number INTEGER NOT NULL,
                    raw_text TEXT NOT NULL,
                    raw_text_hash TEXT NOT NULL,
                    entities JSONB,
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    source_account TEXT,
                    synced BOOLEAN NOT NULL DEFAULT FALSE,
                    sync_batch_id TEXT,
                    last_sync_at TIMESTAMPTZ,
                    sync_attempts INTEGER NOT NULL DEFAULT 0,
                    UNIQUE (command, bot_command, page_number, arguments_hash, raw_text_hash)
                );
                """
            )

            # Entities extracted from a page, normalized into their own table.
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS mcp_entities (
                    id BIGSERIAL PRIMARY KEY,
                    result_id BIGINT NOT NULL REFERENCES mcp_results (id) ON DELETE CASCADE,
                    entity_type TEXT NOT NULL,
                    entity_value TEXT NOT NULL,
                    metadata JSONB,
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    UNIQUE (result_id, entity_type, entity_value)
                );
                """
            )

    async def save_response(
        self,
        command: str,
        bot_command: str,
        arguments: Dict[str, Any],
        pages: List[PageRecord],
        source_account: Optional[str] = None,
    ) -> List[int]:
        if not pages or not self.pool:
            return []

        arguments_json = self._normalize_arguments(arguments)
        arguments_hash = self._hash_value(arguments_json)
        inserted_ids: List[int] = []

        async with self.pool.acquire() as conn:
            async with conn.transaction():
                for page in pages:
                    raw_text = page.text or ""
                    raw_hash = self._hash_value(raw_text)

                    # ON CONFLICT DO NOTHING: pages already stored with the
                    # same command/arguments/content are skipped, in which
                    # case fetchrow() returns None.
                    record = await conn.fetchrow(
                        """
                        INSERT INTO mcp_results (
                            command, bot_command, arguments, arguments_hash,
                            page_number, raw_text, raw_text_hash, entities,
                            source_account
                        ) VALUES ($1, $2, $3::jsonb, $4, $5, $6, $7, $8::jsonb, $9)
                        ON CONFLICT (command, bot_command, page_number, arguments_hash, raw_text_hash)
                        DO NOTHING
                        RETURNING id;
                        """,
                        command,
                        bot_command,
                        arguments_json,
                        arguments_hash,
                        page.page_number,
                        raw_text,
                        raw_hash,
                        json.dumps(page.entities or [], ensure_ascii=False),
                        source_account,
                    )

                    if record:
                        result_id = record["id"]
                        inserted_ids.append(result_id)
                        if page.entities:
                            await self._upsert_entities(conn, result_id, page.entities)

        return inserted_ids

    async def fetch_unsynced_results(self, limit: int = 200) -> List[Dict[str, Any]]:
        if not self.pool:
            return []

        async with self.pool.acquire() as conn:
            # Aggregate each result's entities into a JSON array; the FILTER
            # clause keeps the LEFT JOIN from producing a [null] entry for
            # results that have no entities.
            rows = await conn.fetch(
                """
                SELECT
                    r.id,
                    r.command,
                    r.bot_command,
                    r.arguments,
                    r.page_number,
                    r.raw_text,
                    r.entities,
                    r.created_at,
                    r.source_account,
                    COALESCE(
                        json_agg(
                            json_build_object(
                                'entity_type', e.entity_type,
                                'entity_value', e.entity_value,
                                'metadata', e.metadata
                            )
                        ) FILTER (WHERE e.id IS NOT NULL),
                        '[]'::json
                    ) AS entity_list
                FROM mcp_results r
                LEFT JOIN mcp_entities e ON e.result_id = r.id
                WHERE r.synced = FALSE
                GROUP BY r.id
                ORDER BY r.created_at ASC
                LIMIT $1;
                """,
                limit,
            )

        payload = []
        for row in rows:
            payload.append({
                "id": row["id"],
                "command": row["command"],
                "bot_command": row["bot_command"],
                "arguments": row["arguments"],
                "page_number": row["page_number"],
                "raw_text": row["raw_text"],
                "entities": row["entity_list"],
                "created_at": row["created_at"].isoformat() if row["created_at"] else None,
                "source_account": row["source_account"],
            })
        return payload

    async def mark_synced(self, ids: List[int], batch_id: str):
        if not ids or not self.pool:
            return

        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE mcp_results
                SET synced = TRUE,
                    sync_batch_id = $1,
                    last_sync_at = NOW(),
                    sync_attempts = 0
                WHERE id = ANY($2::bigint[]);
                """,
                batch_id,
                ids,
            )

    async def mark_failed_sync(self, ids: List[int]):
        if not ids or not self.pool:
            return

        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE mcp_results
                SET sync_attempts = sync_attempts + 1
                WHERE id = ANY($1::bigint[]);
                """,
                ids,
            )

    async def _upsert_entities(self, conn: asyncpg.Connection, result_id: int, entities: List[Dict[str, Any]]):
        # One row per entity; re-inserting the same entity for a result is a
        # no-op thanks to the UNIQUE (result_id, entity_type, entity_value)
        # constraint.
        for entity in entities:
            await conn.execute(
                """
                INSERT INTO mcp_entities (result_id, entity_type, entity_value, metadata)
                VALUES ($1, $2, $3, $4::jsonb)
                ON CONFLICT (result_id, entity_type, entity_value) DO NOTHING;
                """,
                result_id,
                entity.get("type", "unknown"),
                str(entity.get("value", "")),
                json.dumps(entity.get("metadata", {}), ensure_ascii=False),
            )

    @staticmethod
    def _normalize_arguments(arguments: Dict[str, Any]) -> str:
        # sort_keys=True makes the JSON canonical, so logically equal
        # argument dicts always serialize (and therefore hash) identically.
        return json.dumps(arguments or {}, ensure_ascii=False, sort_keys=True, default=str)

    @staticmethod
    def _hash_value(value: str) -> str:
        return hashlib.sha256(value.encode("utf-8")).hexdigest()
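

# Minimal usage sketch, not part of the module's public surface. It assumes a
# reachable PostgreSQL at DATABASE_URL and that PageRecord accepts
# page_number/text/entities keyword arguments; adjust to the actual model
# definition in models.py.
if __name__ == "__main__":
    import os

    async def _demo():
        manager = StorageManager(os.environ["DATABASE_URL"])
        await manager.initialize()
        try:
            # Hypothetical page; field names mirror the attributes
            # save_response() reads (page_number, text, entities).
            page = PageRecord(
                page_number=1,
                text="+1 555 0100 example result",
                entities=[{"type": "phone", "value": "+15550100", "metadata": {}}],
            )
            ids = await manager.save_response(
                command="lookup",
                bot_command="/search",
                # Key order is irrelevant: arguments are canonicalized
                # (sort_keys=True) before hashing for the dedupe key.
                arguments={"query": "example"},
                pages=[page],
            )
            print("inserted:", ids)

            # Drain the unsynced queue and mark the batch as delivered.
            pending = await manager.fetch_unsynced_results(limit=50)
            if pending:
                await manager.mark_synced([r["id"] for r in pending], batch_id="batch-001")
        finally:
            await manager.close()

    asyncio.run(_demo())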