mirror of
https://github.com/Manoj-HV30/clawrity.git
synced 2026-05-16 19:35:21 +00:00
174 lines
5.4 KiB
Python
174 lines
5.4 KiB
Python
"""
|
|
Clawrity — Prophet Forecasting Engine
|
|
|
|
Trains Prophet models on branch-level monthly revenue time series.
|
|
Forecasts 6 months ahead. Caches results in PostgreSQL forecasts table.
|
|
|
|
Limitations (be explicit):
|
|
- Predicts revenue TRENDS only
|
|
- Does NOT claim ROI-per-dollar forecasting (spend→revenue is approximate)
|
|
- Requires minimum 2 years of data per branch
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional
|
|
|
|
import pandas as pd
|
|
|
|
from skills.postgres_connector import get_connector
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
MIN_MONTHS = 24 # Minimum 2 years of data
|
|
FORECAST_MONTHS = 6
|
|
|
|
|
|
class ProphetEngine:
|
|
"""Time series forecasting using Facebook Prophet."""
|
|
|
|
def train_and_forecast(self, client_id: str) -> List[Dict]:
|
|
"""
|
|
Train Prophet models for each branch and cache forecasts.
|
|
|
|
Args:
|
|
client_id: Client to forecast for
|
|
|
|
Returns:
|
|
List of forecast result dicts (one per branch)
|
|
"""
|
|
from prophet import Prophet
|
|
|
|
db = get_connector()
|
|
|
|
# Get monthly revenue per branch
|
|
sql = """
|
|
SELECT branch, country,
|
|
DATE_TRUNC('month', date) AS month,
|
|
SUM(revenue) AS monthly_revenue
|
|
FROM spend_data
|
|
WHERE client_id = %s
|
|
GROUP BY branch, country, DATE_TRUNC('month', date)
|
|
ORDER BY branch, month
|
|
"""
|
|
df = db.execute_query(sql, (client_id,))
|
|
|
|
if df.empty:
|
|
logger.warning(f"No data for forecasting: {client_id}")
|
|
return []
|
|
|
|
results = []
|
|
branches = df.groupby(["branch", "country"])
|
|
|
|
for (branch, country), group in branches:
|
|
group = group.sort_values("month").reset_index(drop=True)
|
|
|
|
if len(group) < MIN_MONTHS:
|
|
logger.info(
|
|
f"Skipping {branch} ({country}): only {len(group)} months "
|
|
f"(need {MIN_MONTHS})"
|
|
)
|
|
continue
|
|
|
|
try:
|
|
# Prepare Prophet format: ds (date), y (value)
|
|
prophet_df = pd.DataFrame({
|
|
"ds": pd.to_datetime(group["month"]),
|
|
"y": group["monthly_revenue"].astype(float),
|
|
})
|
|
|
|
# Train
|
|
model = Prophet(
|
|
yearly_seasonality=True,
|
|
weekly_seasonality=False,
|
|
daily_seasonality=False,
|
|
)
|
|
model.fit(prophet_df)
|
|
|
|
# Forecast
|
|
future = model.make_future_dataframe(
|
|
periods=FORECAST_MONTHS, freq="MS"
|
|
)
|
|
forecast = model.predict(future)
|
|
|
|
# Extract forecast period only
|
|
forecast_only = forecast.tail(FORECAST_MONTHS)
|
|
|
|
forecast_data = {
|
|
"branch": branch,
|
|
"country": country,
|
|
"horizon_months": FORECAST_MONTHS,
|
|
"dates": forecast_only["ds"].dt.strftime("%Y-%m-%d").tolist(),
|
|
"forecast_revenue": forecast_only["yhat"].round(2).tolist(),
|
|
"lower_bound": forecast_only["yhat_lower"].round(2).tolist(),
|
|
"upper_bound": forecast_only["yhat_upper"].round(2).tolist(),
|
|
"computed_at": datetime.utcnow().isoformat(),
|
|
}
|
|
|
|
# Cache in PostgreSQL
|
|
self._cache_forecast(client_id, forecast_data)
|
|
results.append(forecast_data)
|
|
|
|
logger.info(
|
|
f"Forecast generated for {branch} ({country}): "
|
|
f"{FORECAST_MONTHS} months ahead"
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Prophet failed for {branch} ({country}): {e}")
|
|
|
|
logger.info(f"Forecasting complete: {len(results)} branches forecast")
|
|
return results
|
|
|
|
def get_cached_forecast(
|
|
self,
|
|
client_id: str,
|
|
branch: str,
|
|
) -> Optional[Dict]:
|
|
"""Get the most recent cached forecast for a branch."""
|
|
db = get_connector()
|
|
|
|
sql = """
|
|
SELECT forecast_data, computed_at
|
|
FROM forecasts
|
|
WHERE client_id = %s AND branch = %s
|
|
ORDER BY computed_at DESC
|
|
LIMIT 1
|
|
"""
|
|
rows = db.execute_raw(sql, (client_id, branch))
|
|
|
|
if not rows:
|
|
return None
|
|
|
|
row = rows[0]
|
|
data = row["forecast_data"]
|
|
if isinstance(data, str):
|
|
data = json.loads(data)
|
|
|
|
data["computed_at"] = str(row["computed_at"])
|
|
return data
|
|
|
|
def _cache_forecast(self, client_id: str, forecast_data: Dict):
|
|
"""Store forecast in PostgreSQL."""
|
|
db = get_connector()
|
|
|
|
# Delete old forecast for this branch
|
|
db.execute_write(
|
|
"DELETE FROM forecasts WHERE client_id = %s AND branch = %s AND country = %s",
|
|
(client_id, forecast_data["branch"], forecast_data["country"]),
|
|
)
|
|
|
|
# Insert new
|
|
db.execute_write(
|
|
"""INSERT INTO forecasts (client_id, branch, country, horizon_months, forecast_data)
|
|
VALUES (%s, %s, %s, %s, %s)""",
|
|
(
|
|
client_id,
|
|
forecast_data["branch"],
|
|
forecast_data["country"],
|
|
forecast_data["horizon_months"],
|
|
json.dumps(forecast_data),
|
|
),
|
|
)
|