mirror of https://github.com/Manoj-HV30/clawrity.git
prototype
@@ -0,0 +1,26 @@
# =============================================================================
# Clawrity — Environment Variables
# Copy this file to .env and fill in your values.
# NEVER commit .env to git.
# =============================================================================

# --- Groq API (free at https://console.groq.com) ---
GROQ_API_KEY=

# --- PostgreSQL + pgvector (docker-compose handles this if using defaults) ---
DATABASE_URL=postgresql://user:pass@localhost:5432/clawrity

# --- Slack Bot (Socket Mode) ---
# 1. Create app at https://api.slack.com/apps
# 2. Enable Socket Mode → generate App-Level Token (xapp-...)
# 3. OAuth & Permissions → install to workspace → copy Bot Token (xoxb-...)
# 4. Basic Information → Signing Secret
SLACK_BOT_TOKEN=
SLACK_APP_TOKEN=
SLACK_SIGNING_SECRET=

# --- Tavily Web Search (free at https://app.tavily.com) ---
TAVILY_API_KEY=

# --- Slack Webhook for digest delivery ---
ACME_SLACK_WEBHOOK=
@@ -0,0 +1,43 @@
# === Environment & Secrets ===
.env
*.env

# === Dataset files — never commit raw or processed data ===
data/raw/
data/processed/

# === Python ===
__pycache__/
*.py[cod]
*$py.class
*.so
*.egg-info/
dist/
build/
*.egg

# === Virtual Environment ===
venv/
.venv/
env/

# === IDE ===
.vscode/
.idea/
*.swp
*.swo

# === OS ===
.DS_Store
Thumbs.db

# === Logs ===
logs/
*.log
*.jsonl

# === Docker ===
pg_data/

# === Model Cache ===
.cache/
@@ -0,0 +1,23 @@
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies for psycopg2 and Prophet
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    libpq-dev \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy project
COPY . .

# Create necessary directories
RUN mkdir -p data/raw data/processed logs

EXPOSE 8000

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
@@ -0,0 +1,213 @@
# Clawrity

**Multi-channel AI business intelligence agent.** Enterprise clients interact via Slack (or Teams) and get data-grounded answers, daily digests, budget recommendations, ROI forecasts, and competitor/sector intelligence — all specific to their business data.

---

## Architecture

Built on the **OpenClaw pattern**:

- **ProtocolAdapter** — normalises messages from any channel (Slack, Teams, etc.)
- **SOUL.md** — per-client personality, rules, and business context
- **HEARTBEAT.md** — autonomous daily digest scheduling

All intelligence lives in the Clawrity backend; the OpenClaw layer has zero business logic.
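
Every channel handler emits the same unified message shape, defined in `channels/protocol_adapter.py`:

```python
# Unified message format (copied from channels/protocol_adapter.py).
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Optional

@dataclass
class NormalisedMessage:
    """Unified message format — channel-agnostic."""
    text: str
    channel: str  # Channel/conversation ID
    user_id: str
    client_id: str
    timestamp: datetime = field(default_factory=datetime.utcnow)
    source: str = "unknown"  # "slack", "teams", "api"
    raw_event: Optional[Dict] = None
```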
## Tech Stack

| Component | Tool |
|---|---|
| Language | Python 3.11 |
| API Framework | FastAPI + uvicorn |
| LLM | Groq API — llama-3.3-70b-versatile |
| Embeddings | sentence-transformers all-MiniLM-L6-v2 (CPU, 384d) |
| Database | PostgreSQL + pgvector |
| Channel (dev) | Slack Bolt SDK (Socket Mode) |
| Channel (demo) | Microsoft Teams Bot Framework SDK |
| Scheduler | APScheduler AsyncIOScheduler |
| Web Search | Tavily API + DuckDuckGo fallback |
| Forecasting | Prophet |
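
The embedding model is small enough to run on CPU. A minimal sketch of loading it with sentence-transformers (illustrative only, not taken from this repo's code):

```python
# Sketch: shows the model named in the table above; the repo's actual
# wiring lives in the rag/ modules.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # runs on CPU
vectors = model.encode(["Q3 revenue by branch"])
print(vectors.shape)  # (1, 384), matching the 384d figure above
```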
## Quick Start

### 1. Prerequisites

- Python 3.11+
- Docker & Docker Compose
- Groq API key (free: https://console.groq.com)
- Tavily API key (free: https://app.tavily.com)

### 2. Environment Setup

```bash
cp .env.example .env
# Fill in your API keys in .env
```

### 3. Start PostgreSQL + pgvector

```bash
docker compose up -d postgres
```

### 4. Install Dependencies

```bash
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

### 5. Download Kaggle Datasets

Download these two datasets and place them in `data/raw/`:

1. **Global Superstore**: https://kaggle.com/datasets/apoorvaappz/global-super-store-dataset
2. **Marketing Campaign Performance**: https://kaggle.com/datasets/manishabhatt22/marketing-campaign-performance-dataset

```bash
mkdir -p data/raw data/processed
# Place downloaded files in data/raw/
```

### 6. Seed Demo Data

```bash
python scripts/seed_demo_data.py --client_id acme_corp \
    --superstore data/raw/Global_Superstore2.csv \
    --marketing data/raw/marketing_campaign_dataset.csv
```

### 7. Run RAG Pipeline

```bash
python scripts/run_rag_pipeline.py --client_id acme_corp
```

### 8. Start the API

```bash
uvicorn main:app --reload --port 8000
```

---

## Slack Bot Setup (Socket Mode)

### Step 1: Create Slack App

1. Go to https://api.slack.com/apps
2. Click **Create New App** → **From scratch**
3. Name it `Clawrity` and select your workspace

### Step 2: Enable Socket Mode

1. In the left sidebar, click **Socket Mode**
2. Toggle **Enable Socket Mode** to ON
3. Click **Generate Token** — name it `clawrity-socket`
4. Copy the `xapp-...` token → paste into `.env` as `SLACK_APP_TOKEN`

### Step 3: Configure Bot Token

1. Go to **OAuth & Permissions**
2. Under **Bot Token Scopes**, add:
   - `app_mentions:read`
   - `chat:write`
   - `channels:history`
   - `channels:read`
3. Click **Install to Workspace**
4. Copy the `xoxb-...` token → paste into `.env` as `SLACK_BOT_TOKEN`

### Step 4: Enable Events

1. Go to **Event Subscriptions**
2. Toggle **Enable Events** to ON (no Request URL needed in Socket Mode)
3. Under **Subscribe to bot events**, add:
   - `app_mention`
   - `message.channels`
4. Click **Save Changes**

### Step 5: Get Signing Secret

1. Go to **Basic Information**
2. Under **App Credentials**, copy **Signing Secret**
3. Paste into `.env` as `SLACK_SIGNING_SECRET`

### Step 6: Invite Bot to Channel

In Slack, go to your desired channel and type:

```
/invite @Clawrity
```

---

## API Endpoints

| Method | Path | Description |
|--------|------|-------------|
| POST | `/chat` | Send message → get AI response |
| POST | `/slack/events` | Slack webhook fallback |
| POST | `/compare` | Side-by-side RAG vs no-RAG |
| POST | `/forecast/run/{client_id}` | Trigger Prophet forecasting |
| GET | `/forecast/{client_id}/{branch}` | Get cached forecast |
| GET | `/admin/stats/{client_id}` | RAG monitoring stats |
| GET | `/health` | System status |
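
A quick smoke test against `/chat` (a sketch: the request and response field names are assumptions inferred from `ProtocolAdapter.normalise_api` and the orchestrator's result dict, not a documented schema):

```python
# Sketch only: "client_id"/"message" in the payload and "response" in the
# result are inferred from the code, not a documented contract.
import requests

resp = requests.post(
    "http://localhost:8000/chat",
    json={"client_id": "acme_corp", "message": "How can we improve Toronto's ROI?"},
    timeout=120,  # the Gen/QA retry loop can take a while
)
print(resp.json()["response"])
```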
## Adding a New Client

1. Create `config/clients/client_newclient.yaml` (copy from `client_acme.yaml`)
2. Create `soul/newclient_soul.md`
3. Create `heartbeat/newclient_heartbeat.md`
4. Place data in `data/raw/` and run seed + RAG scripts
5. Restart — zero code changes required

---

## Project Structure

```
clawrity/
├── main.py                        # FastAPI application
├── config/                        # Configuration
│   ├── settings.py                # pydantic-settings from .env
│   ├── client_loader.py           # YAML client config loader
│   └── clients/client_acme.yaml   # Per-client config
├── soul/                          # Per-client personality
│   ├── soul_loader.py
│   └── acme_soul.md
├── heartbeat/                     # Autonomous digest scheduling
│   ├── heartbeat_loader.py
│   ├── scheduler.py
│   └── acme_heartbeat.md
├── agents/                        # AI agents
│   ├── gen_agent.py               # Response generation
│   ├── qa_agent.py                # Quality assurance
│   ├── orchestrator.py            # Pipeline coordinator
│   └── scout_agent.py             # Competitor intelligence
├── skills/                        # Capabilities
│   ├── postgres_connector.py      # DB connection pool
│   ├── nl_to_sql.py               # Natural language → SQL
│   └── web_search.py              # Tavily + DuckDuckGo
├── channels/                      # Message channels
│   ├── protocol_adapter.py        # OpenClaw normalisation
│   ├── slack_handler.py           # Slack Socket Mode
│   └── teams_handler.py           # Teams stub
├── rag/                           # Retrieval-augmented generation
│   ├── preprocessor.py
│   ├── chunker.py
│   ├── vector_store.py
│   ├── retriever.py
│   ├── evaluator.py
│   └── monitoring.py
├── forecasting/
│   └── prophet_engine.py
├── connectors/
│   ├── base_connector.py
│   └── csv_connector.py
├── etl/
│   └── normaliser.py
└── scripts/
    ├── seed_demo_data.py
    └── run_rag_pipeline.py
```
@@ -0,0 +1,184 @@
"""
Clawrity — Gen Agent

Generates newsletter-style, data-grounded responses using an LLM.
Supports NVIDIA NIM and Groq via an OpenAI-compatible API.
Temperature 0.7 (reduced by 0.2 on each retry).
Augmented with SOUL.md + live query results + RAG chunks (Phase 2).
"""

import logging
from typing import List, Optional, Dict

import pandas as pd

from config.llm_client import get_llm_client, get_model_name

logger = logging.getLogger(__name__)


class GenAgent:
    """Response generation agent using an LLM (NVIDIA NIM or Groq)."""

    def __init__(self):
        self.client = get_llm_client()
        self.model = get_model_name()
        self.base_temperature = 0.7

    def generate(
        self,
        question: str,
        soul_content: str,
        data_context: Optional[pd.DataFrame] = None,
        rag_chunks: Optional[List[Dict]] = None,
        retry_issues: Optional[List[str]] = None,
        retry_count: int = 0,
        strict_data_instruction: Optional[str] = None,
        supplementary_context: Optional[pd.DataFrame] = None,
    ) -> str:
        """
        Generate a data-grounded response.

        Args:
            question: User's original question
            soul_content: SOUL.md content for personality/rules
            data_context: DataFrame from PostgreSQL query results
            rag_chunks: Retrieved chunks with similarity scores (Phase 2)
            retry_issues: QA Agent issues from previous attempt
            retry_count: Current retry number (0-2)
            strict_data_instruction: Data-only instruction injected on retry
            supplementary_context: Benchmark rows (top performers) for comparison

        Returns:
            Markdown-formatted response string
        """
        temperature = max(0.1, self.base_temperature - (retry_count * 0.2))

        prompt = self._build_prompt(
            question, soul_content, data_context, rag_chunks, retry_issues,
            strict_data_instruction, supplementary_context,
        )

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": soul_content},
                    {"role": "user", "content": prompt},
                ],
                temperature=temperature,
                max_tokens=2048,
            )
            result = response.choices[0].message.content.strip()
            logger.info(
                f"Gen Agent produced {len(result)} chars "
                f"(temp={temperature}, retry={retry_count})"
            )
            return result

        except Exception as e:
            logger.error(f"Gen Agent failed: {e}")
            return "I encountered an error generating your response. Please try again."

    def generate_digest(
        self,
        soul_content: str,
        data_context: pd.DataFrame,
        rag_chunks: Optional[List[Dict]] = None,
    ) -> str:
        """Generate a daily digest newsletter."""
        prompt = f"""Generate a professional daily business intelligence digest.

## Performance Data (Last 7 Days)
{data_context.to_markdown(index=False) if data_context is not None and len(data_context) > 0 else "No data available."}

"""
        if rag_chunks:
            prompt += "## Historical Context\n"
            for i, chunk in enumerate(rag_chunks, 1):
                sim = chunk.get("similarity", 0)
                prompt += f"{i}. {chunk['text']} (relevance: {sim:.2f})\n"
            prompt += "\n"

        prompt += """Format as a newsletter with:
1. **Executive Summary** — key highlights in 2-3 sentences
2. **Top Performers** — best performing branches
3. **Attention Required** — bottom 3 branches by revenue (ALWAYS include this)
4. **Channel Insights** — spending efficiency across channels
5. **Recommendations** — specific, data-backed suggestions

Use bullet points, bold key numbers, and keep it concise."""

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": soul_content},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.7,
                max_tokens=3000,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"Digest generation failed: {e}")
            return "Daily digest generation encountered an error."

    def _build_prompt(
        self,
        question: str,
        soul_content: str,
        data_context: Optional[pd.DataFrame],
        rag_chunks: Optional[List[Dict]],
        retry_issues: Optional[List[str]],
        strict_data_instruction: Optional[str] = None,
        supplementary_context: Optional[pd.DataFrame] = None,
    ) -> str:
        """Build the augmented prompt for response generation."""
        parts = []

        # Strict data instruction (on retry — prevents hallucination)
        if strict_data_instruction:
            parts.append(f"## ⚠️ STRICT REQUIREMENT\n{strict_data_instruction}\n")

        # Data context
        if data_context is not None and len(data_context) > 0:
            parts.append("## Data Context (query results for the user's question)")
            parts.append(data_context.to_markdown(index=False))
        else:
            parts.append("## Data Context\nNo query results available.")

        # Supplementary context (top performers for comparison)
        if supplementary_context is not None and len(supplementary_context) > 0:
            parts.append("\n## Benchmark Data (top-performing branches for comparison)")
            parts.append(supplementary_context.to_markdown(index=False))
            parts.append(
                "\nUse this benchmark data to compare the queried branch's performance "
                "against top performers. Identify which channels and strategies work "
                "best, and recommend specific, actionable improvements based on what "
                "top-performing branches are doing differently."
            )

        # RAG chunks (Phase 2)
        if rag_chunks:
            parts.append("\n## Historical Business Context (retrieved from intelligence layer)")
            if strict_data_instruction:
                parts.append("⚠️ ONLY use historical context that is about branches/entities in the Data Context above. IGNORE any historical context about other branches.")
            for i, chunk in enumerate(rag_chunks, 1):
                sim = chunk.get("similarity", 0)
                parts.append(f"{i}. {chunk['text']} (relevance: {sim:.2f})")
            parts.append("\nBase suggestions on historical context. Cite specific data points.")

        # Retry instructions
        if retry_issues:
            parts.append("\n## IMPORTANT — Previous Response Issues")
            parts.append("Your previous response had these problems. Fix them:")
            for issue in retry_issues:
                parts.append(f"- {issue}")
            parts.append("Be more precise. Only state facts supported by the data above.")
            parts.append("Do NOT introduce any new branches, cities, or figures that are not in the Data Context.")

        # User question
        parts.append(f"\n## User Question\n{question}")

        parts.append("\nProvide a professional, data-grounded response. Cite specific numbers from the data.")

        return "\n".join(parts)
@@ -0,0 +1,294 @@
"""
Clawrity — Orchestrator

Coordinates the full message pipeline:
NormalisedMessage → NL-to-SQL → PostgreSQL → (RAG Retriever) → Gen Agent → QA Agent → Response

Max 2 retries per query. Returns the best attempt with a confidence warning after max retries.

Context enrichment: when a query returns sparse data (≤3 rows) and the question
asks for recommendations, automatically pulls top-performing branches as comparison
context so the Gen Agent can give actionable suggestions.
"""

import re
import logging
import time
from typing import Dict, Optional, List

import pandas as pd

from agents.gen_agent import GenAgent
from agents.qa_agent import QAAgent
from channels.protocol_adapter import NormalisedMessage
from config.client_loader import ClientConfig
from skills.nl_to_sql import NLToSQL
from skills.postgres_connector import get_connector
from soul.soul_loader import load_soul

logger = logging.getLogger(__name__)

MAX_RETRIES = 2

# Keywords that signal the user wants recommendations, not just raw data
_RECOMMENDATION_KEYWORDS = re.compile(
    r"\b(improve|increase|boost|grow|fix|help|recommend|suggest|advice|strategy|"
    r"what (should|can|do)|how (to|can|do|should))\b",
    re.IGNORECASE,
)


class Orchestrator:
    """Pipeline orchestrator — the central brain of Clawrity."""

    def __init__(self):
        self.nl_to_sql = NLToSQL()
        self.gen_agent = GenAgent()
        self.qa_agent = QAAgent()
        self.retriever = None  # Set in Phase 2 via set_retriever()

    def set_retriever(self, retriever):
        """Attach the RAG retriever (Phase 2)."""
        self.retriever = retriever

    async def process(
        self,
        message: NormalisedMessage,
        client_config: ClientConfig,
    ) -> Dict:
        """
        Process a user message through the full pipeline.

        Returns:
            Dict with: response, qa_score, qa_passed, retries, metadata
        """
        start_time = time.time()
        db = get_connector()

        # Load SOUL
        soul_content = load_soul(client_config)

        # Step 1: NL-to-SQL
        schema_meta = db.get_spend_data_schema(client_config.client_id)
        sql = self.nl_to_sql.generate_sql(
            question=message.text,
            client_id=client_config.client_id,
            schema_metadata=schema_meta,
        )

        # Step 2: Execute SQL
        data_context = None
        if sql:
            try:
                data_context = db.execute_query(sql)
                logger.info(f"SQL returned {len(data_context)} rows")
            except Exception as e:
                logger.error(f"SQL execution failed: {e}")
                data_context = pd.DataFrame()
        else:
            data_context = pd.DataFrame()

        # Step 2b: Context enrichment for sparse results.
        # When data is sparse and the user wants recommendations, pull
        # top performers and channel benchmarks as supplementary context.
        supplementary_context = None
        if self._needs_enrichment(message.text, data_context):
            supplementary_context = self._enrich_context(
                db, client_config.client_id, message.text, data_context
            )
            if supplementary_context is not None:
                logger.info(
                    f"Context enriched: {len(supplementary_context)} supplementary rows"
                )

        # Step 3: RAG Retrieval (Phase 2)
        rag_chunks = None
        if self.retriever:
            try:
                rag_chunks = self.retriever.retrieve(
                    query=message.text,
                    client_id=client_config.client_id,
                )
            except Exception as e:
                logger.warning(f"RAG retrieval failed: {e}")

        # Step 4: Gen Agent → QA Agent loop (max 2 retries).
        # When supplementary context is provided (enrichment mode), use a relaxed
        # QA threshold since the response naturally references broader benchmark data.
        qa_threshold = client_config.hallucination_threshold
        if supplementary_context is not None and len(supplementary_context) > 0:
            qa_threshold = min(qa_threshold, 0.5)
            logger.info(f"Using relaxed QA threshold ({qa_threshold}) for enriched context")

        best_response = None
        best_score = 0.0
        qa_result = {"score": 0, "passed": False, "issues": []}
        retries = 0

        for attempt in range(MAX_RETRIES + 1):
            retry_issues = qa_result["issues"] if attempt > 0 else None

            # On retry, add an explicit data-only instruction to prevent hallucination
            strict_data_instruction = None
            if attempt > 0:
                if supplementary_context is not None and len(supplementary_context) > 0:
                    strict_data_instruction = (
                        "CRITICAL: Only use data from the Data Context and Benchmark Data "
                        "sections provided. Do NOT invent figures or branch names that are "
                        "not present in either of those sections. You MAY reference benchmark "
                        "branches for comparison and recommendations."
                    )
                else:
                    strict_data_instruction = (
                        "CRITICAL: Do NOT mention any branches, figures, or historical data "
                        "that are not in the SQL query result provided. Stick strictly to the "
                        "data. If historical context from RAG is about different branches than "
                        "what the query returned, IGNORE that context entirely."
                    )

            response = self.gen_agent.generate(
                question=message.text,
                soul_content=soul_content,
                data_context=data_context,
                rag_chunks=rag_chunks,
                retry_issues=retry_issues,
                retry_count=attempt,
                strict_data_instruction=strict_data_instruction,
                supplementary_context=supplementary_context,
            )

            qa_result = self.qa_agent.evaluate(
                response=response,
                data_context=data_context,
                threshold=qa_threshold,
                supplementary_context=supplementary_context,
                user_question=message.text,
            )

            # Track the best response (prefer longer, richer responses over "no data" stubs)
            if qa_result["score"] > best_score or (
                qa_result["score"] == best_score
                and best_response is not None
                and len(response) > len(best_response)
            ):
                best_score = qa_result["score"]
                best_response = response

            if qa_result["passed"]:
                logger.info(f"QA passed on attempt {attempt + 1}")
                break
            else:
                retries += 1
                logger.warning(
                    f"QA failed on attempt {attempt + 1}: "
                    f"score={qa_result['score']:.2f}, issues={qa_result['issues']}"
                )

        # If max retries exceeded, use the best response with a confidence warning
        final_response = best_response or response
        if not qa_result["passed"] and retries >= MAX_RETRIES:
            final_response += (
                "\n\n---\n"
                f"⚠️ *Confidence: {best_score:.0%} — "
                "This response may contain approximations. "
                "Please verify critical numbers against your source data.*"
            )

        elapsed = time.time() - start_time

        result = {
            "response": final_response,
            "qa_score": best_score,
            "qa_passed": qa_result["passed"],
            "retries": retries,
            "sql": sql,
            "data_rows": len(data_context) if data_context is not None else 0,
            "rag_chunks_used": len(rag_chunks) if rag_chunks else 0,
            "elapsed_seconds": round(elapsed, 2),
        }

        # Log interaction
        self._log_interaction(message, client_config, result)

        return result

    def _needs_enrichment(
        self,
        question: str,
        data_context: Optional[pd.DataFrame],
    ) -> bool:
        """Check if the query result is too sparse for a recommendation question."""
        # Only enrich if data is sparse
        if data_context is not None and len(data_context) > 3:
            return False

        # Only enrich if the user is asking for recommendations/improvement
        return bool(_RECOMMENDATION_KEYWORDS.search(question))

    def _enrich_context(
        self,
        db,
        client_id: str,
        question: str,
        data_context: Optional[pd.DataFrame],
    ) -> Optional[pd.DataFrame]:
        """
        Pull supplementary context: top-performing branches and channel
        benchmarks to help the Gen Agent give actionable recommendations.
        """
        try:
            # Get the top 10 branches by ROI for comparison
            enrichment_sql = """
                SELECT branch, country, channel,
                       SUM(spend) as total_spend,
                       SUM(revenue) as total_revenue,
                       SUM(leads) as total_leads,
                       SUM(conversions) as total_conversions,
                       ROUND((SUM(revenue)/NULLIF(SUM(spend),0))::numeric, 2) as roi
                FROM spend_data
                WHERE client_id = %s
                  AND date >= CURRENT_DATE - INTERVAL '90 days'
                GROUP BY branch, country, channel
                HAVING SUM(spend) > 0
                ORDER BY roi DESC
                LIMIT 10
            """
            top_performers = db.execute_query(enrichment_sql, (client_id,))

            if top_performers is not None and len(top_performers) > 0:
                logger.info(f"Enrichment: fetched {len(top_performers)} top performer rows")
                return top_performers

        except Exception as e:
            logger.warning(f"Context enrichment failed: {e}")

        return None

    def _log_interaction(
        self,
        message: NormalisedMessage,
        client_config: ClientConfig,
        result: Dict,
    ):
        """Log the interaction for monitoring."""
        try:
            from rag.monitoring import log_interaction
            log_interaction(
                client_id=client_config.client_id,
                query=message.text,
                num_chunks=result.get("rag_chunks_used", 0),
                chunk_types_used=[],  # Populated when the retriever provides this info
                qa_score=result.get("qa_score", 0),
                qa_passed=result.get("qa_passed", False),
                retries=result.get("retries", 0),
                response_length=len(result.get("response", "")),
                elapsed_seconds=result.get("elapsed_seconds", 0),
            )
        except Exception as e:
            logger.debug(f"Monitoring log failed: {e}")

        logger.info(
            f"[{client_config.client_id}] Query processed: "
            f"score={result['qa_score']:.2f}, passed={result['qa_passed']}, "
            f"retries={result['retries']}, time={result['elapsed_seconds']}s"
        )
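

# --- Illustration (sketch, not part of the original module) ---
# The keyword gate above sends recommendation-style questions into context
# enrichment, while plain data lookups skip it.
if __name__ == "__main__":
    assert _RECOMMENDATION_KEYWORDS.search("How can we improve Toronto's ROI?")
    assert not _RECOMMENDATION_KEYWORDS.search("Show me revenue by branch")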
@@ -0,0 +1,165 @@
"""
Clawrity — QA Agent

Evaluates Gen Agent responses for faithfulness against data context.
Uses Groq LLM at temperature 0.1 for strict, deterministic evaluation.
Returns JSON: { score, passed, issues }
Threshold from client YAML hallucination_threshold (default 0.75).
"""

import json
import logging
from typing import Optional, List, Dict

import pandas as pd

from config.llm_client import get_llm_client, get_model_name

logger = logging.getLogger(__name__)

EVAL_PROMPT = """You are a strict quality assurance evaluator for business intelligence responses.

Your job: verify that the response ONLY contains claims supported by the provided data.

## Data Context (ground truth)
{data_context}

## Response to Evaluate
{response}

## Evaluation Criteria

### 1. Branch Name Validation (CRITICAL)
- Extract ALL branch/city names mentioned in the response
- Compare against the branch names in the Data Context above
- If ANY branch name appears in the response but NOT in the Data Context, this is a HALLUCINATION
- Deduct 0.3 from score for EACH unrelated branch mentioned

### 2. Numerical Accuracy (CRITICAL)
- ALL revenue, spend, lead, conversion, and ROI figures in the response must match the Data Context EXACTLY
- If a number is mentioned that does not appear in the Data Context, deduct 0.2 from score
- Rounded numbers are acceptable only if clearly approximate (e.g., "~$1.2M")

### 3. Historical Context Relevance
- If the response includes historical context or trends, it is acceptable ONLY if it directly supports the answer about branches/entities present in the Data Context
- Historical context about branches NOT in the current Data Context must be penalized: deduct 0.3 from score
- Example: If Data Context shows Toronto, Vancouver, Dubai but response mentions "Lawton showed 16436% growth" — this is IRRELEVANT historical context and must be penalized

### 4. Completeness
- Does the response address the user's question?
- Are key data points from the Data Context included?

### 5. Appropriate Hedging
- Does the response use uncertain language for inferences?
- Recommendations should be clearly marked as suggestions, not facts

## Scoring
Start at 1.0 and deduct points per the rules above. Minimum score is 0.0.

Return a JSON object with exactly this structure:
{{
    "score": <float between 0.0 and 1.0>,
    "passed": <true if score >= {threshold}>,
    "issues": [<list of specific issues found, empty if none>]
}}

IMPORTANT: If score < {threshold}, include in issues list exactly which branches, figures, or historical data were mentioned that do NOT appear in the Data Context. Format as:
"Mentioned branches/figures not in current query result: [list them]"

Return ONLY the JSON. No other text."""


class QAAgent:
    """Quality assurance agent for validating Gen Agent responses."""

    def __init__(self):
        self.client = get_llm_client()
        self.model = get_model_name()

    def evaluate(
        self,
        response: str,
        data_context: Optional[pd.DataFrame] = None,
        threshold: float = 0.75,
        supplementary_context: Optional[pd.DataFrame] = None,
        user_question: str = "",
    ) -> Dict:
        """
        Evaluate a response for faithfulness.

        Args:
            response: Gen Agent's response text
            data_context: The data the response should be grounded in
            threshold: Minimum score to pass (from client YAML)
            supplementary_context: Benchmark data (top performers) that is also valid ground truth
            user_question: The user's original question (entities mentioned here are valid context)

        Returns:
            Dict with score (float), passed (bool), issues (list[str])
        """
        data_str = ""
        if data_context is not None and len(data_context) > 0:
            data_str = data_context.to_markdown(index=False)
        else:
            data_str = "No structured data available."

        # Include supplementary (benchmark) context as valid ground truth
        if supplementary_context is not None and len(supplementary_context) > 0:
            data_str += "\n\n### Benchmark Data (also valid ground truth)\n"
            data_str += supplementary_context.to_markdown(index=False)

        # Include user question so QA knows which entities are valid context
        if user_question:
            data_str += f"\n\n### User Question Context\nThe user asked: \"{user_question}\"\nBranch/entity names mentioned in the user's question are valid to reference in the response."

        prompt = EVAL_PROMPT.format(
            data_context=data_str,
            response=response,
            threshold=threshold,
        )

        try:
            result = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a strict QA evaluator. Return only valid JSON. Pay special attention to branch names and figures that appear in the response but NOT in the data context — these are hallucinations."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
                max_tokens=512,
            )

            raw = result.choices[0].message.content.strip()
            evaluation = self._parse_response(raw, threshold)
            logger.info(
                f"QA evaluation: score={evaluation['score']:.2f}, "
                f"passed={evaluation['passed']}, issues={len(evaluation['issues'])}"
            )
            return evaluation

        except Exception as e:
            logger.error(f"QA evaluation failed: {e}")
            # On failure, pass with warning
            return {"score": 0.5, "passed": True, "issues": [f"QA evaluation error: {str(e)}"]}

    def _parse_response(self, raw: str, threshold: float) -> Dict:
        """Parse JSON response from QA LLM call."""
        try:
            # Strip markdown code fences if present
            cleaned = raw.strip()
            if cleaned.startswith("```"):
                cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned[3:]
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
            cleaned = cleaned.strip()

            data = json.loads(cleaned)
            score = float(data.get("score", 0.5))
            return {
                "score": score,
                "passed": score >= threshold,
                "issues": data.get("issues", []),
            }
        except (json.JSONDecodeError, ValueError) as e:
            logger.warning(f"Could not parse QA response: {e}. Raw: {raw[:200]}")
            return {"score": 0.5, "passed": True, "issues": ["QA response parsing failed"]}
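

# --- Illustration (sketch, not part of the original module) ---
# _parse_response tolerates markdown-fenced JSON from the LLM.
if __name__ == "__main__":
    agent = QAAgent.__new__(QAAgent)  # bypass __init__; no LLM client needed here
    raw = '```json\n{"score": 0.9, "passed": true, "issues": []}\n```'
    print(agent._parse_response(raw, threshold=0.75))
    # → {'score': 0.9, 'passed': True, 'issues': []}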
@@ -0,0 +1,214 @@
"""
Clawrity — Scout Agent

Fetches real-time competitor updates and sector-specific news.
Runs inside HEARTBEAT digest job ONLY — never on ad-hoc /chat queries.
Appends "Market Intelligence" section to morning digest.

If nothing relevant is found, the section is omitted entirely — no filler.
"""

import logging
from datetime import datetime
from typing import Optional

from config.llm_client import get_llm_client, get_model_name
from config.client_loader import ClientConfig
from config.settings import get_settings
from skills.web_search import web_search

logger = logging.getLogger(__name__)

SCOUT_PROMPT = """You are a business intelligence scout for {client_name}.
Their sector: {sector}
Their competitors: {competitors}

Below are web search results from the last {lookback} day(s).
Extract ONLY what is directly relevant to this client's business.
Ignore anything generic or unrelated to their sector.
If nothing is relevant, respond with exactly: NO_RELEVANT_NEWS

Format relevant findings as a clean "Market Intelligence" section with bullet points.
Each bullet should summarize one key finding with its source.

Results:
{search_results}"""

QUERY_PROMPT = """You are a business intelligence scout for {client_name}.
Sector: {sector}
Competitors: {competitors}

The user asked: "{query}"

Below are web search results. Extract ONLY what is directly relevant to the
user's question and this client's business context. Ignore generic or unrelated content.
If nothing is relevant, respond with exactly: NO_RELEVANT_NEWS

Format findings as concise bullet points with sources.

Results:
{search_results}"""


class ScoutAgent:
    """Competitor and sector intelligence agent."""

    def __init__(self):
        self.client = get_llm_client()
        self.model = get_model_name()

    async def gather_intelligence(
        self,
        client_config: ClientConfig,
    ) -> Optional[str]:
        """
        Fetch and summarize competitor/sector news for digest.

        Args:
            client_config: Client config with scout section

        Returns:
            Formatted "Market Intelligence" markdown section, or None if nothing relevant
        """
        scout_config = client_config.scout
        if not scout_config.sector and not scout_config.competitors:
            logger.info(f"[{client_config.client_id}] No scout config — skipping")
            return None

        lookback = scout_config.news_lookback_days
        today = datetime.now().strftime("%Y-%m-%d")

        # Gather search results
        all_results = []

        # Search for each competitor
        for competitor in scout_config.competitors:
            query = f"{competitor} latest news"
            results = web_search(query, max_results=3, lookback_days=lookback)
            all_results.extend(results)

        # Search for sector keywords
        for keyword in scout_config.keywords[:3]:  # Limit to 3 keywords
            query = f"{keyword} news {today}"
            results = web_search(query, max_results=3, lookback_days=lookback)
            all_results.extend(results)

        if not all_results:
            logger.info(f"[{client_config.client_id}] No search results found")
            return None

        # Format results for LLM
        results_text = "\n\n".join(
            f"**{r['title']}** ({r['url']})\n{r['content']}"
            for r in all_results
        )

        # Summarize with Groq
        prompt = SCOUT_PROMPT.format(
            client_name=client_config.client_name,
            sector=scout_config.sector,
            competitors=", ".join(scout_config.competitors),
            lookback=lookback,
            search_results=results_text,
        )

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a business intelligence scout."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.3,
                max_tokens=1024,
            )

            result = response.choices[0].message.content.strip()

            if result == "NO_RELEVANT_NEWS":
                logger.info(f"[{client_config.client_id}] Scout: no relevant news found")
                return None

            section = f"## 🔭 Market Intelligence\n\n{result}"
            logger.info(f"[{client_config.client_id}] Scout: generated intelligence section")
            return section

        except Exception as e:
            logger.error(f"Scout Agent failed: {e}")
            return None

    async def search_query(
        self,
        client_config: ClientConfig,
        query: str,
    ) -> Optional[str]:
        """
        Run a targeted scout search for a specific user query.

        Used by the /scout endpoint for ad-hoc competitor/news queries.

        Args:
            client_config: Client config with scout section
            query: User's specific question about competitors/market

        Returns:
            Formatted intelligence summary, or None if nothing relevant
        """
        scout_config = client_config.scout

        # Search with the user's query directly
        results = web_search(query, max_results=5, lookback_days=scout_config.news_lookback_days)

        # Also search with competitor names if they appear in the query
        for competitor in scout_config.competitors:
            if competitor.lower() in query.lower():
                extra = web_search(f"{competitor} latest news", max_results=3, lookback_days=scout_config.news_lookback_days)
                results.extend(extra)

        if not results:
            logger.info(f"[{client_config.client_id}] Scout query returned no results")
            return None

        # Deduplicate by URL
        seen_urls = set()
        unique_results = []
        for r in results:
            if r["url"] not in seen_urls:
                seen_urls.add(r["url"])
                unique_results.append(r)

        # Format results for LLM
        results_text = "\n\n".join(
            f"**{r['title']}** ({r['url']})\n{r['content']}"
            for r in unique_results
        )

        prompt = QUERY_PROMPT.format(
            client_name=client_config.client_name,
            sector=scout_config.sector,
            competitors=", ".join(scout_config.competitors),
            query=query,
            search_results=results_text,
        )

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a business intelligence scout."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.3,
                max_tokens=1024,
            )

            result = response.choices[0].message.content.strip()

            if result == "NO_RELEVANT_NEWS":
                return None

            return result

        except Exception as e:
            logger.error(f"Scout query failed: {e}")
            return None
@@ -0,0 +1,121 @@
"""
Clawrity — Protocol Adapter (OpenClaw Pattern)

Normalises messages from any channel into a unified NormalisedMessage.
Maps workspace/team IDs → client_id. Strips bot mentions.
Interface: any channel handler produces NormalisedMessage — adding Teams,
WhatsApp, etc. requires zero pipeline changes.
"""

import re
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Optional

from config.client_loader import ClientConfig

logger = logging.getLogger(__name__)


@dataclass
class NormalisedMessage:
    """Unified message format — channel-agnostic."""
    text: str
    channel: str  # Channel/conversation ID
    user_id: str
    client_id: str
    timestamp: datetime = field(default_factory=datetime.utcnow)
    source: str = "unknown"  # "slack", "teams", "api"
    raw_event: Optional[Dict] = None


# Pattern to match Slack bot mentions like <@U1234567890>
SLACK_MENTION_PATTERN = re.compile(r"<@[A-Z0-9]+>\s*")


class ProtocolAdapter:
    """Normalises raw channel events into NormalisedMessages."""

    def __init__(self, client_configs: Dict[str, ClientConfig]):
        """
        Args:
            client_configs: Dict of client_id → ClientConfig
        """
        self.client_configs = client_configs
        # Build workspace → client_id lookup
        self._workspace_map: Dict[str, str] = {}
        for cid, config in client_configs.items():
            for ws_id in config.slack_workspace_ids:
                self._workspace_map[ws_id] = cid
        # If only one client, use it as default
        self._default_client_id = (
            list(client_configs.keys())[0] if len(client_configs) == 1 else None
        )

    def normalise_slack(self, event: dict, team_id: Optional[str] = None) -> NormalisedMessage:
        """
        Normalise a Slack event into a NormalisedMessage.

        Args:
            event: Raw Slack event dict (from Bolt SDK)
            team_id: Slack workspace/team ID

        Returns:
            NormalisedMessage
        """
        text = event.get("text", "")
        # Strip bot mention tags
        text = SLACK_MENTION_PATTERN.sub("", text).strip()

        channel = event.get("channel", "")
        user_id = event.get("user", "")

        # Map workspace to client
        client_id = self._resolve_client_id(team_id)

        return NormalisedMessage(
            text=text,
            channel=channel,
            user_id=user_id,
            client_id=client_id,
            source="slack",
            raw_event=event,
        )

    def normalise_api(self, client_id: str, message: str) -> NormalisedMessage:
        """Normalise a direct API call (POST /chat)."""
        return NormalisedMessage(
            text=message,
            channel="api",
            user_id="api_user",
            client_id=client_id,
            source="api",
        )

    def normalise_teams(self, activity: dict) -> NormalisedMessage:
        """
        Normalise a Microsoft Teams Bot Framework activity.

        # TODO: Implement full Teams normalisation when Teams handler is wired up.
        """
        text = activity.get("text", "")
        # Strip Teams bot mention (usually <at>BotName</at>)
        text = re.sub(r"<at>.*?</at>\s*", "", text).strip()

        return NormalisedMessage(
            text=text,
            channel=activity.get("channelId", "teams"),
            user_id=activity.get("from", {}).get("id", ""),
            client_id=self._default_client_id or "unknown",
            source="teams",
            raw_event=activity,
        )

    def _resolve_client_id(self, workspace_id: Optional[str]) -> str:
        """Resolve workspace/team ID to client_id."""
        if workspace_id and workspace_id in self._workspace_map:
            return self._workspace_map[workspace_id]
        if self._default_client_id:
            return self._default_client_id
        logger.warning(f"Could not resolve client for workspace: {workspace_id}")
        return "unknown"
@@ -0,0 +1,263 @@
"""
Clawrity — Slack Handler (Socket Mode)

Listens for app_mention and message events via the Slack Bolt SDK.
Runs in a background thread so it does not block FastAPI.

=== SETUP REQUIRED ===
Before running, configure these in your .env file:

SLACK_BOT_TOKEN=xoxb-...       ← OAuth & Permissions → Install to Workspace
SLACK_APP_TOKEN=xapp-...       ← Socket Mode → Generate App-Level Token
SLACK_SIGNING_SECRET=...       ← Basic Information → App Credentials

See README.md for detailed Slack app setup instructions.
=======================
"""

import asyncio
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional, Set

from config.settings import get_settings
from config.client_loader import ClientConfig
from channels.protocol_adapter import ProtocolAdapter, NormalisedMessage

logger = logging.getLogger(__name__)

# Thread pool for processing the LLM pipeline without blocking event handlers
_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="clawrity-slack")

# Module-level guard: only one SlackHandler should be active at a time
_active_handler: Optional["SlackHandler"] = None


class SlackHandler:
    """Slack bot using Socket Mode via the Bolt SDK."""

    def __init__(
        self,
        protocol_adapter: ProtocolAdapter,
        client_configs: Dict[str, ClientConfig],
        orchestrator,  # agents.orchestrator.Orchestrator
    ):
        self.adapter = protocol_adapter
        self.client_configs = client_configs
        self.orchestrator = orchestrator
        self._thread: Optional[threading.Thread] = None

        settings = get_settings()

        # ---------------------------------------------------------------
        # Bot Token (xoxb-...) — from .env SLACK_BOT_TOKEN
        # This is the OAuth token installed to your workspace.
        # ---------------------------------------------------------------
        self.bot_token = settings.slack_bot_token

        # ---------------------------------------------------------------
        # App-Level Token (xapp-...) — from .env SLACK_APP_TOKEN
        # Required for Socket Mode. Generated in the Slack app settings.
        # ---------------------------------------------------------------
        self.app_token = settings.slack_app_token

        # ---------------------------------------------------------------
        # Signing Secret — from .env SLACK_SIGNING_SECRET
        # Used to verify incoming requests from Slack.
        # ---------------------------------------------------------------
        self.signing_secret = settings.slack_signing_secret

        self.app = None
        self.handler = None

        # Deduplication: track recently processed event timestamps.
        # Slack retries events if the handler is slow — this prevents duplicates.
        self._processed_events: Set[str] = set()
        self._processed_lock = threading.Lock()

    def _validate_tokens(self) -> bool:
        """Check that all required Slack tokens are configured."""
        if not self.bot_token:
            logger.warning(
                "SLACK_BOT_TOKEN not set. Slack bot will not start. "
                "See README.md → Slack Bot Setup for instructions."
            )
            return False
        if not self.app_token:
            logger.warning(
                "SLACK_APP_TOKEN not set. Socket Mode requires an app-level token. "
                "Go to your Slack app → Socket Mode → Generate Token."
            )
            return False
        return True

    def _is_duplicate_event(self, event: dict) -> bool:
        """Check if we've already processed this event (Slack retry dedup)."""
        # Use multiple fields to build a robust dedup key.
        # client_msg_id is unique per user message (present on message events,
        # but NOT on app_mention events). event_ts is present on both.
        # We store keys for all strategies so cross-event-type dedup works.
        msg_id = event.get("client_msg_id")
        event_ts = event.get("event_ts") or event.get("ts", "")
        user = event.get("user", "")

        # Build candidate keys
        keys = set()
        if msg_id:
            keys.add(f"msg:{msg_id}")
        if event_ts:
            keys.add(f"ts:{event_ts}")
        # Fallback: combine event type + ts + user for events without client_msg_id
        event_type = event.get("type", "")
        if event_ts and user:
            keys.add(f"evt:{event_type}:{event_ts}:{user}")

        if not keys:
            return False

        with self._processed_lock:
            # Check ALL keys — if any match, it's a duplicate
            for key in keys:
                if key in self._processed_events:
                    logger.debug(f"Skipping duplicate event (matched key: {key})")
                    return True

            # Register ALL keys so cross-event-type dedup works
            # (app_mention and message for the same user message share event_ts)
            self._processed_events.update(keys)

            # Prune old entries (keep the set from growing indefinitely)
            if len(self._processed_events) > 500:
                self._processed_events = set(list(self._processed_events)[-200:])

        return False

    def _setup_app(self):
        """Initialize the Slack Bolt App and register event handlers."""
        from slack_bolt import App
        from slack_bolt.adapter.socket_mode import SocketModeHandler

        self.app = App(
            token=self.bot_token,
            signing_secret=self.signing_secret if self.signing_secret else None,
        )

        # Track the bot's own user ID to prevent self-response loops
        self._bot_user_id = None
        try:
            auth = self.app.client.auth_test()
            self._bot_user_id = auth.get("user_id", "")
            logger.info(f"Bot user ID: {self._bot_user_id}")
        except Exception as e:
            logger.warning(f"Could not fetch bot user ID: {e}")

        # --- Event: Bot mentioned in a channel ---
        @self.app.event("app_mention")
        def handle_mention(event, say, context):
            # Return IMMEDIATELY so Slack gets an ack — process in the background
            if self._is_duplicate_event(event):
                return
            _executor.submit(self._handle_event, event, say, context)

        # --- Event: Direct message to the bot ---
        @self.app.event("message")
        def handle_message(event, say, context):
            # Ignore the bot's own messages and message_changed events
            if event.get("subtype") in (
                "bot_message",
                "message_changed",
                "message_deleted",
            ):
                return
            if event.get("bot_id"):
                return
            # Ignore if this is from the bot itself
            if self._bot_user_id and event.get("user") == self._bot_user_id:
                return
            # Skip channel messages that contain a bot mention —
            # those are handled by the app_mention handler above.
            # Only process DMs here (channel_type == "im").
            channel_type = event.get("channel_type", "")
            if channel_type != "im":
                return
            if self._is_duplicate_event(event):
                return
            # Return IMMEDIATELY — process in the background
            _executor.submit(self._handle_event, event, say, context)

        self.handler = SocketModeHandler(self.app, self.app_token)

    def _handle_event(self, event: dict, say, context):
        """Process an incoming Slack event (runs in a background thread)."""
        try:
            team_id = context.get("team_id", None) if context else None
            message = self.adapter.normalise_slack(event, team_id=team_id)

            if not message.text:
                return

            if message.client_id == "unknown":
                say("⚠️ Could not identify your workspace. Please contact support.")
                return

            client_config = self.client_configs.get(message.client_id)
            if not client_config:
                say(f"⚠️ No configuration found for client: {message.client_id}")
                return

            # Run the orchestrator pipeline (async in a sync context)
            loop = asyncio.new_event_loop()
            try:
                result = loop.run_until_complete(
                    self.orchestrator.process(message, client_config)
                )
                say(result["response"])
            finally:
                loop.close()

        except Exception as e:
            logger.error(f"Slack event handler error: {e}", exc_info=True)
            say(
                "❌ I encountered an error processing your request. "
                "Please try again or contact support."
            )

    def start(self):
        """Start the Slack bot in a background thread."""
        global _active_handler

        if not self._validate_tokens():
            logger.info("Slack bot not started — missing tokens")
            return

        # Stop any existing handler to prevent duplicate Socket Mode connections
        if _active_handler is not None:
            logger.info("Stopping previous Slack handler before starting new one")
            _active_handler.stop()
            _active_handler = None

        try:
            self._setup_app()

            def _run():
                logger.info("Starting Slack bot (Socket Mode)...")
                self.handler.start()

            self._thread = threading.Thread(target=_run, daemon=True)
            self._thread.start()
            _active_handler = self
            logger.info("Slack bot started in background thread")

        except Exception as e:
            logger.error(f"Failed to start Slack bot: {e}")

    def stop(self):
        """Stop the Slack bot."""
        if self.handler:
            try:
                self.handler.close()
                logger.info("Slack bot stopped")
            except Exception as e:
                logger.warning(f"Error stopping Slack bot: {e}")
@@ -0,0 +1,124 @@
"""
Clawrity — Microsoft Teams Handler (STUB)

Skeleton implementation of the Bot Framework adapter for Microsoft Teams.
Proves the multi-channel architecture is real — any channel handler produces
NormalisedMessage via ProtocolAdapter, so the entire pipeline works unchanged.

# TODO: Wire up Azure Bot credentials when ready for Teams demo.
# Required: MICROSOFT_APP_ID, MICROSOFT_APP_PASSWORD
# Package: botbuilder-core, botbuilder-schema

Status: NOT IMPLEMENTED — Slack is the priority for development.
"""

import logging
from typing import Dict, Optional

from channels.protocol_adapter import ProtocolAdapter, NormalisedMessage
from config.client_loader import ClientConfig

logger = logging.getLogger(__name__)


class TeamsHandler:
    """
    Microsoft Teams bot handler stub.

    Architecture:
        Teams Activity → ProtocolAdapter.normalise_teams() → Orchestrator → Response

    The same pipeline used by Slack — zero business logic in this layer.
    """

    def __init__(
        self,
        protocol_adapter: ProtocolAdapter,
        client_configs: Dict[str, ClientConfig],
        orchestrator,  # agents.orchestrator.Orchestrator
    ):
        self.adapter = protocol_adapter
        self.client_configs = client_configs
        self.orchestrator = orchestrator

        # TODO: Wire up Azure Bot credentials from .env
        # self.app_id = settings.microsoft_app_id
        # self.app_password = settings.microsoft_app_password

    async def handle_activity(self, activity: dict) -> str:
        """
        Process an incoming Teams Bot Framework activity.

        # TODO: Implement when ready for Teams demo.

        Expected flow:
        1. Receive activity from Bot Framework webhook
        2. Normalise via ProtocolAdapter.normalise_teams(activity)
        3. Route to Orchestrator.process(message, client_config)
        4. Return response via Bot Framework turn context

        Args:
            activity: Raw Bot Framework activity dict

        Returns:
            Response text to send back to Teams
        """
        # --- Stub implementation ---
        message = self.adapter.normalise_teams(activity)

        client_config = self.client_configs.get(message.client_id)
        if not client_config:
            return f"No configuration found for client: {message.client_id}"

        result = await self.orchestrator.process(message, client_config)
        return result["response"]

    def setup_routes(self, app):
        """
        Register Teams webhook endpoint with FastAPI.

        # TODO: Implement Bot Framework adapter integration.

        Expected endpoint:
            POST /api/teams/messages → Bot Framework webhook

        Requires:
        - botbuilder-core package
        - BotFrameworkAdapter with app_id + app_password
        - CloudAdapter or BotFrameworkHttpClient
        """
        logger.info(
            "Teams handler stub loaded. "
            "To enable Teams: install botbuilder-core, set Azure Bot credentials."
        )

        # TODO: Uncomment and implement when ready
        #
        # from botbuilder.core import (
        #     BotFrameworkAdapter,
        #     BotFrameworkAdapterSettings,
        #     TurnContext,
        # )
        #
        # settings = BotFrameworkAdapterSettings(
        #     app_id=self.app_id,
        #     app_password=self.app_password,
        # )
        # adapter = BotFrameworkAdapter(settings)
        #
        # @app.post("/api/teams/messages")
        # async def teams_webhook(request: Request):
        #     body = await request.json()
        #     activity = Activity().deserialize(body)
        #     auth_header = request.headers.get("Authorization", "")
        #     response = await adapter.process_activity(
        #         activity, auth_header, self._on_turn
        #     )
        #     return response
        #
        # async def _on_turn(turn_context: TurnContext):
        #     activity = turn_context.activity
        #     response = await self.handle_activity(activity.__dict__)
        #     await turn_context.send_activity(response)
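The stub above is exercisable end to end once `ProtocolAdapter.normalise_teams` lands. A hedged sketch; the activity fields beyond `type` and `text` are illustrative assumptions, not the full Bot Framework schema:

```python
# Illustrative only — a minimal Bot Framework-style activity dict.
# Field names beyond "type" and "text" are assumptions for this sketch.
import asyncio

async def demo(handler):
    activity = {
        "type": "message",
        "text": "How did MENA branches perform last week?",
        "channelData": {"tenant": {"id": "tenant-123"}},
    }
    reply = await handler.handle_activity(activity)
    print(reply)
```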
@@ -0,0 +1,158 @@
"""
Clawrity — Client Configuration Loader

Scans config/clients/ for YAML files and parses each into a ClientConfig model.
Supports ${ENV_VAR} interpolation in YAML values.
New client = new YAML file. Zero code changes.
"""

import os
import re
import glob
import logging
from typing import Dict, List, Optional
from pathlib import Path

import yaml
from pydantic import BaseModel

from config.settings import get_settings

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Pydantic models for client YAML structure
# ---------------------------------------------------------------------------

class DataSourceConfig(BaseModel):
    type: str = "csv"
    path: str = ""


class DatabaseConfig(BaseModel):
    url: str = ""
    schema_name: str = ""  # 'schema' is a Pydantic reserved attr


class ScoutConfig(BaseModel):
    sector: str = ""
    competitors: List[str] = []
    keywords: List[str] = []
    news_lookback_days: int = 1


class ClientConfig(BaseModel):
    client_id: str
    client_name: str = ""

    data_source: DataSourceConfig = DataSourceConfig()
    database: DatabaseConfig = DatabaseConfig()

    countries: List[str] = []
    risk_threshold: float = 0.15
    hallucination_threshold: float = 0.75

    digest_schedule: str = "08:00"
    timezone: str = "UTC"

    channels: Dict[str, str] = {}

    soul_file: str = ""
    heartbeat_file: str = ""

    column_mapping: Dict[str, str] = {}

    scout: ScoutConfig = ScoutConfig()

    # Runtime: workspace/team ID → client_id mapping for ProtocolAdapter
    slack_workspace_ids: List[str] = []


# ---------------------------------------------------------------------------
# Environment variable interpolation
# ---------------------------------------------------------------------------

_ENV_PATTERN = re.compile(r"\$\{(\w+)\}")


def _interpolate_env(value: str) -> str:
    """Replace ${ENV_VAR} placeholders with actual environment variable values."""
    def _replace(match):
        var_name = match.group(1)
        return os.environ.get(var_name, match.group(0))

    if isinstance(value, str):
        return _ENV_PATTERN.sub(_replace, value)
    return value


def _interpolate_dict(d: dict) -> dict:
    """Recursively interpolate environment variables in a dictionary."""
    result = {}
    for key, value in d.items():
        if isinstance(value, dict):
            result[key] = _interpolate_dict(value)
        elif isinstance(value, list):
            result[key] = [
                _interpolate_env(v) if isinstance(v, str) else v
                for v in value
            ]
        elif isinstance(value, str):
            result[key] = _interpolate_env(value)
        else:
            result[key] = value
    return result


# ---------------------------------------------------------------------------
# Loader
# ---------------------------------------------------------------------------

def load_client_configs(config_dir: Optional[str] = None) -> Dict[str, ClientConfig]:
    """
    Load all client YAML files from the config directory.

    Returns:
        Dict mapping client_id → ClientConfig
    """
    if config_dir is None:
        config_dir = get_settings().clients_config_dir

    configs: Dict[str, ClientConfig] = {}
    yaml_pattern = os.path.join(config_dir, "*.yaml")

    for yaml_path in glob.glob(yaml_pattern):
        try:
            with open(yaml_path, "r") as f:
                raw = yaml.safe_load(f)

            if not raw or "client_id" not in raw:
                logger.warning(f"Skipping {yaml_path}: missing client_id")
                continue

            # Interpolate environment variables
            interpolated = _interpolate_dict(raw)

            # Handle 'schema' → 'schema_name' mapping for Pydantic
            if "database" in interpolated and "schema" in interpolated["database"]:
                interpolated["database"]["schema_name"] = interpolated["database"].pop("schema")

            config = ClientConfig(**interpolated)
            configs[config.client_id] = config
            logger.info(f"Loaded client config: {config.client_id} from {yaml_path}")

        except Exception as e:
            logger.error(f"Error loading {yaml_path}: {e}")

    if not configs:
        logger.warning(f"No client configs found in {config_dir}")

    return configs


def get_client_config(client_id: str, configs: Optional[Dict[str, ClientConfig]] = None) -> Optional[ClientConfig]:
    """Get a specific client config by ID."""
    if configs is None:
        configs = load_client_configs()
    return configs.get(client_id)
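A minimal usage sketch of the loader, assuming at least one YAML file exists under config/clients/:

```python
# Load every config/clients/*.yaml, then pick one client by ID.
from config.client_loader import load_client_configs, get_client_config

configs = load_client_configs()            # {"acme_corp": ClientConfig(...), ...}
acme = get_client_config("acme_corp", configs)
if acme:
    print(acme.client_name, acme.digest_schedule, acme.scout.competitors)
```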
@@ -0,0 +1,36 @@
client_id: acme_corp
client_name: ACME Corporation

data_source:
  type: "csv"
  path: "data/processed/acme_merged.csv"

database:
  url: "${DATABASE_URL}"
  schema: "acme"

countries: ["US", "Canada", "MENA"]
risk_threshold: 0.15
hallucination_threshold: 0.75

digest_schedule: "08:00"
timezone: "Asia/Kolkata"

channels:
  slack_webhook: "${ACME_SLACK_WEBHOOK}"

soul_file: "soul/acme_soul.md"
heartbeat_file: "heartbeat/acme_heartbeat.md"

column_mapping:
  Order Date: date
  Country: country
  City: branch
  Sales: revenue
  Profit: profit

scout:
  sector: "global retail"
  competitors: ["IKEA", "Amazon", "Walmart", "Staples"]
  keywords: ["retail supply chain", "furniture market trends", "office supplies demand", "global retail ecommerce"]
  news_lookback_days: 1
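For illustration, this is roughly how the `${DATABASE_URL}` placeholder above resolves at load time; `_interpolate_dict` is the loader's internal helper and is normally invoked for you by `load_client_configs`:

```python
# Sketch: env-var interpolation applied to the YAML fragment above.
import os
from config.client_loader import _interpolate_dict

os.environ["DATABASE_URL"] = "postgresql://user:pass@localhost:5432/clawrity"
raw = {"database": {"url": "${DATABASE_URL}", "schema": "acme"}}
print(_interpolate_dict(raw)["database"]["url"])
# -> postgresql://user:pass@localhost:5432/clawrity
```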
@@ -0,0 +1,76 @@
"""
Clawrity — LLM Client Factory

Provides a unified LLM client that works with both NVIDIA NIM and Groq.
Both are OpenAI-compatible APIs, so we use the OpenAI client with different
base URLs and API keys.

Auto-detects provider from settings:
- NVIDIA NIM: base_url="https://integrate.api.nvidia.com/v1"
- Groq: base_url="https://api.groq.com/openai/v1"
"""

import logging
from functools import lru_cache

from openai import OpenAI

from config.settings import get_settings

logger = logging.getLogger(__name__)

# Provider configs
_PROVIDERS = {
    "nvidia": {
        "base_url": "https://integrate.api.nvidia.com/v1",
        "default_model": "meta/llama-3.3-70b-instruct",
    },
    "groq": {
        "base_url": "https://api.groq.com/openai/v1",
        "default_model": "llama-3.3-70b-versatile",
    },
}


def get_llm_client() -> OpenAI:
    """Get the configured LLM client (NVIDIA NIM or Groq)."""
    settings = get_settings()
    provider = settings.active_llm_provider

    if provider == "nvidia":
        api_key = settings.nvidia_api_key
    elif provider == "groq":
        api_key = settings.groq_api_key
    else:
        raise ValueError(f"Unknown LLM provider: {provider}")

    if not api_key:
        raise ValueError(
            f"No API key configured for LLM provider '{provider}'. "
            f"Set {'NVIDIA_API_KEY' if provider == 'nvidia' else 'GROQ_API_KEY'} in .env"
        )

    config = _PROVIDERS[provider]
    client = OpenAI(
        api_key=api_key,
        base_url=config["base_url"],
    )

    logger.info(f"LLM client: {provider} ({config['base_url']})")
    return client


def get_model_name() -> str:
    """Get the model name for the active provider."""
    settings = get_settings()
    provider = settings.active_llm_provider

    # If user specified a model in settings, use it
    # Otherwise use the provider default
    model = settings.llm_model
    if model == "meta/llama-3.3-70b-instruct" and provider == "groq":
        model = _PROVIDERS["groq"]["default_model"]
    elif model == "llama-3.3-70b-versatile" and provider == "nvidia":
        model = _PROVIDERS["nvidia"]["default_model"]

    return model
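A usage sketch; the import path `skills.llm_client` is an assumption, since the factory's module location isn't shown here:

```python
# One completion against whichever provider is configured in .env.
from skills.llm_client import get_llm_client, get_model_name  # module path assumed

client = get_llm_client()
resp = client.chat.completions.create(
    model=get_model_name(),
    messages=[{"role": "user", "content": "Summarise last week's revenue in one line."}],
)
print(resp.choices[0].message.content)
```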
@@ -0,0 +1,72 @@
"""
Clawrity — Application Settings

Loads environment variables via pydantic-settings.
All secrets read from .env file — nothing is hardcoded.
"""

import os
from functools import lru_cache
from typing import Optional

from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # --- Database ---
    database_url: str = "postgresql://user:pass@localhost:5432/clawrity"

    # --- LLM Providers ---
    groq_api_key: str = ""
    nvidia_api_key: str = ""

    # --- Slack (Socket Mode) ---
    # Bot Token (xoxb-...) — OAuth & Permissions → Install to Workspace
    slack_bot_token: str = ""
    # App-Level Token (xapp-...) — Socket Mode → Generate Token
    slack_app_token: str = ""
    # Signing Secret — Basic Information → App Credentials
    slack_signing_secret: str = ""

    # --- Tavily Web Search ---
    tavily_api_key: str = ""

    # --- Slack Webhook for digest delivery ---
    acme_slack_webhook: str = ""

    # --- Paths ---
    data_raw_dir: str = "data/raw"
    data_processed_dir: str = "data/processed"
    logs_dir: str = "logs"
    clients_config_dir: str = "config/clients"

    # --- Model Defaults ---
    llm_model: str = "meta/llama-3.3-70b-instruct"
    llm_provider: str = ""  # auto-detected: "nvidia" or "groq"
    embedding_model: str = "all-MiniLM-L6-v2"
    embedding_dim: int = 384

    @property
    def active_llm_provider(self) -> str:
        """Auto-detect which LLM provider to use based on available keys."""
        if self.llm_provider:
            return self.llm_provider
        if self.nvidia_api_key:
            return "nvidia"
        if self.groq_api_key:
            return "groq"
        return "nvidia"  # default

    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
    }


@lru_cache()
def get_settings() -> Settings:
    """Singleton settings instance. Cached after first call."""
    return Settings()
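A small sketch of the provider-resolution order (explicit `llm_provider` wins, then the NVIDIA key, then the Groq key); constructing `Settings` with keyword overrides bypasses `.env` for the overridden fields:

```python
# Sketch: active_llm_provider resolution with dummy keys.
from config.settings import Settings

s = Settings(groq_api_key="gsk-test", nvidia_api_key="")
assert s.active_llm_provider == "groq"

s = Settings(llm_provider="nvidia", groq_api_key="gsk-test")
assert s.active_llm_provider == "nvidia"  # explicit setting wins
```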
@@ -0,0 +1,42 @@
"""
Clawrity — Base Data Connector

Abstract interface for data connectors.
All connectors implement load() → pd.DataFrame.
"""

from abc import ABC, abstractmethod

import pandas as pd


class BaseConnector(ABC):
    """Abstract base class for data source connectors."""

    @abstractmethod
    def load(self, path: str, **kwargs) -> pd.DataFrame:
        """
        Load data from the source.

        Args:
            path: Path to the data source
            **kwargs: Additional arguments specific to the connector

        Returns:
            pandas DataFrame with loaded data
        """
        pass

    @abstractmethod
    def validate(self, df: pd.DataFrame, required_columns: list) -> bool:
        """
        Validate that the DataFrame has expected columns.

        Args:
            df: DataFrame to validate
            required_columns: List of column names that must be present

        Returns:
            True if valid
        """
        pass
@@ -0,0 +1,88 @@
"""
Clawrity — CSV/Excel Data Connector

Auto-detects file format based on extension:
    .csv → pandas read_csv
    .xlsx / .xls → pandas read_excel (via openpyxl)

Supports both formats since Kaggle datasets vary by download version.
"""

import logging
from pathlib import Path

import pandas as pd

from connectors.base_connector import BaseConnector

logger = logging.getLogger(__name__)


class CSVConnector(BaseConnector):
    """Connector for CSV and Excel files with auto-detection."""

    def load(self, path: str, **kwargs) -> pd.DataFrame:
        """
        Load data from a CSV or Excel file.
        Auto-detects format based on file extension.

        Args:
            path: Path to the file (.csv, .xlsx, .xls)
            **kwargs: Passed through to pandas read function.
                Useful kwargs: sheet_name, encoding, sep

        Returns:
            pandas DataFrame
        """
        file_path = Path(path)

        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {path}")

        ext = file_path.suffix.lower()

        if ext == ".csv":
            logger.info(f"Loading CSV: {path}")
            df = pd.read_csv(path, encoding='latin-1', **kwargs)
        elif ext in (".xlsx", ".xls"):
            logger.info(f"Loading Excel ({ext}): {path}")
            # Default to first sheet unless specified
            sheet_name = kwargs.pop("sheet_name", 0)
            df = pd.read_excel(path, sheet_name=sheet_name, engine="openpyxl", **kwargs)
        else:
            raise ValueError(
                f"Unsupported file format: {ext}. "
                f"Supported: .csv, .xlsx, .xls"
            )

        logger.info(f"Loaded {len(df)} rows, {len(df.columns)} columns from {file_path.name}")
        return df

    def validate(self, df: pd.DataFrame, required_columns: list) -> bool:
        """
        Validate that the DataFrame has all required columns.
        Uses case-insensitive matching.

        Args:
            df: DataFrame to validate
            required_columns: List of column names that must be present

        Returns:
            True if all required columns found
        """
        df_cols_lower = {col.lower().strip() for col in df.columns}
        missing = []

        for col in required_columns:
            if col.lower().strip() not in df_cols_lower:
                missing.append(col)

        if missing:
            logger.error(
                f"Missing required columns: {missing}. "
                f"Available: {list(df.columns)}"
            )
            return False

        return True
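Usage sketch, assuming the processed ACME file from the client YAML exists and already uses the canonical column names:

```python
# Load and validate the merged ACME dataset referenced in config/clients.
from connectors.csv_connector import CSVConnector

conn = CSVConnector()
df = conn.load("data/processed/acme_merged.csv")
if conn.validate(df, ["date", "country", "branch", "revenue", "profit"]):
    print(df.head())
```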
@@ -0,0 +1,38 @@
services:
  clawrity-api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - DATABASE_URL=postgresql://user:pass@postgres:5432/clawrity
      - GROQ_API_KEY=${GROQ_API_KEY}
      - SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN}
      - SLACK_APP_TOKEN=${SLACK_APP_TOKEN}
      - SLACK_SIGNING_SECRET=${SLACK_SIGNING_SECRET}
      - TAVILY_API_KEY=${TAVILY_API_KEY}
      - ACME_SLACK_WEBHOOK=${ACME_SLACK_WEBHOOK}
    depends_on:
      postgres:
        condition: service_healthy
    volumes:
      - ./data:/app/data
      - ./logs:/app/logs

  postgres:
    image: ankane/pgvector
    environment:
      POSTGRES_DB: clawrity
      POSTGRES_USER: user
      POSTGRES_PASSWORD: pass
    volumes:
      - pg_data:/var/lib/postgresql/data
    ports:
      - "5432:5432"
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U user -d clawrity"]
      interval: 5s
      timeout: 5s
      retries: 5

volumes:
  pg_data:
@@ -0,0 +1,82 @@
"""
Clawrity — ETL Normaliser

Applies column mappings from client YAML, normalises data types,
cleans strings, and handles nulls.
"""

import logging
from typing import Dict

import pandas as pd
import numpy as np

logger = logging.getLogger(__name__)


def normalise_dataframe(
    df: pd.DataFrame,
    column_mapping: Dict[str, str],
    date_column: str = "date",
) -> pd.DataFrame:
    """Normalise a DataFrame using the client's column mapping."""
    df = df.copy()

    # Step 1: Apply column mapping (case-insensitive)
    df_cols_map = {col.strip(): col for col in df.columns}
    rename_map = {}
    for source, target in column_mapping.items():
        if source in df_cols_map:
            rename_map[df_cols_map[source]] = target
        else:
            for orig_col, actual_col in df_cols_map.items():
                if orig_col.lower() == source.lower():
                    rename_map[actual_col] = target
                    break
    if rename_map:
        df = df.rename(columns=rename_map)
        logger.info(f"Renamed columns: {rename_map}")

    # Step 2: Parse dates
    if date_column in df.columns:
        df[date_column] = pd.to_datetime(df[date_column], errors="coerce")
        df = df.dropna(subset=[date_column])
        df[date_column] = df[date_column].dt.date

    # Step 3: Clean string columns
    for col in ["country", "branch", "channel"]:
        if col in df.columns:
            df[col] = (
                df[col].astype(str).str.strip().str.title()
                .replace({"Nan": None, "None": None, "": None})
            )

    # Step 4: Handle numeric nulls
    for col in ["spend", "revenue", "profit", "leads", "conversions"]:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")

    # Step 5: Remove duplicates (count against the pre-dedup length, so
    # rows dropped earlier for unparseable dates are not mislabelled)
    before_dedup = len(df)
    df = df.drop_duplicates()
    dropped = before_dedup - len(df)
    if dropped > 0:
        logger.info(f"Removed {dropped} duplicate rows")

    logger.info(f"Normalisation complete: {len(df)} rows")
    return df


def remove_outliers(df: pd.DataFrame, columns: list, n_std: float = 3.0) -> pd.DataFrame:
    """Remove rows with values > n_std standard deviations from mean."""
    df = df.copy()
    original_len = len(df)
    for col in columns:
        if col in df.columns and pd.api.types.is_numeric_dtype(df[col]):
            mean, std = df[col].mean(), df[col].std()
            if std > 0:
                df = df[(df[col] - mean).abs() <= n_std * std]
    removed = original_len - len(df)
    if removed > 0:
        logger.info(f"Removed {removed} outlier rows (>{n_std} std devs)")
    return df
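A sketch of the ACME column mapping flowing through `normalise_dataframe`; the import path `etl.normaliser` is an assumption:

```python
# Raw Kaggle-style columns → canonical names via the ACME column_mapping.
import pandas as pd
from etl.normaliser import normalise_dataframe  # module path assumed

raw = pd.DataFrame({
    "Order Date": ["2024-01-03", "2024-01-04", "bad-date"],  # third row drops
    "Country": [" us ", "Canada", "US"],
    "City": ["Austin", "Toronto", "Austin"],
    "Sales": ["1200.5", "980", "x"],  # "x" coerces to NaN
    "Profit": [220, 140, 90],
})
mapping = {"Order Date": "date", "Country": "country", "City": "branch",
           "Sales": "revenue", "Profit": "profit"}
clean = normalise_dataframe(raw, mapping)
print(clean.columns.tolist())  # ['date', 'country', 'branch', 'revenue', 'profit']
```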
@@ -0,0 +1,173 @@
"""
Clawrity — Prophet Forecasting Engine

Trains Prophet models on branch-level monthly revenue time series.
Forecasts 6 months ahead. Caches results in PostgreSQL forecasts table.

Limitations (be explicit):
- Predicts revenue TRENDS only
- Does NOT claim ROI-per-dollar forecasting (spend→revenue is approximate)
- Requires minimum 2 years of data per branch
"""

import json
import logging
from datetime import datetime
from typing import Dict, List, Optional

import pandas as pd

from skills.postgres_connector import get_connector

logger = logging.getLogger(__name__)

MIN_MONTHS = 24  # Minimum 2 years of data
FORECAST_MONTHS = 6


class ProphetEngine:
    """Time series forecasting using Facebook Prophet."""

    def train_and_forecast(self, client_id: str) -> List[Dict]:
        """
        Train Prophet models for each branch and cache forecasts.

        Args:
            client_id: Client to forecast for

        Returns:
            List of forecast result dicts (one per branch)
        """
        from prophet import Prophet

        db = get_connector()

        # Get monthly revenue per branch
        sql = """
            SELECT branch, country,
                   DATE_TRUNC('month', date) AS month,
                   SUM(revenue) AS monthly_revenue
            FROM spend_data
            WHERE client_id = %s
            GROUP BY branch, country, DATE_TRUNC('month', date)
            ORDER BY branch, month
        """
        df = db.execute_query(sql, (client_id,))

        if df.empty:
            logger.warning(f"No data for forecasting: {client_id}")
            return []

        results = []
        branches = df.groupby(["branch", "country"])

        for (branch, country), group in branches:
            group = group.sort_values("month").reset_index(drop=True)

            if len(group) < MIN_MONTHS:
                logger.info(
                    f"Skipping {branch} ({country}): only {len(group)} months "
                    f"(need {MIN_MONTHS})"
                )
                continue

            try:
                # Prepare Prophet format: ds (date), y (value)
                prophet_df = pd.DataFrame({
                    "ds": pd.to_datetime(group["month"]),
                    "y": group["monthly_revenue"].astype(float),
                })

                # Train
                model = Prophet(
                    yearly_seasonality=True,
                    weekly_seasonality=False,
                    daily_seasonality=False,
                )
                model.fit(prophet_df)

                # Forecast
                future = model.make_future_dataframe(
                    periods=FORECAST_MONTHS, freq="MS"
                )
                forecast = model.predict(future)

                # Extract forecast period only
                forecast_only = forecast.tail(FORECAST_MONTHS)

                forecast_data = {
                    "branch": branch,
                    "country": country,
                    "horizon_months": FORECAST_MONTHS,
                    "dates": forecast_only["ds"].dt.strftime("%Y-%m-%d").tolist(),
                    "forecast_revenue": forecast_only["yhat"].round(2).tolist(),
                    "lower_bound": forecast_only["yhat_lower"].round(2).tolist(),
                    "upper_bound": forecast_only["yhat_upper"].round(2).tolist(),
                    "computed_at": datetime.utcnow().isoformat(),
                }

                # Cache in PostgreSQL
                self._cache_forecast(client_id, forecast_data)
                results.append(forecast_data)

                logger.info(
                    f"Forecast generated for {branch} ({country}): "
                    f"{FORECAST_MONTHS} months ahead"
                )

            except Exception as e:
                logger.error(f"Prophet failed for {branch} ({country}): {e}")

        logger.info(f"Forecasting complete: {len(results)} branches forecast")
        return results

    def get_cached_forecast(
        self,
        client_id: str,
        branch: str,
    ) -> Optional[Dict]:
        """Get the most recent cached forecast for a branch."""
        db = get_connector()

        sql = """
            SELECT forecast_data, computed_at
            FROM forecasts
            WHERE client_id = %s AND branch = %s
            ORDER BY computed_at DESC
            LIMIT 1
        """
        rows = db.execute_raw(sql, (client_id, branch))

        if not rows:
            return None

        row = rows[0]
        data = row["forecast_data"]
        if isinstance(data, str):
            data = json.loads(data)

        data["computed_at"] = str(row["computed_at"])
        return data

    def _cache_forecast(self, client_id: str, forecast_data: Dict):
        """Store forecast in PostgreSQL."""
        db = get_connector()

        # Delete old forecast for this branch
        db.execute_write(
            "DELETE FROM forecasts WHERE client_id = %s AND branch = %s AND country = %s",
            (client_id, forecast_data["branch"], forecast_data["country"]),
        )

        # Insert new
        db.execute_write(
            """INSERT INTO forecasts (client_id, branch, country, horizon_months, forecast_data)
               VALUES (%s, %s, %s, %s, %s)""",
            (
                client_id,
                forecast_data["branch"],
                forecast_data["country"],
                forecast_data["horizon_months"],
                json.dumps(forecast_data),
            ),
        )
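Why `tail(FORECAST_MONTHS)` works: `make_future_dataframe` returns the history plus the requested future rows, so the tail is exactly the forecast horizon. A toy sketch with synthetic data (not project data):

```python
# Toy sketch: 24 months of history, 6 forecast months at month-start frequency.
import pandas as pd
from prophet import Prophet

hist = pd.DataFrame({
    "ds": pd.date_range("2022-01-01", periods=24, freq="MS"),
    "y": range(24),
})
m = Prophet(yearly_seasonality=True, weekly_seasonality=False, daily_seasonality=False)
m.fit(hist)
future = m.make_future_dataframe(periods=6, freq="MS")
print(len(future))           # 30 = 24 history rows + 6 future rows
print(future.tail(6)["ds"])  # 2024-01-01 ... 2024-06-01
```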
@@ -0,0 +1,18 @@
# HEARTBEAT — ACME Corporation

## Schedule
- trigger: daily
- time: "08:00"
- timezone: "Asia/Kolkata"

## Digest Tasks
1. Pull last 7 days spend + revenue per branch
2. Identify bottom 3 performing branches by revenue
3. Generate newsletter-style summary via Gen Agent → QA Agent
4. Run Scout Agent for competitor + sector news
5. Append Market Intelligence section to digest
6. Push complete digest to Slack channel

## Retry
- on_failure: retry after 15 minutes
- max_retries: 3
@@ -0,0 +1,124 @@
"""
Clawrity — HEARTBEAT Loader

Parses HEARTBEAT.md files to extract schedule, digest tasks, and retry config.
HEARTBEAT.md drives autonomous daily digest generation per client.
"""

import re
import logging
from pathlib import Path
from typing import Optional, Dict, Any

from config.client_loader import ClientConfig

logger = logging.getLogger(__name__)


class HeartbeatConfig:
    """Parsed heartbeat configuration."""

    def __init__(self):
        self.trigger: str = "daily"
        self.time: str = "08:00"
        self.timezone: str = "UTC"
        self.retry_delay_minutes: int = 15
        self.max_retries: int = 3
        self.tasks: list = []
        self.raw_content: str = ""

    @property
    def hour(self) -> int:
        """Extract hour from time string."""
        return int(self.time.split(":")[0])

    @property
    def minute(self) -> int:
        """Extract minute from time string."""
        return int(self.time.split(":")[1])


def load_heartbeat(client_config: ClientConfig) -> HeartbeatConfig:
    """
    Load and parse the HEARTBEAT.md file for a client.

    Args:
        client_config: The client's configuration containing heartbeat_file path.

    Returns:
        Parsed HeartbeatConfig with schedule, tasks, and retry settings.
    """
    config = HeartbeatConfig()
    heartbeat_path = Path(client_config.heartbeat_file)

    # Use client YAML timezone as fallback
    config.timezone = client_config.timezone

    if not heartbeat_path.exists():
        logger.warning(
            f"HEARTBEAT file not found at {heartbeat_path} for client "
            f"{client_config.client_id}. Using defaults from client YAML."
        )
        config.time = client_config.digest_schedule
        return config

    try:
        content = heartbeat_path.read_text(encoding="utf-8")
        config.raw_content = content
        _parse_heartbeat(content, config)
        logger.info(
            f"Loaded HEARTBEAT for {client_config.client_id}: "
            f"{config.trigger} at {config.time} {config.timezone}"
        )
    except Exception as e:
        logger.error(f"Error parsing HEARTBEAT file {heartbeat_path}: {e}")
        config.time = client_config.digest_schedule

    return config


def _parse_heartbeat(content: str, config: HeartbeatConfig) -> None:
    """Parse markdown content and extract structured config."""
    lines = content.split("\n")

    current_section = None

    for line in lines:
        stripped = line.strip()

        # Detect section headers
        if stripped.startswith("## "):
            current_section = stripped[3:].strip().lower()
            continue

        if current_section == "schedule":
            # Parse key-value pairs like "- trigger: daily"
            match = re.match(r"-\s*(\w+):\s*\"?([^\"]+)\"?", stripped)
            if match:
                key, value = match.group(1).strip(), match.group(2).strip()
                if key == "trigger":
                    config.trigger = value
                elif key == "time":
                    config.time = value
                elif key == "timezone":
                    config.timezone = value

        elif current_section == "digest tasks":
            # Parse numbered list items
            match = re.match(r"\d+\.\s+(.*)", stripped)
            if match:
                config.tasks.append(match.group(1).strip())

        elif current_section == "retry":
            # Parse retry config
            match = re.match(r"-\s*(\w+):\s*(.+)", stripped)
            if match:
                key, value = match.group(1).strip(), match.group(2).strip()
                if key == "on_failure" and "retry" in value:
                    # Extract minutes from "retry after 15 minutes"
                    mins = re.search(r"(\d+)", value)
                    if mins:
                        config.retry_delay_minutes = int(mins.group(1))
                elif key == "max_retries":
                    config.max_retries = int(value)
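A sketch of what parsing the ACME HEARTBEAT.md yields; `_parse_heartbeat` is a private helper and is used directly here only for illustration:

```python
# Feed a trimmed copy of the ACME heartbeat markdown through the parser.
from heartbeat.heartbeat_loader import HeartbeatConfig, _parse_heartbeat

md = """## Schedule
- trigger: daily
- time: "08:00"
- timezone: "Asia/Kolkata"

## Retry
- on_failure: retry after 15 minutes
- max_retries: 3
"""
cfg = HeartbeatConfig()
_parse_heartbeat(md, cfg)
print(cfg.trigger, cfg.hour, cfg.minute, cfg.timezone,
      cfg.retry_delay_minutes, cfg.max_retries)
# daily 8 0 Asia/Kolkata 15 3
```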
@@ -0,0 +1,295 @@
"""
Clawrity — HEARTBEAT Scheduler

APScheduler AsyncIOScheduler fires digest jobs per client at configured times.
Schedule: ETL at 02:00 → RAG re-index at 03:00 → Digest + Scout at configured time.
Retry: on failure, retry after N minutes, max retries from HEARTBEAT.md.
"""

import asyncio
import json
import logging
import os
from datetime import datetime
from typing import Dict, Optional

import httpx
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger

from agents.orchestrator import Orchestrator
from channels.protocol_adapter import NormalisedMessage
from config.client_loader import ClientConfig
from config.settings import get_settings
from heartbeat.heartbeat_loader import load_heartbeat
from skills.postgres_connector import get_connector
from soul.soul_loader import load_soul

logger = logging.getLogger(__name__)


async def run_digest(
    client_config: ClientConfig,
    orchestrator: Orchestrator,
    retry_count: int = 0,
) -> Optional[str]:
    """
    Run the daily digest for a client.

    Steps:
    1. Query bottom 3 branches by revenue (last 7 days)
    2. Gen Agent → QA Agent pipeline for digest
    3. Scout Agent for competitor/sector news
    4. Push to Slack webhook
    5. Log success/failure to JSONL

    Returns:
        Full digest text if successful, None on failure
    """
    from agents.gen_agent import GenAgent
    from agents.qa_agent import QAAgent

    client_id = client_config.client_id
    logger.info(f"[{client_id}] Running daily digest (attempt {retry_count + 1})")

    db = get_connector()

    try:
        # Step 1: Get bottom 3 branches by revenue with ROI
        bottom_sql = """
            SELECT branch, country,
                   SUM(revenue) as total_revenue,
                   SUM(spend) as total_spend,
                   SUM(leads) as total_leads,
                   ROUND((SUM(revenue)/NULLIF(SUM(spend),0))::numeric, 2) as roi
            FROM spend_data
            WHERE client_id = %s
              AND date >= CURRENT_DATE - INTERVAL '7 days'
            GROUP BY branch, country
            ORDER BY total_revenue ASC
            LIMIT 3
        """
        data = db.execute_query(bottom_sql, (client_id,))

        # Step 2: Generate digest via Gen Agent with specific prompt
        soul_content = load_soul(client_config)
        gen_agent = GenAgent()
        qa_agent = QAAgent()

        # Retrieve RAG chunks for digest context
        rag_chunks = None
        if orchestrator.retriever:
            try:
                rag_chunks = orchestrator.retriever.retrieve(
                    query="weekly performance bottom performers budget recommendations",
                    client_id=client_id,
                )
            except Exception as e:
                logger.warning(f"RAG retrieval for digest failed: {e}")

        # Generate digest with explicit prompt
        digest = gen_agent.generate(
            question="Generate morning business digest. Highlight bottom 3 branches. Suggest where to focus budget. Newsletter style.",
            soul_content=soul_content,
            data_context=data,
            rag_chunks=rag_chunks,
        )

        # Step 2b: QA pass on digest (more lenient threshold for digest)
        qa_result = qa_agent.evaluate(
            response=digest,
            data_context=data,
            threshold=0.6,  # More lenient for digest
        )

        if not qa_result["passed"]:
            logger.warning(
                f"[{client_id}] Digest QA failed (score={qa_result['score']:.2f}), "
                f"retrying with strict instruction"
            )
            # Retry digest generation with strict instruction
            digest = gen_agent.generate(
                question="Generate morning business digest. Highlight bottom 3 branches. Suggest where to focus budget. Newsletter style.",
                soul_content=soul_content,
                data_context=data,
                rag_chunks=rag_chunks,
                retry_issues=qa_result["issues"],
                retry_count=1,
                strict_data_instruction=(
                    "CRITICAL: Only mention branches and figures that appear in the "
                    "Data Context. Do not reference any other branches or historical data."
                ),
            )

        # Step 3: Scout Agent for competitor/sector news
        scout_section = None
        try:
            from agents.scout_agent import ScoutAgent
            scout = ScoutAgent()
            scout_section = await scout.gather_intelligence(client_config)
        except Exception as e:
            logger.warning(f"Scout Agent failed: {e}")

        # Step 4: Assemble full digest
        full_digest = f"📊 **Clawrity Daily Digest — {client_config.client_name}**\n"
        full_digest += f"*{datetime.now().strftime('%B %d, %Y')}*\n\n"
        full_digest += digest

        if scout_section:
            full_digest += f"\n\n---\n\n{scout_section}"

        # Step 5: Push to Slack webhook
        webhook_url = client_config.channels.get("slack_webhook", "")
        if webhook_url:
            await _push_to_slack(webhook_url, full_digest)
        else:
            logger.warning(f"[{client_id}] No Slack webhook configured")

        # Step 6: Log success to JSONL
        _log_digest_event(client_id, "success", {
            "qa_score": qa_result["score"],
            "qa_passed": qa_result["passed"],
            "scout_included": scout_section is not None,
            "digest_length": len(full_digest),
        })

        logger.info(f"[{client_id}] Digest completed successfully")
        return full_digest

    except Exception as e:
        logger.error(f"[{client_id}] Digest failed: {e}", exc_info=True)
        _log_digest_event(client_id, "failure", {"error": str(e), "attempt": retry_count + 1})

        heartbeat = load_heartbeat(client_config)

        if retry_count < heartbeat.max_retries:
            delay_minutes = heartbeat.retry_delay_minutes
            logger.info(
                f"[{client_id}] Scheduling digest retry in {delay_minutes} minutes "
                f"(attempt {retry_count + 2}/{heartbeat.max_retries + 1})"
            )
            await asyncio.sleep(delay_minutes * 60)
            return await run_digest(client_config, orchestrator, retry_count + 1)
        else:
            logger.error(f"[{client_id}] Digest failed after {heartbeat.max_retries + 1} attempts")
            # Post failure notification to Slack
            webhook_url = client_config.channels.get("slack_webhook", "")
            if webhook_url:
                await _push_to_slack(
                    webhook_url,
                    "Clawrity digest unavailable. Backend may be offline."
                )
            return None


async def _push_to_slack(webhook_url: str, message: str):
    """Push a message to a Slack incoming webhook."""
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                webhook_url,
                json={"text": message},
                timeout=30,
            )
            if response.status_code == 200:
                logger.info("Digest pushed to Slack successfully")
            else:
                logger.error(f"Slack webhook returned {response.status_code}: {response.text}")
    except Exception as e:
        logger.error(f"Failed to push digest to Slack: {e}")


def _log_digest_event(client_id: str, status: str, details: dict):
    """Log digest event to JSONL monitoring file."""
    settings = get_settings()
    logs_dir = settings.logs_dir
    os.makedirs(logs_dir, exist_ok=True)
    log_path = os.path.join(logs_dir, f"{client_id}_digest.jsonl")

    entry = {
        "timestamp": datetime.utcnow().isoformat(),
        "client_id": client_id,
        "event": "digest",
        "status": status,
        **details,
    }

    try:
        with open(log_path, "a") as f:
            f.write(json.dumps(entry) + "\n")
    except Exception as e:
        logger.error(f"Failed to log digest event: {e}")


def start_scheduler(
    client_configs: Dict[str, ClientConfig],
    orchestrator: Orchestrator,
) -> AsyncIOScheduler:
    """
    Start the APScheduler with digest jobs for all clients.

    Schedule per client:
    - Digest at configured time (from HEARTBEAT.md)
    - ETL sync at 02:00 (placeholder)
    - RAG re-index at 03:00 (placeholder)
    """
    scheduler = AsyncIOScheduler()

    for client_id, config in client_configs.items():
        heartbeat = load_heartbeat(config)

        # Daily digest at configured time
        scheduler.add_job(
            run_digest,
            CronTrigger(
                hour=heartbeat.hour,
                minute=heartbeat.minute,
                timezone=heartbeat.timezone,
            ),
            args=[config, orchestrator],
            id=f"digest_{client_id}",
            name=f"Daily Digest — {config.client_name}",
            replace_existing=True,
        )
        logger.info(
            f"Scheduled digest for {client_id}: "
            f"{heartbeat.time} {heartbeat.timezone}"
        )

        # ETL sync at 02:00 (placeholder)
        scheduler.add_job(
            _etl_sync_placeholder,
            CronTrigger(hour=2, minute=0, timezone=heartbeat.timezone),
            args=[client_id],
            id=f"etl_{client_id}",
            name=f"ETL Sync — {config.client_name}",
            replace_existing=True,
        )

        # RAG re-index at 03:00 (placeholder)
        scheduler.add_job(
            _rag_reindex_placeholder,
            CronTrigger(hour=3, minute=0, timezone=heartbeat.timezone),
            args=[client_id],
            id=f"rag_reindex_{client_id}",
            name=f"RAG Re-index — {config.client_name}",
            replace_existing=True,
        )

    scheduler.start()
    return scheduler


async def _etl_sync_placeholder(client_id: str):
    """Placeholder for nightly ETL data sync."""
    logger.info(f"[{client_id}] ETL sync triggered (placeholder)")


async def _rag_reindex_placeholder(client_id: str):
    """Placeholder for nightly RAG re-indexing."""
    logger.info(f"[{client_id}] RAG re-index triggered (placeholder)")
    try:
        from scripts.run_rag_pipeline import run_pipeline
        run_pipeline(client_id)
    except Exception as e:
        logger.warning(f"RAG re-index failed: {e}")
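Note that `CronTrigger` accepts an IANA timezone string, so the HEARTBEAT timezone flows straight through to the job. A tiny sketch (the printed repr is indicative, not exact):

```python
# Sketch: the ACME digest job fires at 08:00 Asia/Kolkata.
from apscheduler.triggers.cron import CronTrigger

trig = CronTrigger(hour=8, minute=0, timezone="Asia/Kolkata")
print(trig)  # cron trigger with hour=8, minute=0 in Asia/Kolkata
```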
@@ -0,0 +1,345 @@
"""
Clawrity — FastAPI Application

Main entry point. Initializes database, loads client configs,
starts Slack bot, and exposes REST endpoints.
"""

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from agents.orchestrator import Orchestrator
from channels.protocol_adapter import ProtocolAdapter, NormalisedMessage
from channels.slack_handler import SlackHandler
from config.client_loader import ClientConfig, load_client_configs
from config.settings import get_settings
from skills.postgres_connector import get_connector

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
client_configs: Dict[str, ClientConfig] = {}
orchestrator: Optional[Orchestrator] = None
protocol_adapter: Optional[ProtocolAdapter] = None
slack_handler: Optional[SlackHandler] = None
scheduler = None  # Set by heartbeat.scheduler


# ---------------------------------------------------------------------------
# Lifespan
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup and shutdown logic."""
    global client_configs, orchestrator, protocol_adapter, slack_handler, scheduler

    logger.info("=== Clawrity starting up ===")

    # 1. Init database schema
    db = get_connector()
    db.init_schema()
    logger.info("Database schema ready")

    # 2. Load client configs
    client_configs = load_client_configs()
    logger.info(f"Loaded {len(client_configs)} client(s): {list(client_configs.keys())}")

    # 3. Init orchestrator
    orchestrator = Orchestrator()

    # 4. Try to attach RAG retriever
    try:
        from rag.retriever import Retriever
        retriever = Retriever()
        orchestrator.set_retriever(retriever)
        logger.info("RAG retriever attached to orchestrator")
    except Exception as e:
        logger.info(f"RAG retriever not available (Phase 2): {e}")

    # 5. Init protocol adapter
    protocol_adapter = ProtocolAdapter(client_configs)

    # 6. Start Slack bot
    slack_handler = SlackHandler(protocol_adapter, client_configs, orchestrator)
    slack_handler.start()

    # 7. Start scheduler
    try:
        from heartbeat.scheduler import start_scheduler
        scheduler = start_scheduler(client_configs, orchestrator)
        logger.info("HEARTBEAT scheduler started")
    except Exception as e:
        logger.warning(f"Scheduler not started: {e}")

    logger.info("=== Clawrity ready ===")

    yield  # App runs here

    # Shutdown
    logger.info("=== Clawrity shutting down ===")
    if slack_handler:
        slack_handler.stop()
    if scheduler:
        scheduler.shutdown(wait=False)
    db.close()


# ---------------------------------------------------------------------------
# FastAPI App
# ---------------------------------------------------------------------------
app = FastAPI(
    title="Clawrity",
    description="Multi-channel AI business intelligence agent",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------
# Request/Response Models
# ---------------------------------------------------------------------------
class ChatRequest(BaseModel):
    client_id: str
    message: str


class ChatResponse(BaseModel):
    response: str
    qa_score: float
    qa_passed: bool
    retries: int
    sql: Optional[str] = None
    data_rows: int = 0
    rag_chunks_used: int = 0
    elapsed_seconds: float = 0.0


class CompareRequest(BaseModel):
    client_id: str
    message: str


class CompareResponse(BaseModel):
    without_rag: ChatResponse
    with_rag: ChatResponse


class ScoutRequest(BaseModel):
    client_id: str
    query: str


class ClientRequest(BaseModel):
    client_id: str


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Send a message and get an AI response."""
    if request.client_id not in client_configs:
        raise HTTPException(status_code=404, detail=f"Client not found: {request.client_id}")

    config = client_configs[request.client_id]
    message = protocol_adapter.normalise_api(request.client_id, request.message)

    result = await orchestrator.process(message, config)
    return ChatResponse(**result)


@app.post("/compare", response_model=CompareResponse)
async def compare(request: CompareRequest):
    """Side-by-side comparison: with RAG vs without RAG."""
    if request.client_id not in client_configs:
        raise HTTPException(status_code=404, detail=f"Client not found: {request.client_id}")

    config = client_configs[request.client_id]
    message = protocol_adapter.normalise_api(request.client_id, request.message)

    # Without RAG
    saved_retriever = orchestrator.retriever
    orchestrator.retriever = None
    result_no_rag = await orchestrator.process(message, config)
    orchestrator.retriever = saved_retriever

    # With RAG
    result_with_rag = await orchestrator.process(message, config)

    return CompareResponse(
        without_rag=ChatResponse(**result_no_rag),
        with_rag=ChatResponse(**result_with_rag),
    )


@app.post("/scout")
async def scout(request: ScoutRequest):
    """Run a targeted scout search for competitor/market intelligence."""
    if request.client_id not in client_configs:
        raise HTTPException(status_code=404, detail=f"Client not found: {request.client_id}")

    config = client_configs[request.client_id]

    try:
        from agents.scout_agent import ScoutAgent
        scout_agent = ScoutAgent()
        result = await scout_agent.search_query(config, request.query)

        if result is None:
            return {"response": "No relevant competitor or market news found for this query.", "has_results": False}

        return {"response": result, "has_results": True}
    except Exception as e:
        logger.error(f"Scout endpoint failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/scout/digest")
async def scout_digest(request: ClientRequest):
    """Run the full scout agent digest for a client."""
    if request.client_id not in client_configs:
        raise HTTPException(status_code=404, detail=f"Client not found: {request.client_id}")

    config = client_configs[request.client_id]

    try:
        from agents.scout_agent import ScoutAgent
        scout_agent = ScoutAgent()
        result = await scout_agent.gather_intelligence(config)

        if result is None:
            return {"response": "No relevant market intelligence found.", "has_results": False}

        return {"response": result, "has_results": True}
    except Exception as e:
        logger.error(f"Scout digest failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/digest")
async def trigger_digest(request: ClientRequest):
    """Manually trigger the daily digest pipeline (same as the scheduled job)."""
    if request.client_id not in client_configs:
        raise HTTPException(status_code=404, detail=f"Client not found: {request.client_id}")

    config = client_configs[request.client_id]

    try:
        from heartbeat.scheduler import run_digest
        digest_text = await run_digest(config, orchestrator)

        if digest_text is None:
            raise HTTPException(status_code=500, detail="Digest generation failed after all retries")

        return {"response": digest_text, "status": "success"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Manual digest trigger failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/admin/stats/{client_id}")
async def admin_stats(client_id: str):
    """RAG monitoring stats for a client."""
    if client_id not in client_configs:
        raise HTTPException(status_code=404, detail=f"Client not found: {client_id}")

    try:
        from rag.monitoring import get_stats
        return get_stats(client_id)
    except Exception as e:
        return {"error": str(e), "message": "Monitoring not yet configured"}


@app.post("/forecast/run/{client_id}")
async def run_forecast(client_id: str):
    """Trigger Prophet forecasting for a client."""
    if client_id not in client_configs:
        raise HTTPException(status_code=404, detail=f"Client not found: {client_id}")

    try:
        from forecasting.prophet_engine import ProphetEngine
        engine = ProphetEngine()
        results = engine.train_and_forecast(client_id)
        return {"status": "success", "branches_forecast": len(results)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/forecast/{client_id}/{branch}")
async def get_forecast(client_id: str, branch: str):
    """Get the cached forecast for a branch."""
    if client_id not in client_configs:
        raise HTTPException(status_code=404, detail=f"Client not found: {client_id}")

    try:
        from forecasting.prophet_engine import ProphetEngine
        engine = ProphetEngine()
        forecast = engine.get_cached_forecast(client_id, branch)
        if not forecast:
            raise HTTPException(status_code=404, detail=f"No forecast found for {branch}")
        return forecast
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health():
    """System health check."""
    db = get_connector()
    db_connected = False
    try:
        db.execute_raw("SELECT 1")
        db_connected = True
    except Exception:
        pass

    scheduled_jobs = []
    if scheduler and hasattr(scheduler, "get_jobs"):
        try:
            scheduled_jobs = [
                {"id": job.id, "name": job.name, "next_run": str(job.next_run_time)}
                for job in scheduler.get_jobs()
            ]
        except Exception:
            pass

    return {
        "status": "healthy" if db_connected else "degraded",
        "database": "connected" if db_connected else "disconnected",
        "clients": list(client_configs.keys()),
        "scheduler_running": bool(scheduler and scheduler.running),
        "scheduled_jobs": scheduled_jobs,
        "slack_active": slack_handler is not None and slack_handler._thread is not None,
    }


@app.post("/slack/events")
async def slack_events():
    """Slack webhook endpoint (HTTP mode fallback). Socket Mode is primary."""
    return {"message": "Slack events are handled via Socket Mode. This endpoint is a fallback."}
+287
@@ -0,0 +1,287 @@
"""
Clawrity — RAG Chunker

Aggregation-based semantic chunking — NOT fixed-size, NOT sliding window.
Source is structured tabular data. We aggregate rows into business-meaningful
units and write natural language narratives.

Three chunk types:
1. branch_weekly   — GROUP BY branch, country, week
2. channel_monthly — GROUP BY channel, country, month
3. trend_qoq       — GROUP BY branch, country, quarter (QoQ delta COMPUTED)

Plus template-generated narrative summaries reflecting real patterns.
"""

import hashlib
import logging
from dataclasses import dataclass
from typing import Dict, List

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


@dataclass
class Chunk:
    """A single RAG chunk."""
    id: str
    client_id: str
    chunk_type: str
    text: str
    metadata: Dict

    def to_dict(self) -> Dict:
        return {
            "id": self.id,
            "client_id": self.client_id,
            "chunk_type": self.chunk_type,
            "text": self.text,
            "metadata": self.metadata,
        }


def generate_chunks(df: pd.DataFrame, client_id: str) -> List[Chunk]:
    """Generate all chunk types from preprocessed data."""
    chunks = []

    df = df.copy()
    df["date"] = pd.to_datetime(df["date"])

    chunks.extend(_branch_weekly(df, client_id))
    chunks.extend(_channel_monthly(df, client_id))
    chunks.extend(_trend_qoq(df, client_id))
    chunks.extend(_faker_narratives(df, client_id))

    logger.info(f"Generated {len(chunks)} total chunks for {client_id}")
    return chunks


def _chunk_id(client_id: str, chunk_type: str, *parts) -> str:
    """Generate a deterministic chunk ID."""
    raw = f"{client_id}:{chunk_type}:" + ":".join(str(p) for p in parts)
    return hashlib.md5(raw.encode()).hexdigest()[:16]


# ---------------------------------------------------------------------------
# Chunk Type 1: Branch Weekly
# ---------------------------------------------------------------------------

def _branch_weekly(df: pd.DataFrame, client_id: str) -> List[Chunk]:
    """GROUP BY branch, country, week. One chunk per branch per week."""
    chunks = []
    df = df.copy()
    df["week"] = df["date"].dt.isocalendar().week.astype(int)
    df["month"] = df["date"].dt.month_name()
    df["year"] = df["date"].dt.year

    grouped = df.groupby(["branch", "country", "year", "week", "month"]).agg(
        spend=("spend", "sum"),
        revenue=("revenue", "sum"),
        leads=("leads", "sum"),
        conversions=("conversions", "sum"),
    ).reset_index()

    for _, row in grouped.iterrows():
        spend = row["spend"]
        revenue = row["revenue"]
        roi = round(revenue / spend, 2) if spend > 0 else 0
        conv_rate = round(row["conversions"] / row["leads"] * 100, 1) if row["leads"] > 0 else 0

        text = (
            f"{row['branch']} ({row['country']}) in week {row['week']} of "
            f"{row['month']} {row['year']}: spent ${spend:,.0f}, earned "
            f"${revenue:,.0f}, ROI {roi}x, {row['leads']} leads, "
            f"{conv_rate}% conversion rate."
        )

        chunks.append(Chunk(
            id=_chunk_id(client_id, "branch_weekly", row["branch"], row["year"], row["week"]),
            client_id=client_id,
            chunk_type="branch_weekly",
            text=text,
            metadata={
                "branch": row["branch"],
                "country": row["country"],
                "week": int(row["week"]),
                "month": row["month"],
                "year": int(row["year"]),
                "roi": roi,
            },
        ))

    logger.info(f"Generated {len(chunks)} branch_weekly chunks")
    return chunks


# ---------------------------------------------------------------------------
# Chunk Type 2: Channel Monthly
# ---------------------------------------------------------------------------

def _channel_monthly(df: pd.DataFrame, client_id: str) -> List[Chunk]:
    """GROUP BY channel, country, month, quarter."""
    chunks = []
    df = df.copy()
    df["month"] = df["date"].dt.month_name()
    df["quarter"] = "Q" + df["date"].dt.quarter.astype(str)
    df["year"] = df["date"].dt.year

    grouped = df.groupby(["channel", "country", "year", "month", "quarter"]).agg(
        spend=("spend", "sum"),
        revenue=("revenue", "sum"),
        leads=("leads", "sum"),
        conversions=("conversions", "sum"),
    ).reset_index()

    for _, row in grouped.iterrows():
        spend = row["spend"]
        revenue = row["revenue"]
        roi = round(revenue / spend, 2) if spend > 0 else 0

        text = (
            f"{row['channel']} in {row['country']} during {row['month']} "
            f"({row['quarter']}) {row['year']}: ${spend:,.0f} spent, "
            f"${revenue:,.0f} revenue, ROI {roi}x."
        )

        chunks.append(Chunk(
            id=_chunk_id(client_id, "channel_monthly", row["channel"], row["country"], row["year"], row["month"]),
            client_id=client_id,
            chunk_type="channel_monthly",
            text=text,
            metadata={
                "channel": row["channel"],
                "country": row["country"],
                "month": row["month"],
                "quarter": row["quarter"],
                "year": int(row["year"]),
                "roi": roi,
            },
        ))

    logger.info(f"Generated {len(chunks)} channel_monthly chunks")
    return chunks


# ---------------------------------------------------------------------------
# Chunk Type 3: QoQ Trend (Most Important)
# ---------------------------------------------------------------------------

def _trend_qoq(df: pd.DataFrame, client_id: str) -> List[Chunk]:
    """GROUP BY branch, country, quarter. Compute quarter-over-quarter delta."""
    chunks = []
    df = df.copy()
    df["quarter"] = df["date"].dt.to_period("Q").astype(str)

    grouped = df.groupby(["branch", "country", "quarter"]).agg(
        spend=("spend", "sum"),
        revenue=("revenue", "sum"),
    ).reset_index()

    # Sort for QoQ calculation
    grouped = grouped.sort_values(["branch", "country", "quarter"])

    for (branch, country), group in grouped.groupby(["branch", "country"]):
        group = group.sort_values("quarter").reset_index(drop=True)

        for i in range(1, len(group)):
            prev = group.iloc[i - 1]
            curr = group.iloc[i]

            prev_rev = prev["revenue"]
            curr_rev = curr["revenue"]

            if prev_rev > 0:
                delta = round((curr_rev - prev_rev) / prev_rev * 100, 1)
            else:
                delta = 0

            direction = "grew" if delta > 0 else "declined"

            text = (
                f"{branch} ({country}) revenue {direction} {abs(delta)}% "
                f"in {curr['quarter']} vs {prev['quarter']}. "
                f"Total spend: ${curr['spend']:,.0f}, revenue: ${curr_rev:,.0f}."
            )

            chunks.append(Chunk(
                id=_chunk_id(client_id, "trend_qoq", branch, country, curr["quarter"]),
                client_id=client_id,
                chunk_type="trend_qoq",
                text=text,
                metadata={
                    "branch": branch,
                    "country": country,
                    "quarter": curr["quarter"],
                    "prev_quarter": prev["quarter"],
                    "delta_pct": delta,
                },
            ))

    logger.info(f"Generated {len(chunks)} trend_qoq chunks")
    return chunks


# ---------------------------------------------------------------------------
# Narrative Chunks (template-based)
# ---------------------------------------------------------------------------

def _faker_narratives(df: pd.DataFrame, client_id: str) -> List[Chunk]:
    """Generate plausible narrative chunks reflecting real data patterns."""
    chunks = []
    df = df.copy()
    df["quarter"] = df["date"].dt.to_period("Q").astype(str)

    # Aggregate performance per branch, country, quarter
    quarterly = df.groupby(["branch", "country", "quarter"]).agg(
        revenue=("revenue", "sum"),
        spend=("spend", "sum"),
        leads=("leads", "sum"),
    ).reset_index()

    templates = [
        "{branch} branch demonstrated strong {quarter} performance driven by {channel} efficiency, outperforming regional averages.",
        "In {quarter}, {branch} ({country}) showed {trend} momentum with revenue reaching ${revenue:,.0f}, primarily through {channel} campaigns.",
        "{branch} branch in {country} maintained steady growth in {quarter}, with lead generation up and conversion rates holding above {conv_rate:.1f}%.",
        "Cost efficiency at {branch} ({country}) improved in {quarter}, with spend-to-revenue ratio tightening to {ratio:.2f}x.",
    ]

    channels = df["channel"].dropna().unique().tolist() or ["Paid Search", "Social Media", "Email"]

    for _, row in quarterly.iterrows():
        roi = row["revenue"] / row["spend"] if row["spend"] > 0 else 0
        conv_rate = np.random.uniform(5, 20)
        trend = "positive" if roi > 1.5 else "moderate" if roi > 1 else "challenging"
        channel = np.random.choice(channels)

        template = np.random.choice(templates)
        text = template.format(
            branch=row["branch"],
            country=row["country"],
            quarter=row["quarter"],
            channel=channel,
            revenue=row["revenue"],
            trend=trend,
            conv_rate=conv_rate,
            ratio=1 / roi if roi > 0 else 0,
        )

        chunks.append(Chunk(
            id=_chunk_id(client_id, "narrative", row["branch"], row["country"], row["quarter"]),
            client_id=client_id,
            chunk_type="narrative",
            text=text,
            metadata={
                "branch": row["branch"],
                "country": row["country"],
                "quarter": row["quarter"],
                "source": "generated_narrative",
            },
        ))

    logger.info(f"Generated {len(chunks)} narrative chunks")
    return chunks
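A minimal sketch of driving the chunker directly with a toy frame; the column names match the spend_data schema and the client_id is just a placeholder:

```python
# Sketch: run the chunker on a tiny in-memory frame.
import pandas as pd
from rag.chunker import generate_chunks

df = pd.DataFrame({
    "date": ["2024-01-01", "2024-01-08", "2024-04-01"],
    "branch": ["Cairo", "Cairo", "Cairo"],
    "country": ["Egypt", "Egypt", "Egypt"],
    "channel": ["Email", "Paid Search", "Email"],
    "spend": [400.0, 550.0, 610.0],
    "revenue": [1200.0, 900.0, 2100.0],
    "leads": [30, 42, 51],
    "conversions": [4, 5, 9],
})

# Spanning two quarters yields branch_weekly, channel_monthly,
# one trend_qoq chunk, and one narrative chunk per branch-quarter.
chunks = generate_chunks(df, client_id="acme_corp")
for c in chunks:
    print(c.chunk_type, "→", c.text[:80])
```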
@@ -0,0 +1,123 @@
"""
Clawrity — RAG Evaluator

Lightweight Groq-based evaluation (no OpenAI, no full RAGAS).
Four metrics: faithfulness, answer_relevancy, context_precision, context_recall.
Single Groq call with structured JSON output.
"""

import json
import logging
from dataclasses import dataclass
from typing import Dict, List

from groq import Groq

from config.settings import get_settings

logger = logging.getLogger(__name__)

EVAL_PROMPT = """Evaluate this RAG-augmented response on four criteria.

## User Query
{query}

## Retrieved Context Chunks
{chunks}

## Generated Response
{response}

## Evaluation Criteria (score each 0.0 to 1.0)

1. **Faithfulness**: Does the response ONLY contain information from the retrieved chunks? No hallucination?
2. **Answer Relevancy**: Does the response directly address the user's question?
3. **Context Precision**: Were the retrieved chunks actually relevant to the question?
4. **Context Recall**: Did the retrieval capture enough context to answer the question fully?

Return ONLY a JSON object:
{{
  "faithfulness": <float>,
  "answer_relevancy": <float>,
  "context_precision": <float>,
  "context_recall": <float>,
  "overall": <float (average of all four)>,
  "notes": "<brief explanation>"
}}"""


@dataclass
class EvalResult:
    faithfulness: float = 0.0
    answer_relevancy: float = 0.0
    context_precision: float = 0.0
    context_recall: float = 0.0
    overall: float = 0.0
    notes: str = ""


class RAGEvaluator:
    """Evaluates RAG pipeline quality using the Groq LLM."""

    def __init__(self):
        settings = get_settings()
        self.client = Groq(api_key=settings.groq_api_key)
        self.model = settings.llm_model

    def evaluate(
        self,
        query: str,
        chunks: List[Dict],
        response: str,
    ) -> EvalResult:
        """Evaluate a RAG response."""
        chunks_text = "\n".join(
            f"{i+1}. {c.get('text', '')} (similarity: {c.get('similarity', 0):.2f})"
            for i, c in enumerate(chunks)
        ) if chunks else "No chunks retrieved."

        prompt = EVAL_PROMPT.format(
            query=query,
            chunks=chunks_text,
            response=response,
        )

        try:
            result = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a RAG evaluation expert. Return only valid JSON."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
                max_tokens=512,
            )

            raw = result.choices[0].message.content.strip()
            return self._parse(raw)

        except Exception as e:
            logger.error(f"RAG evaluation failed: {e}")
            return EvalResult(notes=f"Evaluation error: {str(e)}")

    def _parse(self, raw: str) -> EvalResult:
        """Parse the JSON evaluation response, stripping any code fences."""
        try:
            cleaned = raw.strip()
            if cleaned.startswith("```"):
                cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else cleaned[3:]
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]

            data = json.loads(cleaned.strip())
            return EvalResult(
                faithfulness=float(data.get("faithfulness", 0)),
                answer_relevancy=float(data.get("answer_relevancy", 0)),
                context_precision=float(data.get("context_precision", 0)),
                context_recall=float(data.get("context_recall", 0)),
                overall=float(data.get("overall", 0)),
                notes=data.get("notes", ""),
            )
        except Exception as e:
            logger.warning(f"Could not parse evaluation: {e}")
            return EvalResult(notes="Parse error")
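A sketch of a one-off evaluation call, assuming GROQ_API_KEY is configured in .env; the query, chunk, and response strings below are made up for illustration:

```python
# Sketch: evaluate a single RAG exchange.
from rag.evaluator import RAGEvaluator

evaluator = RAGEvaluator()
result = evaluator.evaluate(
    query="How did the Cairo branch perform last quarter?",
    chunks=[{"text": "Cairo (Egypt) revenue grew 12.5% in 2024Q2 vs 2024Q1.", "similarity": 0.83}],
    response="Cairo revenue grew 12.5% quarter-over-quarter.",
)

# All scores default to 0.0 with an explanatory note if the call or parse failed.
print(result.overall, result.notes)
```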
@@ -0,0 +1,105 @@
"""
Clawrity — RAG Monitoring

Logs every interaction to JSONL and provides aggregated stats.
Exposes data for the /admin/stats/{client_id} endpoint.
"""

import json
import logging
import os
from datetime import datetime
from typing import Dict

from config.settings import get_settings

logger = logging.getLogger(__name__)


def _log_path(client_id: str) -> str:
    """Get the JSONL log file path for a client."""
    logs_dir = get_settings().logs_dir
    os.makedirs(logs_dir, exist_ok=True)
    return os.path.join(logs_dir, f"{client_id}_interactions.jsonl")


def log_interaction(
    client_id: str,
    query: str,
    num_chunks: int,
    chunk_types_used: list,
    qa_score: float,
    qa_passed: bool,
    retries: int,
    response_length: int,
    elapsed_seconds: float = 0.0,
):
    """Append one interaction record to the client's JSONL log."""
    entry = {
        "timestamp": datetime.utcnow().isoformat(),
        "client_id": client_id,
        "query": query,
        "num_chunks": num_chunks,
        "chunk_types_used": chunk_types_used,
        "qa_score": qa_score,
        "qa_passed": qa_passed,
        "retries": retries,
        "response_length": response_length,
        "elapsed_seconds": elapsed_seconds,
    }

    try:
        path = _log_path(client_id)
        with open(path, "a") as f:
            f.write(json.dumps(entry) + "\n")
    except Exception as e:
        logger.error(f"Failed to log interaction: {e}")


def get_stats(client_id: str) -> Dict:
    """
    Get aggregated monitoring stats for a client.

    Returns:
        Dict with: total_queries, pass_rate, avg_qa_score, avg_retries,
        queries_needing_retry
    """
    path = _log_path(client_id)
    if not os.path.exists(path):
        return {
            "client_id": client_id,
            "total_queries": 0,
            "pass_rate": 0.0,
            "avg_qa_score": 0.0,
            "avg_retries": 0.0,
            "queries_needing_retry": 0,
        }

    entries = []
    try:
        with open(path, "r") as f:
            for line in f:
                line = line.strip()
                if line:
                    entries.append(json.loads(line))
    except Exception as e:
        logger.error(f"Error reading log file: {e}")
        return {"error": str(e)}

    if not entries:
        return {"client_id": client_id, "total_queries": 0}

    total = len(entries)
    passed = sum(1 for e in entries if e.get("qa_passed", False))
    scores = [e.get("qa_score", 0) for e in entries]
    retries = [e.get("retries", 0) for e in entries]
    retry_queries = sum(1 for r in retries if r > 0)

    return {
        "client_id": client_id,
        "total_queries": total,
        "pass_rate": round(passed / total * 100, 1),
        "avg_qa_score": round(sum(scores) / total, 3),
        "avg_retries": round(sum(retries) / total, 2),
        "queries_needing_retry": retry_queries,
    }
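A sketch of the write/read cycle, assuming this module lives at rag/monitoring.py (its import path isn't shown in this commit) and that settings.logs_dir points somewhere writable:

```python
# Sketch: log one interaction, then read back the aggregates.
from rag.monitoring import log_interaction, get_stats

log_interaction(
    client_id="acme_corp",
    query="Top 3 branches by ROI last month?",
    num_chunks=5,
    chunk_types_used=["branch_weekly"],
    qa_score=0.91,
    qa_passed=True,
    retries=0,
    response_length=412,
    elapsed_seconds=1.8,
)

stats = get_stats("acme_corp")
print(stats["total_queries"], stats["pass_rate"], stats["avg_qa_score"])
```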
@@ -0,0 +1,72 @@
"""
Clawrity — RAG Preprocessor

Fetches data from PostgreSQL, cleans it for RAG chunking:
- Removes nulls, outliers > 3 std devs, duplicates
- Normalises string columns
"""

import logging

import pandas as pd

from etl.normaliser import remove_outliers
from skills.postgres_connector import get_connector

logger = logging.getLogger(__name__)


def preprocess_for_rag(
    client_id: str,
    days: int = 365,
) -> pd.DataFrame:
    """
    Fetch and preprocess data for RAG chunking.

    Args:
        client_id: Client to fetch data for
        days: Number of days of data to fetch (default 365)

    Returns:
        Clean DataFrame ready for chunking
    """
    db = get_connector()

    # INTERVAL literals can't be parameterised directly, so `days` is cast to
    # int and interpolated: the cast guards against SQL injection.
    sql = f"""
        SELECT date, country, branch, channel, spend, revenue, leads, conversions
        FROM spend_data
        WHERE client_id = %s AND date >= CURRENT_DATE - INTERVAL '{int(days)} days'
        ORDER BY date
    """
    df = db.execute_query(sql, (client_id,))
    logger.info(f"Fetched {len(df)} rows for RAG preprocessing")

    if df.empty:
        logger.warning(f"No data found for client {client_id}")
        return df

    # Remove rows with critical nulls
    critical_cols = ["date", "branch", "country", "revenue"]
    df = df.dropna(subset=[c for c in critical_cols if c in df.columns])

    # Remove outliers on numeric columns
    df = remove_outliers(df, ["spend", "revenue", "leads", "conversions"])

    # Clean strings
    for col in ["country", "branch", "channel"]:
        if col in df.columns:
            df[col] = df[col].astype(str).str.strip().str.title()

    # Remove duplicates
    df = df.drop_duplicates()

    logger.info(f"Preprocessed: {len(df)} rows ready for chunking")
    return df
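The remove_outliers helper comes from etl.normaliser, which isn't part of this commit; a hypothetical sketch of the 3-std-dev rule the docstring describes could look like the following, though the real helper may differ:

```python
# Hypothetical sketch of etl.normaliser.remove_outliers; not the actual implementation.
from typing import List

import pandas as pd


def remove_outliers(df: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
    """Drop rows where any listed column is more than 3 std devs from its mean."""
    for col in cols:
        if col not in df.columns:
            continue
        mean, std = df[col].mean(), df[col].std()
        if pd.notna(std) and std > 0:  # skip constant or single-row columns
            df = df[(df[col] - mean).abs() <= 3 * std]
    return df
```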
@@ -0,0 +1,95 @@
"""
Clawrity — RAG Retriever

Detects query intent → selects chunk_type → searches pgvector.
Intent detection is keyword-based:
- "should/recommend/allocate/shift" → trend_qoq
- "channel/paid/email/social"       → channel_monthly
- everything else                   → branch_weekly
"""

import logging
from typing import Dict, List, Optional

from rag.vector_store import search

logger = logging.getLogger(__name__)

# Intent → chunk_type mapping based on keywords
INTENT_PATTERNS = {
    "trend_qoq": [
        "should", "recommend", "allocate", "shift", "increase", "decrease",
        "budget", "realloc", "invest", "optimize", "growth", "trend",
        "quarter", "qoq", "forecast", "predict",
    ],
    "channel_monthly": [
        "channel", "paid", "email", "social", "search", "display",
        "organic", "referral", "campaign", "marketing", "roi",
        "spend", "advertising",
    ],
}


class Retriever:
    """RAG retriever with intent-based chunk type filtering."""

    def retrieve(
        self,
        query: str,
        client_id: str,
        top_k: int = 5,
        chunk_type_override: Optional[str] = None,
    ) -> List[Dict]:
        """
        Retrieve relevant chunks based on query intent.

        Args:
            query: User's natural language query
            client_id: Client to search within
            top_k: Number of chunks to retrieve
            chunk_type_override: Force a specific chunk type

        Returns:
            List of dicts with text, metadata, similarity
        """
        if chunk_type_override:
            chunk_type = chunk_type_override
        else:
            chunk_type = self._detect_intent(query)

        logger.info(f"Detected intent → chunk_type: {chunk_type}")

        results = search(
            query=query,
            client_id=client_id,
            chunk_type=chunk_type,
            top_k=top_k,
        )

        # If no results with the detected type, fall back to all types
        if not results:
            logger.info(f"No results for {chunk_type}, falling back to all types")
            results = search(
                query=query,
                client_id=client_id,
                chunk_type=None,
                top_k=top_k,
            )

        return results

    def _detect_intent(self, query: str) -> str:
        """Detect query intent from keyword counts."""
        query_lower = query.lower()

        scores = {}
        for chunk_type, keywords in INTENT_PATTERNS.items():
            scores[chunk_type] = sum(1 for kw in keywords if kw in query_lower)

        # Return the chunk type with the highest score; default to branch_weekly
        if max(scores.values()) > 0:
            return max(scores, key=scores.get)

        return "branch_weekly"
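The keyword scoring can be sanity-checked without a database, since _detect_intent never touches pgvector; a quick sketch:

```python
# Sketch: exercise intent detection in isolation (no DB connection is made).
from rag.retriever import Retriever

r = Retriever()
print(r._detect_intent("Should we shift budget into Q3?"))   # trend_qoq (should/shift/budget)
print(r._detect_intent("How is paid search ROI trending?"))  # channel_monthly (paid/search/roi outscore trend)
print(r._detect_intent("Show me last week's numbers"))       # branch_weekly (no keyword hits, default)
```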
@@ -0,0 +1,135 @@
"""
Clawrity — RAG Vector Store

Embeds chunks using sentence-transformers all-MiniLM-L6-v2 (CPU, 384 dims).
Stores and searches via pgvector in PostgreSQL.
"""

import logging
from typing import List, Optional

import numpy as np

from rag.chunker import Chunk
from skills.postgres_connector import get_connector

logger = logging.getLogger(__name__)

_model = None


def _get_embedding_model():
    """Lazy-load the embedding model (CPU only, ~90MB)."""
    global _model
    if _model is None:
        from sentence_transformers import SentenceTransformer
        _model = SentenceTransformer("all-MiniLM-L6-v2")
        logger.info("Loaded embedding model: all-MiniLM-L6-v2 (384 dims)")
    return _model


def embed_texts(texts: List[str], batch_size: int = 100) -> np.ndarray:
    """
    Embed a list of texts using MiniLM.

    Args:
        texts: List of text strings to embed
        batch_size: Batch size for encoding (default 100)

    Returns:
        numpy array of shape (len(texts), 384)
    """
    model = _get_embedding_model()
    embeddings = model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=len(texts) > 100,
        normalize_embeddings=True,
    )
    logger.info(f"Embedded {len(texts)} texts → shape {embeddings.shape}")
    return embeddings


def embed_query(query: str) -> np.ndarray:
    """Embed a single query string."""
    model = _get_embedding_model()
    return model.encode(query, normalize_embeddings=True)


def store_chunks(chunks: List[Chunk], embeddings: np.ndarray):
    """
    Upsert chunks + embeddings into pgvector.
    Uses ON CONFLICT DO UPDATE for safe nightly re-indexing.
    """
    # Deduplicate by chunk ID, since aggregation can emit the same ID twice
    seen = set()
    unique_chunks = []
    unique_embeddings = []
    for chunk, emb in zip(chunks, embeddings):
        if chunk.id not in seen:
            seen.add(chunk.id)
            unique_chunks.append(chunk)
            unique_embeddings.append(emb)
    chunks = unique_chunks
    embeddings = unique_embeddings

    db = get_connector()

    data = []
    for chunk, embedding in zip(chunks, embeddings):
        data.append({
            "id": chunk.id,
            "client_id": chunk.client_id,
            "chunk_type": chunk.chunk_type,
            "text": chunk.text,
            "metadata": chunk.metadata,
            "embedding": embedding.tolist(),
        })

    # Batch upsert
    batch_size = 100
    for i in range(0, len(data), batch_size):
        db.upsert_embeddings(data[i:i + batch_size])

    logger.info(f"Stored {len(data)} chunks in pgvector")

    # Try to create the IVFFlat index (needs enough rows to be useful)
    try:
        db.create_vector_index()
    except Exception:
        pass


def search(
    query: str,
    client_id: str,
    chunk_type: Optional[str] = None,
    top_k: int = 5,
) -> List[dict]:
    """
    Search pgvector for similar chunks.

    Args:
        query: Natural language query
        client_id: Client to search within
        chunk_type: Optional filter (branch_weekly, channel_monthly, trend_qoq)
        top_k: Number of results

    Returns:
        List of dicts with text, metadata, similarity
    """
    query_embedding = embed_query(query)
    db = get_connector()

    results = db.search_embeddings(
        query_embedding=query_embedding,
        client_id=client_id,
        chunk_type=chunk_type,
        top_k=top_k,
    )

    logger.info(
        f"Vector search: query='{query[:50]}...', "
        f"chunk_type={chunk_type}, results={len(results)}"
    )
    return results
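A sketch of the full embed → upsert → search roundtrip, assuming DATABASE_URL points at a pgvector-enabled instance with the schema initialised; the chunk content below is made up:

```python
# Sketch: store one chunk and search it back.
from rag.chunker import Chunk
from rag.vector_store import embed_texts, store_chunks, search

chunk = Chunk(
    id="demo0001",
    client_id="acme_corp",
    chunk_type="branch_weekly",
    text="Cairo (Egypt) in week 2 of January 2024: spent $550, earned $900, ROI 1.64x.",
    metadata={"branch": "Cairo", "country": "Egypt"},
)
store_chunks([chunk], embed_texts([chunk.text]))

hits = search("How did Cairo do in January?", client_id="acme_corp", top_k=3)
for h in hits:
    print(round(h["similarity"], 3), h["text"])
```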
@@ -0,0 +1,42 @@
# === Core Framework ===
fastapi>=0.115.0
uvicorn[standard]>=0.30.0
python-dotenv>=1.0.0

# === LLM ===
groq>=0.11.0

# === Embeddings (CPU only — all-MiniLM-L6-v2, 384 dims, ~90MB) ===
sentence-transformers>=3.0.0

# === Database — PostgreSQL + pgvector ===
psycopg2-binary>=2.9.9
pgvector>=0.3.0
asyncpg>=0.29.0

# === Channel — Slack (Socket Mode) ===
slack-bolt>=1.20.0

# === Scheduler ===
apscheduler>=3.10.0

# === Web Search (Scout Agent) ===
tavily-python>=0.5.0
duckduckgo-search>=6.0.0

# === Forecasting ===
prophet>=1.1.5

# === Data Processing ===
pandas>=2.2.0
numpy>=1.26.0
openpyxl>=3.1.0
faker>=28.0.0

# === Config ===
pydantic>=2.9.0
pydantic-settings>=2.5.0
pyyaml>=6.0.2

# === HTTP Client ===
httpx>=0.27.0
@@ -0,0 +1,67 @@
"""
Clawrity — RAG Pipeline Script

CLI to run the full RAG pipeline: preprocess → chunk → embed → store in pgvector.

Usage:
    python scripts/run_rag_pipeline.py --client_id acme_corp
"""

import argparse
import logging
import os
import sys

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from rag.preprocessor import preprocess_for_rag
from rag.chunker import generate_chunks
from rag.vector_store import embed_texts, store_chunks

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)


def run_pipeline(client_id: str, days: int = 365):
    """Run the full RAG pipeline for a client."""
    logger.info(f"=== RAG Pipeline: {client_id} ===")

    # Step 1: Preprocess
    logger.info("Step 1/4: Preprocessing data...")
    df = preprocess_for_rag(client_id, days=days)
    if df.empty:
        logger.error("No data to process. Run seed_demo_data.py first.")
        return

    # Step 2: Generate chunks
    logger.info("Step 2/4: Generating chunks...")
    chunks = generate_chunks(df, client_id)
    logger.info(f"Generated {len(chunks)} chunks")

    if not chunks:
        logger.error("No chunks generated.")
        return

    # Step 3: Embed
    logger.info("Step 3/4: Embedding chunks (CPU, batch_size=100)...")
    texts = [c.text for c in chunks]
    embeddings = embed_texts(texts, batch_size=100)

    # Step 4: Store in pgvector
    logger.info("Step 4/4: Upserting into pgvector...")
    store_chunks(chunks, embeddings)

    logger.info(f"=== RAG Pipeline complete: {len(chunks)} chunks indexed ===")


def main():
    parser = argparse.ArgumentParser(description="Run RAG pipeline")
    parser.add_argument("--client_id", required=True, help="Client ID")
    parser.add_argument("--days", type=int, default=365, help="Days of data to process")
    args = parser.parse_args()

    run_pipeline(args.client_id, args.days)


if __name__ == "__main__":
    main()
@@ -0,0 +1,214 @@
"""
Clawrity — Demo Data Seeder

Merges the Global Superstore and Marketing Campaign datasets, filling gaps
with seeded random variation. Inserts the result into the PostgreSQL
spend_data table.

Usage:
    python scripts/seed_demo_data.py --client_id acme_corp \
        --superstore data/raw/Global_Superstore2.csv \
        --marketing data/raw/marketing_campaign_dataset.csv
"""

import argparse
import logging
import os
import random
import sys

import numpy as np
import pandas as pd
from faker import Faker

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from connectors.csv_connector import CSVConnector
from etl.normaliser import normalise_dataframe
from skills.postgres_connector import PostgresConnector

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

# Seed everything for reproducible demo data
fake = Faker()
Faker.seed(42)
random.seed(42)
np.random.seed(42)

# Marketing channels to assign
CHANNELS = ["Paid Search", "Social Media", "Email", "Display", "Organic", "Referral"]

# Column mapping for Global Superstore
SUPERSTORE_MAPPING = {
    "Order Date": "date",
    "Country": "country",
    "City": "branch",
    "Sales": "revenue",
    "Profit": "profit",
}


def load_superstore(path: str) -> pd.DataFrame:
    """Load and normalise the Global Superstore dataset."""
    connector = CSVConnector()
    df = connector.load(path)
    logger.info(f"Superstore columns: {list(df.columns)}")

    # Apply column mapping
    df = normalise_dataframe(df, SUPERSTORE_MAPPING)

    # Keep only needed columns
    keep = ["date", "country", "branch", "revenue", "profit"]
    available = [c for c in keep if c in df.columns]
    df = df[available].copy()

    logger.info(f"Superstore: {len(df)} rows after normalisation")
    return df


def load_marketing(path: str) -> pd.DataFrame:
    """Load the Marketing Campaign Performance dataset."""
    connector = CSVConnector()
    df = connector.load(path)
    logger.info(f"Marketing columns: {list(df.columns)}")

    # Standardise column names
    col_map = {}
    for col in df.columns:
        cl = col.lower().strip()
        if "channel" in cl:
            col_map[col] = "channel"
        elif "spend" in cl or "budget" in cl:
            col_map[col] = "spend"
        elif "click" in cl:
            col_map[col] = "leads"
        elif "conversion" in cl:
            col_map[col] = "conversions"
        elif "roi" in cl:
            col_map[col] = "roi_raw"
        elif "impression" in cl:
            col_map[col] = "impressions"

    df = df.rename(columns=col_map)
    logger.info(f"Marketing: {len(df)} rows, mapped columns: {list(df.columns)}")
    return df


def merge_datasets(superstore: pd.DataFrame, marketing: pd.DataFrame) -> pd.DataFrame:
    """
    Merge superstore (base) with marketing channel metrics.
    Each superstore row gets a channel + spend/leads/conversions.
    """
    df = superstore.copy()

    # Assign channels proportionally from marketing data
    if "channel" in marketing.columns:
        channel_list = marketing["channel"].dropna().unique().tolist()
        if not channel_list:
            channel_list = CHANNELS
    else:
        channel_list = CHANNELS

    # Assign a channel to each row (deterministic, based on index)
    df["channel"] = [channel_list[i % len(channel_list)] for i in range(len(df))]

    # Build channel-level spend/leads/conversions stats from marketing data
    channel_stats = {}
    if "spend" in marketing.columns and "channel" in marketing.columns:
        for ch in channel_list:
            ch_data = marketing[marketing["channel"] == ch]
            channel_stats[ch] = {
                "avg_spend": ch_data["spend"].mean() if "spend" in ch_data.columns and len(ch_data) > 0 else 500,
                "avg_leads": ch_data["leads"].mean() if "leads" in ch_data.columns and len(ch_data) > 0 else 50,
                "avg_conv": ch_data["conversions"].mean() if "conversions" in ch_data.columns and len(ch_data) > 0 else 5,
            }

    # Fill spend, leads, conversions using channel stats + seeded random variation
    spends, leads_list, conv_list = [], [], []
    for _, row in df.iterrows():
        ch = row["channel"]
        stats = channel_stats.get(ch, {"avg_spend": 500, "avg_leads": 50, "avg_conv": 5})

        rev = row.get("revenue", 1000)
        # Spend: a proportion of revenue with channel-based noise
        spend = max(10, rev * random.uniform(0.3, 0.6) + random.gauss(0, stats["avg_spend"] * 0.1))
        leads = max(1, int(spend / random.uniform(15, 40)))
        conversions = max(0, int(leads * random.uniform(0.05, 0.20)))

        spends.append(round(spend, 2))
        leads_list.append(leads)
        conv_list.append(conversions)

    df["spend"] = spends
    df["leads"] = leads_list
    df["conversions"] = conv_list

    # Drop the profit column (not in the spend_data schema)
    if "profit" in df.columns:
        df = df.drop(columns=["profit"])

    logger.info(f"Merged dataset: {len(df)} rows, columns: {list(df.columns)}")
    return df


def seed_to_postgres(df: pd.DataFrame, client_id: str):
    """Insert merged data into the PostgreSQL spend_data table."""
    connector = PostgresConnector()
    connector.init_schema()

    # Clear existing data for this client
    connector.execute_write(
        "DELETE FROM spend_data WHERE client_id = %s", (client_id,)
    )
    logger.info(f"Cleared existing data for client: {client_id}")

    # Add client_id column
    df["client_id"] = client_id

    # Prepare batch insert
    sql = """
        INSERT INTO spend_data (date, country, branch, channel, spend, revenue, leads, conversions, client_id)
        VALUES %s
    """
    data = [
        (
            row["date"], row["country"], row["branch"], row["channel"],
            row["spend"], row["revenue"], row["leads"], row["conversions"],
            row["client_id"]
        )
        for _, row in df.iterrows()
    ]

    connector.execute_batch(sql, data, page_size=2000)

    count = connector.get_table_count("spend_data", client_id)
    logger.info(f"Seeded {count} rows into spend_data for client: {client_id}")

    # Save the processed CSV
    os.makedirs("data/processed", exist_ok=True)
    output_path = f"data/processed/{client_id}_merged.csv"
    df.to_csv(output_path, index=False)
    logger.info(f"Saved processed data to {output_path}")

    connector.close()


def main():
    parser = argparse.ArgumentParser(description="Seed demo data into PostgreSQL")
    parser.add_argument("--client_id", default="acme_corp", help="Client ID")
    parser.add_argument("--superstore", required=True, help="Path to Global Superstore CSV/XLSX")
    parser.add_argument("--marketing", required=True, help="Path to Marketing Campaign CSV")
    args = parser.parse_args()

    logger.info(f"=== Seeding data for client: {args.client_id} ===")

    superstore = load_superstore(args.superstore)
    marketing = load_marketing(args.marketing)
    merged = merge_datasets(superstore, marketing)
    seed_to_postgres(merged, args.client_id)

    logger.info("=== Seeding complete ===")


if __name__ == "__main__":
    main()
@@ -0,0 +1,140 @@
"""
Clawrity — NL-to-SQL Engine

Converts natural language questions into valid PostgreSQL SELECT queries.
Uses the LLM at temperature 0.1 for near-deterministic SQL generation.
Safety: only SELECT queries are allowed; INSERT/UPDATE/DELETE/DROP are rejected.
"""

import logging
import re
from typing import Optional

from config.llm_client import get_llm_client, get_model_name

logger = logging.getLogger(__name__)

# Dangerous SQL patterns — reject anything that isn't a plain SELECT
UNSAFE_PATTERNS = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|TRUNCATE|CREATE|GRANT|REVOKE|EXEC)\b",
    re.IGNORECASE,
)

SYSTEM_PROMPT = """You are a PostgreSQL SQL generator. Generate ONLY a valid SELECT query.
Return ONLY the raw SQL — no markdown, no explanation, no code fences.

Table: spend_data
Columns:
- id: SERIAL PRIMARY KEY
- date: DATE
- country: VARCHAR(100)
- branch: VARCHAR(100)
- channel: VARCHAR(100)
- spend: FLOAT
- revenue: FLOAT
- leads: INT
- conversions: INT
- client_id: VARCHAR(100)

Available countries: {countries}
Available branches (sample): {branches}
Available channels: {channels}
Date range: {date_min} to {date_max}

RULES:
1. ALWAYS include WHERE client_id = '{client_id}' in your queries
2. Use standard PostgreSQL syntax
3. For date ranges, use DATE type comparisons
4. For "last N days", use: date >= CURRENT_DATE - INTERVAL 'N days'
5. For "last month", use: date >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month')
6. Return meaningful aggregations with GROUP BY when appropriate
7. Use aliases for computed columns (e.g., SUM(revenue) AS total_revenue)
8. LIMIT results to 50 rows maximum unless the user asks for all
9. For "bottom N" use ASC ordering, for "top N" use DESC ordering
"""


class NLToSQL:
    """Natural language to SQL converter using the LLM."""

    def __init__(self):
        self.client = get_llm_client()
        self.model = get_model_name()

    def generate_sql(
        self,
        question: str,
        client_id: str,
        schema_metadata: dict,
    ) -> Optional[str]:
        """
        Convert a natural language question to a PostgreSQL SELECT query.

        Args:
            question: User's natural language question
            client_id: Client ID for filtering
            schema_metadata: Dict with countries, branches, channels, date_min, date_max

        Returns:
            Valid SQL SELECT string, or None on failure
        """
        # Build the system prompt with schema context
        system = SYSTEM_PROMPT.format(
            countries=", ".join(schema_metadata.get("countries", [])[:20]),
            branches=", ".join(schema_metadata.get("branches", [])[:20]),
            channels=", ".join(schema_metadata.get("channels", [])),
            date_min=schema_metadata.get("date_min", "unknown"),
            date_max=schema_metadata.get("date_max", "unknown"),
            client_id=client_id,
        )

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": question},
                ],
                temperature=0.1,
                max_tokens=1024,
            )

            raw_sql = response.choices[0].message.content.strip()
            sql = self._clean_sql(raw_sql)

            if not self._validate_sql(sql):
                logger.warning(f"Generated SQL failed validation: {sql}")
                return None

            logger.info(f"Generated SQL: {sql}")
            return sql

        except Exception as e:
            logger.error(f"NL-to-SQL generation failed: {e}")
            return None

    def _clean_sql(self, raw: str) -> str:
        """Extract SQL from the LLM response, stripping markdown code fences."""
        cleaned = re.sub(r"```(?:sql)?\s*", "", raw)
        cleaned = re.sub(r"```\s*$", "", cleaned)
        cleaned = cleaned.strip().rstrip(";") + ";"
        return cleaned

    def _validate_sql(self, sql: str) -> bool:
        """Validate that the SQL is a safe SELECT query."""
        if not sql or len(sql) < 10:
            return False

        # Must start with SELECT
        if not sql.strip().upper().startswith("SELECT"):
            logger.warning("SQL does not start with SELECT")
            return False

        # Must not contain dangerous operations
        if UNSAFE_PATTERNS.search(sql):
            logger.warning("SQL contains unsafe operations")
            return False

        return True
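The safety gate can be exercised without an LLM call; a sketch (the __new__ trick just skips creating an LLM client, since _validate_sql only touches module-level state):

```python
# Sketch: drive the SQL validator directly.
from skills.nl_to_sql import NLToSQL

nl2sql = NLToSQL.__new__(NLToSQL)  # bypass __init__: no LLM client needed

print(nl2sql._validate_sql(
    "SELECT branch, SUM(revenue) AS total_revenue FROM spend_data "
    "WHERE client_id = 'acme_corp' GROUP BY branch LIMIT 50;"
))  # True

print(nl2sql._validate_sql("DROP TABLE spend_data;"))             # False: not a SELECT
print(nl2sql._validate_sql("SELECT 1; DELETE FROM spend_data;"))  # False: unsafe keyword
```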
@@ -0,0 +1,384 @@
"""
Clawrity — PostgreSQL + pgvector Connector

Connection management, schema initialisation, and query execution.
A single database handles both structured queries (NL-to-SQL) and vector search (pgvector).
"""

import logging
import time
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import psycopg2
import psycopg2.extras
from pgvector.psycopg2 import register_vector

from config.settings import get_settings

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Schema DDL
# ---------------------------------------------------------------------------

INIT_SCHEMA_SQL = """
-- Enable pgvector extension
CREATE EXTENSION IF NOT EXISTS vector;

-- Structured business data (replaces BigQuery)
CREATE TABLE IF NOT EXISTS spend_data (
    id SERIAL PRIMARY KEY,
    date DATE,
    country VARCHAR(100),
    branch VARCHAR(100),
    channel VARCHAR(100),
    spend FLOAT,
    revenue FLOAT,
    leads INT,
    conversions INT,
    client_id VARCHAR(100)
);

-- Vector embeddings (replaces ChromaDB)
CREATE TABLE IF NOT EXISTS embeddings (
    id VARCHAR(200) PRIMARY KEY,
    client_id VARCHAR(100),
    chunk_type VARCHAR(50),
    text TEXT,
    metadata JSONB,
    embedding vector(384)
);

-- Forecast cache
CREATE TABLE IF NOT EXISTS forecasts (
    id SERIAL PRIMARY KEY,
    client_id VARCHAR(100),
    branch VARCHAR(100),
    country VARCHAR(100),
    horizon_months INT,
    forecast_data JSONB,
    computed_at TIMESTAMP DEFAULT NOW()
);

-- Indexes
CREATE INDEX IF NOT EXISTS idx_spend_data_client
    ON spend_data (client_id);
CREATE INDEX IF NOT EXISTS idx_spend_data_date
    ON spend_data (client_id, date);
CREATE INDEX IF NOT EXISTS idx_embeddings_client_type
    ON embeddings (client_id, chunk_type);
CREATE INDEX IF NOT EXISTS idx_forecasts_client
    ON forecasts (client_id, branch, country);
"""

# IVFFlat index requires rows to exist — created separately after data load
IVFFLAT_INDEX_SQL = """
CREATE INDEX IF NOT EXISTS idx_embeddings_cosine
    ON embeddings USING ivfflat (embedding vector_cosine_ops)
    WITH (lists = 100);
"""


class PostgresConnector:
    """PostgreSQL + pgvector connection manager."""

    def __init__(self, database_url: Optional[str] = None):
        self.database_url = database_url or get_settings().database_url
        self._conn: Optional[psycopg2.extensions.connection] = None

    def _get_connection(self) -> psycopg2.extensions.connection:
        """Get or create a database connection with retry logic."""
        if self._conn is None or self._conn.closed:
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    self._conn = psycopg2.connect(self.database_url)
                    register_vector(self._conn)
                    logger.info("Connected to PostgreSQL with pgvector support")
                    return self._conn
                except psycopg2.OperationalError as e:
                    logger.warning(
                        f"DB connection attempt {attempt + 1}/{max_retries} failed: {e}"
                    )
                    if attempt < max_retries - 1:
                        wait = 2 ** attempt
                        logger.warning(f"Retrying in {wait}s...")
                        time.sleep(wait)
            raise ConnectionError(
                f"Failed to connect to PostgreSQL after {max_retries} attempts"
            )
        return self._conn

    def close(self):
        """Close the database connection."""
        if self._conn and not self._conn.closed:
            self._conn.close()
            logger.info("PostgreSQL connection closed")

    def init_schema(self):
        """Create tables and extensions if they don't exist."""
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(INIT_SCHEMA_SQL)
            conn.commit()
            logger.info("Database schema initialized successfully")
        except Exception as e:
            conn.rollback()
            logger.error(f"Schema initialization failed: {e}")
            raise

    def create_vector_index(self):
        """Create the IVFFlat index — call AFTER data has been loaded into embeddings."""
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(IVFFLAT_INDEX_SQL)
            conn.commit()
            logger.info("IVFFlat vector index created")
        except Exception as e:
            conn.rollback()
            logger.warning(f"Could not create IVFFlat index (may need more rows): {e}")

    # ------------------------------------------------------------------
    # Query execution
    # ------------------------------------------------------------------

    def execute_query(self, sql: str, params: Optional[tuple] = None) -> pd.DataFrame:
        """
        Execute a SELECT query and return results as a DataFrame.

        Args:
            sql: SQL query string (must be SELECT only)
            params: Query parameters for parameterised queries

        Returns:
            pandas DataFrame with query results
        """
        conn = self._get_connection()
        try:
            df = pd.read_sql_query(sql, conn, params=params)
            conn.rollback()  # end the read-only transaction
            logger.debug(f"Query returned {len(df)} rows")
            return df
        except Exception as e:
            logger.error(f"Query execution failed: {e}")
            conn.rollback()
            raise

    def execute_raw(self, sql: str, params: Optional[tuple] = None) -> List[Dict]:
        """Execute a query and return raw dictionaries."""
        conn = self._get_connection()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, params)
                if cur.description:
                    results = [dict(row) for row in cur.fetchall()]
                    conn.rollback()  # end the read-only transaction
                    return results
            conn.commit()
            return []
        except Exception as e:
            conn.rollback()
            logger.error(f"Raw query execution failed: {e}")
            raise

    def execute_write(self, sql: str, params: Optional[tuple] = None):
        """Execute an INSERT/UPDATE/DELETE statement."""
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(sql, params)
            conn.commit()
        except Exception as e:
            conn.rollback()
            logger.error(f"Write execution failed: {e}")
            raise

    def execute_batch(self, sql: str, data: List[tuple], page_size: int = 1000):
        """Execute a batch INSERT using execute_values for performance."""
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                psycopg2.extras.execute_values(cur, sql, data, page_size=page_size)
            conn.commit()
            logger.info(f"Batch insert: {len(data)} rows")
        except Exception as e:
            conn.rollback()
            logger.error(f"Batch execution failed: {e}")
            raise

    # ------------------------------------------------------------------
    # pgvector operations
    # ------------------------------------------------------------------

    def upsert_embeddings(self, embeddings_data: List[Dict[str, Any]]):
        """
        Upsert embedding records into the embeddings table.

        Args:
            embeddings_data: List of dicts with keys:
                id, client_id, chunk_type, text, metadata, embedding
        """
        conn = self._get_connection()
        sql = """
            INSERT INTO embeddings (id, client_id, chunk_type, text, metadata, embedding)
            VALUES %s
            ON CONFLICT (id) DO UPDATE SET
                text = EXCLUDED.text,
                metadata = EXCLUDED.metadata,
                embedding = EXCLUDED.embedding
        """
        data = [
            (
                d["id"],
                d["client_id"],
                d["chunk_type"],
                d["text"],
                psycopg2.extras.Json(d["metadata"]),
                np.array(d["embedding"]),
            )
            for d in embeddings_data
        ]
        try:
            with conn.cursor() as cur:
                psycopg2.extras.execute_values(cur, sql, data, page_size=100)
            conn.commit()
            logger.info(f"Upserted {len(data)} embeddings")
        except Exception as e:
            conn.rollback()
            logger.error(f"Embedding upsert failed: {e}")
            raise

    def search_embeddings(
        self,
        query_embedding: np.ndarray,
        client_id: str,
        chunk_type: Optional[str] = None,
        top_k: int = 5,
    ) -> List[Dict]:
        """
        Search for similar embeddings using pgvector cosine distance.

        Args:
            query_embedding: Query vector (384 dims)
            client_id: Filter by client
            chunk_type: Optional filter by chunk type
            top_k: Number of results to return

        Returns:
            List of dicts with text, metadata, and similarity score
        """
        conn = self._get_connection()
        query_vec = np.array(query_embedding)

        if chunk_type:
            sql = """
                SELECT text, metadata, 1 - (embedding <=> %s) AS similarity
                FROM embeddings
                WHERE client_id = %s AND chunk_type = %s
                ORDER BY embedding <=> %s
                LIMIT %s
            """
            params = (query_vec, client_id, chunk_type, query_vec, top_k)
        else:
            sql = """
                SELECT text, metadata, 1 - (embedding <=> %s) AS similarity
                FROM embeddings
                WHERE client_id = %s
                ORDER BY embedding <=> %s
                LIMIT %s
            """
            params = (query_vec, client_id, query_vec, top_k)

        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, params)
                results = [dict(row) for row in cur.fetchall()]
            logger.debug(f"Vector search returned {len(results)} results")
            return results
        except Exception as e:
            logger.error(f"Vector search failed: {e}")
            raise

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------

    def get_table_count(self, table: str, client_id: Optional[str] = None) -> int:
        """Get the row count for a table, optionally filtered by client_id."""
        # NOTE: `table` is interpolated into the SQL, so it must only ever
        # come from internal callers, never from user input.
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                if client_id:
                    cur.execute(
                        f"SELECT COUNT(*) FROM {table} WHERE client_id = %s",
                        (client_id,),
                    )
                else:
                    cur.execute(f"SELECT COUNT(*) FROM {table}")
                return cur.fetchone()[0]
        except Exception as e:
            logger.error(f"Count query failed: {e}")
            return 0

    def get_spend_data_schema(self, client_id: str) -> Dict:
        """Get metadata about available data for a client — used by NL-to-SQL."""
        conn = self._get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT DISTINCT country FROM spend_data WHERE client_id = %s ORDER BY country",
                    (client_id,),
                )
                countries = [row[0] for row in cur.fetchall()]

                cur.execute(
                    "SELECT DISTINCT branch FROM spend_data WHERE client_id = %s ORDER BY branch",
                    (client_id,),
                )
                branches = [row[0] for row in cur.fetchall()]

                cur.execute(
                    "SELECT DISTINCT channel FROM spend_data WHERE client_id = %s ORDER BY channel",
                    (client_id,),
                )
                channels = [row[0] for row in cur.fetchall()]

                cur.execute(
                    "SELECT MIN(date), MAX(date) FROM spend_data WHERE client_id = %s",
                    (client_id,),
                )
                date_range = cur.fetchone()

            return {
                "countries": countries,
                "branches": branches,
                "channels": channels,
                "date_min": str(date_range[0]) if date_range[0] else None,
                "date_max": str(date_range[1]) if date_range[1] else None,
            }
        except Exception as e:
            logger.error(f"Schema metadata query failed: {e}")
            return {
                "countries": [],
                "branches": [],
                "channels": [],
                "date_min": None,
                "date_max": None,
            }


# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------

_connector: Optional[PostgresConnector] = None


def get_connector() -> PostgresConnector:
    """Get the shared PostgresConnector singleton."""
    global _connector
    if _connector is None:
        _connector = PostgresConnector()
    return _connector
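Worth noting: pgvector's <=> operator returns cosine *distance*, which is why similarity is computed as 1 - distance in search_embeddings. A sketch of calling the search path directly, with a normalised random vector standing in for a real MiniLM embedding:

```python
# Sketch: direct similarity query via the connector (assumes DATABASE_URL is set).
import numpy as np
from skills.postgres_connector import get_connector

db = get_connector()
db.init_schema()  # idempotent: everything is CREATE ... IF NOT EXISTS

# Stand-in for an embedding; normalised like the MiniLM vectors in the store.
rng = np.random.default_rng(0)
query_vec = rng.standard_normal(384).astype(np.float32)
query_vec /= np.linalg.norm(query_vec)

rows = db.search_embeddings(query_vec, client_id="acme_corp", top_k=3)
for row in rows:
    print(round(row["similarity"], 3), row["text"][:60])
```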
@@ -0,0 +1,139 @@
"""
Clawrity — Web Search Skill

Primary: Tavily API (clean, summarised results built for LLM agents)
Fallback: duckduckgo-search (no API key, no rate limits, free)

Auto-fallback: if Tavily errors or its quota is exceeded, silently switch to DuckDuckGo.
"""

import logging
from datetime import datetime, timedelta
from typing import Dict, List

from config.settings import get_settings

logger = logging.getLogger(__name__)


def web_search(
    query: str,
    max_results: int = 5,
    lookback_days: int = 1,
) -> List[Dict]:
    """
    Search the web using Tavily (primary) or DuckDuckGo (fallback).

    Args:
        query: Search query string
        max_results: Maximum number of results
        lookback_days: Only keep results from the last N days

    Returns:
        List of dicts with: title, url, content, date
    """
    results = _tavily_search(query, max_results)

    if not results:
        logger.info("Tavily returned no results, falling back to DuckDuckGo")
        results = _ddg_search(query, max_results)

    # Filter by recency
    if lookback_days > 0:
        results = _filter_recent(results, lookback_days)

    return results


def _tavily_search(query: str, max_results: int = 5) -> List[Dict]:
    """Search using the Tavily API."""
    settings = get_settings()

    if not settings.tavily_api_key:
        logger.info("Tavily API key not configured, skipping")
        return []

    try:
        from tavily import TavilyClient

        client = TavilyClient(api_key=settings.tavily_api_key)
        response = client.search(
            query=query,
            search_depth="advanced",
            max_results=max_results,
        )

        results = []
        for item in response.get("results", []):
            results.append({
                "title": item.get("title", ""),
                "url": item.get("url", ""),
                "content": item.get("content", ""),
                "date": item.get("published_date", ""),
                "source": "tavily",
            })

        logger.info(f"Tavily returned {len(results)} results for: {query[:50]}")
        return results

    except Exception as e:
        logger.warning(f"Tavily search failed: {e}")
        return []


def _ddg_search(query: str, max_results: int = 5) -> List[Dict]:
    """Search using DuckDuckGo (fallback — no API key needed)."""
    try:
        from duckduckgo_search import DDGS

        results = []
        with DDGS() as ddgs:
            for r in ddgs.text(query, max_results=max_results):
                results.append({
                    "title": r.get("title", ""),
                    "url": r.get("href", ""),
                    "content": r.get("body", ""),
                    "date": "",
                    "source": "duckduckgo",
                })

        logger.info(f"DuckDuckGo returned {len(results)} results for: {query[:50]}")
        return results

    except Exception as e:
        logger.warning(f"DuckDuckGo search failed: {e}")
        return []


def _filter_recent(results: List[Dict], lookback_days: int) -> List[Dict]:
    """Filter results to only include items from the last N days."""
    if not results:
        return results

    cutoff = datetime.utcnow() - timedelta(days=lookback_days)
    filtered = []

    for r in results:
        date_str = r.get("date", "")
        if not date_str:
            # No date info — include it (benefit of the doubt)
            filtered.append(r)
            continue

        # Try common date formats; stop at the first one that parses
        parsed = None
        for fmt in ("%Y-%m-%dT%H:%M:%S", "%Y-%m-%d", "%B %d, %Y"):
            try:
                parsed = datetime.strptime(date_str[:19], fmt)
                break
            except ValueError:
                continue

        # Keep recent items; unparseable dates get the benefit of the doubt
        if parsed is None or parsed >= cutoff:
            filtered.append(r)

    return filtered
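A usage sketch; leaving TAVILY_API_KEY unset exercises the DuckDuckGo fallback, and the query string is arbitrary:

```python
# Sketch: one search call through the auto-fallback path.
from skills.web_search import web_search

results = web_search("fintech marketing benchmarks MENA", max_results=3, lookback_days=7)
for r in results:
    print(f"[{r['source']}] {r['title']} — {r['url']}")
```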
@@ -0,0 +1,17 @@
# SOUL — ACME Corporation

## Identity
You are Clawrity, ACME's business intelligence assistant.
Speak professionally but conversationally.
Always ground answers in data. Never speculate.

## Business Context
- Operates in: US, Canada, MENA
- Primary metric: Revenue per lead
- Risk tolerance: Conservative (max 15% budget reallocation per suggestion)

## Rules
- If data is unavailable, say "I don't have that data right now"
- Always surface the bottom 3 branches in daily digests
- Budget suggestions must cite specific historical data points
- Never compare to competitors by name unless the comparison comes from the Scout Agent
@@ -0,0 +1,56 @@
"""
Clawrity — SOUL Loader

Reads the SOUL.md file for a client and returns raw text for prompt injection.
SOUL.md defines the AI's personality, business context, and rules per client.
"""

import logging
from pathlib import Path

from config.client_loader import ClientConfig

logger = logging.getLogger(__name__)


def load_soul(client_config: ClientConfig) -> str:
    """
    Load the SOUL.md content for a client.

    Args:
        client_config: The client's configuration containing the soul_file path.

    Returns:
        Raw markdown text of the SOUL file, or a default prompt if the file is missing.
    """
    soul_path = Path(client_config.soul_file)

    if not soul_path.exists():
        logger.warning(
            f"SOUL file not found at {soul_path} for client {client_config.client_id}. "
            f"Using default personality."
        )
        return _default_soul(client_config)

    try:
        content = soul_path.read_text(encoding="utf-8")
        logger.info(f"Loaded SOUL for {client_config.client_id} from {soul_path}")
        return content
    except Exception as e:
        logger.error(f"Error reading SOUL file {soul_path}: {e}")
        return _default_soul(client_config)


def _default_soul(client_config: ClientConfig) -> str:
    """Generate a minimal default SOUL if the file is missing."""
    return f"""# SOUL — {client_config.client_name}

## Identity
You are Clawrity, {client_config.client_name}'s business intelligence assistant.
Speak professionally. Always ground answers in data. Never speculate.

## Rules
- If data is unavailable, say "I don't have that data right now"
- Always cite specific data points in your responses
"""