feat: add postgres storage and remote sync
This commit is contained in:
228
core/storage.py
Normal file
228
core/storage.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import asyncio
|
||||
import json
|
||||
import hashlib
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import asyncpg
|
||||
|
||||
from models import PageRecord
|
||||
|
||||
|
||||
class StorageManager:
    """Async PostgreSQL persistence for paginated bot-command results.

    Each response page is stored as one row in ``mcp_results``,
    deduplicated by the composite unique key
    (command, bot_command, page_number, arguments_hash, raw_text_hash).
    Extracted entities are mirrored into ``mcp_entities``.  The sync
    bookkeeping columns (``synced``, ``sync_batch_id``, ``last_sync_at``,
    ``sync_attempts``) support a remote-sync worker through
    ``fetch_unsynced_results`` / ``mark_synced`` / ``mark_failed_sync``.
    """

    def __init__(self, database_url: str):
        """Record connection settings; no I/O happens until initialize()."""
        self.database_url = database_url
        # Lazily-created connection pool; None until initialize() succeeds.
        self.pool: Optional[asyncpg.pool.Pool] = None
        # Serializes pool creation.  Previously this lock was created but
        # never used, leaving initialize() with a check-then-create race
        # that could leak a second pool under concurrent callers.
        self._lock = asyncio.Lock()

    async def initialize(self):
        """Create the connection pool and ensure the schema exists.

        Idempotent and safe under concurrent callers: the first caller
        creates the pool and schema while any others wait on the lock,
        then observe the already-created pool and return.
        """
        if self.pool:
            return
        async with self._lock:
            # Re-check under the lock: another task may have finished
            # initializing while we were waiting.
            if self.pool:
                return
            self.pool = await asyncpg.create_pool(
                self.database_url,
                min_size=1,
                max_size=5,
                timeout=10
            )
            await self._ensure_schema()

    async def close(self):
        """Close the pool (if any) and reset so initialize() can run again."""
        if self.pool:
            await self.pool.close()
            self.pool = None

    async def _ensure_schema(self):
        """Create the ``mcp_results`` and ``mcp_entities`` tables if missing."""
        assert self.pool is not None
        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS mcp_results (
                    id BIGSERIAL PRIMARY KEY,
                    command TEXT NOT NULL,
                    bot_command TEXT NOT NULL,
                    arguments JSONB NOT NULL,
                    arguments_hash TEXT NOT NULL,
                    page_number INTEGER NOT NULL,
                    raw_text TEXT NOT NULL,
                    raw_text_hash TEXT NOT NULL,
                    entities JSONB,
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    source_account TEXT,
                    synced BOOLEAN NOT NULL DEFAULT FALSE,
                    sync_batch_id TEXT,
                    last_sync_at TIMESTAMPTZ,
                    sync_attempts INTEGER NOT NULL DEFAULT 0,
                    UNIQUE (command, bot_command, page_number, arguments_hash, raw_text_hash)
                );
                """
            )

            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS mcp_entities (
                    id BIGSERIAL PRIMARY KEY,
                    result_id BIGINT NOT NULL REFERENCES mcp_results (id) ON DELETE CASCADE,
                    entity_type TEXT NOT NULL,
                    entity_value TEXT NOT NULL,
                    metadata JSONB,
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    UNIQUE (result_id, entity_type, entity_value)
                );
                """
            )

    async def save_response(
        self,
        command: str,
        bot_command: str,
        arguments: Dict[str, Any],
        pages: List[PageRecord],
        source_account: Optional[str] = None
    ) -> List[int]:
        """Persist one row per page; return the ids of newly inserted rows.

        Pages whose composite key already exists are skipped via
        ``ON CONFLICT DO NOTHING`` and do not appear in the returned list.
        All pages of one response are written in a single transaction.
        Returns ``[]`` (deliberate best-effort no-op, not an error) when
        there are no pages or the pool has not been initialized.
        """
        if not pages or not self.pool:
            return []

        # Arguments are hashed on their canonical (sorted-key) JSON form so
        # that logically-equal dicts dedupe to the same row.
        arguments_json = self._normalize_arguments(arguments)
        arguments_hash = self._hash_value(arguments_json)
        inserted_ids: List[int] = []

        async with self.pool.acquire() as conn:
            async with conn.transaction():
                for page in pages:
                    raw_text = page.text or ""
                    raw_hash = self._hash_value(raw_text)

                    record = await conn.fetchrow(
                        """
                        INSERT INTO mcp_results (
                            command, bot_command, arguments, arguments_hash,
                            page_number, raw_text, raw_text_hash, entities,
                            source_account
                        ) VALUES ($1,$2,$3::jsonb,$4,$5,$6,$7,$8::jsonb,$9)
                        ON CONFLICT (command, bot_command, page_number, arguments_hash, raw_text_hash)
                        DO NOTHING
                        RETURNING id;
                        """,
                        command,
                        bot_command,
                        arguments_json,
                        arguments_hash,
                        page.page_number,
                        raw_text,
                        raw_hash,
                        json.dumps(page.entities or [], ensure_ascii=False),
                        source_account,
                    )

                    # fetchrow() returns None on conflict (duplicate page);
                    # only genuinely new rows get entity rows attached.
                    if record:
                        result_id = record["id"]
                        inserted_ids.append(result_id)
                        if page.entities:
                            await self._upsert_entities(conn, result_id, page.entities)

        return inserted_ids

    async def fetch_unsynced_results(self, limit: int = 200) -> List[Dict[str, Any]]:
        """Return up to ``limit`` unsynced rows (oldest first) as plain dicts.

        Each dict carries the stored columns plus ``entities``: a JSON
        aggregation of the row's ``mcp_entities`` children (``'[]'`` when
        none).  Returns ``[]`` when the pool has not been initialized.
        """
        if not self.pool:
            return []

        async with self.pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT
                    r.id,
                    r.command,
                    r.bot_command,
                    r.arguments,
                    r.page_number,
                    r.raw_text,
                    r.entities,
                    r.created_at,
                    r.source_account,
                    COALESCE(
                        json_agg(
                            json_build_object(
                                'entity_type', e.entity_type,
                                'entity_value', e.entity_value,
                                'metadata', e.metadata
                            )
                        ) FILTER (WHERE e.id IS NOT NULL),
                        '[]'::json
                    ) AS entity_list
                FROM mcp_results r
                LEFT JOIN mcp_entities e ON e.result_id = r.id
                WHERE r.synced = FALSE
                GROUP BY r.id
                ORDER BY r.created_at ASC
                LIMIT $1;
                """,
                limit,
            )

            return [
                {
                    "id": row["id"],
                    "command": row["command"],
                    "bot_command": row["bot_command"],
                    "arguments": row["arguments"],
                    "page_number": row["page_number"],
                    "raw_text": row["raw_text"],
                    "entities": row["entity_list"],
                    # Timestamps go out as ISO strings so the payload is
                    # JSON-serializable for the remote sync endpoint.
                    "created_at": row["created_at"].isoformat() if row["created_at"] else None,
                    "source_account": row["source_account"],
                }
                for row in rows
            ]

    async def mark_synced(self, ids: List[int], batch_id: str):
        """Flag the given rows as synced under ``batch_id``.

        Also resets ``sync_attempts`` to 0 so a later re-sync of the same
        row starts with a clean failure counter.  No-op when ``ids`` is
        empty or the pool has not been initialized.
        """
        if not ids or not self.pool:
            return

        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE mcp_results
                SET synced = TRUE,
                    sync_batch_id = $1,
                    last_sync_at = NOW(),
                    sync_attempts = 0
                WHERE id = ANY($2::bigint[]);
                """,
                batch_id,
                ids,
            )

    async def mark_failed_sync(self, ids: List[int]):
        """Increment ``sync_attempts`` for rows whose sync push failed.

        No-op when ``ids`` is empty or the pool has not been initialized.
        """
        if not ids or not self.pool:
            return

        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE mcp_results
                SET sync_attempts = sync_attempts + 1
                WHERE id = ANY($1::bigint[]);
                """,
                ids,
            )

    async def _upsert_entities(self, conn: asyncpg.Connection, result_id: int, entities: List[Dict[str, Any]]):
        """Insert entity rows for ``result_id``; duplicates are ignored.

        Runs on the caller's connection so the writes share the enclosing
        transaction in save_response().
        """
        for entity in entities:
            await conn.execute(
                """
                INSERT INTO mcp_entities (result_id, entity_type, entity_value, metadata)
                VALUES ($1, $2, $3, $4::jsonb)
                ON CONFLICT (result_id, entity_type, entity_value) DO NOTHING;
                """,
                result_id,
                entity.get("type", "unknown"),
                str(entity.get("value", "")),
                json.dumps(entity.get("metadata", {}), ensure_ascii=False),
            )

    @staticmethod
    def _normalize_arguments(arguments: Dict[str, Any]) -> str:
        """Canonical JSON for an argument dict (sorted keys, str fallback).

        ``None``/falsy input normalizes to ``"{}"`` so hashing never fails.
        """
        return json.dumps(arguments or {}, ensure_ascii=False, sort_keys=True, default=str)

    @staticmethod
    def _hash_value(value: str) -> str:
        """Hex SHA-256 of ``value`` (UTF-8), used for dedup keys."""
        return hashlib.sha256(value.encode("utf-8")).hexdigest()
|
||||
Reference in New Issue
Block a user