"""
|
|
Clawrity — Orchestrator
|
|
|
|
Coordinates the full message pipeline:
|
|
NormalisedMessage → NL-to-SQL → PostgreSQL → (RAG Retriever) → Gen Agent → QA Agent → Response
|
|
|
|
Max 2 retries per query. Returns best attempt with confidence warning after max retries.
|
|
|
|
Context enrichment: when a query returns sparse data (≤3 rows) and the question
|
|
asks for recommendations, automatically pulls top-performing branches as comparison
|
|
context so the Gen Agent can give actionable suggestions.
|
|
"""
|
|
|
|
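
# Usage sketch (illustrative; assumes the channel layer has already produced a
# `NormalisedMessage` and a `ClientConfig`; their construction is not shown):
#
#     orchestrator = Orchestrator()
#     orchestrator.set_retriever(retriever)  # optional, Phase 2
#     result = await orchestrator.process(message, client_config)  # in async context
#     reply = result["response"]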

import logging
import re
import time
from typing import Dict, Optional

import pandas as pd

from agents.gen_agent import GenAgent
from agents.qa_agent import QAAgent
from channels.protocol_adapter import NormalisedMessage
from config.client_loader import ClientConfig
from skills.nl_to_sql import NLToSQL
from skills.postgres_connector import get_connector
from soul.soul_loader import load_soul

logger = logging.getLogger(__name__)

MAX_RETRIES = 2

# Keywords that signal the user wants recommendations, not just raw data
_RECOMMENDATION_KEYWORDS = re.compile(
    r"\b(improve|increase|boost|grow|fix|help|recommend|suggest|advice|strategy|"
    r"what (should|can|do)|how (to|can|do|should))\b",
    re.IGNORECASE,
)
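# Illustrative behaviour (example questions assumed, not from a test suite):
#   "How can we improve the Mumbai branch's ROI?" -> matches ("improve", "how can")
#   "Suggest a strategy to grow leads"            -> matches ("suggest", "strategy", "grow")
#   "Show total spend by channel for March"       -> no match (plain data query)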


class Orchestrator:
    """Pipeline orchestrator — the central brain of Clawrity."""

    def __init__(self):
        self.nl_to_sql = NLToSQL()
        self.gen_agent = GenAgent()
        self.qa_agent = QAAgent()
        self.retriever = None  # Set in Phase 2 via set_retriever()

    def set_retriever(self, retriever):
        """Attach the RAG retriever (Phase 2)."""
        self.retriever = retriever

    async def process(
        self,
        message: NormalisedMessage,
        client_config: ClientConfig,
    ) -> Dict:
        """
        Process a user message through the full pipeline.

        Returns:
            Dict with: response, qa_score, qa_passed, retries, sql,
            data_rows, rag_chunks_used, elapsed_seconds
        """
        start_time = time.time()
        db = get_connector()

        # Load SOUL
        soul_content = load_soul(client_config)

        # Step 1: NL-to-SQL
        schema_meta = db.get_spend_data_schema(client_config.client_id)
        sql = self.nl_to_sql.generate_sql(
            question=message.text,
            client_id=client_config.client_id,
            schema_metadata=schema_meta,
        )
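        # (Illustrative only: a question like "total spend by branch last month"
        # might yield SQL of the form
        #     SELECT branch, SUM(spend) FROM spend_data WHERE client_id = %s GROUP BY branch
        # the actual statement is whatever NLToSQL generates.)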

        # Step 2: Execute SQL
        data_context = pd.DataFrame()
        if sql:
            try:
                data_context = db.execute_query(sql)
                logger.info(f"SQL returned {len(data_context)} rows")
            except Exception as e:
                logger.error(f"SQL execution failed: {e}")
                data_context = pd.DataFrame()

        # Step 2b: Context enrichment for sparse results.
        # When data is sparse and the user wants recommendations, pull
        # top performers and channel benchmarks as supplementary context.
        supplementary_context = None
        if self._needs_enrichment(message.text, data_context):
            supplementary_context = self._enrich_context(
                db, client_config.client_id, message.text, data_context
            )
            if supplementary_context is not None:
                logger.info(
                    f"Context enriched: {len(supplementary_context)} supplementary rows"
                )

        # Step 3: RAG Retrieval (Phase 2)
        rag_chunks = None
        if self.retriever:
            try:
                rag_chunks = self.retriever.retrieve(
                    query=message.text,
                    client_id=client_config.client_id,
                )
            except Exception as e:
                logger.warning(f"RAG retrieval failed: {e}")

        # Step 4: Gen Agent → QA Agent loop (max 2 retries).
        # When supplementary context is provided (enrichment mode), use a relaxed
        # QA threshold, since the response naturally references broader benchmark data.
        qa_threshold = client_config.hallucination_threshold
        if supplementary_context is not None and len(supplementary_context) > 0:
            qa_threshold = min(qa_threshold, 0.5)
            logger.info(f"Using relaxed QA threshold ({qa_threshold}) for enriched context")
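        # (Illustrative: a hallucination_threshold of 0.7 relaxes to 0.5, while
        # one of 0.4 stays at 0.4; min() only ever lowers the bar, never raises it.)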

        best_response = None
        best_score = 0.0
        qa_result = {"score": 0, "passed": False, "issues": []}
        retries = 0

        for attempt in range(MAX_RETRIES + 1):
            retry_issues = qa_result["issues"] if attempt > 0 else None

            # On retry, add an explicit data-only instruction to prevent hallucination
            strict_data_instruction = None
            if attempt > 0:
                if supplementary_context is not None and len(supplementary_context) > 0:
                    strict_data_instruction = (
                        "CRITICAL: Only use data from the Data Context and Benchmark Data "
                        "sections provided. Do NOT invent figures or branch names that are "
                        "not present in either of those sections. You MAY reference benchmark "
                        "branches for comparison and recommendations."
                    )
                else:
                    strict_data_instruction = (
                        "CRITICAL: Do NOT mention any branches, figures, or historical data "
                        "that are not in the SQL query result provided. Stick strictly to the "
                        "data. If historical context from RAG is about different branches than "
                        "what the query returned, IGNORE that context entirely."
                    )

            response = self.gen_agent.generate(
                question=message.text,
                soul_content=soul_content,
                data_context=data_context,
                rag_chunks=rag_chunks,
                retry_issues=retry_issues,
                retry_count=attempt,
                strict_data_instruction=strict_data_instruction,
                supplementary_context=supplementary_context,
            )

            qa_result = self.qa_agent.evaluate(
                response=response,
                data_context=data_context,
                threshold=qa_threshold,
                supplementary_context=supplementary_context,
                user_question=message.text,
            )

            # Track best response (prefer longer, richer responses over "no data" stubs)
            if qa_result["score"] > best_score or (
                qa_result["score"] == best_score
                and best_response is not None
                and len(response) > len(best_response)
            ):
                best_score = qa_result["score"]
                best_response = response

            if qa_result["passed"]:
                logger.info(f"QA passed on attempt {attempt + 1}")
                break
            logger.warning(
                f"QA failed on attempt {attempt + 1}: "
                f"score={qa_result['score']:.2f}, issues={qa_result['issues']}"
            )
            # Count a retry only when another attempt will actually run, so a
            # fully failed run reports MAX_RETRIES retries, not MAX_RETRIES + 1.
            if attempt < MAX_RETRIES:
                retries += 1

        # If max retries are exceeded, use the best response with a confidence warning
        final_response = best_response or response
        if not qa_result["passed"] and retries >= MAX_RETRIES:
            final_response += (
                "\n\n---\n"
                f"⚠️ *Confidence: {best_score:.0%} — "
                "This response may contain approximations. "
                "Please verify critical numbers against your source data.*"
            )
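        # (Rendered example, illustrative: with best_score = 0.55 the footer reads
        # "⚠️ *Confidence: 55% — This response may contain approximations. ...*")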

        elapsed = time.time() - start_time

        result = {
            "response": final_response,
            "qa_score": best_score,
            "qa_passed": qa_result["passed"],
            "retries": retries,
            "sql": sql,
            "data_rows": len(data_context) if data_context is not None else 0,
            "rag_chunks_used": len(rag_chunks) if rag_chunks else 0,
            "elapsed_seconds": round(elapsed, 2),
        }

        # Log interaction
        self._log_interaction(message, client_config, result)

        return result

    def _needs_enrichment(
        self,
        question: str,
        data_context: Optional[pd.DataFrame],
    ) -> bool:
        """Check if the query result is too sparse for a recommendation question."""
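        # Decision examples (illustrative):
        #   2 rows  + "How can we improve ROI?" -> True  (sparse + recommendation intent)
        #   50 rows + "How can we improve ROI?" -> False (enough data already)
        #   2 rows  + "Show spend by branch"    -> False (no recommendation intent)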

        # Only enrich if data is sparse
        if data_context is not None and len(data_context) > 3:
            return False

        # Only enrich if the user is asking for recommendations/improvement
        return bool(_RECOMMENDATION_KEYWORDS.search(question))

    def _enrich_context(
        self,
        db,
        client_id: str,
        question: str,
        data_context: Optional[pd.DataFrame],
    ) -> Optional[pd.DataFrame]:
        """
        Pull supplementary context: top-performing branches and channel
        benchmarks to help the Gen Agent give actionable recommendations.
        """
        try:
            # Get the top branches by ROI for comparison (up to 10)
            enrichment_sql = """
                SELECT branch, country, channel,
                       SUM(spend) AS total_spend,
                       SUM(revenue) AS total_revenue,
                       SUM(leads) AS total_leads,
                       SUM(conversions) AS total_conversions,
                       ROUND((SUM(revenue) / NULLIF(SUM(spend), 0))::numeric, 2) AS roi
                FROM spend_data
                WHERE client_id = %s
                  AND date >= CURRENT_DATE - INTERVAL '90 days'
                GROUP BY branch, country, channel
                HAVING SUM(spend) > 0
                ORDER BY roi DESC
                LIMIT 10
            """
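            # The resulting frame mirrors the SELECT above: one row per
            # (branch, country, channel) with total_spend, total_revenue,
            # total_leads, total_conversions and roi; at most 10 rows, best ROI first.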
            top_performers = db.execute_query(enrichment_sql, (client_id,))

            if top_performers is not None and len(top_performers) > 0:
                logger.info(f"Enrichment: fetched {len(top_performers)} top performer rows")
                return top_performers

        except Exception as e:
            logger.warning(f"Context enrichment failed: {e}")

        return None

    def _log_interaction(
        self,
        message: NormalisedMessage,
        client_config: ClientConfig,
        result: Dict,
    ):
        """Log interaction for monitoring."""
        try:
            from rag.monitoring import log_interaction

            log_interaction(
                client_id=client_config.client_id,
                query=message.text,
                num_chunks=result.get("rag_chunks_used", 0),
                chunk_types_used=[],  # Populated when the retriever provides this info
                qa_score=result.get("qa_score", 0),
                qa_passed=result.get("qa_passed", False),
                retries=result.get("retries", 0),
                response_length=len(result.get("response", "")),
                elapsed_seconds=result.get("elapsed_seconds", 0),
            )
        except Exception as e:
            logger.debug(f"Monitoring log failed: {e}")

        logger.info(
            f"[{client_config.client_id}] Query processed: "
            f"score={result['qa_score']:.2f}, passed={result['qa_passed']}, "
            f"retries={result['retries']}, time={result['elapsed_seconds']}s"
        )