From bf358f724893b36e30296998614da04e37ebfb4c Mon Sep 17 00:00:00 2001 From: TimHoogervorst Date: Mon, 25 May 2026 10:19:50 +0200 Subject: [PATCH 1/5] added quick connect auth from jellyfin, still needs to have some more cleaning before push to prod --- .env.example | 32 +++- .gitignore | 3 +- agents/__init__.py | 1 + agents/media_agent.py | 9 +- api/v1/auth.py | 189 +++++++++++++++++++ auth/__init__.py | 93 ++++++++++ auth/jellyfin.py | 401 ++++++++++++++++++++++++++++++++++++++++ bot/discord_bot.py | 146 ++++++++++++++- core/auth_store.py | 316 +++++++++++++++++++++++++++++++ core/graph.py | 42 ++--- core/state.py | 1 + main.py | 10 +- requirements.txt | 3 +- skills/__init__.py | 27 ++- skills/media_info.py | 10 + skills/watch_history.py | 80 ++++++++ 16 files changed, 1318 insertions(+), 45 deletions(-) create mode 100644 api/v1/auth.py create mode 100644 auth/__init__.py create mode 100644 auth/jellyfin.py create mode 100644 core/auth_store.py create mode 100644 skills/watch_history.py diff --git a/.env.example b/.env.example index fc20b32..8aa184b 100644 --- a/.env.example +++ b/.env.example @@ -1,16 +1,19 @@ -# --------------------------------------------------------------------------- -# Agent Backend — Environment Variables -# Copy this to .env and fill in your values. -# --------------------------------------------------------------------------- +# ============================================================================= +# Agent Bot — Environment Configuration +# ============================================================================= +# Copy this file to .env and fill in your values. +# .env is git-ignored — never commit real secrets. +# --------------------------------------------------------------------------- # LLM — DeepSeek (OpenAI-compatible) +# --------------------------------------------------------------------------- DEEPSEEK_API_KEY=sk-your-deepseek-api-key # --------------------------------------------------------------------------- # Discord Bot # --------------------------------------------------------------------------- DISCORD_BOT_TOKEN=your-discord-bot-token-here -# DISCORD_MAX_HISTORY=7 # optional, defaults to 7 (max past messages per user) +# DISCORD_MAX_HISTORY=7 # optional, defaults to 7 (max past messages per user) # DISCORD_DEFAULT_AGENT=media-agent # optional, which agent the DM bot uses # --------------------------------------------------------------------------- @@ -18,4 +21,21 @@ DISCORD_BOT_TOKEN=your-discord-bot-token-here # --------------------------------------------------------------------------- SEERR_URL=https://seerr.example.com SEERR_API_KEY=your-seerr-api-key -# SEERR_TIMEOUT=30 # optional, defaults to 30 seconds +# SEERR_USERNAME=your-username # alternative: username+password auth +# SEERR_PASSWORD=your-password +# SEERR_TIMEOUT=30 # optional, defaults to 30 seconds + +# --------------------------------------------------------------------------- +# Auth System (Discord ↔ external services) +# --------------------------------------------------------------------------- +# The public-facing URL where users reach this bot's web API. +# Used to build the "Click here to link" URLs sent via Discord DM. +# For local dev: http://localhost:8000 +# For production behind a reverse proxy: https://bot.yourdomain.com +BASE_URL=http://localhost:8000 + +# Where the auth SQLite database lives (relative to project root) +# AUTH_DB_PATH=data/auth.db + +# Link token expiry in minutes (default 10) +# AUTH_TOKEN_EXPIRY=10 diff --git a/.gitignore b/.gitignore index 089e111..c84e18a 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,5 @@ cython_debug/ # PyPI configuration file .pypirc -.docs/ \ No newline at end of file +.docs/ +data/ \ No newline at end of file diff --git a/agents/__init__.py b/agents/__init__.py index 48ae91a..76df02e 100644 --- a/agents/__init__.py +++ b/agents/__init__.py @@ -65,3 +65,4 @@ def load_all_agents() -> None: import skills.seerr # noqa: F401 import skills.triage # noqa: F401 import skills.easter_eggs # noqa: F401 + import skills.watch_history # noqa: F401 diff --git a/agents/media_agent.py b/agents/media_agent.py index a599b63..c26ab73 100644 --- a/agents/media_agent.py +++ b/agents/media_agent.py @@ -14,11 +14,16 @@ media_agent = Agent( agent_id="media-agent", description="Media assistant — handles movie/TV/subtitle/ticket requests " "via Seerr, Jellyfin, Sonarr, etc.", - skills=["media_info", "seerr", "triage", "easter_eggs"], + skills=["media_info", "seerr", "triage", "easter_eggs", "watch_history"], base_prompt=( "You are a media assistant connected to Seerr and other media services. " "Help users discover, request, and troubleshoot their media library. " - "Use the tools provided to perform real actions." + "Use the tools provided to perform real actions.\n\n" + "## Authentication\n" + "If a tool returns a message saying the user needs to log in first, " + "tell the user to type `/login ` in their DM (e.g. `/login jellyfin`). " + "This opens Quick Connect on their Jellyfin app so they can link their account. " + "Do NOT tell the user you 'can't connect' or 'don't have access' — just relay the login instructions." ), ) diff --git a/api/v1/auth.py b/api/v1/auth.py new file mode 100644 index 0000000..7c2d1a4 --- /dev/null +++ b/api/v1/auth.py @@ -0,0 +1,189 @@ +""" +Auth API — generic endpoints for linking Discord users to external services. + +GET /api/v1/auth/login?service=X&token=Y&discord_id=Z + Validates the link token and serves a service-specific login form. + +POST /api/v1/auth/login + Accepts the form submission, validates credentials against the service, + stores the session, and returns a result page. + +GET /api/v1/auth/status?discord_id=Z + Returns which services are linked for this Discord user. +""" + +from __future__ import annotations + +import logging + +from fastapi import APIRouter, Form, HTTPException, Request +from fastapi.responses import HTMLResponse + +from auth import get_auth_service, list_auth_services +from core import auth_store + +logger = logging.getLogger("api.auth") + +router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) + + +# --------------------------------------------------------------------------- +# GET /auth/login — serve the login form +# --------------------------------------------------------------------------- + +@router.get("/login") +async def login_form( + service: str, + token: str, + discord_id: int, +): + """Validate the one-time link token and return a service-specific login form.""" + + # Validate the token WITHOUT consuming it (the POST will consume it) + result = auth_store.validate_token(token) + if result is None: + raise HTTPException(status_code=400, detail="Invalid or expired link token.") + + uid, svc = result + if uid != discord_id or svc != service: + raise HTTPException(status_code=400, detail="Token does not match the request.") + + # Look up the AuthService + svc_obj = get_auth_service(service) + if svc_obj is None: + raise HTTPException(status_code=404, detail=f"Unknown service: {service}") + + logger.info("Serving login form: user=%s service=%s", discord_id, service) + return HTMLResponse(svc_obj.render_login_form(token, discord_id)) + + +# --------------------------------------------------------------------------- +# POST /auth/login — handle form submission +# --------------------------------------------------------------------------- + +@router.post("/login") +async def login_submit(request: Request): + """Handle the login form POST: validate credentials, store auth, show result.""" + + # Parse form data + form = await request.form() + token = form.get("token", "") + discord_id_str = form.get("discord_id", "") + service = form.get("service", "") + + if not token or not discord_id_str or not service: + raise HTTPException(status_code=400, detail="Missing required fields.") + + try: + discord_id = int(discord_id_str) + except (ValueError, TypeError): + raise HTTPException(status_code=400, detail="Invalid discord_id.") + + # Consume the token on POST (the GET only validated, didn't consume) + result = auth_store.consume_token(token) + if result is None: + raise HTTPException(status_code=400, detail="Invalid or expired link token.") + + # Look up the AuthService + svc_obj = get_auth_service(service) + if svc_obj is None: + raise HTTPException(status_code=404, detail=f"Unknown service: {service}") + + # Collect service-specific form fields (everything except token, discord_id, service) + form_data: dict[str, str] = {} + for key, value in form.items(): + if key not in ("token", "discord_id", "service"): + form_data[key] = str(value) + + # Authenticate against the service + auth_result = await svc_obj.authenticate(form_data) + + if not auth_result.success: + return HTMLResponse( + status_code=401, + content=f""" +Login Failed + +

❌ Login Failed

+

{auth_result.error_message or "Authentication failed. Please try again."}

+

← Go back and try again

+""", + ) + + # Store the successful auth + auth_store.store_auth( + discord_user_id=discord_id, + service=service, + external_user_id=auth_result.external_user_id or "", + external_name=auth_result.external_name or "", + credentials=auth_result.credentials, + ) + + logger.info( + "Auth linked: discord=%s → %s (%s)", + discord_id, + service, + auth_result.external_name, + ) + + return HTMLResponse(f""" + + + + +Account Linked + + + +

✅ Account Linked!

+

Logged in as {auth_result.external_name} on {svc_obj.display_name}.

+

You can close this page and return to Discord.

+ +""") + + +# --------------------------------------------------------------------------- +# GET /auth/status — check which services are linked +# --------------------------------------------------------------------------- + +@router.get("/status") +async def auth_status(discord_id: int): + """Return which services this Discord user has linked.""" + services: dict[str, bool] = {} + for svc_name in list_auth_services(): + services[svc_name] = auth_store.is_authenticated(discord_id, svc_name) + return {"discord_id": discord_id, "services": services} + + +# --------------------------------------------------------------------------- +# POST /auth/reset — wipe auth store (DEV ONLY) +# --------------------------------------------------------------------------- + +from core.config import get_config # noqa: E402 + +@router.post("/reset") +async def reset_auth(): + """ + Reset the entire auth store — clears all link tokens and user auth records. + + Only enabled when ALLOW_AUTH_RESET=true in the environment. + Returns 403 in production. + """ + if get_config("ALLOW_AUTH_RESET", "false").lower() != "true": + raise HTTPException( + status_code=403, + detail="Auth reset is disabled. Set ALLOW_AUTH_RESET=true to enable (dev only).", + ) + + auth_store.reset_all() + logger.warning("Auth store reset via API endpoint.") + return {"status": "ok", "message": "Auth store cleared — all tokens and auth records removed."} diff --git a/auth/__init__.py b/auth/__init__.py new file mode 100644 index 0000000..2344e47 --- /dev/null +++ b/auth/__init__.py @@ -0,0 +1,93 @@ +""" +Auth Service registry — generic, pluggable authentication for any service. + +Add a new service (Plex, Seerr, etc.) by: +1. Subclassing AuthService +2. Dropping the module in this package +3. Calling register_auth_service() at import time +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Optional + + +# --------------------------------------------------------------------------- +# AuthResult — returned by AuthService.authenticate() +# --------------------------------------------------------------------------- + +@dataclass +class AuthResult: + """Outcome of a credential validation attempt.""" + success: bool + external_user_id: Optional[str] = None + external_name: Optional[str] = None + credentials: Optional[dict] = None + error_message: Optional[str] = None + + +# --------------------------------------------------------------------------- +# AuthService — abstract base class +# --------------------------------------------------------------------------- + +class AuthService(ABC): + """A service that users can authenticate against (Jellyfin, Seerr, Plex, etc.) + + Subclasses must implement: + - name : unique identifier used in URLs and DB keys + - display_name : human-readable label shown in Discord + - render_login_form(token, discord_id) → HTML string + - authenticate(form_data) → AuthResult + """ + + @property + @abstractmethod + def name(self) -> str: + """Unique service name: "jellyfin", "seerr", etc.""" + ... + + @property + @abstractmethod + def display_name(self) -> str: + """Human-readable: "Jellyfin", "Seerr", "Plex" """ + ... + + @abstractmethod + def render_login_form(self, token: str, discord_id: int) -> str: + """Return HTML string with a login form for this service. + + The form MUST include these hidden fields: + + + + """ + ... + + @abstractmethod + async def authenticate(self, form_data: dict) -> AuthResult: + """Validate credentials against the service. Return AuthResult.""" + ... + + +# --------------------------------------------------------------------------- +# Global registry +# --------------------------------------------------------------------------- + +_registry: dict[str, AuthService] = {} + + +def register_auth_service(svc: AuthService) -> None: + """Register an AuthService so it can be looked up by name.""" + _registry[svc.name] = svc + + +def get_auth_service(name: str) -> AuthService | None: + """Look up a registered AuthService by name.""" + return _registry.get(name) + + +def list_auth_services() -> list[str]: + """Return names of all registered auth services.""" + return list(_registry.keys()) diff --git a/auth/jellyfin.py b/auth/jellyfin.py new file mode 100644 index 0000000..b2744df --- /dev/null +++ b/auth/jellyfin.py @@ -0,0 +1,401 @@ +""" +Jellyfin AuthService — validates Jellyfin credentials and stores the session token. + +Two authentication flows: + 1. Quick Connect (primary): user enters a short code on their Jellyfin app. + - initiate_quick_connect() → {code, secret} + - poll_quick_connect(secret) → "Active" | "Authorized" | "Expired" + - authenticate_quick_connect(secret) → AuthResult with token + + 2. Username/password (legacy): renders an HTML form, called via the REST API. + - render_login_form(token, discord_id) → HTML string + - authenticate(form_data) → AuthResult +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Optional + +import httpx + +from auth import AuthService, AuthResult, register_auth_service +from core.config import get_config + +logger = logging.getLogger("auth.jellyfin") + +# Emby-style authorization header required by Jellyfin's AuthenticateByName +_EMBY_HEADER = ( + 'MediaBrowser Client="AgentBot",' + 'Device="DiscordBot",' + 'DeviceId="agent-bot",' + 'Version="1.0"' +) + + +@dataclass +class QuickConnectResult: + """Result of a Quick Connect initiation.""" + secret: str + code: str + device_id: str + device_name: str + + +class JellyfinAuth(AuthService): + name = "jellyfin" + display_name = "Jellyfin" + + # ------------------------------------------------------------------ + # Quick Connect helpers + # ------------------------------------------------------------------ + + def _qc_headers(self) -> dict[str, str]: + """Return headers used by all Quick Connect API calls.""" + return { + "X-Emby-Authorization": ( + 'MediaBrowser Client="AgentBot",' + 'Device="DiscordBot",' + 'DeviceId="agent-bot-qc",' + 'Version="1.0"' + ) + } + + async def _resolve_url(self) -> str | None: + """ + Resolve the Jellyfin server URL. + 1. Check JELLYFIN_URL env var (used in deployment). + 2. Check if user already has a stored auth with a URL (from legacy login). + Returns None if no URL is configured. + """ + # First: explicit env var + env_url = get_config("JELLYFIN_URL") + if env_url: + return env_url.strip().rstrip("/") + return None + + # ------------------------------------------------------------------ + # Phase 1a: initiate Quick Connect + # ------------------------------------------------------------------ + + async def initiate_quick_connect(self, url: str | None = None) -> QuickConnectResult | None: + """ + Call Jellyfin's POST /QuickConnect/Initiate. + Returns a QuickConnectResult with {secret, code} or None on failure. + + The *code* is what the user enters on their Jellyfin page. + The *secret* is used internally to poll/authenticate. + """ + base_url = url or await self._resolve_url() + if not base_url: + logger.error("QuickConnect failed — no JELLYFIN_URL configured.") + return None + + logger.info("Initiating Quick Connect on %s", base_url) + + async with httpx.AsyncClient(timeout=10) as client: + try: + resp = await client.post( + f"{base_url}/QuickConnect/Initiate", + headers=self._qc_headers(), + json={}, + ) + if resp.status_code != 200: + logger.warning( + "QuickConnect init failed: HTTP %s — %s", + resp.status_code, resp.text[:200], + ) + return None + + data = resp.json() + secret = data.get("Secret", "") + code = data.get("Code", "") + device_id = data.get("DeviceId", "") + device_name = data.get("DeviceName", "") + + if not secret or not code: + logger.warning("QuickConnect init returned unexpected payload: %s", data) + return None + + logger.info( + "Quick Connect initiated: code=%s device=%s", + code, device_name, + ) + return QuickConnectResult( + secret=secret, + code=code, + device_id=device_id, + device_name=device_name, + ) + + except httpx.TimeoutException: + logger.error("QuickConnect init timed out reaching %s", base_url) + return None + except httpx.ConnectError: + logger.error("QuickConnect init — cannot connect to %s", base_url) + return None + except Exception: + logger.exception("Unexpected error during QuickConnect init") + return None + + # ------------------------------------------------------------------ + # Phase 1b: poll Quick Connect status + # ------------------------------------------------------------------ + + async def poll_quick_connect(self, secret: str, url: str | None = None) -> str: + """ + Call Jellyfin's GET /QuickConnect/Connect?secret=. + Returns one of: + - "Active" → user hasn't entered the code yet + - "Authorized" → user entered code AND approved + - "Expired" → code expired / unknown secret + - "Error" → network or unexpected failure + """ + base_url = url or await self._resolve_url() + if not base_url: + logger.error("QuickConnect poll failed — no JELLYFIN_URL configured.") + return "Error" + + async with httpx.AsyncClient(timeout=10) as client: + try: + resp = await client.get( + f"{base_url}/QuickConnect/Connect", + params={"secret": secret}, + headers=self._qc_headers(), + ) + if resp.status_code == 404: + return "Expired" + + if resp.status_code == 200: + data = resp.json() + # Jellyfin returns "Authenticated" (not "Authorized") + if data.get("Authenticated") is True: + return "Authorized" + # "Authenticated" is false, missing, or null → still active + return "Active" + + logger.warning( + "QuickConnect poll unexpected: HTTP %s — %s", + resp.status_code, resp.text[:200], + ) + return "Error" + + except (httpx.TimeoutException, httpx.ConnectError): + logger.warning("QuickConnect poll network error") + return "Error" + except Exception: + logger.exception("Unexpected error during QuickConnect poll") + return "Error" + + # ------------------------------------------------------------------ + # Phase 1c: exchange secret for token + # ------------------------------------------------------------------ + + async def authenticate_quick_connect( + self, secret: str, url: str | None = None + ) -> AuthResult: + """ + After poll_quick_connect returns "Authorized", call + POST /Users/AuthenticateWithQuickConnect to exchange the secret + for a real access token. + + Returns AuthResult with token, user_id, username on success. + """ + base_url = url or await self._resolve_url() + if not base_url: + return AuthResult( + success=False, + error_message="No Jellyfin server URL configured.", + ) + + logger.info("Exchanging QuickConnect secret for token on %s", base_url) + + async with httpx.AsyncClient(timeout=10) as client: + try: + resp = await client.post( + f"{base_url}/Users/AuthenticateWithQuickConnect", + json={"Secret": secret}, + headers=self._qc_headers(), + ) + if resp.status_code != 200: + logger.warning( + "QuickConnect auth exchange failed: HTTP %s", + resp.status_code, + ) + return AuthResult( + success=False, + error_message="Quick Connect authentication failed. The code may have expired.", + ) + + data = resp.json() + user = data.get("User", {}) + token = data.get("AccessToken", "") + + if not token: + return AuthResult( + success=False, + error_message="Jellyfin returned an unexpected response.", + ) + + logger.info( + "QuickConnect linked: user=%s (%s)", + user.get("Name", "?"), + user.get("Id", "?"), + ) + + return AuthResult( + success=True, + external_user_id=user.get("Id", ""), + external_name=user.get("Name", "?"), + credentials={ + "token": token, + "url": base_url, + "user_id": user.get("Id", ""), + }, + ) + + except httpx.TimeoutException: + return AuthResult( + success=False, + error_message=f"Could not reach {base_url} — connection timed out.", + ) + except httpx.ConnectError: + return AuthResult( + success=False, + error_message=f"Could not connect to {base_url}. Is the server running?", + ) + except Exception: + logger.exception("Unexpected error during QuickConnect auth exchange") + return AuthResult( + success=False, + error_message="An unexpected error occurred during authentication.", + ) + + # ------------------------------------------------------------------ + # Login form (legacy — used by the REST API) + # ------------------------------------------------------------------ + + def render_login_form(self, token: str, discord_id: int) -> str: + return f""" + + + + +Link Jellyfin + + + +

