Files
clawrity/skills/postgres_connector.py
2026-05-04 22:00:38 +05:30

385 lines
13 KiB
Python

"""
Clawrity — PostgreSQL + pgvector Connector
Connection pool management, schema initialization, and query execution.
Single database handles both structured queries (NL-to-SQL) and vector search (pgvector).
"""
import logging
import time
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import psycopg2
import psycopg2.extras
from pgvector.psycopg2 import register_vector
from config.settings import get_settings
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Schema DDL
# ---------------------------------------------------------------------------
# Idempotent bootstrap DDL: enables the pgvector extension and creates the
# three core tables plus their btree indexes.  Safe to run on every startup
# thanks to IF NOT EXISTS.
# NOTE(review): embedding dimensionality is fixed at vector(384) —
# presumably matching the configured embedding model; confirm before
# switching models, since stored vectors would become incompatible.
INIT_SCHEMA_SQL = """
-- Enable pgvector extension
CREATE EXTENSION IF NOT EXISTS vector;
-- Structured business data (replaces BigQuery)
CREATE TABLE IF NOT EXISTS spend_data (
id SERIAL PRIMARY KEY,
date DATE,
country VARCHAR(100),
branch VARCHAR(100),
channel VARCHAR(100),
spend FLOAT,
revenue FLOAT,
leads INT,
conversions INT,
client_id VARCHAR(100)
);
-- Vector embeddings (replaces ChromaDB)
CREATE TABLE IF NOT EXISTS embeddings (
id VARCHAR(200) PRIMARY KEY,
client_id VARCHAR(100),
chunk_type VARCHAR(50),
text TEXT,
metadata JSONB,
embedding vector(384)
);
-- Forecast cache
CREATE TABLE IF NOT EXISTS forecasts (
id SERIAL PRIMARY KEY,
client_id VARCHAR(100),
branch VARCHAR(100),
country VARCHAR(100),
horizon_months INT,
forecast_data JSONB,
computed_at TIMESTAMP DEFAULT NOW()
);
-- Indexes
CREATE INDEX IF NOT EXISTS idx_spend_data_client
ON spend_data (client_id);
CREATE INDEX IF NOT EXISTS idx_spend_data_date
ON spend_data (client_id, date);
CREATE INDEX IF NOT EXISTS idx_embeddings_client_type
ON embeddings (client_id, chunk_type);
CREATE INDEX IF NOT EXISTS idx_forecasts_client
ON forecasts (client_id, branch, country);
"""
# IVFFlat index requires rows to exist — created separately after data load.
# (IVFFlat clusters the vectors present at build time into `lists` buckets;
# building it on an empty table produces a useless index, hence the deferred
# creation via create_vector_index().)
IVFFLAT_INDEX_SQL = """
CREATE INDEX IF NOT EXISTS idx_embeddings_cosine
ON embeddings USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);
"""
class PostgresConnector:
    """PostgreSQL + pgvector connection manager.

    A single database backs both structured analytics (NL-to-SQL against
    ``spend_data``) and semantic search (pgvector against ``embeddings``).
    One connection is opened lazily per instance and re-opened on demand
    if it has been closed.

    NOTE(review): all methods share one connection per instance, so
    concurrent use from multiple threads is presumably unsafe — confirm
    before sharing an instance across threads.
    """

    def __init__(self, database_url: Optional[str] = None):
        """
        Args:
            database_url: Postgres DSN.  Defaults to the application
                settings' ``database_url``.
        """
        self.database_url = database_url or get_settings().database_url
        self._conn: Optional[psycopg2.extensions.connection] = None

    def _get_connection(self) -> psycopg2.extensions.connection:
        """Get or create a database connection with exponential-backoff retry.

        Returns:
            An open psycopg2 connection with the pgvector adapter registered.

        Raises:
            ConnectionError: if every connection attempt fails.
        """
        if self._conn is None or self._conn.closed:
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    self._conn = psycopg2.connect(self.database_url)
                    # Teach psycopg2 to adapt numpy arrays <-> vector columns.
                    register_vector(self._conn)
                    logger.info("Connected to PostgreSQL with pgvector support")
                    return self._conn
                except psycopg2.OperationalError as e:
                    wait = 2**attempt
                    logger.warning(
                        f"DB connection attempt {attempt + 1}/{max_retries} failed: {e}. "
                        f"Retrying in {wait}s..."
                    )
                    # No point sleeping after the final attempt — we raise next.
                    if attempt < max_retries - 1:
                        time.sleep(wait)
            raise ConnectionError(
                f"Failed to connect to PostgreSQL after {max_retries} attempts"
            )
        return self._conn

    def close(self):
        """Close the database connection if it is open."""
        if self._conn and not self._conn.closed:
            self._conn.close()
            logger.info("PostgreSQL connection closed")

    def init_schema(self):
        """Create tables, indexes, and extensions if they don't exist.

        Idempotent — safe to call on every startup.

        Raises:
            Exception: re-raises any DDL failure after rolling back.
        """
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(INIT_SCHEMA_SQL)
            conn.commit()
            logger.info("Database schema initialized successfully")
        except Exception as e:
            conn.rollback()
            logger.error(f"Schema initialization failed: {e}")
            raise

    def create_vector_index(self):
        """Create IVFFlat index — call AFTER data has been loaded into embeddings.

        Failures are logged but swallowed on purpose: IVFFlat needs existing
        rows to build useful clusters, so this helper is best-effort.
        """
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(IVFFLAT_INDEX_SQL)
            conn.commit()
            logger.info("IVFFlat vector index created")
        except Exception as e:
            conn.rollback()
            logger.warning(f"Could not create IVFFlat index (may need more rows): {e}")

    # ------------------------------------------------------------------
    # Query execution
    # ------------------------------------------------------------------
    def execute_query(self, sql: str, params: Optional[tuple] = None) -> pd.DataFrame:
        """
        Execute a SELECT query and return results as a DataFrame.

        Args:
            sql: SQL query string (must be SELECT only — nothing is committed)
            params: Query parameters for parameterised queries

        Returns:
            pandas DataFrame with query results

        Raises:
            Exception: re-raises any query failure after rolling back.
        """
        conn = self._get_connection()
        try:
            # NOTE: pandas warns about raw DBAPI connections; psycopg2 works here.
            df = pd.read_sql_query(sql, conn, params=params)
            # End the implicit read transaction so the connection does not
            # sit "idle in transaction".
            conn.rollback()
            logger.debug(f"Query returned {len(df)} rows")
            return df
        except Exception as e:
            logger.error(f"Query execution failed: {e}")
            # Roll back so the aborted transaction doesn't poison later queries.
            conn.rollback()
            raise

    def execute_raw(self, sql: str, params: Optional[tuple] = None) -> List[Dict]:
        """Execute a query and return rows as plain dictionaries.

        Statements that produce a result set are rolled back (read-only, no
        transaction left open); statements without one are committed.
        """
        conn = self._get_connection()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, params)
                if cur.description:  # statement produced a result set
                    results = [dict(row) for row in cur.fetchall()]
                    conn.rollback()
                    return results
                conn.commit()
                return []
        except Exception as e:
            conn.rollback()
            logger.error(f"Raw query execution failed: {e}")
            raise

    def execute_write(self, sql: str, params: Optional[tuple] = None):
        """Execute an INSERT/UPDATE/DELETE statement and commit it.

        Raises:
            Exception: re-raises any failure after rolling back.
        """
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(sql, params)
            conn.commit()
        except Exception as e:
            conn.rollback()
            logger.error(f"Write execution failed: {e}")
            raise

    def execute_batch(self, sql: str, data: List[tuple], page_size: int = 1000):
        """Execute a batch INSERT using execute_values for performance.

        Args:
            sql: INSERT statement with a single ``VALUES %s`` placeholder
            data: Rows to insert
            page_size: Rows sent per round-trip

        Raises:
            Exception: re-raises any failure after rolling back.
        """
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                psycopg2.extras.execute_values(cur, sql, data, page_size=page_size)
            conn.commit()
            logger.info(f"Batch insert: {len(data)} rows")
        except Exception as e:
            conn.rollback()
            logger.error(f"Batch execution failed: {e}")
            raise

    # ------------------------------------------------------------------
    # pgvector operations
    # ------------------------------------------------------------------
    def upsert_embeddings(self, embeddings_data: List[Dict[str, Any]]):
        """
        Upsert embedding records into the embeddings table.

        Args:
            embeddings_data: List of dicts with keys:
                id, client_id, chunk_type, text, metadata, embedding

        Raises:
            Exception: re-raises any failure after rolling back.
        """
        conn = self._get_connection()
        sql = """
        INSERT INTO embeddings (id, client_id, chunk_type, text, metadata, embedding)
        VALUES %s
        ON CONFLICT (id) DO UPDATE SET
        text = EXCLUDED.text,
        metadata = EXCLUDED.metadata,
        embedding = EXCLUDED.embedding
        """
        data = [
            (
                d["id"],
                d["client_id"],
                d["chunk_type"],
                d["text"],
                psycopg2.extras.Json(d["metadata"]),  # adapt dict -> JSONB
                np.array(d["embedding"]),  # adapted via register_vector
            )
            for d in embeddings_data
        ]
        try:
            with conn.cursor() as cur:
                psycopg2.extras.execute_values(cur, sql, data, page_size=100)
            conn.commit()
            logger.info(f"Upserted {len(data)} embeddings")
        except Exception as e:
            conn.rollback()
            logger.error(f"Embedding upsert failed: {e}")
            raise

    def search_embeddings(
        self,
        query_embedding: np.ndarray,
        client_id: str,
        chunk_type: Optional[str] = None,
        top_k: int = 5,
    ) -> List[Dict]:
        """
        Search for similar embeddings using pgvector cosine similarity.

        Args:
            query_embedding: Query vector (must match the stored 384 dims)
            client_id: Filter by client
            chunk_type: Optional filter by chunk type
            top_k: Number of results to return

        Returns:
            List of dicts with text, metadata, and similarity score
            (1 - cosine distance, so higher is more similar)

        Raises:
            Exception: re-raises any failure after rolling back.
        """
        conn = self._get_connection()
        query_vec = np.array(query_embedding)
        # Build the WHERE clause once instead of duplicating the whole query.
        where = "client_id = %s"
        filter_params: tuple = (client_id,)
        if chunk_type:
            where += " AND chunk_type = %s"
            filter_params += (chunk_type,)
        sql = f"""
        SELECT text, metadata, 1 - (embedding <=> %s) AS similarity
        FROM embeddings
        WHERE {where}
        ORDER BY embedding <=> %s
        LIMIT %s
        """
        params = (query_vec, *filter_params, query_vec, top_k)
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, params)
                results = [dict(row) for row in cur.fetchall()]
            # End the read transaction so the connection isn't left open.
            conn.rollback()
            logger.debug(f"Vector search returned {len(results)} results")
            return results
        except Exception as e:
            # Roll back: a failed statement otherwise leaves the transaction
            # aborted and every later query on this connection would fail.
            conn.rollback()
            logger.error(f"Vector search failed: {e}")
            raise

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------
    def get_table_count(self, table: str, client_id: Optional[str] = None) -> int:
        """Get row count for a table, optionally filtered by client_id.

        Returns 0 on query failure (best-effort helper).

        Raises:
            ValueError: if ``table`` is not a plain identifier.
        """
        # Table names cannot be bound as query parameters, so `table` is
        # interpolated into the SQL — reject anything that is not a bare
        # identifier to close off SQL injection through this argument.
        if not table.isidentifier():
            raise ValueError(f"Invalid table name: {table!r}")
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                if client_id:
                    cur.execute(
                        f"SELECT COUNT(*) FROM {table} WHERE client_id = %s",
                        (client_id,),
                    )
                else:
                    cur.execute(f"SELECT COUNT(*) FROM {table}")
                return cur.fetchone()[0]
        except Exception as e:
            # Roll back so the aborted transaction doesn't poison later queries.
            conn.rollback()
            logger.error(f"Count query failed: {e}")
            return 0

    def get_spend_data_schema(self, client_id: str) -> Dict:
        """Get metadata about available data for a client — used by NL-to-SQL.

        Returns:
            Dict with keys countries, branches, channels, date_min, date_max.
            On failure the same shape is returned with empty lists / None.
        """
        conn = self._get_connection()

        def _distinct(cur, column: str) -> List:
            # `column` is always a hard-coded literal below, never user input.
            cur.execute(
                f"SELECT DISTINCT {column} FROM spend_data "
                f"WHERE client_id = %s ORDER BY {column}",
                (client_id,),
            )
            return [row[0] for row in cur.fetchall()]

        try:
            with conn.cursor() as cur:
                countries = _distinct(cur, "country")
                branches = _distinct(cur, "branch")
                channels = _distinct(cur, "channel")
                cur.execute(
                    "SELECT MIN(date), MAX(date) FROM spend_data WHERE client_id = %s",
                    (client_id,),
                )
                date_min, date_max = cur.fetchone()
            return {
                "countries": countries,
                "branches": branches,
                "channels": channels,
                "date_min": str(date_min) if date_min else None,
                "date_max": str(date_max) if date_max else None,
            }
        except Exception as e:
            # Roll back so the aborted transaction doesn't poison later queries.
            conn.rollback()
            logger.error(f"Schema metadata query failed: {e}")
            return {
                "countries": [],
                "branches": [],
                "channels": [],
                "date_min": None,
                "date_max": None,
            }
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_connector: Optional[PostgresConnector] = None


def get_connector() -> PostgresConnector:
    """Return the process-wide PostgresConnector, creating it on first use."""
    global _connector
    if _connector is not None:
        return _connector
    _connector = PostgresConnector()
    return _connector