mirror of https://github.com/Manoj-HV30/clawrity.git (synced 2026-05-16 19:35:21 +00:00)

commit: prototype
@@ -0,0 +1,287 @@
"""
Clawrity — RAG Chunker

Aggregation-based semantic chunking — NOT fixed-size, NOT sliding window.
Source is structured tabular data. We aggregate rows into business-meaningful
units and write natural language narratives.

Three chunk types:
1. branch_weekly   — GROUP BY branch, country, week
2. channel_monthly — GROUP BY channel, country, month
3. trend_qoq       — GROUP BY branch, country, quarter (QoQ delta COMPUTED)

Plus Faker-generated narrative summaries reflecting real patterns.
"""

import hashlib
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
from faker import Faker

logger = logging.getLogger(__name__)
fake = Faker()


@dataclass
class Chunk:
    """A single RAG chunk."""
    id: str
    client_id: str
    chunk_type: str
    text: str
    metadata: Dict

    def to_dict(self) -> Dict:
        return {
            "id": self.id,
            "client_id": self.client_id,
            "chunk_type": self.chunk_type,
            "text": self.text,
            "metadata": self.metadata,
        }


def generate_chunks(df: pd.DataFrame, client_id: str) -> List[Chunk]:
    """Generate all chunk types from preprocessed data."""
    chunks = []

    df = df.copy()
    df["date"] = pd.to_datetime(df["date"])

    chunks.extend(_branch_weekly(df, client_id))
    chunks.extend(_channel_monthly(df, client_id))
    chunks.extend(_trend_qoq(df, client_id))
    chunks.extend(_faker_narratives(df, client_id))

    logger.info(f"Generated {len(chunks)} total chunks for {client_id}")
    return chunks


def _chunk_id(client_id: str, chunk_type: str, *parts) -> str:
    """Generate a deterministic chunk ID."""
    raw = f"{client_id}:{chunk_type}:" + ":".join(str(p) for p in parts)
    return hashlib.md5(raw.encode()).hexdigest()[:16]
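
# Note: _chunk_id is deterministic, so re-chunking the same data yields the same IDs
# and the downstream pgvector upsert (ON CONFLICT DO UPDATE, see rag vector store)
# overwrites rather than duplicates. For example (hypothetical values):
#   _chunk_id("acme", "branch_weekly", "London", 2025, 12)
# always hashes the string "acme:branch_weekly:London:2025:12" to the same
# 16-character MD5 prefix.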


# ---------------------------------------------------------------------------
# Chunk Type 1: Branch Weekly
# ---------------------------------------------------------------------------

def _branch_weekly(df: pd.DataFrame, client_id: str) -> List[Chunk]:
    """GROUP BY branch, country, week. One chunk per branch per week."""
    chunks = []
    df = df.copy()
    df["week"] = df["date"].dt.isocalendar().week.astype(int)
    df["month"] = df["date"].dt.month_name()
    df["year"] = df["date"].dt.year

    grouped = df.groupby(["branch", "country", "year", "week", "month"]).agg(
        spend=("spend", "sum"),
        revenue=("revenue", "sum"),
        leads=("leads", "sum"),
        conversions=("conversions", "sum"),
    ).reset_index()

    for _, row in grouped.iterrows():
        spend = row["spend"]
        revenue = row["revenue"]
        roi = round(revenue / spend, 2) if spend > 0 else 0
        conv_rate = round(row["conversions"] / row["leads"] * 100, 1) if row["leads"] > 0 else 0

        text = (
            f"{row['branch']} ({row['country']}) in week {row['week']} of "
            f"{row['month']} {row['year']}: spent ${spend:,.0f}, earned "
            f"${revenue:,.0f}, ROI {roi}x, {row['leads']} leads, "
            f"{conv_rate}% conversion rate."
        )

        chunks.append(Chunk(
            id=_chunk_id(client_id, "branch_weekly", row["branch"], row["year"], row["week"]),
            client_id=client_id,
            chunk_type="branch_weekly",
            text=text,
            metadata={
                "branch": row["branch"],
                "country": row["country"],
                "week": int(row["week"]),
                "month": row["month"],
                "year": int(row["year"]),
                "roi": roi,
            },
        ))

    logger.info(f"Generated {len(chunks)} branch_weekly chunks")
    return chunks
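
# Example of a rendered branch_weekly chunk (illustrative figures only):
#   "Berlin (Germany) in week 12 of March 2025: spent $42,000, earned $96,500,
#    ROI 2.3x, 310 leads, 12.4% conversion rate."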


# ---------------------------------------------------------------------------
# Chunk Type 2: Channel Monthly
# ---------------------------------------------------------------------------

def _channel_monthly(df: pd.DataFrame, client_id: str) -> List[Chunk]:
    """GROUP BY channel, country, month, quarter."""
    chunks = []
    df = df.copy()
    df["month"] = df["date"].dt.month_name()
    df["quarter"] = "Q" + df["date"].dt.quarter.astype(str)
    df["year"] = df["date"].dt.year

    grouped = df.groupby(["channel", "country", "year", "month", "quarter"]).agg(
        spend=("spend", "sum"),
        revenue=("revenue", "sum"),
        leads=("leads", "sum"),
        conversions=("conversions", "sum"),
    ).reset_index()

    for _, row in grouped.iterrows():
        spend = row["spend"]
        revenue = row["revenue"]
        roi = round(revenue / spend, 2) if spend > 0 else 0

        text = (
            f"{row['channel']} in {row['country']} during {row['month']} "
            f"({row['quarter']}) {row['year']}: ${spend:,.0f} spent, "
            f"${revenue:,.0f} revenue, ROI {roi}x."
        )

        chunks.append(Chunk(
            id=_chunk_id(client_id, "channel_monthly", row["channel"], row["country"], row["year"], row["month"]),
            client_id=client_id,
            chunk_type="channel_monthly",
            text=text,
            metadata={
                "channel": row["channel"],
                "country": row["country"],
                "month": row["month"],
                "quarter": row["quarter"],
                "year": int(row["year"]),
                "roi": roi,
            },
        ))

    logger.info(f"Generated {len(chunks)} channel_monthly chunks")
    return chunks


# ---------------------------------------------------------------------------
# Chunk Type 3: QoQ Trend (Most Important)
# ---------------------------------------------------------------------------

def _trend_qoq(df: pd.DataFrame, client_id: str) -> List[Chunk]:
    """GROUP BY branch, country, quarter. Compute quarter-over-quarter delta."""
    chunks = []
    df = df.copy()
    df["quarter"] = df["date"].dt.to_period("Q").astype(str)

    grouped = df.groupby(["branch", "country", "quarter"]).agg(
        spend=("spend", "sum"),
        revenue=("revenue", "sum"),
    ).reset_index()

    # Sort for QoQ calculation
    grouped = grouped.sort_values(["branch", "country", "quarter"])

    for (branch, country), group in grouped.groupby(["branch", "country"]):
        group = group.sort_values("quarter").reset_index(drop=True)

        for i in range(1, len(group)):
            prev = group.iloc[i - 1]
            curr = group.iloc[i]

            prev_rev = prev["revenue"]
            curr_rev = curr["revenue"]

            if prev_rev > 0:
                delta = round((curr_rev - prev_rev) / prev_rev * 100, 1)
            else:
                delta = 0

            direction = "grew" if delta > 0 else "declined"

            text = (
                f"{branch} ({country}) revenue {direction} {abs(delta)}% "
                f"in {curr['quarter']} vs {prev['quarter']}. "
                f"Total spend: ${curr['spend']:,.0f}, revenue: ${curr_rev:,.0f}."
            )

            chunks.append(Chunk(
                id=_chunk_id(client_id, "trend_qoq", branch, country, curr["quarter"]),
                client_id=client_id,
                chunk_type="trend_qoq",
                text=text,
                metadata={
                    "branch": branch,
                    "country": country,
                    "quarter": curr["quarter"],
                    "prev_quarter": prev["quarter"],
                    "delta_pct": delta,
                },
            ))

    logger.info(f"Generated {len(chunks)} trend_qoq chunks")
    return chunks
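
# Worked example of the QoQ delta (hypothetical figures):
#   prev_rev = 100_000, curr_rev = 125_000
#   delta = round((125_000 - 100_000) / 100_000 * 100, 1) = 25.0
#   → "... revenue grew 25.0% in 2025Q2 vs 2025Q1."
# Note that to_period("Q").astype(str) yields quarter labels like "2025Q1".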


# ---------------------------------------------------------------------------
# Faker Narrative Chunks
# ---------------------------------------------------------------------------

def _faker_narratives(df: pd.DataFrame, client_id: str) -> List[Chunk]:
    """Generate plausible narrative chunks reflecting real data patterns."""
    chunks = []
    df = df.copy()
    df["quarter"] = df["date"].dt.to_period("Q").astype(str)

    # Aggregate per branch/country/quarter to seed the narratives with real figures
    quarterly = df.groupby(["branch", "country", "quarter"]).agg(
        revenue=("revenue", "sum"),
        spend=("spend", "sum"),
        leads=("leads", "sum"),
    ).reset_index()

    templates = [
        "{branch} branch demonstrated strong {quarter} performance driven by {channel} efficiency, outperforming regional averages.",
        "In {quarter}, {branch} ({country}) showed {trend} momentum with revenue reaching ${revenue:,.0f}, primarily through {channel} campaigns.",
        "{branch} branch in {country} maintained steady growth in {quarter}, with lead generation up and conversion rates holding above {conv_rate:.1f}%.",
        "Cost efficiency at {branch} ({country}) improved in {quarter}, with spend-to-revenue ratio tightening to {ratio:.2f}x.",
    ]

    channels = df["channel"].dropna().unique().tolist() or ["Paid Search", "Social Media", "Email"]

    for _, row in quarterly.iterrows():
        roi = row["revenue"] / row["spend"] if row["spend"] > 0 else 0
        conv_rate = np.random.uniform(5, 20)
        trend = "positive" if roi > 1.5 else "moderate" if roi > 1 else "challenging"
        channel = np.random.choice(channels)

        template = np.random.choice(templates)
        text = template.format(
            branch=row["branch"],
            country=row["country"],
            quarter=row["quarter"],
            channel=channel,
            revenue=row["revenue"],
            trend=trend,
            conv_rate=conv_rate,
            ratio=1 / roi if roi > 0 else 0,
        )

        chunks.append(Chunk(
            id=_chunk_id(client_id, "narrative", row["branch"], row["country"], row["quarter"]),
            client_id=client_id,
            chunk_type="narrative",
            text=text,
            metadata={
                "branch": row["branch"],
                "country": row["country"],
                "quarter": row["quarter"],
                "source": "generated_narrative",
            },
        ))

    logger.info(f"Generated {len(chunks)} narrative chunks")
    return chunks
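
# Minimal usage sketch (column names are the ones the chunkers above expect; in
# practice the DataFrame would come from the RAG preprocessor, and the values here
# are illustrative only):
#
#   import pandas as pd
#   df = pd.DataFrame({
#       "date": ["2025-01-06", "2025-01-13"],
#       "country": ["Germany", "Germany"],
#       "branch": ["Berlin", "Berlin"],
#       "channel": ["Email", "Paid Search"],
#       "spend": [1200.0, 1500.0],
#       "revenue": [3000.0, 2800.0],
#       "leads": [40, 55],
#       "conversions": [6, 7],
#   })
#   chunks = generate_chunks(df, client_id="demo_client")
#   payload = [c.to_dict() for c in chunks]   # ready for embedding + pgvector upsert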
@@ -0,0 +1,123 @@
"""
Clawrity — RAG Evaluator

Lightweight Groq-based evaluation (no OpenAI, no full RAGAs).
Four metrics: faithfulness, answer_relevancy, context_precision, context_recall.
Single Groq call with structured JSON output.
"""

import json
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

from groq import Groq

from config.settings import get_settings

logger = logging.getLogger(__name__)

EVAL_PROMPT = """Evaluate this RAG-augmented response on four criteria.

## User Query
{query}

## Retrieved Context Chunks
{chunks}

## Generated Response
{response}

## Evaluation Criteria (score each 0.0 to 1.0)

1. **Faithfulness**: Does the response ONLY contain information from the retrieved chunks? No hallucination?
2. **Answer Relevancy**: Does the response directly address the user's question?
3. **Context Precision**: Were the retrieved chunks actually relevant to the question?
4. **Context Recall**: Did the retrieval capture enough context to answer the question fully?

Return ONLY a JSON object:
{{
  "faithfulness": <float>,
  "answer_relevancy": <float>,
  "context_precision": <float>,
  "context_recall": <float>,
  "overall": <float (average of all four)>,
  "notes": "<brief explanation>"
}}"""


@dataclass
class EvalResult:
    faithfulness: float = 0.0
    answer_relevancy: float = 0.0
    context_precision: float = 0.0
    context_recall: float = 0.0
    overall: float = 0.0
    notes: str = ""


class RAGEvaluator:
    """Evaluates RAG pipeline quality using Groq LLM."""

    def __init__(self):
        settings = get_settings()
        self.client = Groq(api_key=settings.groq_api_key)
        self.model = settings.llm_model

    def evaluate(
        self,
        query: str,
        chunks: List[Dict],
        response: str,
    ) -> EvalResult:
        """Evaluate a RAG response."""
        chunks_text = "\n".join(
            f"{i+1}. {c.get('text', '')} (similarity: {c.get('similarity', 0):.2f})"
            for i, c in enumerate(chunks)
        ) if chunks else "No chunks retrieved."

        prompt = EVAL_PROMPT.format(
            query=query,
            chunks=chunks_text,
            response=response,
        )

        try:
            result = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a RAG evaluation expert. Return only valid JSON."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
                max_tokens=512,
            )

            raw = result.choices[0].message.content.strip()
            return self._parse(raw)

        except Exception as e:
            logger.error(f"RAG evaluation failed: {e}")
            return EvalResult(notes=f"Evaluation error: {str(e)}")

    def _parse(self, raw: str) -> EvalResult:
        """Parse JSON evaluation response."""
        try:
            cleaned = raw.strip()
            if cleaned.startswith("```"):
                cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned[3:]
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]

            data = json.loads(cleaned.strip())
            return EvalResult(
                faithfulness=float(data.get("faithfulness", 0)),
                answer_relevancy=float(data.get("answer_relevancy", 0)),
                context_precision=float(data.get("context_precision", 0)),
                context_recall=float(data.get("context_recall", 0)),
                overall=float(data.get("overall", 0)),
                notes=data.get("notes", ""),
            )
        except Exception as e:
            logger.warning(f"Could not parse evaluation: {e}")
            return EvalResult(notes="Parse error")
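
# Minimal usage sketch (assumes Groq credentials are configured via config.settings;
# the chunk dicts mirror what the retriever returns, and all values are illustrative):
#
#   evaluator = RAGEvaluator()
#   result = evaluator.evaluate(
#       query="How did the Berlin branch perform last quarter?",
#       chunks=[{"text": "Berlin (Germany) revenue grew 12.0% in 2025Q2 vs 2025Q1. ...",
#                "similarity": 0.81}],
#       response="Berlin grew revenue roughly 12% quarter over quarter.",
#   )
#   print(result.overall, result.notes)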
@@ -0,0 +1,105 @@
"""
Clawrity — RAG Monitoring

Logs every interaction to JSONL and provides aggregated stats.
Exposes data for the /admin/stats/{client_id} endpoint.
"""

import json
import logging
import os
from datetime import datetime
from typing import Dict, Optional

from config.settings import get_settings

logger = logging.getLogger(__name__)


def _log_path(client_id: str) -> str:
    """Get the JSONL log file path for a client."""
    logs_dir = get_settings().logs_dir
    os.makedirs(logs_dir, exist_ok=True)
    return os.path.join(logs_dir, f"{client_id}_interactions.jsonl")


def log_interaction(
    client_id: str,
    query: str,
    num_chunks: int,
    chunk_types_used: list,
    qa_score: float,
    qa_passed: bool,
    retries: int,
    response_length: int,
    elapsed_seconds: float = 0.0,
):
    """Log an interaction to JSONL."""
    entry = {
        "timestamp": datetime.utcnow().isoformat(),
        "client_id": client_id,
        "query": query,
        "num_chunks": num_chunks,
        "chunk_types_used": chunk_types_used,
        "qa_score": qa_score,
        "qa_passed": qa_passed,
        "retries": retries,
        "response_length": response_length,
        "elapsed_seconds": elapsed_seconds,
    }

    try:
        path = _log_path(client_id)
        with open(path, "a") as f:
            f.write(json.dumps(entry) + "\n")
    except Exception as e:
        logger.error(f"Failed to log interaction: {e}")


def get_stats(client_id: str) -> Dict:
    """
    Get aggregated monitoring stats for a client.

    Returns:
        Dict with: total_queries, pass_rate, avg_qa_score, avg_retries,
        queries_needing_retry
    """
    path = _log_path(client_id)
    if not os.path.exists(path):
        return {
            "client_id": client_id,
            "total_queries": 0,
            "pass_rate": 0.0,
            "avg_qa_score": 0.0,
            "avg_retries": 0.0,
            "queries_needing_retry": 0,
        }

    entries = []
    try:
        with open(path, "r") as f:
            for line in f:
                line = line.strip()
                if line:
                    entries.append(json.loads(line))
    except Exception as e:
        logger.error(f"Error reading log file: {e}")
        return {"error": str(e)}

    if not entries:
        return {"client_id": client_id, "total_queries": 0}

    total = len(entries)
    passed = sum(1 for e in entries if e.get("qa_passed", False))
    scores = [e.get("qa_score", 0) for e in entries]
    retries = [e.get("retries", 0) for e in entries]
    retry_queries = sum(1 for r in retries if r > 0)

    return {
        "client_id": client_id,
        "total_queries": total,
        "pass_rate": round(passed / total * 100, 1) if total > 0 else 0,
        "avg_qa_score": round(sum(scores) / total, 3) if total > 0 else 0,
        "avg_retries": round(sum(retries) / total, 2) if total > 0 else 0,
        "queries_needing_retry": retry_queries,
    }
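
# Minimal usage sketch (hypothetical client ID and values):
#
#   log_interaction(
#       client_id="demo_client",
#       query="Which channel had the best ROI in March?",
#       num_chunks=5,
#       chunk_types_used=["channel_monthly"],
#       qa_score=0.86,
#       qa_passed=True,
#       retries=0,
#       response_length=412,
#       elapsed_seconds=2.4,
#   )
#   stats = get_stats("demo_client")
#   # e.g. with 4 logged queries of which 3 passed: stats["pass_rate"] == 75.0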
@@ -0,0 +1,72 @@
"""
Clawrity — RAG Preprocessor

Fetches data from PostgreSQL and cleans it for RAG chunking:
- Removes nulls, outliers > 3 std devs, and duplicates
- Normalises string columns
"""

import logging
from typing import Optional

import pandas as pd

from etl.normaliser import remove_outliers
from skills.postgres_connector import get_connector

logger = logging.getLogger(__name__)


def preprocess_for_rag(
    client_id: str,
    days: int = 365,
) -> pd.DataFrame:
    """
    Fetch and preprocess data for RAG chunking.

    Args:
        client_id: Client to fetch data for
        days: Number of days of data to fetch (default 365)

    Returns:
        Clean DataFrame ready for chunking
    """
    db = get_connector()

    # The INTERVAL literal can't be parameterised directly, so the int-coerced day
    # count is interpolated into the SQL; client_id is still a bound parameter.
    safe_sql = f"""
        SELECT date, country, branch, channel, spend, revenue, leads, conversions
        FROM spend_data
        WHERE client_id = %s AND date >= CURRENT_DATE - INTERVAL '{int(days)} days'
        ORDER BY date
    """
    df = db.execute_query(safe_sql, (client_id,))
    logger.info(f"Fetched {len(df)} rows for RAG preprocessing")

    if df.empty:
        logger.warning(f"No data found for client {client_id}")
        return df

    # Remove rows with critical nulls
    critical_cols = ["date", "branch", "country", "revenue"]
    df = df.dropna(subset=[c for c in critical_cols if c in df.columns])

    # Remove outliers on numeric columns
    df = remove_outliers(df, ["spend", "revenue", "leads", "conversions"])

    # Clean strings
    for col in ["country", "branch", "channel"]:
        if col in df.columns:
            df[col] = df[col].astype(str).str.strip().str.title()

    # Remove duplicates
    df = df.drop_duplicates()

    logger.info(f"Preprocessed: {len(df)} rows ready for chunking")
    return df
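
# For reference, a minimal sketch of the 3-standard-deviation filter that
# etl.normaliser.remove_outliers presumably applies (the real helper lives in etl/
# and may differ in signature and behaviour; treat this as illustrative only):
#
#   def remove_outliers(df: pd.DataFrame, cols: list) -> pd.DataFrame:
#       for col in cols:
#           if col in df.columns:
#               mean, std = df[col].mean(), df[col].std()
#               if std and std > 0:
#                   df = df[(df[col] - mean).abs() <= 3 * std]
#       return df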
@@ -0,0 +1,95 @@
"""
Clawrity — RAG Retriever

Detects query intent → selects chunk_type → searches pgvector.
Intent detection based on keywords:
- "should/recommend/allocate/shift" → trend_qoq
- "channel/paid/email/social" → channel_monthly
- everything else → branch_weekly
"""

import logging
import re
from typing import List, Dict, Optional

from rag.vector_store import search

logger = logging.getLogger(__name__)

# Intent → chunk_type mapping based on keywords
INTENT_PATTERNS = {
    "trend_qoq": [
        "should", "recommend", "allocate", "shift", "increase", "decrease",
        "budget", "realloc", "invest", "optimize", "growth", "trend",
        "quarter", "qoq", "forecast", "predict",
    ],
    "channel_monthly": [
        "channel", "paid", "email", "social", "search", "display",
        "organic", "referral", "campaign", "marketing", "roi",
        "spend", "advertising",
    ],
}


class Retriever:
    """RAG retriever with intent-based chunk type filtering."""

    def retrieve(
        self,
        query: str,
        client_id: str,
        top_k: int = 5,
        chunk_type_override: Optional[str] = None,
    ) -> List[Dict]:
        """
        Retrieve relevant chunks based on query intent.

        Args:
            query: User's natural language query
            client_id: Client to search within
            top_k: Number of chunks to retrieve
            chunk_type_override: Force a specific chunk type

        Returns:
            List of dicts with text, metadata, similarity
        """
        if chunk_type_override:
            chunk_type = chunk_type_override
        else:
            chunk_type = self._detect_intent(query)

        logger.info(f"Detected intent → chunk_type: {chunk_type}")

        results = search(
            query=query,
            client_id=client_id,
            chunk_type=chunk_type,
            top_k=top_k,
        )

        # If no results with the detected type, fall back to all types
        if not results:
            logger.info(f"No results for {chunk_type}, falling back to all types")
            results = search(
                query=query,
                client_id=client_id,
                chunk_type=None,
                top_k=top_k,
            )

        return results

    def _detect_intent(self, query: str) -> str:
        """Detect query intent from keywords."""
        query_lower = query.lower()

        scores = {}
        for chunk_type, keywords in INTENT_PATTERNS.items():
            score = sum(1 for kw in keywords if kw in query_lower)
            scores[chunk_type] = score

        # Return the chunk type with the highest score, default to branch_weekly
        if max(scores.values()) > 0:
            return max(scores, key=scores.get)

        return "branch_weekly"
@@ -0,0 +1,135 @@
"""
Clawrity — RAG Vector Store

Embeds chunks using sentence-transformers all-MiniLM-L6-v2 (CPU, 384 dims).
Stores and searches via pgvector in PostgreSQL.
"""

import logging
from typing import List, Optional

import numpy as np

from rag.chunker import Chunk
from skills.postgres_connector import get_connector

logger = logging.getLogger(__name__)

_model = None


def _get_embedding_model():
    """Lazy-load the embedding model (CPU only, ~90MB)."""
    global _model
    if _model is None:
        from sentence_transformers import SentenceTransformer
        _model = SentenceTransformer("all-MiniLM-L6-v2")
        logger.info("Loaded embedding model: all-MiniLM-L6-v2 (384 dims)")
    return _model


def embed_texts(texts: List[str], batch_size: int = 100) -> np.ndarray:
    """
    Embed a list of texts using MiniLM.

    Args:
        texts: List of text strings to embed
        batch_size: Batch size for encoding (default 100)

    Returns:
        numpy array of shape (len(texts), 384)
    """
    model = _get_embedding_model()
    embeddings = model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=len(texts) > 100,
        normalize_embeddings=True,
    )
    logger.info(f"Embedded {len(texts)} texts → shape {embeddings.shape}")
    return embeddings


def embed_query(query: str) -> np.ndarray:
    """Embed a single query string."""
    model = _get_embedding_model()
    return model.encode(query, normalize_embeddings=True)


def store_chunks(chunks: List[Chunk], embeddings: np.ndarray):
    """
    Upsert chunks + embeddings into pgvector.
    Uses ON CONFLICT DO UPDATE for safe nightly re-indexing.
    """
    # Deduplicate by chunk ID so the batch upsert never writes the same ID twice
    seen = set()
    unique_chunks = []
    unique_embeddings = []
    for chunk, emb in zip(chunks, embeddings):
        if chunk.id not in seen:
            seen.add(chunk.id)
            unique_chunks.append(chunk)
            unique_embeddings.append(emb)
    chunks = unique_chunks
    embeddings = unique_embeddings

    db = get_connector()

    data = []
    for chunk, embedding in zip(chunks, embeddings):
        data.append({
            "id": chunk.id,
            "client_id": chunk.client_id,
            "chunk_type": chunk.chunk_type,
            "text": chunk.text,
            "metadata": chunk.metadata,
            "embedding": embedding.tolist(),
        })

    # Batch upsert
    batch_size = 100
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        db.upsert_embeddings(batch)

    logger.info(f"Stored {len(data)} chunks in pgvector")

    # Try to create the IVFFlat index (needs enough rows to build)
    try:
        db.create_vector_index()
    except Exception:
        logger.debug("Vector index creation skipped (likely not enough rows yet)")


def search(
    query: str,
    client_id: str,
    chunk_type: Optional[str] = None,
    top_k: int = 5,
) -> List[dict]:
    """
    Search pgvector for similar chunks.

    Args:
        query: Natural language query
        client_id: Client to search within
        chunk_type: Optional filter (branch_weekly, channel_monthly, trend_qoq)
        top_k: Number of results

    Returns:
        List of dicts with text, metadata, similarity
    """
    query_embedding = embed_query(query)
    db = get_connector()

    results = db.search_embeddings(
        query_embedding=query_embedding,
        client_id=client_id,
        chunk_type=chunk_type,
        top_k=top_k,
    )

    logger.info(
        f"Vector search: query='{query[:50]}...', "
        f"chunk_type={chunk_type}, results={len(results)}"
    )
    return results
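
# End-to-end indexing and search sketch (nightly re-index path; the client ID is
# illustrative, rag.chunker is the import path used above, and the preprocessor
# module path is assumed from its docstring):
#
#   from rag.preprocessor import preprocess_for_rag   # assumed module path
#   from rag.chunker import generate_chunks
#
#   df = preprocess_for_rag("demo_client", days=365)
#   chunks = generate_chunks(df, "demo_client")
#   embeddings = embed_texts([c.text for c in chunks])
#   store_chunks(chunks, embeddings)
#
#   hits = search("Which branch grew fastest last quarter?", "demo_client",
#                 chunk_type="trend_qoq", top_k=5)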