Files

302 lines
11 KiB
Python

"""
Clawrity — Orchestrator
Coordinates the full message pipeline:
NormalisedMessage → NL-to-SQL → PostgreSQL → (RAG Retriever) → Gen Agent → QA Agent → Response
Max 2 retries per query. Returns best attempt with confidence warning after max retries.
Context enrichment: when a query returns sparse data (≤3 rows) and the question
asks for recommendations, automatically pulls top-performing branches as comparison
context so the Gen Agent can give actionable suggestions.
"""
import re
import logging
import time
from typing import Dict, Optional, List
import pandas as pd
from agents.gen_agent import GenAgent
from agents.qa_agent import QAAgent
from channels.protocol_adapter import NormalisedMessage
from config.client_loader import ClientConfig
from skills.nl_to_sql import NLToSQL
from skills.postgres_connector import get_connector
from soul.soul_loader import load_soul
logger = logging.getLogger(__name__)

# Upper bound on Gen Agent re-generations after the first attempt.
MAX_RETRIES = 2

# Words/phrases indicating the user wants advice or recommendations rather than
# a plain data lookup. Matched case-insensitively on word boundaries; used to
# decide whether a sparse SQL result should be enriched with benchmark data.
_RECOMMENDATION_KEYWORDS = re.compile(
    r"\b("
    r"improve|increase|boost|grow|fix|help|recommend|suggest|advice|strategy|"
    r"what (should|can|do)|how (to|can|do|should)"
    r")\b",
    re.IGNORECASE,
)
class Orchestrator:
    """Pipeline orchestrator — the central brain of Clawrity.

    Owns the NL-to-SQL skill and the Gen/QA agents, plus an optional RAG
    retriever that is attached later via ``set_retriever``.
    """

    def __init__(self):
        self.nl_to_sql = NLToSQL()
        self.gen_agent = GenAgent()
        self.qa_agent = QAAgent()
        self.retriever = None  # Set in Phase 2 via set_retriever()

    def set_retriever(self, retriever):
        """Attach the RAG retriever (Phase 2)."""
        self.retriever = retriever

    async def process(
        self,
        message: NormalisedMessage,
        client_config: ClientConfig,
    ) -> Dict:
        """
        Process a user message through the full pipeline.

        Steps: NL-to-SQL → SQL execution → (sparse + recommendation question)
        benchmark enrichment → (if attached) RAG retrieval → Gen Agent / QA
        Agent loop with up to ``MAX_RETRIES`` retries.

        Args:
            message: Incoming message; its ``text`` is the user question.
            client_config: Per-client config; ``client_id`` and
                ``hallucination_threshold`` are read here.

        Returns:
            Dict with keys ``response``, ``qa_score``, ``qa_passed``,
            ``retries``, ``sql``, ``data_rows``, ``rag_chunks_used`` and
            ``elapsed_seconds``. If no attempt passes QA, ``response`` is the
            best attempt with a confidence warning appended.
        """
        # NOTE(review): this coroutine never awaits — the DB and agent calls
        # below are synchronous and block the event loop while they run.
        start_time = time.time()
        db = get_connector()
        question = message.text
        client_id = client_config.client_id

        # Load SOUL
        soul_content = load_soul(client_config)

        # Steps 1-2: NL-to-SQL, then execute the generated SQL
        sql, data_context = self._run_sql(db, question, client_id)

        # Step 2b: Context enrichment for sparse results
        # When data is sparse and the user wants recommendations, pull
        # top performers and channel benchmarks as supplementary context
        supplementary_context = None
        if self._needs_enrichment(question, data_context):
            supplementary_context = self._enrich_context(
                db, client_id, question, data_context
            )
            if supplementary_context is not None:
                logger.info(
                    f"Context enriched: {len(supplementary_context)} supplementary rows"
                )
        has_supplementary = (
            supplementary_context is not None and len(supplementary_context) > 0
        )

        # Step 3: RAG Retrieval (Phase 2)
        rag_chunks = self._retrieve_rag(question, client_id)

        # Step 4: Gen Agent → QA Agent loop.
        # In enrichment mode the response naturally references broader
        # benchmark data, so a relaxed QA threshold is used.
        qa_threshold = client_config.hallucination_threshold
        if has_supplementary:
            qa_threshold = min(qa_threshold, 0.5)
            logger.info(
                f"Using relaxed QA threshold ({qa_threshold}) for enriched context"
            )

        final_response, qa_score, qa_passed, retries = self._generate_with_qa(
            question=question,
            soul_content=soul_content,
            data_context=data_context,
            rag_chunks=rag_chunks,
            supplementary_context=supplementary_context,
            sql=sql,
            qa_threshold=qa_threshold,
            strict_data_instruction=self._strict_data_instruction(has_supplementary),
        )

        # Every attempt failed QA: return the best one with a warning.
        if not qa_passed:
            final_response += (
                "\n\n---\n"
                f"⚠️ *Confidence: {qa_score:.0%}. "
                f"This response may contain approximations. "
                f"Please verify critical numbers against your source data.*"
            )

        elapsed = time.time() - start_time
        result = {
            "response": final_response,
            "qa_score": qa_score,
            "qa_passed": qa_passed,
            "retries": retries,
            "sql": sql,
            "data_rows": len(data_context) if data_context is not None else 0,
            "rag_chunks_used": len(rag_chunks) if rag_chunks else 0,
            "elapsed_seconds": round(elapsed, 2),
        }
        # Log interaction
        self._log_interaction(message, client_config, result)
        return result

    def _run_sql(self, db, question: str, client_id: str) -> tuple:
        """Generate SQL for *question* and execute it.

        Returns:
            ``(sql, data_context)``. ``sql`` is whatever the NL-to-SQL skill
            produced (possibly empty/None). ``data_context`` is always a
            DataFrame: empty when no SQL was produced, the query failed, or
            the connector returned ``None``.
        """
        schema_meta = db.get_spend_data_schema(client_id)
        sql = self.nl_to_sql.generate_sql(
            question=question,
            client_id=client_id,
            schema_metadata=schema_meta,
        )
        if not sql:
            return sql, pd.DataFrame()

        try:
            data_context = db.execute_query(sql)
        except Exception as e:
            # Best-effort: answer from an empty result rather than aborting.
            logger.error(f"SQL execution failed: {e}")
            return sql, pd.DataFrame()

        if data_context is None:
            data_context = pd.DataFrame()
        logger.info(f"SQL returned {len(data_context)} rows")
        return sql, data_context

    def _retrieve_rag(self, question: str, client_id: str) -> Optional[List]:
        """Return RAG chunks for *question*, or ``None`` if there is no retriever or retrieval fails."""
        if not self.retriever:
            return None
        try:
            return self.retriever.retrieve(
                query=question,
                client_id=client_id,
            )
        except Exception as e:
            logger.warning(f"RAG retrieval failed: {e}")
            return None

    @staticmethod
    def _strict_data_instruction(has_supplementary: bool) -> str:
        """Build the grounding instruction passed to the Gen Agent.

        It is always provided, to stop the Gen Agent from hallucinating
        branch/figure data from RAG chunks that don't match the SQL results.
        With benchmark data present, the agent may also cite benchmark
        branches.
        """
        if has_supplementary:
            return (
                "CRITICAL: Only use data from the Data Context and Benchmark Data "
                "sections provided. Do NOT invent figures or branch names that are "
                "not present in either of those sections. You MAY reference benchmark "
                "branches for comparison and recommendations."
            )
        return (
            "CRITICAL: Do NOT mention any branches, figures, or historical data "
            "that are not in the SQL query result provided. Stick strictly to the "
            "data. If historical context from RAG is about different branches than "
            "what the query returned, IGNORE that context entirely."
        )

    def _generate_with_qa(
        self,
        *,
        question: str,
        soul_content,
        data_context: pd.DataFrame,
        rag_chunks: Optional[List],
        supplementary_context: Optional[pd.DataFrame],
        sql: Optional[str],
        qa_threshold: float,
        strict_data_instruction: str,
    ) -> tuple:
        """Run Gen Agent → QA Agent for one attempt plus up to ``MAX_RETRIES`` retries.

        Each retry receives the previous attempt's QA issues.

        Returns:
            ``(response, qa_score, qa_passed, retries)``.
            - On a pass: the response that passed, its own score, ``True``,
              and the number of re-generations it took.
            - If every attempt fails: the highest-scoring response (ties go to
              the longer, richer one over "no data" stubs), its score,
              ``False``, and ``MAX_RETRIES``.
        """
        best_response = None
        best_score = 0.0
        retry_issues = None  # no QA feedback before the first attempt

        for attempt in range(MAX_RETRIES + 1):
            response = self.gen_agent.generate(
                question=question,
                soul_content=soul_content,
                data_context=data_context,
                rag_chunks=rag_chunks,
                retry_issues=retry_issues,
                retry_count=attempt,
                strict_data_instruction=strict_data_instruction,
                supplementary_context=supplementary_context,
                sql=sql,
            )
            qa_result = self.qa_agent.evaluate(
                response=response,
                data_context=data_context,
                threshold=qa_threshold,
                supplementary_context=supplementary_context,
                user_question=question,
                sql=sql,
            )
            score = qa_result["score"]

            if qa_result["passed"]:
                logger.info(f"QA passed on attempt {attempt + 1}")
                # Return the response that actually passed. An earlier failed
                # attempt may have scored as high or higher, but it failed QA.
                return response, score, True, attempt

            logger.warning(
                f"QA failed on attempt {attempt + 1}: "
                f"score={score:.2f}, issues={qa_result['issues']}"
            )
            if (
                best_response is None
                or score > best_score
                or (score == best_score and len(response) > len(best_response))
            ):
                best_response = response
                best_score = score
            retry_issues = qa_result["issues"]

        return best_response, best_score, False, MAX_RETRIES

    def _needs_enrichment(
        self,
        question: str,
        data_context: Optional[pd.DataFrame],
    ) -> bool:
        """Return True when the result is sparse (≤3 rows or None) and the question asks for recommendations."""
        # Only enrich if data is sparse
        if data_context is not None and len(data_context) > 3:
            return False
        # Only enrich if user is asking for recommendations/improvement
        return bool(_RECOMMENDATION_KEYWORDS.search(question))

    def _enrich_context(
        self,
        db,
        client_id: str,
        question: str,
        data_context: Optional[pd.DataFrame],
    ) -> Optional[pd.DataFrame]:
        """
        Pull supplementary benchmark rows to help the Gen Agent give
        actionable recommendations.

        Fetches the 10 highest-ROI (branch, country, channel) combinations for
        this client over the last 90 days.

        Args:
            db: Connector whose ``execute_query(sql, params)`` runs the query.
            client_id: Client whose ``spend_data`` rows are queried.
            question: Currently unused; kept for interface stability.
            data_context: Currently unused; kept for interface stability.

        Returns:
            A non-empty DataFrame of benchmark rows, or ``None`` if the query
            returned no rows or failed. Failures are logged, not raised.
        """
        # Both operands are cast to NUMERIC before dividing, so integer-typed
        # spend/revenue columns (schema not visible here) can't truncate ROI
        # through integer division. ``round(numeric, int)`` then applies.
        # NULLS LAST: PostgreSQL sorts NULLs first under DESC by default, so
        # a NULL ROI (e.g. all-NULL revenue) would otherwise outrank real rows.
        enrichment_sql = """
            SELECT branch, country, channel,
                   SUM(spend) as total_spend,
                   SUM(revenue) as total_revenue,
                   SUM(leads) as total_leads,
                   SUM(conversions) as total_conversions,
                   ROUND(SUM(revenue)::numeric / NULLIF(SUM(spend), 0)::numeric, 2) as roi
            FROM spend_data
            WHERE client_id = %s
              AND date >= CURRENT_DATE - INTERVAL '90 days'
            GROUP BY branch, country, channel
            HAVING SUM(spend) > 0
            ORDER BY roi DESC NULLS LAST
            LIMIT 10
        """
        try:
            top_performers = db.execute_query(enrichment_sql, (client_id,))
        except Exception as e:
            # Enrichment is optional; never let it break the main pipeline.
            logger.warning(f"Context enrichment failed: {e}")
            return None

        if top_performers is None or len(top_performers) == 0:
            return None
        logger.info(
            f"Enrichment: fetched {len(top_performers)} top performer rows"
        )
        return top_performers

    def _log_interaction(
        self,
        message: NormalisedMessage,
        client_config: ClientConfig,
        result: Dict,
    ):
        """Log the interaction for monitoring.

        Forwarding to ``rag.monitoring.log_interaction`` is best-effort: any
        failure there, including the import, is logged at DEBUG and ignored.
        A summary line is always written to this module's logger.
        """
        try:
            from rag.monitoring import log_interaction

            log_interaction(
                client_id=client_config.client_id,
                query=message.text,
                num_chunks=result.get("rag_chunks_used", 0),
                chunk_types_used=[],  # Populated when retriever provides this info
                qa_score=result.get("qa_score", 0),
                qa_passed=result.get("qa_passed", False),
                retries=result.get("retries", 0),
                response_length=len(result.get("response", "")),
                elapsed_seconds=result.get("elapsed_seconds", 0),
            )
        except Exception as e:
            logger.debug(f"Monitoring log failed: {e}")
        logger.info(
            f"[{client_config.client_id}] Query processed: "
            f"score={result['qa_score']:.2f}, passed={result['qa_passed']}, "
            f"retries={result['retries']}, time={result['elapsed_seconds']}s"
        )