🔗 Link Jellyfin to Discord

+

Enter your Jellyfin server URL and credentials to link your account.

+ +
+ + + + + + + + + + + + + + +
+ +""" + + # ------------------------------------------------------------------ + # Authentication + # ------------------------------------------------------------------ + + async def authenticate(self, form_data: dict) -> AuthResult: + url = form_data.get("jellyfin_url", "").strip().rstrip("/") + username = form_data.get("username", "").strip() + password = form_data.get("password", "").strip() + + if not url or not username or not password: + return AuthResult( + success=False, + error_message="All fields are required (URL, username, password).", + ) + + logger.info("Attempting Jellyfin login for '%s' on %s", username, url) + + async with httpx.AsyncClient(timeout=10) as client: + try: + resp = await client.post( + f"{url}/Users/AuthenticateByName", + json={"Username": username, "Pw": password}, + headers={"X-Emby-Authorization": _EMBY_HEADER}, + ) + if resp.status_code != 200: + logger.warning( + "Jellyfin login failed for '%s': HTTP %s", username, resp.status_code + ) + return AuthResult( + success=False, + error_message=f"Login failed — check your server URL and credentials.", + ) + + data = resp.json() + user = data.get("User", {}) + token = data.get("AccessToken", "") + + if not token: + return AuthResult( + success=False, + error_message="Jellyfin returned an unexpected response.", + ) + + logger.info( + "Jellyfin login OK: user=%s (%s)", + user.get("Name", "?"), + user.get("Id", "?"), + ) + + return AuthResult( + success=True, + external_user_id=user.get("Id", ""), + external_name=user.get("Name", username), + credentials={ + "token": token, + "url": url, + "user_id": user.get("Id", ""), + }, + ) + + except httpx.TimeoutException: + return AuthResult( + success=False, + error_message=f"Could not reach {url} — connection timed out. Check the URL.", + ) + except httpx.ConnectError: + return AuthResult( + success=False, + error_message=f"Could not connect to {url}. Is the server running?", + ) + except Exception as exc: + logger.exception("Unexpected error during Jellyfin login") + return AuthResult( + success=False, + error_message=f"An unexpected error occurred. Please try again.", + ) + + +# Self-register at import time +register_auth_service(JellyfinAuth()) diff --git a/bot/discord_bot.py b/bot/discord_bot.py index 57b1d4b..3681d2a 100644 --- a/bot/discord_bot.py +++ b/bot/discord_bot.py @@ -27,6 +27,8 @@ from bot.conversation import ConversationStore from core.config import DEEPSEEK_API_KEY, get_config from core.graph import create_agent_graph from core.llm import create_client +from core import auth_store +from auth import list_auth_services, get_auth_service logger = logging.getLogger("bot.discord") @@ -36,6 +38,7 @@ logger = logging.getLogger("bot.discord") DISCORD_BOT_TOKEN = get_config("DISCORD_BOT_TOKEN") or "" DISCORD_MAX_HISTORY = int(get_config("DISCORD_MAX_HISTORY", "7")) DISCORD_DEFAULT_AGENT = get_config("DISCORD_DEFAULT_AGENT", "media-agent") +BASE_URL = get_config("BASE_URL", "http://localhost:8000").rstrip("/") # --------------------------------------------------------------------------- # LLM client shared by all agents (same as the REST API uses) @@ -138,6 +141,12 @@ class AgentBot(discord.Client): # |--------------------------------------------------------------| user_id = message.author.id + content = message.content.strip() + + # |-- Bot commands — handled directly, never sent to the LLM --| + if await self._handle_command(message, user_id, content): + return + # |--------------------------------------------------------------| # Show typing indicator while the graph runs async with message.channel.typing(): @@ -154,6 +163,140 @@ class AgentBot(discord.Client): "Please try again in a moment." ) + # ------------------------------------------------------------------ + # Bot commands + # ------------------------------------------------------------------ + + async def _handle_command( + self, message: discord.Message, user_id: int, content: str + ) -> bool: + """Handle bot commands (/login, /logout). Returns True if handled.""" + lower = content.lower() + + # --- /login [service] --- + if lower.startswith("/login"): + parts = content.split() + service = parts[1].lower() if len(parts) > 1 else None + + available = list_auth_services() + if not available: + await message.channel.send("No auth services are configured yet.") + return True + + if service is None: + svc_list = ", ".join(available) + await message.channel.send( + f"Available services to link: **{svc_list}**\n" + f"Use `/login ` — e.g. `/login jellyfin`" + ) + return True + + if service not in available: + await message.channel.send( + f"Unknown service '{service}'. Available: {', '.join(available)}" + ) + return True + + if auth_store.is_authenticated(user_id, service): + svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service + await message.channel.send( + f"You're already linked to **{svc_display}**! " + f"Use `/logout {service}` to unlink." + ) + return True + + # --- Quick Connect flow --- + svc = get_auth_service(service) + if svc is None: + await message.channel.send(f"Unknown service: {service}") + return True + + await message.channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…") + + qc_result = await svc.initiate_quick_connect() + if qc_result is None: + await message.channel.send( + f"❌ Could not start Quick Connect for **{svc.display_name}**.\n" + "Check that `JELLYFIN_URL` is configured and the server is reachable." + ) + return True + + await message.channel.send( + f"Open **{svc.display_name}** → **Quick Connect** and enter this code:\n\n" + f"**`{qc_result.code}`**\n\n" + f"⏳ Waiting for you to approve…" + ) + + # Poll for authorization + async with message.channel.typing(): + for attempt in range(24): # 24 × 5s = 2 minutes + await asyncio.sleep(5) + status = await svc.poll_quick_connect(qc_result.secret) + + if status == "Authorized": + auth_result = await svc.authenticate_quick_connect(qc_result.secret) + if auth_result.success: + auth_store.store_auth( + discord_user_id=user_id, + service=service, + external_user_id=auth_result.external_user_id or "", + external_name=auth_result.external_name or "", + credentials=auth_result.credentials, + ) + await message.channel.send( + f"✅ Linked to **{svc.display_name}** as " + f"**{auth_result.external_name}**!" + ) + else: + await message.channel.send( + f"❌ Authentication failed: " + f"{auth_result.error_message or 'Unknown error'}" + ) + return True + + elif status == "Expired": + await message.channel.send( + "⌛ The Quick Connect code expired. " + f"Use `/login {service}` to try again." + ) + return True + + # else: still "Active" — keep polling + + await message.channel.send( + "⌛ Timed out waiting for Quick Connect approval. " + f"Use `/login {service}` to try again." + ) + return True + + # --- /logout [service] --- + if lower.startswith("/logout"): + parts = content.split() + service = parts[1].lower() if len(parts) > 1 else None + + if service is None: + linked = auth_store.list_services(user_id) + if not linked: + await message.channel.send("You don't have any linked services.") + else: + svc_list = ", ".join(linked) + await message.channel.send( + f"Linked services: **{svc_list}**\n" + f"Use `/logout ` to unlink." + ) + return True + + if not auth_store.is_authenticated(user_id, service): + await message.channel.send(f"You're not linked to **{service}**.") + return True + + auth_store.revoke(user_id, service) + svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service + await message.channel.send(f"Unlinked from **{svc_display}**. Use `/login {service}` to re-link.") + return True + + return False + # ------------------------------------------------------------------ # Agent invocation # ------------------------------------------------------------------ @@ -163,7 +306,6 @@ class AgentBot(discord.Client): reply, and return the assistant's final text.""" # 1. Pick agent — defaults to DISCORD_DEFAULT_AGENT env var. - # Change DISCORD_DEFAULT_AGENT in .env to switch agents. agent_id = DISCORD_DEFAULT_AGENT # 2. Build message list from stored history + new user message @@ -172,7 +314,7 @@ class AgentBot(discord.Client): # 3. Run the LangGraph (tools execute inline if needed) graph = _get_graph(agent_id) - state = {"messages": messages} + state = {"messages": messages, "discord_user_id": user_id} result = await graph.ainvoke(state) last_msg = result["messages"][-1] diff --git a/core/auth_store.py b/core/auth_store.py new file mode 100644 index 0000000..cb57cb4 --- /dev/null +++ b/core/auth_store.py @@ -0,0 +1,316 @@ +""" +Auth Store — SQLite-backed persistence for Discord-to-service authentication. + +Two tables: + - link_tokens : one-time tokens sent via Discord DM to initiate login + - user_auth : per-user, per-service credentials (Jellyfin token, etc.) + +Thread-safe via WAL mode and a shared lock. No passwords are ever stored +— only opaque service tokens (e.g. Jellyfin AccessToken). +""" + +from __future__ import annotations + +import logging +import secrets +import sqlite3 +import threading +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Optional + +from core.config import get_config + +logger = logging.getLogger("auth_store") + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- +AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db") +TOKEN_EXPIRY_MINUTES = int(get_config("AUTH_TOKEN_EXPIRY", "10")) + +# --------------------------------------------------------------------------- +# Singleton handle +# --------------------------------------------------------------------------- + +_db_path: Path | None = None +_db_lock = threading.Lock() + + +def _resolve_path() -> Path: + """Turn AUTH_DB_PATH into an absolute path, creating parent dirs.""" + global _db_path + if _db_path is not None: + return _db_path + p = Path(AUTH_DB_PATH) + if not p.is_absolute(): + # Relative to the project root (two levels above this file) + project_root = Path(__file__).resolve().parent.parent + p = project_root / p + p.parent.mkdir(parents=True, exist_ok=True) + _db_path = p + return p + + +def _get_conn() -> sqlite3.Connection: + """Return a thread-local connection to the auth database.""" + import sqlite3 + + conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys=ON") + conn.row_factory = sqlite3.Row + return conn + + +# --------------------------------------------------------------------------- +# Schema +# --------------------------------------------------------------------------- + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS link_tokens ( + token TEXT PRIMARY KEY, + discord_user_id INTEGER NOT NULL, + service TEXT NOT NULL, + expires_at TEXT NOT NULL, + used INTEGER DEFAULT 0, + created_at TEXT DEFAULT (datetime('now')) +); + +CREATE TABLE IF NOT EXISTS user_auth ( + discord_user_id INTEGER NOT NULL, + service TEXT NOT NULL, + external_user_id TEXT, + external_name TEXT, + credentials TEXT, + linked_at TEXT DEFAULT (datetime('now')), + is_active INTEGER DEFAULT 1, + PRIMARY KEY (discord_user_id, service) +); +""" + +_initialized = False + + +def _ensure_schema() -> None: + global _initialized + if _initialized: + return + with _db_lock: + if _initialized: + return + conn = _get_conn() + conn.executescript(_SCHEMA) + conn.commit() + conn.close() + _initialized = True + logger.info("Auth store initialized at %s", _resolve_path()) + + +# --------------------------------------------------------------------------- +# Public API — Link Tokens +# --------------------------------------------------------------------------- + +def create_token(discord_user_id: int, service: str) -> str: + """Generate a one-time link token. Expires after TOKEN_EXPIRY_MINUTES.""" + _ensure_schema() + token = secrets.token_urlsafe(32) + expires = (datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRY_MINUTES)).isoformat() + + with _db_lock: + conn = _get_conn() + conn.execute( + "INSERT INTO link_tokens (token, discord_user_id, service, expires_at) VALUES (?, ?, ?, ?)", + (token, discord_user_id, service, expires), + ) + conn.commit() + conn.close() + + logger.info("Created link token for user %s / service %s", discord_user_id, service) + return token + + +def validate_token(token: str) -> tuple[int, str] | None: + """Read-only validation — does NOT consume the token. + + Returns (discord_user_id, service) if the token exists, is unused, + and has not expired. Returns None otherwise. + """ + _ensure_schema() + + with _db_lock: + conn = _get_conn() + row = conn.execute( + "SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?", + (token,), + ).fetchone() + conn.close() + + if row is None: + return None + if row["used"]: + return None + + expires = datetime.fromisoformat(row["expires_at"]) + if datetime.now(timezone.utc) > expires: + return None + + return (row["discord_user_id"], row["service"]) + + +def consume_token(token: str) -> tuple[int, str] | None: + """Validate and consume a link token. Returns (discord_user_id, service) or None. + + A token is valid if: + - It exists + - It has not been used + - It has not expired + """ + _ensure_schema() + + with _db_lock: + conn = _get_conn() + row = conn.execute( + "SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?", + (token,), + ).fetchone() + + if row is None: + conn.close() + return None + + if row["used"]: + conn.close() + logger.warning("Token already used: %s…", token[:8]) + return None + + expires = datetime.fromisoformat(row["expires_at"]) + if datetime.now(timezone.utc) > expires: + conn.close() + logger.warning("Token expired: %s…", token[:8]) + return None + + conn.execute("UPDATE link_tokens SET used = 1 WHERE token = ?", (token,)) + conn.commit() + result = (row["discord_user_id"], row["service"]) + conn.close() + logger.info("Token consumed: %s… → user=%s service=%s", token[:8], result[0], result[1]) + return result + + +# --------------------------------------------------------------------------- +# Public API — User Auth +# --------------------------------------------------------------------------- + +def store_auth( + discord_user_id: int, + service: str, + *, + external_user_id: str = "", + external_name: str = "", + credentials: dict | None = None, +) -> None: + """Store or update authentication for a user on a service.""" + _ensure_schema() + import json + + creds_json = json.dumps(credentials) if credentials else "{}" + + with _db_lock: + conn = _get_conn() + conn.execute( + """INSERT INTO user_auth (discord_user_id, service, external_user_id, external_name, credentials, linked_at) + VALUES (?, ?, ?, ?, ?, datetime('now')) + ON CONFLICT(discord_user_id, service) DO UPDATE SET + external_user_id = excluded.external_user_id, + external_name = excluded.external_name, + credentials = excluded.credentials, + linked_at = datetime('now'), + is_active = 1""", + (discord_user_id, service, external_user_id, external_name, creds_json), + ) + conn.commit() + conn.close() + + logger.info("Stored auth for user %s on %s as %s", discord_user_id, service, external_name) + + +def get_auth(discord_user_id: int, service: str) -> dict | None: + """Retrieve stored auth for a user on a service. Returns None if not linked.""" + _ensure_schema() + import json + + with _db_lock: + conn = _get_conn() + row = conn.execute( + "SELECT * FROM user_auth WHERE discord_user_id = ? AND service = ? AND is_active = 1", + (discord_user_id, service), + ).fetchone() + conn.close() + + if row is None: + return None + + credentials = json.loads(row["credentials"]) if row["credentials"] else {} + return { + "discord_user_id": row["discord_user_id"], + "service": row["service"], + "external_user_id": row["external_user_id"], + "external_name": row["external_name"], + "credentials": credentials, + "linked_at": row["linked_at"], + } + + +def is_authenticated(discord_user_id: int, service: str) -> bool: + """Quick check: is this user linked to this service?""" + return get_auth(discord_user_id, service) is not None + + +def list_services(discord_user_id: int) -> list[str]: + """Return list of service names this user has linked.""" + _ensure_schema() + + with _db_lock: + conn = _get_conn() + rows = conn.execute( + "SELECT service FROM user_auth WHERE discord_user_id = ? AND is_active = 1", + (discord_user_id,), + ).fetchall() + conn.close() + + return [r["service"] for r in rows] + + +def revoke(discord_user_id: int, service: str) -> None: + """Unlink a user from a service.""" + _ensure_schema() + + with _db_lock: + conn = _get_conn() + conn.execute( + "UPDATE user_auth SET is_active = 0 WHERE discord_user_id = ? AND service = ?", + (discord_user_id, service), + ) + conn.commit() + conn.close() + + logger.info("Revoked auth for user %s on %s", discord_user_id, service) + + +# --------------------------------------------------------------------------- +# Dev / testing — reset the entire store +# --------------------------------------------------------------------------- + +def reset_all() -> None: + """Truncate all auth tables — for development and testing only.""" + _ensure_schema() + + with _db_lock: + conn = _get_conn() + conn.execute("DELETE FROM link_tokens") + conn.execute("DELETE FROM user_auth") + conn.commit() + conn.close() + + logger.warning("Auth store RESET — all tokens and auth records cleared.") diff --git a/core/graph.py b/core/graph.py index 70f755a..4bb2675 100644 --- a/core/graph.py +++ b/core/graph.py @@ -1,12 +1,13 @@ """ LangGraph agent graph factory. -Builds a StateGraph that replaces the manual tool-calling loop in api/v1/chat.py. -The graph has two nodes: +Builds a StateGraph with two nodes: - agent_node : calls the LLM (with system prompt + tool definitions) - tool_node : executes tool calls via the existing skill system A conditional edge routes tool_calls back to the agent, or ends the run. +When a tool fails due to missing authentication, the failure message is +relayed to the LLM, which tells the user to use /login. """ from __future__ import annotations @@ -97,18 +98,14 @@ def _make_agent_node( full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}] for m in messages: if isinstance(m, dict): - # Already a plain dict — pass through. - # But fix tool_calls if they're in LangChain format. d = dict(m) tc = d.get("tool_calls") if tc and isinstance(tc, list) and tc and isinstance(tc[0], dict) and "function" not in tc[0]: d["tool_calls"] = _langchain_tc_to_openai(tc) full.append(d) else: - # LangChain message object → OpenAI-compatible dict role = _lc_role_to_openai(getattr(m, "type", "user")) d: dict[str, Any] = {"role": role, "content": getattr(m, "content", "")} - # Serialize tool_calls back to OpenAI format (if this is an AI msg) tc = getattr(m, "tool_calls", None) if tc: d["tool_calls"] = _langchain_tc_to_openai(tc) @@ -125,7 +122,6 @@ def _make_agent_node( ) choice = resp.choices[0] - # Convert OpenAI tool_calls to the dict format LangChain expects. raw_tool_calls = list(choice.message.tool_calls) if choice.message.tool_calls else [] tool_calls: list[dict[str, Any]] = [] for tc in raw_tool_calls: @@ -153,9 +149,9 @@ def _make_tool_node(skill_names: list[str]): """ Return a callable that executes tool_calls from the last AI message. - This replaces LangGraph's built-in ToolNode — we call our own - `execute_tool()` pipeline so that skill-level auth, httpx sessions, - and ToolResult handling are fully preserved. + If a tool fails because the user isn't authenticated, the failure + message (which tells the user to /login) is returned to the LLM. + The LLM naturally relays the instructions to the user. """ async def tool_node(state: AgentState) -> dict[str, list]: @@ -164,18 +160,16 @@ def _make_tool_node(skill_names: list[str]): if not tool_calls: return {"messages": []} + discord_user_id = state.get("discord_user_id") + results: list[ToolMessage] = [] for tc in tool_calls: - # Handle both LangChain format (top-level name/args) and - # OpenAI format (nested "function" key). if isinstance(tc, dict): if "function" in tc: - # OpenAI format: {"id":..., "function": {"name":..., "arguments":"..."}} fn = tc["function"] fn_name = fn.get("name", "") fn_args_raw = fn.get("arguments", "{}") else: - # LangChain format: {"name":..., "args":{...}, "id":...} fn_name = tc.get("name", "") fn_args_raw = tc.get("args", {}) tc_id = tc.get("id", "") @@ -184,13 +178,15 @@ def _make_tool_node(skill_names: list[str]): fn_args_raw = getattr(tc, "args", {}) tc_id = getattr(tc, "id", "") - # Parse args if they arrive as a JSON string if isinstance(fn_args_raw, str): fn_args = json.loads(fn_args_raw) else: fn_args = fn_args_raw - tr = await execute_tool(skill_names, fn_name, fn_args) + tr = await execute_tool( + skill_names, fn_name, fn_args, + discord_user_id=discord_user_id, + ) content = tr.content if tr else f"Tool '{fn_name}' is not available." results.append(ToolMessage(content=content, tool_call_id=tc_id)) @@ -224,27 +220,16 @@ def create_agent_graph( ) -> StateGraph: """ Build and compile a LangGraph StateGraph for a single agent. - - Parameters - ---------- - client : The OpenAI-compatible client (already authenticated). - agent_skills : Skill names assigned to the agent (e.g. ["seerr", "triage"]). - system_prompt : The fully-built system prompt (base + skill fragments). - model_name : Model identifier sent to the LLM provider. - - Returns - ------- - A compiled LangGraph graph ready for `.ainvoke()` or `.astream()`. """ tool_defs = get_all_tools(agent_skills) graph = StateGraph(AgentState) - # Nodes graph.add_node( "agent_node", _make_agent_node(client, system_prompt, tool_defs, model_name), ) + if tool_defs: graph.add_node("tool_node", _make_tool_node(agent_skills)) graph.add_conditional_edges("agent_node", _should_continue, { @@ -253,7 +238,6 @@ def create_agent_graph( }) graph.add_edge("tool_node", "agent_node") else: - # No tools — agent responds once and finishes graph.add_edge("agent_node", END) graph.set_entry_point("agent_node") diff --git a/core/state.py b/core/state.py index 3e434b2..28ec77e 100644 --- a/core/state.py +++ b/core/state.py @@ -18,3 +18,4 @@ class AgentState(TypedDict): """ messages: Annotated[list, add_messages] + discord_user_id: int | None # set by the Discord bot, None for REST API calls diff --git a/main.py b/main.py index b03c488..2800bd6 100644 --- a/main.py +++ b/main.py @@ -4,8 +4,9 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from api.v1.auth import router as auth_router from api.v1.chat import router as v1_router -from core.config import DEEPSEEK_API_KEY +from core.config import DEEPSEEK_API_KEY, get_config from core.llm import create_client # --------------------------------------------------------------------------- @@ -18,12 +19,14 @@ logging.basicConfig( ) # --------------------------------------------------------------------------- -# Load all agents & skills so they self-register at startup +# Load all agents, skills, AND auth services so they self-register at startup # --------------------------------------------------------------------------- from agents import load_all_agents # noqa: E402 load_all_agents() +import auth.jellyfin # noqa: E402 — self-registers JellyfinAuth + # --------------------------------------------------------------------------- # Lifespan # --------------------------------------------------------------------------- @@ -60,4 +63,5 @@ app.state.agent_graphs: dict = {} # --------------------------------------------------------------------------- # Routers # --------------------------------------------------------------------------- -app.include_router(v1_router, prefix="/v1") \ No newline at end of file +app.include_router(v1_router, prefix="/v1") +app.include_router(auth_router) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d1b0c2f..b62cbf6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ python-dotenv httpx langgraph langgraph-checkpoint -discord.py \ No newline at end of file +discord.py +python-multipart \ No newline at end of file diff --git a/skills/__init__.py b/skills/__init__.py index 1a3b32f..97ded41 100644 --- a/skills/__init__.py +++ b/skills/__init__.py @@ -47,6 +47,7 @@ class Skill: prompt_fragment: str = "" tools: List[Dict[str, Any]] = field(default_factory=list) execute: Optional[ToolExecutor] = None + requires_auth: List[str] = field(default_factory=list) # --------------------------------------------------------------------------- @@ -96,9 +97,15 @@ def get_all_tools(skill_names: list[str]) -> List[Dict[str, Any]]: async def execute_tool( - skill_names: list[str], tool_name: str, args: dict + skill_names: list[str], tool_name: str, args: dict, + discord_user_id: int | None = None, ) -> ToolResult | None: """Find the skill that owns *tool_name* and run its executor. + + If *discord_user_id* is provided, also checks whether the owning skill + requires authentication for any services. If auth is missing, returns + a friendly ToolResult.fail(...) telling the user how to log in. + Only logs failures to the console — successful calls are silent. """ import logging @@ -109,6 +116,24 @@ async def execute_tool( if s and s.execute: for t in s.tools: if t.get("function", {}).get("name") == tool_name: + # --- Auth gate --- + if s.requires_auth and discord_user_id is not None: + from core import auth_store + from auth import get_auth_service + missing: list[str] = [] + for svc in s.requires_auth: + if not auth_store.is_authenticated(discord_user_id, svc): + missing.append(svc) + if missing: + svc_displays = ", ".join( + (get_auth_service(m) and get_auth_service(m).display_name) or m + for m in missing + ) + return ToolResult.fail( + f"You need to log in to {svc_displays} first. " + + " ".join(f"Send `/login {m}` in a DM to get started." for m in missing) + ) + # --- End auth gate --- try: result = await s.execute(tool_name, args) if not result.success: diff --git a/skills/media_info.py b/skills/media_info.py index 8578cdf..3365900 100644 --- a/skills/media_info.py +++ b/skills/media_info.py @@ -23,6 +23,16 @@ When responding: suggest submitting a ticket if there's a problem. - Always confirm successful actions and warn about failures. +## Jellyfin & Authentication + +You are connected to the user's Jellyfin server. If a user asks you to +"connect to Jellyfin", "link my Jellyfin", or asks about their watch history, +simply call the `watch_history` tool. The system will automatically handle +authentication — if the user isn't linked yet, they'll be guided through +Quick Connect seamlessly. NEVER tell a user you "don't have access to +Jellyfin" or "can't connect" — always try the tool first and let the system +sort it out. + This is the base media assistant persona. Real API capabilities come from the attached skills (seerr, triage, etc.).""", ) diff --git a/skills/watch_history.py b/skills/watch_history.py new file mode 100644 index 0000000..29bc0b2 --- /dev/null +++ b/skills/watch_history.py @@ -0,0 +1,80 @@ +""" +Watch History skill — fetch the user's Jellyfin watch history. + +Currently a placeholder — returns a "coming soon" message. +The auth gate (`requires_auth=["jellyfin"]`) is already active: +users who haven't linked Jellyfin will be prompted to /login first. +""" + +from __future__ import annotations + +from skills import Skill, register, ToolResult + +# --------------------------------------------------------------------------- +# Tool definitions +# --------------------------------------------------------------------------- + +TOOLS = [ + { + "type": "function", + "function": { + "name": "watch_history", + "description": ( + "Get the user's recent Jellyfin watch history — movies and TV " + "episodes they have watched, sorted by most recent. " + "Call this when a user asks about their watching activity." + ), + "parameters": { + "type": "object", + "properties": { + "limit": { + "type": "integer", + "description": "How many items to return (default 10, max 20)", + } + }, + }, + }, + } +] + + +# --------------------------------------------------------------------------- +# Executor (placeholder) +# --------------------------------------------------------------------------- + +async def _execute(tool_name: str, args: dict) -> ToolResult: + if tool_name == "watch_history": + return ToolResult.ok( + "👷 **Watch History — Coming Soon!**\n\n" + "This feature is currently being built. Soon you'll be able to " + "see your recently watched movies and TV episodes right here.\n\n" + "In the meantime, you can check your watch history directly in Jellyfin." + ) + return ToolResult.fail(f"Unknown tool: {tool_name}") + + +# --------------------------------------------------------------------------- +# Skill registration +# --------------------------------------------------------------------------- + +watch_history_skill = Skill( + name="watch_history", + description="User's Jellyfin watch history (coming soon)", + requires_auth=["jellyfin"], + prompt_fragment="""## Watch History + +You can fetch the user's Jellyfin watch history with the `watch_history` tool. +Call it when users ask things like: +- "what have I watched?" +- "show my watch history" +- "what did I watch recently?" +- "what was the last movie I saw?" +- "what TV shows have I been watching?" + +The tool is currently a **placeholder** — it returns a "coming soon" message. +Tell the user this feature is being worked on and will be available soon.""", + tools=TOOLS, + execute=_execute, +) + +register(watch_history_skill) -- 2.52.0 From 51e099acdd98267a77c9a5c7c88df384d9aa8a16 Mon Sep 17 00:00:00 2001 From: TimHoogervorst Date: Mon, 25 May 2026 11:12:49 +0200 Subject: [PATCH 2/5] enhance auth status endpoint to return detailed linked services for Discord users --- api/v1/auth.py | 45 ++++++++++++++++++++++++++++++++++++++------- core/auth_store.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 7 deletions(-) diff --git a/api/v1/auth.py b/api/v1/auth.py index 7c2d1a4..094d8b8 100644 --- a/api/v1/auth.py +++ b/api/v1/auth.py @@ -152,16 +152,47 @@ async def login_submit(request: Request): # --------------------------------------------------------------------------- -# GET /auth/status — check which services are linked +# GET /auth/status — get all linked services for a Discord user # --------------------------------------------------------------------------- -@router.get("/status") +@router.get("/Discord/status") async def auth_status(discord_id: int): - """Return which services this Discord user has linked.""" - services: dict[str, bool] = {} - for svc_name in list_auth_services(): - services[svc_name] = auth_store.is_authenticated(discord_id, svc_name) - return {"discord_id": discord_id, "services": services} + """ + Return all services linked to this Discord user with full details. + + Response: + { + "discord_id": 123456789, + "linked_services": { + "jellyfin": { + "external_user_id": "abc123", + "external_name": "Tim", + "linked_at": "2026-05-25T10:00:00", + "credentials": { + "token": "jwt...", + "url": "http://jellyfin:8096", + "user_id": "abc123" + } + } + } + } + """ + auths = auth_store.get_all_auths(discord_id) + + linked_services: dict[str, dict] = {} + for auth in auths: + svc_name = auth["service"] + linked_services[svc_name] = { + "external_user_id": auth["external_user_id"], + "external_name": auth["external_name"], + "linked_at": auth["linked_at"], + "credentials": auth["credentials"], + } + + return { + "discord_id": discord_id, + "linked_services": linked_services, + } # --------------------------------------------------------------------------- diff --git a/core/auth_store.py b/core/auth_store.py index cb57cb4..e6f8d13 100644 --- a/core/auth_store.py +++ b/core/auth_store.py @@ -298,6 +298,48 @@ def revoke(discord_user_id: int, service: str) -> None: logger.info("Revoked auth for user %s on %s", discord_user_id, service) +def get_all_auths(discord_user_id: int) -> list[dict]: + """ + Return all active auth records for a Discord user. + Each record includes service name, external user id, external name, + linked_at timestamp, and the raw credentials (e.g. Jellyfin token + URL). + + Used by the /api/v1/auth/status endpoint so other services can discover + linked accounts for a given Discord ID. + """ + _ensure_schema() + import json + + with _db_lock: + conn = _get_conn() + rows = conn.execute( + """SELECT service, external_user_id, external_name, credentials, linked_at + FROM user_auth + WHERE discord_user_id = ? AND is_active = 1 + ORDER BY linked_at DESC""", + (discord_user_id,), + ).fetchall() + conn.close() + + results: list[dict] = [] + for row in rows: + creds = {} + if row["credentials"]: + try: + creds = json.loads(row["credentials"]) + except (json.JSONDecodeError, TypeError): + creds = {} + results.append({ + "service": row["service"], + "external_user_id": row["external_user_id"] or "", + "external_name": row["external_name"] or "", + "linked_at": row["linked_at"] or "", + "credentials": creds, + }) + + return results + + # --------------------------------------------------------------------------- # Dev / testing — reset the entire store # --------------------------------------------------------------------------- -- 2.52.0 From b0f10b6bb10c8796a136af958a3161d9c6d4c591 Mon Sep 17 00:00:00 2001 From: TimHoogervorst Date: Mon, 25 May 2026 12:16:24 +0200 Subject: [PATCH 3/5] small refactor of the structure --- agents/__init__.py | 12 ++++++------ {skills => agents/skills}/__init__.py | 2 +- {skills => agents/skills}/easter_eggs.py | 2 +- {skills => agents/skills}/media_info.py | 2 +- {skills => agents/skills}/seerr.py | 2 +- {skills => agents/skills}/triage.py | 2 +- {skills => agents/skills}/watch_history.py | 2 +- {api => gateway}/__init__.py | 0 {api => gateway}/api.md | 0 {auth => gateway/auth}/__init__.py | 0 {auth => gateway/auth}/jellyfin.py | 4 ++-- {api => gateway}/dependencies.py | 2 +- {bot => gateway/discord}/__init__.py | 0 bot/discord_bot.py => gateway/discord/bot.py | 12 ++++++------ {bot => gateway/discord}/conversation.py | 0 {api => gateway}/v1/__init__.py | 0 {api => gateway}/v1/auth.py | 8 ++++---- {api => gateway}/v1/chat.py | 4 ++-- main.py | 12 ++++++------ {core => src}/__init__.py | 0 {core => src}/auth_store.py | 2 +- {core => src}/config.py | 0 {core => src}/graph.py | 4 ++-- {core => src}/llm.py | 0 {core => src}/state.py | 0 {core => src}/tools_adapter.py | 2 +- 26 files changed, 37 insertions(+), 37 deletions(-) rename {skills => agents/skills}/__init__.py (98%) rename {skills => agents/skills}/easter_eggs.py (99%) rename {skills => agents/skills}/media_info.py (97%) rename {skills => agents/skills}/seerr.py (99%) rename {skills => agents/skills}/triage.py (97%) rename {skills => agents/skills}/watch_history.py (98%) rename {api => gateway}/__init__.py (100%) rename {api => gateway}/api.md (100%) rename {auth => gateway/auth}/__init__.py (100%) rename {auth => gateway/auth}/jellyfin.py (99%) rename {api => gateway}/dependencies.py (96%) rename {bot => gateway/discord}/__init__.py (100%) rename bot/discord_bot.py => gateway/discord/bot.py (98%) rename {bot => gateway/discord}/conversation.py (100%) rename {api => gateway}/v1/__init__.py (100%) rename {api => gateway}/v1/auth.py (97%) rename {api => gateway}/v1/chat.py (98%) rename {core => src}/__init__.py (100%) rename {core => src}/auth_store.py (99%) rename {core => src}/config.py (100%) rename {core => src}/graph.py (98%) rename {core => src}/llm.py (100%) rename {core => src}/state.py (100%) rename {core => src}/tools_adapter.py (97%) diff --git a/agents/__init__.py b/agents/__init__.py index 76df02e..3947b03 100644 --- a/agents/__init__.py +++ b/agents/__init__.py @@ -12,7 +12,7 @@ An Agent is a lightweight wrapper: from dataclasses import dataclass, field from typing import Dict, List -from skills import Skill, get_combined_prompt, list_all as list_all_skills +from agents.skills import Skill, get_combined_prompt, list_all as list_all_skills @dataclass @@ -61,8 +61,8 @@ def load_all_agents() -> None: import agents.media_agent # noqa: F401 # Also import skill modules so they self-register - import skills.media_info # noqa: F401 - import skills.seerr # noqa: F401 - import skills.triage # noqa: F401 - import skills.easter_eggs # noqa: F401 - import skills.watch_history # noqa: F401 + import agents.skills.media_info # noqa: F401 + import agents.skills.seerr # noqa: F401 + import agents.skills.triage # noqa: F401 + import agents.skills.easter_eggs # noqa: F401 + import agents.skills.watch_history # noqa: F401 diff --git a/skills/__init__.py b/agents/skills/__init__.py similarity index 98% rename from skills/__init__.py rename to agents/skills/__init__.py index 97ded41..1e37b4d 100644 --- a/skills/__init__.py +++ b/agents/skills/__init__.py @@ -12,7 +12,7 @@ A Skill is a lightweight object with: from dataclasses import dataclass, field from typing import Any, Awaitable, Callable, Dict, List, Optional -from core.config import get_config # re-export so every skill can use it +from src.config import get_config # re-export so every skill can use it # --------------------------------------------------------------------------- diff --git a/skills/easter_eggs.py b/agents/skills/easter_eggs.py similarity index 99% rename from skills/easter_eggs.py rename to agents/skills/easter_eggs.py index 3fa3783..5900c18 100644 --- a/skills/easter_eggs.py +++ b/agents/skills/easter_eggs.py @@ -8,7 +8,7 @@ requested actions normally. Functionality is never sacrificed for a reference. Add a new theme by adding one entry to THEMES — no code changes needed. """ -from skills import Skill, register +from agents.skills import Skill, register THEMES = { "naruto": { diff --git a/skills/media_info.py b/agents/skills/media_info.py similarity index 97% rename from skills/media_info.py rename to agents/skills/media_info.py index 3365900..116ecef 100644 --- a/skills/media_info.py +++ b/agents/skills/media_info.py @@ -5,7 +5,7 @@ A lightweight base skill that teaches the agent it is a media assistant. Real API capabilities come from other skills (seerr, triage, etc.). """ -from skills import Skill, register +from agents.skills import Skill, register media_info_skill = Skill( name="media_info", diff --git a/skills/seerr.py b/agents/skills/seerr.py similarity index 99% rename from skills/seerr.py rename to agents/skills/seerr.py index 7cdbc33..46269c2 100644 --- a/skills/seerr.py +++ b/agents/skills/seerr.py @@ -24,7 +24,7 @@ from urllib.parse import quote import httpx -from skills import Skill, register, ToolResult, get_config +from agents.skills import Skill, register, ToolResult, get_config # --------------------------------------------------------------------------- # Config diff --git a/skills/triage.py b/agents/skills/triage.py similarity index 97% rename from skills/triage.py rename to agents/skills/triage.py index 55bd21b..2a92231 100644 --- a/skills/triage.py +++ b/agents/skills/triage.py @@ -10,7 +10,7 @@ cancelling requests, banning users), this skill teaches the LLM to: 3. Use the seerr_submit_issue tool (if available) to create the ticket. """ -from skills import Skill, register +from agents.skills import Skill, register # This skill has no tools of its own — it guides the LLM's behavior. # The actual ticket submission is handled by seerr_submit_issue. diff --git a/skills/watch_history.py b/agents/skills/watch_history.py similarity index 98% rename from skills/watch_history.py rename to agents/skills/watch_history.py index 29bc0b2..ff414fb 100644 --- a/skills/watch_history.py +++ b/agents/skills/watch_history.py @@ -8,7 +8,7 @@ users who haven't linked Jellyfin will be prompted to /login first. from __future__ import annotations -from skills import Skill, register, ToolResult +from agents.skills import Skill, register, ToolResult # --------------------------------------------------------------------------- # Tool definitions diff --git a/api/__init__.py b/gateway/__init__.py similarity index 100% rename from api/__init__.py rename to gateway/__init__.py diff --git a/api/api.md b/gateway/api.md similarity index 100% rename from api/api.md rename to gateway/api.md diff --git a/auth/__init__.py b/gateway/auth/__init__.py similarity index 100% rename from auth/__init__.py rename to gateway/auth/__init__.py diff --git a/auth/jellyfin.py b/gateway/auth/jellyfin.py similarity index 99% rename from auth/jellyfin.py rename to gateway/auth/jellyfin.py index b2744df..6a49efd 100644 --- a/auth/jellyfin.py +++ b/gateway/auth/jellyfin.py @@ -20,8 +20,8 @@ from typing import Optional import httpx -from auth import AuthService, AuthResult, register_auth_service -from core.config import get_config +from gateway.auth import AuthService, AuthResult, register_auth_service +from src.config import get_config logger = logging.getLogger("auth.jellyfin") diff --git a/api/dependencies.py b/gateway/dependencies.py similarity index 96% rename from api/dependencies.py rename to gateway/dependencies.py index d57fee1..6dfaba8 100644 --- a/api/dependencies.py +++ b/gateway/dependencies.py @@ -1,7 +1,7 @@ from fastapi import Request from openai import OpenAI -from core.graph import create_agent_graph +from src.graph import create_agent_graph def get_llm_client(request: Request) -> OpenAI: diff --git a/bot/__init__.py b/gateway/discord/__init__.py similarity index 100% rename from bot/__init__.py rename to gateway/discord/__init__.py diff --git a/bot/discord_bot.py b/gateway/discord/bot.py similarity index 98% rename from bot/discord_bot.py rename to gateway/discord/bot.py index 3681d2a..05da455 100644 --- a/bot/discord_bot.py +++ b/gateway/discord/bot.py @@ -23,12 +23,12 @@ import os import discord from agents import list_all as list_all_agents -from bot.conversation import ConversationStore -from core.config import DEEPSEEK_API_KEY, get_config -from core.graph import create_agent_graph -from core.llm import create_client -from core import auth_store -from auth import list_auth_services, get_auth_service +from gateway.discord.conversation import ConversationStore +from src.config import DEEPSEEK_API_KEY, get_config +from src.graph import create_agent_graph +from src.llm import create_client +from src import auth_store +from gateway.auth import list_auth_services, get_auth_service logger = logging.getLogger("bot.discord") diff --git a/bot/conversation.py b/gateway/discord/conversation.py similarity index 100% rename from bot/conversation.py rename to gateway/discord/conversation.py diff --git a/api/v1/__init__.py b/gateway/v1/__init__.py similarity index 100% rename from api/v1/__init__.py rename to gateway/v1/__init__.py diff --git a/api/v1/auth.py b/gateway/v1/auth.py similarity index 97% rename from api/v1/auth.py rename to gateway/v1/auth.py index 094d8b8..0ca93b2 100644 --- a/api/v1/auth.py +++ b/gateway/v1/auth.py @@ -19,10 +19,10 @@ import logging from fastapi import APIRouter, Form, HTTPException, Request from fastapi.responses import HTMLResponse -from auth import get_auth_service, list_auth_services -from core import auth_store +from gateway.auth import get_auth_service, list_auth_services +from src import auth_store -logger = logging.getLogger("api.auth") +logger = logging.getLogger("gateway.auth") router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) @@ -199,7 +199,7 @@ async def auth_status(discord_id: int): # POST /auth/reset — wipe auth store (DEV ONLY) # --------------------------------------------------------------------------- -from core.config import get_config # noqa: E402 +from src.config import get_config # noqa: E402 @router.post("/reset") async def reset_auth(): diff --git a/api/v1/chat.py b/gateway/v1/chat.py similarity index 98% rename from api/v1/chat.py rename to gateway/v1/chat.py index 5f71fb8..68cfd95 100644 --- a/api/v1/chat.py +++ b/gateway/v1/chat.py @@ -4,9 +4,9 @@ from openai import OpenAI from pydantic import BaseModel import json -from api.dependencies import get_llm_client, get_agent_graph +from gateway.dependencies import get_llm_client, get_agent_graph from agents import get as get_agent, list_all as list_all_agents -from core.state import AgentState +from src.state import AgentState router = APIRouter() diff --git a/main.py b/main.py index 2800bd6..586ca45 100644 --- a/main.py +++ b/main.py @@ -4,10 +4,10 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from api.v1.auth import router as auth_router -from api.v1.chat import router as v1_router -from core.config import DEEPSEEK_API_KEY, get_config -from core.llm import create_client +from gateway.v1.auth import router as auth_router +from gateway.v1.chat import router as v1_router +from src.config import DEEPSEEK_API_KEY, get_config +from src.llm import create_client # --------------------------------------------------------------------------- # Logging — tool calls will appear in the uvicorn console @@ -25,14 +25,14 @@ from agents import load_all_agents # noqa: E402 load_all_agents() -import auth.jellyfin # noqa: E402 — self-registers JellyfinAuth +import gateway.auth.jellyfin # noqa: E402 — self-registers JellyfinAuth # --------------------------------------------------------------------------- # Lifespan # --------------------------------------------------------------------------- @asynccontextmanager async def lifespan(app: FastAPI): - from bot.discord_bot import start_in_background # noqa: E402 + from gateway.discord.bot import start_in_background # noqa: E402 start_in_background() diff --git a/core/__init__.py b/src/__init__.py similarity index 100% rename from core/__init__.py rename to src/__init__.py diff --git a/core/auth_store.py b/src/auth_store.py similarity index 99% rename from core/auth_store.py rename to src/auth_store.py index e6f8d13..f14864f 100644 --- a/core/auth_store.py +++ b/src/auth_store.py @@ -19,7 +19,7 @@ from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Optional -from core.config import get_config +from src.config import get_config logger = logging.getLogger("auth_store") diff --git a/core/config.py b/src/config.py similarity index 100% rename from core/config.py rename to src/config.py diff --git a/core/graph.py b/src/graph.py similarity index 98% rename from core/graph.py rename to src/graph.py index 4bb2675..bdbcd69 100644 --- a/core/graph.py +++ b/src/graph.py @@ -20,8 +20,8 @@ from langchain_core.messages import AIMessage, ToolMessage from langgraph.graph import END, StateGraph from openai import OpenAI -from core.state import AgentState -from skills import get_all_tools, execute_tool +from src.state import AgentState +from agents.skills import get_all_tools, execute_tool logger = logging.getLogger("graph") diff --git a/core/llm.py b/src/llm.py similarity index 100% rename from core/llm.py rename to src/llm.py diff --git a/core/state.py b/src/state.py similarity index 100% rename from core/state.py rename to src/state.py diff --git a/core/tools_adapter.py b/src/tools_adapter.py similarity index 97% rename from core/tools_adapter.py rename to src/tools_adapter.py index e7fccba..ee393b4 100644 --- a/core/tools_adapter.py +++ b/src/tools_adapter.py @@ -13,7 +13,7 @@ from typing import Any from langchain_core.tools import tool -from skills import get_all_tools, execute_tool +from agents.skills import get_all_tools, execute_tool def build_langgraph_tools(skill_names: list[str]) -> list: -- 2.52.0 From 4b87b817a8b7c0b03fcd5c4820fcdab088b52791 Mon Sep 17 00:00:00 2001 From: TimHoogervorst Date: Mon, 25 May 2026 12:24:53 +0200 Subject: [PATCH 4/5] added some quick .md files --- gateway/api.md | 268 ++++++++----------------------------- gateway/auth/auth.md | 152 +++++++++++++++++++++ gateway/discord/discord.md | 73 ++++++++++ gateway/v1/v1.md | 106 +++++++++++++++ 4 files changed, 385 insertions(+), 214 deletions(-) create mode 100644 gateway/auth/auth.md create mode 100644 gateway/discord/discord.md create mode 100644 gateway/v1/v1.md diff --git a/gateway/api.md b/gateway/api.md index 1153989..5f3479f 100644 --- a/gateway/api.md +++ b/gateway/api.md @@ -1,235 +1,75 @@ -# API Architecture — Agent + Skill + Graph Pipeline +# Gateway Architecture — Agent + Skill + Graph Pipeline -This document explains how the API routes user messages through the -agent / skill / LangGraph pipeline to produce responses. +This is the **interface layer** of the Agents project. Everything that connects +the outside world to the agent system lives here — REST APIs, Discord bot, +and authentication. --- -## Overview +## Directory Map -``` -┌─────────────────────────────────────────────────────────────────┐ -│ OpenWebUI / Client │ -│ POST /v1/chat/completions { model, messages, stream } │ -└──────────────────────────────┬──────────────────────────────────┘ - │ - ▼ -┌──────────────────────────────────────────────────────────────────┐ -│ api/v1/chat.py — chat_completions() │ -│ │ -│ 1. _resolve_agent(req.model) → Agent │ -│ 2. get_agent_graph(agent_id) → compiled StateGraph │ -│ 3. graph.ainvoke(state) or _stream_graph(graph, messages) │ -└──────────────────────────────┬───────────────────────────────────┘ - │ - ▼ -┌──────────────────────────────────────────────────────────────────┐ -│ LangGraph StateGraph (core/graph.py) │ -│ │ -│ ┌──────────────┐ tool_calls? ┌──────────────┐ │ -│ │ agent_node │ ───────────────▶ │ tool_node │ │ -│ │ (LLM call) │ ◀─────────────── │ (skill exec) │ │ -│ └──────┬───────┘ └──────────────┘ │ -│ │ no tool_calls │ -│ ▼ │ -│ [END] │ -└──────────────────────────────────────────────────────────────────┘ +| Path | Description | Docs | +|---|---|---| +| `gateway/v1/` | REST API endpoints — chat, agent listing, OpenAI-compatible completions | [v1.md](v1/v1.md) | +| `gateway/discord/` | Discord bot connector — in-process DM handler with LangGraph integration | [discord.md](discord/discord.md) | +| `gateway/auth/` | Auth service registry + Jellyfin Quick Connect implementation | [auth.md](auth/auth.md) | -## Key Concepts +--- -### 1. Agent +## Supporting Modules -An **Agent** is a persona + skill bundle. Defined in `agents/`. - -```python -# agents/media_agent.py -Agent( - agent_id="media-agent", - description="Media assistant with Seerr integration", - skills=["media_info", "seerr", "triage"], - base_prompt="You are a media assistant...", -) -``` - -- `agent_id` — unique name, exposed as a model in OpenWebUI -- `skills` — list of skill names to load -- `base_prompt` — starting system prompt, combined with skill fragments -- `build_system_prompt()` — merges base_prompt + all skill prompt fragments - -Agents self-register at import time via `agents/__init__.py`'s `register()`. -`main.py` calls `load_all_agents()` at startup to import every agent and skill -module. - -### 2. Skill - -A **Skill** is a capability bundle. Defined in `skills/`. - -```python -# skills/seerr.py -Skill( - name="seerr", - description="Seerr integration — trending, discover, request media, submit issues", - prompt_fragment="## Seerr Media Tools\n...", - tools=[...], # OpenAI function-calling schema - execute=_execute, # async handler: tool_name + args → ToolResult -) -``` - -- `prompt_fragment` — injected into the agent's system prompt. -- `tools` — list of OpenAI function definitions (name, description, parameters). -- `execute` — async callable that routes tool calls to API handlers. - -### 3. Graph - -Each agent gets a **compiled LangGraph StateGraph** built by -`core/graph.py:create_agent_graph()`. The graph is compiled lazily on the -first request and cached on `app.state.agent_graphs` for the lifetime of the -process. - -| Graph node / edge | What it does | +| Path | Purpose | |---|---| -| `agent_node` | Converts state messages to OpenAI dicts, calls the LLM with the agent's system prompt + tool definitions, returns an `AIMessage` | -| `tool_node` | Reads `tool_calls` from the last AI message, calls `execute_tool()` from the skill system, returns `ToolMessage` results | -| `_should_continue` | Conditional edge — returns `"tool_node"` if the AI message has `tool_calls`, else `END` | - -### 4. State - -Defined in `core/state.py`: - -```python -class AgentState(TypedDict): - messages: Annotated[list, add_messages] -``` - -LangGraph's `add_messages` reducer appends new messages and replaces messages -with matching IDs (so tool-call results overwrite their placeholders). - -### 5. Message Conversion - -Because we use the raw `openai` client (not `langchain-openai`), messages must -be converted between LangChain and OpenAI formats at every LLM call: - -- **LangChain → OpenAI** (`_lc_role_to_openai`, `_langchain_tc_to_openai`): - Maps `type` → `role` and converts top-level `name`/`args` tool-calls into - the nested `function` sub-object that the OpenAI API expects. - -- **OpenAI → LangChain** (inside `agent_node`): - Converts the `ChatCompletionMessage` response into an `AIMessage` with - LangChain-format `tool_calls` (top-level `name`/`args`/`id`). +| `gateway/dependencies.py` | FastAPI `Depends` providers — `get_llm_client()`, `get_agent_graph()` | +| `src/config.py` | `.env` loader and config accessor | +| `src/llm.py` | OpenAI-compatible client factory (DeepSeek) | +| `src/state.py` | LangGraph `AgentState` TypedDict | +| `src/graph.py` | LangGraph StateGraph factory — agent_node, tool_node, routing | +| `src/tools_adapter.py` | Wraps skill tools as LangChain `@tool` functions | +| `src/auth_store.py` | SQLite persistence for Discord → service auth linking | +| `agents/` | Agent definitions (dataclass + registry) | +| `agents/skills/` | Skill definitions — prompt fragments, tool schemas, executors | --- -## Full Request Flow - -### Step-by-step: "What are trending movies?" +## High-Level Request Flow ``` -1. OpenWebUI sends: - POST /v1/chat/completions - { - "model": "media-agent", - "messages": [ - {"role": "user", "content": "What are trending movies?"} - ], - "stream": false - } - -2. chat_completions(): - → _resolve_agent(model="media-agent") - → get_agent("media-agent") → Agent(skills=["media_info", "seerr", "triage"]) - → get_agent_graph("media-agent", request) - → looks up app.state.agent_graphs["media-agent"] - → first call → create_agent_graph() compiles the graph with 7 Seerr tools - → run_agent_with_tools(request, messages, agent_id) - → _invoke_graph(graph, messages) - -3. Graph — Pass 1 (agent_node): - → LLM receives: [system prompt] + [user: "What are trending movies?"] - → LLM responds with tool_calls: seerr_trending(kind="movie") - → agent_node returns AIMessage with tool_calls in LangChain format - -4. Graph — _should_continue: - → AIMessage has tool_calls → route to "tool_node" - -5. Graph — tool_node: - → Reads tool_call: name="seerr_trending", args={"kind": "movie"} - → execute_tool(["media_info", "seerr", "triage"], "seerr_trending", ...) - → Seerr API → GET /api/v1/discover/trending?mediaType=movie - → Returns ToolMessage with formatted results including [tmdb:IDs] - -6. Graph — Pass 2 (agent_node): - → LLM receives previous exchange + tool result - → LLM responds with text only (no tool_calls) - → agent_node returns AIMessage(content="Here are the top trending movies!...") - -7. Graph — _should_continue: - → No tool_calls → route to END - -8. chat_completions() returns: - { "choices": [{"message": {"role": "assistant", "content": "Here are the top..."}}] } +┌──────────────────────────────┐ +│ Client (OpenWebUI / HTTP) │ +└──────────────┬───────────────┘ + │ POST /v1/chat/completions + ▼ +┌──────────────────────────────┐ +│ gateway/v1/chat.py │ ← resolves agent, invokes graph +└──────────────┬───────────────┘ + │ + ▼ +┌──────────────────────────────┐ +│ LangGraph StateGraph │ ← src/graph.py +│ ┌──────────┐ ┌──────────┐│ +│ │agent_node│──▶│tool_node ││ +│ │(LLM call)│◀──│(skills) ││ +│ └──────────┘ └──────────┘│ +└──────────────┬───────────────┘ + │ + ▼ +┌──────────────────────────────┐ +│ agents/skills/ │ ← Seerr API, Jellyfin API, etc. +└──────────────────────────────┘ ``` -### Step-by-step: "Request the 2026 one" (multi-turn context) - -``` -1. OpenWebUI sends the FULL history: - { - "model": "media-agent", - "messages": [ - {"role": "user", "content": "What are trending movies?"}, - {"role": "assistant", "content": "Here are the top 10 trending movies! - 1. **Mortal Kombat II** (2026) [tmdb:931285] — ..."}, - {"role": "user", "content": "could request the mortal kombat one?"}, - {"role": "assistant", "content": "There are several Mortal Kombat entries! ..."}, - {"role": "user", "content": "the 2026 one"} - ] - } - -2. chat_completions(): - → req.messages contains the ENTIRE conversation history - → graph.ainvoke({"messages": all_messages}) - → agent_node prepends system prompt and sends everything to the LLM - -3. LLM reasons from full context: - - Previously listed Mortal Kombat II (2026) with [tmdb:931285] - - The user said "request the mortal kombat one" → I searched and showed 4 options - - Now they say "the 2026 one" → that matches Mortal Kombat II (2026) [tmdb:931285] - - I should call seerr_request_media(kind="movie", title="Mortal Kombat II", tmdb_id=931285) - -4. tool_node executes the request → ✅ Success -``` +For a detailed step-by-step walkthrough of the graph execution (including +multi-turn context and tool-calling loops), see [v1.md](v1/v1.md). --- -## Streaming +## Startup -Streaming works slightly differently from the sync path: +`main.py` is the entry point. At startup it: -``` -chat_completions(stream=True) - → _stream_graph(graph, messages) - → graph.ainvoke(state) # runs graph to completion (tools execute silently) - → yields content character-by-character via SSE -``` - -For true token-level streaming (tokens appear as the LLM generates them), -the agent_node would need to use `langchain-openai`'s `ChatOpenAI` instead of -the raw `openai` client. The current approach is a pragmatic middle ground -that avoids adding another dependency while still giving the SSE client -incremental output. - ---- - -## File Map - -| File | Responsibility | -|---|---| -| `main.py` | FastAPI app, singleton creation, router mounting | -| `api/v1/chat.py` | Endpoints — resolves agent, invokes graph, formats responses | -| `api/dependencies.py` | `get_llm_client()`, `get_agent_graph()` — FastAPI `Depends` | -| `core/graph.py` | `create_agent_graph()` — builds the StateGraph | -| `core/state.py` | `AgentState` TypedDict | -| `core/llm.py` | `create_client()` — OpenAI client factory | -| `core/config.py` | Environment variable loader | -| `agents/` | Agent definitions (dataclass + self-registration) | -| `skills/` | Skill definitions (prompt fragments + tools + executors) | +1. Loads `.env` → creates the LLM client (DeepSeek) → stores on `app.state.llm_client` +2. Calls `load_all_agents()` → imports every agent and skill module (they self-register) +3. Imports `gateway.auth.jellyfin` → self-registers the Jellyfin auth service +4. Mounts routers: `/v1/*` (chat endpoints) and `/api/v1/auth/*` (auth endpoints) +5. Starts the Discord bot as a background asyncio task (lifespan) diff --git a/gateway/auth/auth.md b/gateway/auth/auth.md new file mode 100644 index 0000000..df39c50 --- /dev/null +++ b/gateway/auth/auth.md @@ -0,0 +1,152 @@ +# Auth — Service Registry & Persistence + +The authentication system lets Discord users link their accounts to external +services (currently **Jellyfin**) so the agent can perform actions on their +behalf (e.g. checking watch history). + +--- + +## Architecture + +``` +gateway/auth/ gateway/v1/auth.py +┌──────────────────────┐ ┌──────────────────────────────┐ +│ AuthService (ABC) │ │ GET /api/v1/auth/login │ +│ ├─ JellyfinAuth │◀─────────│ POST /api/v1/auth/login │ +│ └─ (Plex, Seerr…) │ │ GET /api/v1/auth/status │ +│ │ │ GET /api/v1/auth/reset │ +└─────────┬────────────┘ └──────────────────────────────┘ + │ + ▼ +src/auth_store.py +┌──────────────────────┐ +│ SQLite │ +│ ├─ link_tokens │ one-time tokens sent via Discord DM +│ └─ user_auth │ per-user, per-service credentials +└──────────────────────┘ +``` + +--- + +## Files + +| File | Purpose | +|---|---| +| `gateway/auth/__init__.py` | Abstract `AuthService` base class + global registry | +| `gateway/auth/jellyfin.py` | Jellyfin implementation — Quick Connect + username/password | +| `gateway/v1/auth.py` | REST endpoints for the web-based login flow | +| `src/auth_store.py` | SQLite persistence for link tokens and stored credentials | + +--- + +## Flow: Discord User Links Jellyfin + +``` +Discord DM Web Browser Jellyfin Server + │ │ │ + │ 1. /login jellyfin │ │ + │ ──────────────────────────────▶│ │ + │ Bot creates link token in │ │ + │ SQLite, DMs the user a URL │ │ + │ │ │ + │ 2. User clicks link │ │ + │ ◀─────────────────────────────▶│ │ + │ │ GET /api/v1/auth/login │ + │ │ ?service=jellyfin │ + │ │ &token=xxx&discord_id=123 │ + │ │ │ + │ │ 3. Serve Quick Connect form │ + │ │ ◀──────────────────────────── │ + │ │ │ + │ │ 4. Initiate Quick Connect │ + │ │ ─────────────────────────────▶│ + │ │ POST /QuickConnect/Initiate │ + │ │ ◀── { Code: "ABC123" } │ + │ │ │ + │ 5. User enters code in │ │ + │ Jellyfin app │ │ + │ │ │ + │ │ 6. Poll: is it authorized? │ + │ │ ─────────────────────────────▶│ + │ │ GET /QuickConnect/Connect │ + │ │ ◀── Authenticated + Token │ + │ │ │ + │ 7. auth_store saves: │ │ + │ (discord_id, jellyfin, │ │ + │ AccessToken, username) │ │ + │ │ │ + │ 8. "✅ Linked to Jellyfin!" │ │ + │ ◀───────────────────────────── │ │ +``` + +--- + +## AuthService Base Class + +```python +class AuthService(ABC): + name: str # "jellyfin" + display_name: str # "Jellyfin" + + def render_login_form(token, discord_id) -> str: ... + async def authenticate(form_data) -> AuthResult: ... +``` + +Add a new service (e.g. Plex, Seerr) by subclassing `AuthService`, dropping +the module in `gateway/auth/`, and calling `register_auth_service()` at import +time. The REST endpoints and auth store work generically — no changes needed. + +--- + +## Current Implementation: Jellyfin + +`gateway/auth/jellyfin.py` supports two flows: + +| Method | How it works | +|---|---| +| **Quick Connect** (primary) | Calls Jellyfin's `/QuickConnect/Initiate` → polls `/QuickConnect/Connect` → stores the `AccessToken` | +| **Username/Password** (fallback) | Renders an HTML form → user submits credentials → calls `/Users/AuthenticateByName` → stores the `AccessToken` | + +The stored credentials include: +- `external_user_id` — Jellyfin user ID +- `external_name` — Jellyfin username +- `credentials` dict — `{"AccessToken": "...", "ServerURL": "..."}` + +--- + +## Auth Store (SQLite) + +Two tables in `data/auth.db`: + +```sql +-- One-time tokens for the web login flow (expire after 10 min) +CREATE TABLE link_tokens ( + token TEXT PRIMARY KEY, + discord_id INTEGER NOT NULL, + service TEXT NOT NULL, + created_at TEXT NOT NULL, + used INTEGER DEFAULT 0 +); + +-- Per-user, per-service stored credentials +CREATE TABLE user_auth ( + discord_id INTEGER NOT NULL, + service TEXT NOT NULL, + external_user_id TEXT, + external_name TEXT, + credentials TEXT, -- JSON + created_at TEXT NOT NULL, + PRIMARY KEY (discord_id, service) +); +``` + +--- + +## Skill-Level Auth Gating + +Skills can declare `requires_auth=["jellyfin"]`. When a tool is executed, +the skill system checks the auth store. If the user isn't linked: + +1. The tool returns `ToolResult.fail("Please login first using /login jellyfin")` +2. The LLM relays this message to the user in Discord +3. The user types `/login jellyfin` → Quick Connect flow → re-linked → try again diff --git a/gateway/discord/discord.md b/gateway/discord/discord.md new file mode 100644 index 0000000..7afbe2f --- /dev/null +++ b/gateway/discord/discord.md @@ -0,0 +1,73 @@ +# Discord — Connector + +The Discord module embeds a Discord bot **in-process** alongside FastAPI. +It uses the same LangGraph graphs and LLM client as the REST API — there is +no HTTP loopback, no separate process, and no code duplication. + +--- + +## Files + +| File | Purpose | +|---|---| +| `bot.py` | Discord `Client` subclass (`AgentBot`) — DM handler, command parser, graph invoker, Quick Connect orchestrator | +| `conversation.py` | In-memory conversation history store, keyed by Discord user ID | + +--- + +## Architecture + +``` +Discord Gateway (websocket) + │ DM: "What's trending?" + ▼ +discord.Client.on_message() + │ 1. Check: is this a DM? shares a guild? not a command? + │ 2. Build message history from ConversationStore + │ 3. Append user message + ▼ +_create_agent_graph(agent_id="media-agent") + │ Uses the exact same create_agent_graph() from src/graph.py + │ as the REST API — same LLM client, same tools, same cache. + ▼ +graph.ainvoke({"messages": [...]}) + │ LangGraph runs agent_node → tool_node → agent_node → END + ▼ +Response text → split into ≤2000-char Discord messages → sent to user +``` + +--- + +## Commands + +Commands are DMs that start with `/`. The bot parses them before hitting the +LLM: + +| Command | Action | +|---|---| +| `/login ` | Generate a one-time auth link, DM it to the user | +| `/jellyfin login` | Alias for `/login jellyfin` | +| `/help` | Show available agents and commands | +| `/` | Switch to a different agent for future messages | + +--- + +## Auth Flow (Quick Connect) + +When a user types `/login jellyfin`: + +1. Bot generates a one-time token via `auth_store` +2. Bot calls `auth_store.create_link_token(discord_id, "jellyfin")` +3. Bot DMs the user: `https:///api/v1/auth/login?service=jellyfin&token=...&discord_id=...` +4. User clicks the link → FastAPI serves the Jellyfin login form (or Quick Connect prompt) +5. User authenticates → credentials stored in `auth_store` +6. Future tool calls (e.g. `watch_history`) automatically use the stored Jellyfin session + +--- + +## Conversation Persistence + +- Per-user history stored in `ConversationStore` (in-memory dict) +- Max history length configurable via `DISCORD_MAX_HISTORY` env var (default: 7) +- Oldest messages are silently dropped when the limit is exceeded +- History is NOT persisted across restarts (future: could use SQLite) diff --git a/gateway/v1/v1.md b/gateway/v1/v1.md new file mode 100644 index 0000000..a627a02 --- /dev/null +++ b/gateway/v1/v1.md @@ -0,0 +1,106 @@ +# V1 — Chat & Agent API Endpoints + +This is the primary HTTP API surface for the chatbot agent system. It exposes +both a custom streaming chat endpoint and an OpenAI-compatible +`/chat/completions` endpoint so it works as a drop-in backend for OpenWebUI, +LibreChat, or any OpenAI-compatible client. + +--- + +## Endpoints + +| Method | Path | Description | +|---|---|---| +| `GET ` | `/v1/` | Health check — returns `{"status": "ok"}` | +| `GET ` | `/v1/agents` | List all registered agents (id + description) | +| `GET ` | `/v1/models` | OpenAI-compatible model list (one entry per agent) | +| `POST` | `/v1/chat` | Chat with an agent — streaming (SSE) | +| `POST` | `/v1/chat/sync` | Chat with an agent — non-streaming | +| `POST` | `/v1/chat/completions` | OpenAI-compatible chat completions (supports `stream: true`) | + +All `/v1/*` endpoints are mounted by `main.py` via: + +```python +app.include_router(v1_router, prefix="/v1") +``` + +--- + +## Agent Resolution + +Each request can target a specific agent. The resolution order is: + +1. **Explicit `agent_id`** field in the request body +2. **OpenAI `model` field** (OpenWebUI sends this — mapped to `agent_id` if a matching agent is registered) +3. **Fallback** to the `"naked"` agent (a plain LLM with no tools) + +This means an OpenWebUI client can simply set `model: "media-agent"` and get +the full Media Agent with Seerr tools. + +--- + +## Request Flow + +``` +Client (OpenWebUI / HTTP) + │ POST /v1/chat/completions + │ { model: "media-agent", messages: [...], stream: true/false } + ▼ +chat_completions() + │ 1. _resolve_agent(req.model) → Agent(id="media-agent", skills=[...]) + │ 2. get_agent_graph("media-agent", request) + │ → lazy-compiled LangGraph StateGraph, cached on app.state + │ 3. stream=True → _stream_graph(graph, messages) → SSE token stream + │ stream=False → _invoke_graph(graph, messages) → plain response + ▼ +LangGraph StateGraph (src/graph.py) + │ + ├── agent_node: calls LLM with system prompt + tool definitions + │ └── LLM returns text OR tool_calls + │ + ├── _should_continue: if tool_calls → tool_node, else → END + │ + └── tool_node: executes tool via agents/skills system → ToolMessage + └── loops back to agent_node with the result +``` + +For a detailed walkthrough, see [api.md](../api.md). + +--- + +## Streaming + +Two streaming modes exist: + +### SSE (Server-Sent Events) — `/v1/chat` +``` +data: {"token": "Here"} +data: {"token": " are"} +data: {"token": " the"} +... +data: [DONE] +``` + +The graph runs to completion (tools execute silently), then the final text is +yielded token-by-token as SSE events. + +### OpenAI-compatible — `/v1/chat/completions` with `stream: true` +``` +data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"}}]} +data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{"content":"!"}}]} +data: [DONE] +``` + +> **Future improvement:** true token-level streaming (tokens appear as the LLM +> generates them) would require using `langchain-openai`'s `ChatOpenAI` in +> place of the raw `openai` client. The current approach avoids adding that +> dependency. + +--- + +## Dependencies + +Endpoints receive shared singletons via FastAPI `Depends`: + +- **`get_llm_client(request)`** → returns `request.app.state.llm_client` (OpenAI client singleton, created once in `main.py`) +- **`get_agent_graph(agent_id, request)`** → returns a lazy-compiled LangGraph from `request.app.state.agent_graphs` -- 2.52.0 From 0151c8210ebf39a29082a457eaca7ce074d8414f Mon Sep 17 00:00:00 2001 From: TimHoogervorst Date: Mon, 25 May 2026 13:54:30 +0200 Subject: [PATCH 5/5] implement JellyStat API for watch history, genre summary, and user summary; add PostgreSQL connection pool and update requirements --- .env.example | 9 + agents/skills/__init__.py | 7 +- agents/skills/watch_history.py | 255 +++++++++++++++++++++--- gateway/jellystat/__init__.py | 0 gateway/jellystat/api.py | 106 ++++++++++ gateway/jellystat/db.py | 130 ++++++++++++ gateway/jellystat/models.py | 36 ++++ gateway/jellystat/startup-functions.sql | 224 +++++++++++++++++++++ main.py | 8 +- requirements.txt | 3 +- 10 files changed, 743 insertions(+), 35 deletions(-) create mode 100644 gateway/jellystat/__init__.py create mode 100644 gateway/jellystat/api.py create mode 100644 gateway/jellystat/db.py create mode 100644 gateway/jellystat/models.py create mode 100644 gateway/jellystat/startup-functions.sql diff --git a/.env.example b/.env.example index 8aa184b..d72c0f8 100644 --- a/.env.example +++ b/.env.example @@ -39,3 +39,12 @@ BASE_URL=http://localhost:8000 # Link token expiry in minutes (default 10) # AUTH_TOKEN_EXPIRY=10 + +# --------------------------------------------------------------------------- +# JellyStat — PostgreSQL watch-history database +# --------------------------------------------------------------------------- +JELLYSTAT_DB_HOST=localhost +JELLYSTAT_DB_PORT=5432 +JELLYSTAT_DB_USER=postgres +JELLYSTAT_DB_PASSWORD= +JELLYSTAT_DB_NAME=jfstat diff --git a/agents/skills/__init__.py b/agents/skills/__init__.py index 1e37b4d..bdd3a83 100644 --- a/agents/skills/__init__.py +++ b/agents/skills/__init__.py @@ -118,8 +118,8 @@ async def execute_tool( if t.get("function", {}).get("name") == tool_name: # --- Auth gate --- if s.requires_auth and discord_user_id is not None: - from core import auth_store - from auth import get_auth_service + from src import auth_store + from gateway.auth import get_auth_service missing: list[str] = [] for svc in s.requires_auth: if not auth_store.is_authenticated(discord_user_id, svc): @@ -134,6 +134,9 @@ async def execute_tool( + " ".join(f"Send `/login {m}` in a DM to get started." for m in missing) ) # --- End auth gate --- + # Inject discord_user_id so skills can resolve external user IDs + if discord_user_id is not None: + args = {**args, "_discord_user_id": discord_user_id} try: result = await s.execute(tool_name, args) if not result.success: diff --git a/agents/skills/watch_history.py b/agents/skills/watch_history.py index ff414fb..a7c75c8 100644 --- a/agents/skills/watch_history.py +++ b/agents/skills/watch_history.py @@ -1,14 +1,29 @@ """ -Watch History skill — fetch the user's Jellyfin watch history. +Watch History skill — fetch the user's Jellyfin watch history via JellyStat API. -Currently a placeholder — returns a "coming soon" message. -The auth gate (`requires_auth=["jellyfin"]`) is already active: -users who haven't linked Jellyfin will be prompted to /login first. +Requires the user to have linked Jellyfin via `/login jellyfin` in Discord. +The auth gate (`requires_auth=["jellyfin"]`) is already active — users who +haven't linked Jellyfin will be prompted to /login first. + +Architecture +------------ +This skill calls the JellyStat REST API (same FastAPI process, via HTTP) +rather than accessing the PostgreSQL database directly. This keeps the +bot isolated from database credentials. """ from __future__ import annotations +import httpx + from agents.skills import Skill, register, ToolResult +from src import auth_store +from src.config import get_config + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- +BASE_URL = (get_config("BASE_URL") or "http://localhost:8000").rstrip("/") # --------------------------------------------------------------------------- # Tool definitions @@ -20,59 +35,237 @@ TOOLS = [ "function": { "name": "watch_history", "description": ( - "Get the user's recent Jellyfin watch history — movies and TV " - "episodes they have watched, sorted by most recent. " - "Call this when a user asks about their watching activity." + "Get the user's Jellyfin watch history — titles grouped by total " + "watch time in a configurable time window. Use this when a user " + "asks what they've watched, what they've been watching recently, " + "or wants to see their viewing activity." ), "parameters": { "type": "object", "properties": { "limit": { "type": "integer", - "description": "How many items to return (default 10, max 20)", - } + "description": "How many titles to return (default 10, max 20).", + }, + "minutes": { + "type": "integer", + "description": ( + "Time window in minutes. Default 10080 (7 days). " + "Use a large number like 525600 for 'all time' (1 year)." + ), + }, }, }, }, - } + }, + { + "type": "function", + "function": { + "name": "watch_genres", + "description": ( + "Get the user's most-watched genres from Jellyfin, ranked by " + "total watch time. Use this when a user asks what kinds of " + "content they watch most, their favourite genres, or what " + "categories dominate their viewing." + ), + "parameters": { + "type": "object", + "properties": { + "minutes": { + "type": "integer", + "description": ( + "Time window in minutes. Default 10080 (7 days). " + "Use a large number like 525600 for 'all time'." + ), + }, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "watch_summary", + "description": ( + "Get an all-time Jellyfin watch summary — total watch time, " + "most-watched series, most-watched movie, 30-day and 7-day " + "activity, and top 3 genres. Use this when a user asks for " + "their overall stats, a dashboard, or 'how much have I watched?'." + ), + "parameters": {"type": "object", "properties": {}}, + }, + }, ] # --------------------------------------------------------------------------- -# Executor (placeholder) +# Helpers # --------------------------------------------------------------------------- + +def _resolve_jellyfin_id(args: dict) -> str | None: + """Extract the Jellyfin user ID from auth_store using the injected Discord ID.""" + discord_user_id = args.pop("_discord_user_id", None) + if discord_user_id is None: + return None # not called from Discord — shouldn't happen with auth gate + + auth = auth_store.get_auth(discord_user_id, "jellyfin") + if auth is None or not auth.get("external_user_id"): + return None + + return auth["external_user_id"] + + +async def _fetch_json(url: str) -> dict: + """GET *url* and return the parsed JSON body, or {} on failure.""" + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get(url) + resp.raise_for_status() + return resp.json() + + +def _format_seconds(total: float) -> str: + """Convert seconds to a human-friendly string.""" + total = max(total, 0) + hours = int(total // 3600) + minutes = int((total % 3600) // 60) + if hours and minutes: + return f"{hours}h {minutes}m" + if hours: + return f"{hours}h" + if minutes: + return f"{minutes}m" + return f"{int(total)}s" + + +def _format_history(data: dict, limit: int) -> ToolResult: + """Format a watch-history API response for the LLM.""" + items = data.get("items", [])[:limit] + if not items: + return ToolResult.ok("You haven't watched anything in this time window.") + + lines = [f"**Watch History** (last {data.get('window_minutes', '?')} minutes):"] + for i, item in enumerate(items, 1): + duration = _format_seconds(item["watch_time_sec"]) + icon = "📺" if item["media_type"] == "series" else "🎬" + lines.append(f"{i}. {icon} **{item['title']}** — {duration}") + + return ToolResult.ok("\n".join(lines)) + + +def _format_genres(data: dict) -> ToolResult: + """Format a genre-summary API response for the LLM.""" + genres = data.get("genres", []) + if not genres: + return ToolResult.ok("No genre data available for this time window.") + + lines = [f"**Top Genres** (last {data.get('window_minutes', '?')} minutes):"] + for i, g in enumerate(genres, 1): + duration = _format_seconds(g["watch_time_sec"]) + lines.append(f"{i}. **{g['genre']}** — {duration}") + + return ToolResult.ok("\n".join(lines)) + + +def _format_summary(data: dict) -> ToolResult: + """Format a user-summary API response for the LLM.""" + total = _format_seconds(data.get("total_watch_time_sec", 0)) + last_30 = _format_seconds(data.get("total_last_30d_sec", 0)) + last_7 = _format_seconds(data.get("total_last_7d_sec", 0)) + + top_series = data.get("most_watched_series") or "—" + top_movie = data.get("most_watched_movie") or "—" + top_genres = data.get("top_genres", []) + genres_str = ", ".join(top_genres) if top_genres else "—" + + lines = [ + "**Your Jellyfin Summary** (all time):", + f"⏱️ Total watch time: **{total}**", + f"📺 Most-watched series: **{top_series}**", + f"🎬 Most-watched movie: **{top_movie}**", + f"📅 Last 30 days: **{last_30}**", + f"📅 Last 7 days: **{last_7}**", + f"🏷️ Top genres: {genres_str}", + ] + return ToolResult.ok("\n".join(lines)) + + +# --------------------------------------------------------------------------- +# Executor +# --------------------------------------------------------------------------- + + async def _execute(tool_name: str, args: dict) -> ToolResult: - if tool_name == "watch_history": - return ToolResult.ok( - "👷 **Watch History — Coming Soon!**\n\n" - "This feature is currently being built. Soon you'll be able to " - "see your recently watched movies and TV episodes right here.\n\n" - "In the meantime, you can check your watch history directly in Jellyfin." + # 1. Resolve Jellyfin user ID + jellyfin_id = _resolve_jellyfin_id(args) + if jellyfin_id is None: + return ToolResult.fail( + "Your Jellyfin account is not linked. Use `/login jellyfin` in a DM to connect." + ) + + # 2. Route to the right JellyStat endpoint + try: + match tool_name: + case "watch_history": + limit = args.get("limit", 10) + minutes = args.get("minutes", 10080) + url = f"{BASE_URL}/jellystat/history/{jellyfin_id}?minutes={minutes}" + data = await _fetch_json(url) + return _format_history(data, limit) + + case "watch_genres": + minutes = args.get("minutes", 10080) + url = f"{BASE_URL}/jellystat/genres/{jellyfin_id}?minutes={minutes}" + data = await _fetch_json(url) + return _format_genres(data) + + case "watch_summary": + url = f"{BASE_URL}/jellystat/summary/{jellyfin_id}" + data = await _fetch_json(url) + return _format_summary(data) + + case _: + return ToolResult.fail(f"Unknown tool: {tool_name}") + + except httpx.HTTPError: + return ToolResult.fail( + "Could not reach the watch-history service right now. " + "Please try again in a moment." ) - return ToolResult.fail(f"Unknown tool: {tool_name}") # --------------------------------------------------------------------------- # Skill registration # --------------------------------------------------------------------------- +_PROMPT = ( + "## Watch History\n" + "\n" + "You have THREE tools to answer questions about the user's Jellyfin watch activity:\n" + "\n" + "1. **`watch_history`** — per-title watch time in a time window (default: 7 days).\n" + " Use when a user asks what they've watched, to show their history,\n" + " or what they watched this week or yesterday.\n" + "\n" + "2. **`watch_genres`** — watch time broken down by genre.\n" + " Use when a user asks what genres they watch, whether they watch more\n" + " comedy than drama, or what their most-watched genre is.\n" + "\n" + "3. **`watch_summary`** — all-time dashboard: total watch time, most-watched\n" + " series and movie, 30-day and 7-day activity, and top 3 genres.\n" + " Use when a user asks for their stats, how much they've watched in\n" + " total, or what their favourites are.\n" + "\n" + "Always call the appropriate tool before answering — NEVER guess at watch data.\n" + "Format watch times in a human-readable way (hours and minutes), but keep the\n" + "raw data visible too." +) + watch_history_skill = Skill( name="watch_history", - description="User's Jellyfin watch history (coming soon)", + description="User's Jellyfin watch history, genres, and summary stats", requires_auth=["jellyfin"], - prompt_fragment="""## Watch History - -You can fetch the user's Jellyfin watch history with the `watch_history` tool. -Call it when users ask things like: -- "what have I watched?" -- "show my watch history" -- "what did I watch recently?" -- "what was the last movie I saw?" -- "what TV shows have I been watching?" - -The tool is currently a **placeholder** — it returns a "coming soon" message. -Tell the user this feature is being worked on and will be available soon.""", + prompt_fragment=_PROMPT, tools=TOOLS, execute=_execute, ) diff --git a/gateway/jellystat/__init__.py b/gateway/jellystat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gateway/jellystat/api.py b/gateway/jellystat/api.py new file mode 100644 index 0000000..dd60310 --- /dev/null +++ b/gateway/jellystat/api.py @@ -0,0 +1,106 @@ +"""JellyStat REST API — watch history, genre summary, and user summary.""" + +from __future__ import annotations + +import asyncpg +from fastapi import APIRouter, Depends, Query + +from gateway.jellystat.db import get_pool +from gateway.jellystat.models import ( + GenreSummaryResponse, + UserSummaryResponse, + WatchHistoryResponse, +) + +router = APIRouter(prefix="/jellystat", tags=["jellystat"]) + +DEFAULT_WINDOW_MINUTES = 10080 # 7 days + + +# --------------------------------------------------------------------------- +# GET /jellystat/history/{user_id} +# --------------------------------------------------------------------------- + + +@router.get("/history/{user_id}", response_model=WatchHistoryResponse) +async def get_watch_history( + user_id: str, + minutes: int = Query( + default=DEFAULT_WINDOW_MINUTES, ge=1, description="Time window in minutes" + ), + pool: asyncpg.Pool = Depends(get_pool), +): + """Return watch history grouped by title, ordered by most-watched first.""" + rows = await pool.fetch( + "SELECT * FROM fn_user_watch_history($1, $2)", user_id, minutes + ) + return WatchHistoryResponse( + user_id=user_id, + window_minutes=minutes, + items=[ + { + "title": r["title"], + "watch_time_sec": float(r["watch_time_sec"]), + "media_type": r["media_type"], + } + for r in rows + ], + ) + + +# --------------------------------------------------------------------------- +# GET /jellystat/genres/{user_id} +# --------------------------------------------------------------------------- + + +@router.get("/genres/{user_id}", response_model=GenreSummaryResponse) +async def get_genre_summary( + user_id: str, + minutes: int = Query( + default=DEFAULT_WINDOW_MINUTES, ge=1, description="Time window in minutes" + ), + pool: asyncpg.Pool = Depends(get_pool), +): + """Return total watch time per genre, ordered by most-watched first.""" + rows = await pool.fetch( + "SELECT * FROM fn_user_genre_summary($1, $2)", user_id, minutes + ) + return GenreSummaryResponse( + user_id=user_id, + window_minutes=minutes, + genres=[ + {"genre": r["genre"], "watch_time_sec": float(r["watch_time_sec"])} + for r in rows + ], + ) + + +# --------------------------------------------------------------------------- +# GET /jellystat/summary/{user_id} +# --------------------------------------------------------------------------- + + +@router.get("/summary/{user_id}", response_model=UserSummaryResponse) +async def get_user_summary( + user_id: str, + pool: asyncpg.Pool = Depends(get_pool), +): + """Return all-time summary: total watch time, most-watched titles, top genres.""" + rows = await pool.fetch("SELECT * FROM fn_user_summary($1)", user_id) + + # fn_user_summary returns key-value rows — build a dict + # asyncpg already deserialises JSONB → Python objects + metrics: dict[str, object] = {r["metric"]: r["value"] for r in rows} + + top_genres_raw = metrics.get("top_genres", []) + top_genres: list[str] = top_genres_raw if isinstance(top_genres_raw, list) else [] + + return UserSummaryResponse( + user_id=user_id, + total_watch_time_sec=float(metrics.get("total_watch_time", 0)), + most_watched_series=metrics.get("most_watched_series"), + most_watched_movie=metrics.get("most_watched_movie"), + total_last_30d_sec=float(metrics.get("total_last_30d", 0)), + total_last_7d_sec=float(metrics.get("total_last_7d", 0)), + top_genres=top_genres, + ) diff --git a/gateway/jellystat/db.py b/gateway/jellystat/db.py new file mode 100644 index 0000000..09b5b7a --- /dev/null +++ b/gateway/jellystat/db.py @@ -0,0 +1,130 @@ +"""PostgreSQL connection pool for the JellyStat database.""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import asyncpg +from fastapi import FastAPI, Request + +from src.config import get_config + +logger = logging.getLogger("gateway.jellystat") + +# --------------------------------------------------------------------------- +# DSN builder +# --------------------------------------------------------------------------- + + +def _build_dsn() -> str: + """Build a PostgreSQL DSN from individual environment variables.""" + host = get_config("JELLYSTAT_DB_HOST", "localhost") + port = get_config("JELLYSTAT_DB_PORT", "5432") + user = get_config("JELLYSTAT_DB_USER", "postgres") + password = get_config("JELLYSTAT_DB_PASSWORD", "") + dbname = get_config("JELLYSTAT_DB_NAME", "jfstat") + return f"postgresql://{user}:{password}@{host}:{port}/{dbname}" + + +# --------------------------------------------------------------------------- +# Pool lifecycle (called from main.py lifespan) +# --------------------------------------------------------------------------- + + +async def init_pool(app: FastAPI) -> None: + """Create the connection pool and store it on app.state.""" + dsn = _build_dsn() + safe = dsn.split("@")[1] if "@" in dsn else dsn + logger.info("Connecting to JellyStat database at %s", safe) + + pool = await asyncpg.create_pool(dsn, min_size=1, max_size=5) + app.state.jellystat_pool = pool + + # Deploy functions on every startup (CREATE OR REPLACE is idempotent) + await _ensure_functions(pool) + + +async def close_pool(app: FastAPI) -> None: + """Close the pool on shutdown.""" + pool: asyncpg.Pool | None = getattr(app.state, "jellystat_pool", None) + if pool: + await pool.close() + logger.info("JellyStat pool closed") + + +# --------------------------------------------------------------------------- +# FastAPI dependency +# --------------------------------------------------------------------------- + + +async def get_pool(request: Request) -> asyncpg.Pool: + """Return the JellyStat connection pool from app state.""" + return request.app.state.jellystat_pool + + +# --------------------------------------------------------------------------- +# Function deployment +# --------------------------------------------------------------------------- + + +async def _ensure_functions(pool: asyncpg.Pool) -> None: + """Run startup-functions.sql to create or replace all JellyStat functions.""" + sql_path = Path(__file__).parent / "startup-functions.sql" + if not sql_path.exists(): + logger.warning("startup-functions.sql not found — skipping function deployment") + return + + sql = sql_path.read_text() + statements = _split_sql(sql) + + async with pool.acquire() as conn: + for stmt in statements: + try: + await conn.execute(stmt) + except Exception: + # Log but don't crash — functions might already exist + logger.exception("Failed to deploy SQL statement — continuing") + + logger.info("JellyStat functions deployed (%d statements)", len(statements)) + + +def _split_sql(sql: str) -> list[str]: + """ + Split a multi-statement SQL string into individual statements. + + Respects $$ dollar-quoting so that semicolons inside function bodies + don't cause premature splits. Pure comment lines (starting with ``--``) + outside dollar-quoted blocks are stripped. + """ + statements: list[str] = [] + current: list[str] = [] + in_dollar_quote = False + + for line in sql.split("\n"): + stripped = line.strip() + + # Skip pure comment lines outside of dollar-quoted blocks + if not in_dollar_quote and stripped.startswith("--"): + continue + + # Toggle dollar-quote state whenever we see $$ + if "$$" in line: + in_dollar_quote = not in_dollar_quote + + current.append(line) + + # Statement terminator: semicolon at end of line, outside $$ block + if not in_dollar_quote and line.rstrip().endswith(";"): + stmt = "\n".join(current).strip() + if stmt: + statements.append(stmt) + current = [] + + # Catch any trailing statement that wasn't terminated by a semicolon + if current: + stmt = "\n".join(current).strip() + if stmt: + statements.append(stmt) + + return statements diff --git a/gateway/jellystat/models.py b/gateway/jellystat/models.py new file mode 100644 index 0000000..f59d7f0 --- /dev/null +++ b/gateway/jellystat/models.py @@ -0,0 +1,36 @@ +"""Pydantic response models for the JellyStat API.""" + +from pydantic import BaseModel + + +class WatchHistoryItem(BaseModel): + title: str + watch_time_sec: float + media_type: str + + +class WatchHistoryResponse(BaseModel): + user_id: str + window_minutes: int + items: list[WatchHistoryItem] + + +class GenreSummaryItem(BaseModel): + genre: str + watch_time_sec: float + + +class GenreSummaryResponse(BaseModel): + user_id: str + window_minutes: int + genres: list[GenreSummaryItem] + + +class UserSummaryResponse(BaseModel): + user_id: str + total_watch_time_sec: float + most_watched_series: str | None + most_watched_movie: str | None + total_last_30d_sec: float + total_last_7d_sec: float + top_genres: list[str] diff --git a/gateway/jellystat/startup-functions.sql b/gateway/jellystat/startup-functions.sql new file mode 100644 index 0000000..f741482 --- /dev/null +++ b/gateway/jellystat/startup-functions.sql @@ -0,0 +1,224 @@ +-- ============================================================================ +-- JellyStat API Functions +-- Parameterized database functions callable by the API layer as: +-- SELECT * FROM fn_user_watch_history('user_id_here', 10080); +-- SELECT * FROM fn_user_genre_summary('user_id_here', 10080); +-- SELECT * FROM fn_user_summary('user_id_here'); +-- ============================================================================ + +-- ---------------------------------------------------------------------------- +-- 1. User Watch History +-- Returns every distinct title watched in the last N minutes, +-- grouped and summed by title, ordered by most-watched first. +-- ---------------------------------------------------------------------------- +CREATE OR REPLACE FUNCTION public.fn_user_watch_history( + p_user_id TEXT, + p_minutes INTEGER DEFAULT 10080 -- 7 days in minutes +) +RETURNS TABLE( + title TEXT, + watch_time_sec NUMERIC, + media_type TEXT +) +LANGUAGE sql +STABLE +AS $$ + SELECT + COALESCE(a."SeriesName", a."NowPlayingItemName") AS title, + SUM(a."PlaybackDuration")::NUMERIC AS watch_time_sec, + CASE + WHEN a."SeriesName" IS NOT NULL THEN 'series' + ELSE 'movie' + END AS media_type + FROM jf_playback_activity a + WHERE a."UserId" = p_user_id + AND a."ActivityDateInserted" + >= NOW() - (p_minutes * INTERVAL '1 minute') + GROUP BY + COALESCE(a."SeriesName", a."NowPlayingItemName"), + CASE WHEN a."SeriesName" IS NOT NULL THEN 'series' ELSE 'movie' END + ORDER BY watch_time_sec DESC; +$$; + +-- ---------------------------------------------------------------------------- +-- 2. Genre Summary +-- Returns total watch time per genre for a user over the last N minutes. +-- Resolves genres for both movies (directly on the item) and series +-- episodes (via jf_library_episodes → jf_library_items chain). +-- ---------------------------------------------------------------------------- +CREATE OR REPLACE FUNCTION public.fn_user_genre_summary( + p_user_id TEXT, + p_minutes INTEGER DEFAULT 10080 +) +RETURNS TABLE( + genre TEXT, + watch_time_sec NUMERIC +) +LANGUAGE sql +STABLE +AS $$ + WITH movie_genres AS ( + -- Movies: join playback directly to library_items on NowPlayingItemId + SELECT + genre_item.value AS genre, + SUM(a."PlaybackDuration") AS watch_time_sec + FROM jf_playback_activity a + JOIN jf_library_items i + ON i."Id" = a."NowPlayingItemId" + CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value) + WHERE a."UserId" = p_user_id + AND a."SeriesName" IS NULL -- movies only + AND a."ActivityDateInserted" + >= NOW() - (p_minutes * INTERVAL '1 minute') + AND i."Genres" IS NOT NULL + AND jsonb_array_length(i."Genres") > 0 + GROUP BY genre_item.value + ), + series_genres AS ( + -- Series: playback → episodes → series item → genres + SELECT + genre_item.value AS genre, + SUM(a."PlaybackDuration") AS watch_time_sec + FROM jf_playback_activity a + JOIN jf_library_episodes e + ON e."EpisodeId" = a."EpisodeId" + JOIN jf_library_items i + ON i."Id" = e."SeriesId" + CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value) + WHERE a."UserId" = p_user_id + AND a."SeriesName" IS NOT NULL -- TV episodes only + AND a."ActivityDateInserted" + >= NOW() - (p_minutes * INTERVAL '1 minute') + AND i."Genres" IS NOT NULL + AND jsonb_array_length(i."Genres") > 0 + GROUP BY genre_item.value + ), + combined AS ( + SELECT genre, watch_time_sec FROM movie_genres + UNION ALL + SELECT genre, watch_time_sec FROM series_genres + ) + SELECT + genre, + SUM(watch_time_sec)::NUMERIC AS watch_time_sec + FROM combined + GROUP BY genre + ORDER BY watch_time_sec DESC; +$$; + +-- ---------------------------------------------------------------------------- +-- 3. User Summary +-- One-shot dashboard: all-time stats + recent windows + top genres. +-- Returns key-value rows that the API trivially converts to a JSON object +-- with Object.fromEntries() or similar. +-- ---------------------------------------------------------------------------- +CREATE OR REPLACE FUNCTION public.fn_user_summary( + p_user_id TEXT +) +RETURNS TABLE( + metric TEXT, + value JSONB +) +LANGUAGE sql +STABLE +AS $$ + -- total_watch_time (all time) + SELECT 'total_watch_time'::TEXT AS metric, + to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value + FROM jf_playback_activity + WHERE "UserId" = p_user_id + + UNION ALL + + -- most_watched_series (by total watch time) + SELECT 'most_watched_series'::TEXT AS metric, + COALESCE( + (SELECT to_jsonb("SeriesName") + FROM jf_playback_activity + WHERE "UserId" = p_user_id + AND "SeriesName" IS NOT NULL + GROUP BY "SeriesName" + ORDER BY SUM("PlaybackDuration") DESC + LIMIT 1), + 'null'::JSONB + ) AS value + + UNION ALL + + -- most_watched_movie (by total watch time) + SELECT 'most_watched_movie'::TEXT AS metric, + COALESCE( + (SELECT to_jsonb("NowPlayingItemName") + FROM jf_playback_activity + WHERE "UserId" = p_user_id + AND "SeriesName" IS NULL + GROUP BY "NowPlayingItemName" + ORDER BY SUM("PlaybackDuration") DESC + LIMIT 1), + 'null'::JSONB + ) AS value + + UNION ALL + + -- total_watch_time_last_month (last 30 days) + SELECT 'total_last_30d'::TEXT AS metric, + to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value + FROM jf_playback_activity + WHERE "UserId" = p_user_id + AND "ActivityDateInserted" >= NOW() - INTERVAL '30 days' + + UNION ALL + + -- total_watch_time_last_week (last 7 days) + SELECT 'total_last_7d'::TEXT AS metric, + to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value + FROM jf_playback_activity + WHERE "UserId" = p_user_id + AND "ActivityDateInserted" >= NOW() - INTERVAL '7 days' + + UNION ALL + + -- top_genres (top 3 all-time, as a JSON array) + SELECT 'top_genres'::TEXT AS metric, + COALESCE( + (SELECT jsonb_agg(genre ORDER BY watch_time_sec DESC) + FROM ( + SELECT genre, SUM(watch_time_sec) AS watch_time_sec + FROM ( + -- movies + SELECT + genre_item.value AS genre, + SUM(a."PlaybackDuration") AS watch_time_sec + FROM jf_playback_activity a + JOIN jf_library_items i ON i."Id" = a."NowPlayingItemId" + CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value) + WHERE a."UserId" = p_user_id + AND a."SeriesName" IS NULL + AND i."Genres" IS NOT NULL + AND jsonb_array_length(i."Genres") > 0 + GROUP BY genre_item.value + + UNION ALL + + -- series + SELECT + genre_item.value AS genre, + SUM(a."PlaybackDuration") AS watch_time_sec + FROM jf_playback_activity a + JOIN jf_library_episodes e ON e."EpisodeId" = a."EpisodeId" + JOIN jf_library_items i ON i."Id" = e."SeriesId" + CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value) + WHERE a."UserId" = p_user_id + AND a."SeriesName" IS NOT NULL + AND i."Genres" IS NOT NULL + AND jsonb_array_length(i."Genres") > 0 + GROUP BY genre_item.value + ) combined + GROUP BY genre + ORDER BY SUM(watch_time_sec) DESC + LIMIT 3 + ) top3 + ), + '[]'::JSONB + ) AS value; +$$; \ No newline at end of file diff --git a/main.py b/main.py index 586ca45..67a4de9 100644 --- a/main.py +++ b/main.py @@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware from gateway.v1.auth import router as auth_router from gateway.v1.chat import router as v1_router +from gateway.jellystat.api import router as jellystat_router from src.config import DEEPSEEK_API_KEY, get_config from src.llm import create_client @@ -33,11 +34,15 @@ import gateway.auth.jellyfin # noqa: E402 — self-registers JellyfinAuth @asynccontextmanager async def lifespan(app: FastAPI): from gateway.discord.bot import start_in_background # noqa: E402 + from gateway.jellystat.db import init_pool, close_pool # noqa: E402 + await init_pool(app) start_in_background() yield + await close_pool(app) + # --------------------------------------------------------------------------- # App @@ -64,4 +69,5 @@ app.state.agent_graphs: dict = {} # Routers # --------------------------------------------------------------------------- app.include_router(v1_router, prefix="/v1") -app.include_router(auth_router) \ No newline at end of file +app.include_router(auth_router) +app.include_router(jellystat_router) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b62cbf6..f1afa9f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ httpx langgraph langgraph-checkpoint discord.py -python-multipart \ No newline at end of file +python-multipart +asyncpg \ No newline at end of file -- 2.52.0