diff --git a/.env.example b/.env.example index fc20b32..8aa184b 100644 --- a/.env.example +++ b/.env.example @@ -1,16 +1,19 @@ -# --------------------------------------------------------------------------- -# Agent Backend — Environment Variables -# Copy this to .env and fill in your values. -# --------------------------------------------------------------------------- +# ============================================================================= +# Agent Bot — Environment Configuration +# ============================================================================= +# Copy this file to .env and fill in your values. +# .env is git-ignored — never commit real secrets. +# --------------------------------------------------------------------------- # LLM — DeepSeek (OpenAI-compatible) +# --------------------------------------------------------------------------- DEEPSEEK_API_KEY=sk-your-deepseek-api-key # --------------------------------------------------------------------------- # Discord Bot # --------------------------------------------------------------------------- DISCORD_BOT_TOKEN=your-discord-bot-token-here -# DISCORD_MAX_HISTORY=7 # optional, defaults to 7 (max past messages per user) +# DISCORD_MAX_HISTORY=7 # optional, defaults to 7 (max past messages per user) # DISCORD_DEFAULT_AGENT=media-agent # optional, which agent the DM bot uses # --------------------------------------------------------------------------- @@ -18,4 +21,21 @@ DISCORD_BOT_TOKEN=your-discord-bot-token-here # --------------------------------------------------------------------------- SEERR_URL=https://seerr.example.com SEERR_API_KEY=your-seerr-api-key -# SEERR_TIMEOUT=30 # optional, defaults to 30 seconds +# SEERR_USERNAME=your-username # alternative: username+password auth +# SEERR_PASSWORD=your-password +# SEERR_TIMEOUT=30 # optional, defaults to 30 seconds + +# --------------------------------------------------------------------------- +# Auth System (Discord ↔ external services) +# --------------------------------------------------------------------------- +# The public-facing URL where users reach this bot's web API. +# Used to build the "Click here to link" URLs sent via Discord DM. +# For local dev: http://localhost:8000 +# For production behind a reverse proxy: https://bot.yourdomain.com +BASE_URL=http://localhost:8000 + +# Where the auth SQLite database lives (relative to project root) +# AUTH_DB_PATH=data/auth.db + +# Link token expiry in minutes (default 10) +# AUTH_TOKEN_EXPIRY=10 diff --git a/.gitignore b/.gitignore index 089e111..c84e18a 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,5 @@ cython_debug/ # PyPI configuration file .pypirc -.docs/ \ No newline at end of file +.docs/ +data/ \ No newline at end of file diff --git a/agents/__init__.py b/agents/__init__.py index 48ae91a..76df02e 100644 --- a/agents/__init__.py +++ b/agents/__init__.py @@ -65,3 +65,4 @@ def load_all_agents() -> None: import skills.seerr # noqa: F401 import skills.triage # noqa: F401 import skills.easter_eggs # noqa: F401 + import skills.watch_history # noqa: F401 diff --git a/agents/media_agent.py b/agents/media_agent.py index a599b63..c26ab73 100644 --- a/agents/media_agent.py +++ b/agents/media_agent.py @@ -14,11 +14,16 @@ media_agent = Agent( agent_id="media-agent", description="Media assistant — handles movie/TV/subtitle/ticket requests " "via Seerr, Jellyfin, Sonarr, etc.", - skills=["media_info", "seerr", "triage", "easter_eggs"], + skills=["media_info", "seerr", "triage", "easter_eggs", "watch_history"], base_prompt=( "You are a media assistant connected to Seerr and other media services. " "Help users discover, request, and troubleshoot their media library. " - "Use the tools provided to perform real actions." + "Use the tools provided to perform real actions.\n\n" + "## Authentication\n" + "If a tool returns a message saying the user needs to log in first, " + "tell the user to type `/login ` in their DM (e.g. `/login jellyfin`). " + "This opens Quick Connect on their Jellyfin app so they can link their account. " + "Do NOT tell the user you 'can't connect' or 'don't have access' — just relay the login instructions." ), ) diff --git a/api/v1/auth.py b/api/v1/auth.py new file mode 100644 index 0000000..7c2d1a4 --- /dev/null +++ b/api/v1/auth.py @@ -0,0 +1,189 @@ +""" +Auth API — generic endpoints for linking Discord users to external services. + +GET /api/v1/auth/login?service=X&token=Y&discord_id=Z + Validates the link token and serves a service-specific login form. + +POST /api/v1/auth/login + Accepts the form submission, validates credentials against the service, + stores the session, and returns a result page. + +GET /api/v1/auth/status?discord_id=Z + Returns which services are linked for this Discord user. +""" + +from __future__ import annotations + +import logging + +from fastapi import APIRouter, Form, HTTPException, Request +from fastapi.responses import HTMLResponse + +from auth import get_auth_service, list_auth_services +from core import auth_store + +logger = logging.getLogger("api.auth") + +router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) + + +# --------------------------------------------------------------------------- +# GET /auth/login — serve the login form +# --------------------------------------------------------------------------- + +@router.get("/login") +async def login_form( + service: str, + token: str, + discord_id: int, +): + """Validate the one-time link token and return a service-specific login form.""" + + # Validate the token WITHOUT consuming it (the POST will consume it) + result = auth_store.validate_token(token) + if result is None: + raise HTTPException(status_code=400, detail="Invalid or expired link token.") + + uid, svc = result + if uid != discord_id or svc != service: + raise HTTPException(status_code=400, detail="Token does not match the request.") + + # Look up the AuthService + svc_obj = get_auth_service(service) + if svc_obj is None: + raise HTTPException(status_code=404, detail=f"Unknown service: {service}") + + logger.info("Serving login form: user=%s service=%s", discord_id, service) + return HTMLResponse(svc_obj.render_login_form(token, discord_id)) + + +# --------------------------------------------------------------------------- +# POST /auth/login — handle form submission +# --------------------------------------------------------------------------- + +@router.post("/login") +async def login_submit(request: Request): + """Handle the login form POST: validate credentials, store auth, show result.""" + + # Parse form data + form = await request.form() + token = form.get("token", "") + discord_id_str = form.get("discord_id", "") + service = form.get("service", "") + + if not token or not discord_id_str or not service: + raise HTTPException(status_code=400, detail="Missing required fields.") + + try: + discord_id = int(discord_id_str) + except (ValueError, TypeError): + raise HTTPException(status_code=400, detail="Invalid discord_id.") + + # Consume the token on POST (the GET only validated, didn't consume) + result = auth_store.consume_token(token) + if result is None: + raise HTTPException(status_code=400, detail="Invalid or expired link token.") + + # Look up the AuthService + svc_obj = get_auth_service(service) + if svc_obj is None: + raise HTTPException(status_code=404, detail=f"Unknown service: {service}") + + # Collect service-specific form fields (everything except token, discord_id, service) + form_data: dict[str, str] = {} + for key, value in form.items(): + if key not in ("token", "discord_id", "service"): + form_data[key] = str(value) + + # Authenticate against the service + auth_result = await svc_obj.authenticate(form_data) + + if not auth_result.success: + return HTMLResponse( + status_code=401, + content=f""" +Login Failed + +

❌ Login Failed

+

{auth_result.error_message or "Authentication failed. Please try again."}

+

← Go back and try again

+""", + ) + + # Store the successful auth + auth_store.store_auth( + discord_user_id=discord_id, + service=service, + external_user_id=auth_result.external_user_id or "", + external_name=auth_result.external_name or "", + credentials=auth_result.credentials, + ) + + logger.info( + "Auth linked: discord=%s → %s (%s)", + discord_id, + service, + auth_result.external_name, + ) + + return HTMLResponse(f""" + + + + +Account Linked + + + +

✅ Account Linked!

+

Logged in as {auth_result.external_name} on {svc_obj.display_name}.

+

You can close this page and return to Discord.

+ +""") + + +# --------------------------------------------------------------------------- +# GET /auth/status — check which services are linked +# --------------------------------------------------------------------------- + +@router.get("/status") +async def auth_status(discord_id: int): + """Return which services this Discord user has linked.""" + services: dict[str, bool] = {} + for svc_name in list_auth_services(): + services[svc_name] = auth_store.is_authenticated(discord_id, svc_name) + return {"discord_id": discord_id, "services": services} + + +# --------------------------------------------------------------------------- +# POST /auth/reset — wipe auth store (DEV ONLY) +# --------------------------------------------------------------------------- + +from core.config import get_config # noqa: E402 + +@router.post("/reset") +async def reset_auth(): + """ + Reset the entire auth store — clears all link tokens and user auth records. + + Only enabled when ALLOW_AUTH_RESET=true in the environment. + Returns 403 in production. + """ + if get_config("ALLOW_AUTH_RESET", "false").lower() != "true": + raise HTTPException( + status_code=403, + detail="Auth reset is disabled. Set ALLOW_AUTH_RESET=true to enable (dev only).", + ) + + auth_store.reset_all() + logger.warning("Auth store reset via API endpoint.") + return {"status": "ok", "message": "Auth store cleared — all tokens and auth records removed."} diff --git a/auth/__init__.py b/auth/__init__.py new file mode 100644 index 0000000..2344e47 --- /dev/null +++ b/auth/__init__.py @@ -0,0 +1,93 @@ +""" +Auth Service registry — generic, pluggable authentication for any service. + +Add a new service (Plex, Seerr, etc.) by: +1. Subclassing AuthService +2. Dropping the module in this package +3. Calling register_auth_service() at import time +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Optional + + +# --------------------------------------------------------------------------- +# AuthResult — returned by AuthService.authenticate() +# --------------------------------------------------------------------------- + +@dataclass +class AuthResult: + """Outcome of a credential validation attempt.""" + success: bool + external_user_id: Optional[str] = None + external_name: Optional[str] = None + credentials: Optional[dict] = None + error_message: Optional[str] = None + + +# --------------------------------------------------------------------------- +# AuthService — abstract base class +# --------------------------------------------------------------------------- + +class AuthService(ABC): + """A service that users can authenticate against (Jellyfin, Seerr, Plex, etc.) + + Subclasses must implement: + - name : unique identifier used in URLs and DB keys + - display_name : human-readable label shown in Discord + - render_login_form(token, discord_id) → HTML string + - authenticate(form_data) → AuthResult + """ + + @property + @abstractmethod + def name(self) -> str: + """Unique service name: "jellyfin", "seerr", etc.""" + ... + + @property + @abstractmethod + def display_name(self) -> str: + """Human-readable: "Jellyfin", "Seerr", "Plex" """ + ... + + @abstractmethod + def render_login_form(self, token: str, discord_id: int) -> str: + """Return HTML string with a login form for this service. + + The form MUST include these hidden fields: + + + + """ + ... + + @abstractmethod + async def authenticate(self, form_data: dict) -> AuthResult: + """Validate credentials against the service. Return AuthResult.""" + ... + + +# --------------------------------------------------------------------------- +# Global registry +# --------------------------------------------------------------------------- + +_registry: dict[str, AuthService] = {} + + +def register_auth_service(svc: AuthService) -> None: + """Register an AuthService so it can be looked up by name.""" + _registry[svc.name] = svc + + +def get_auth_service(name: str) -> AuthService | None: + """Look up a registered AuthService by name.""" + return _registry.get(name) + + +def list_auth_services() -> list[str]: + """Return names of all registered auth services.""" + return list(_registry.keys()) diff --git a/auth/jellyfin.py b/auth/jellyfin.py new file mode 100644 index 0000000..b2744df --- /dev/null +++ b/auth/jellyfin.py @@ -0,0 +1,401 @@ +""" +Jellyfin AuthService — validates Jellyfin credentials and stores the session token. + +Two authentication flows: + 1. Quick Connect (primary): user enters a short code on their Jellyfin app. + - initiate_quick_connect() → {code, secret} + - poll_quick_connect(secret) → "Active" | "Authorized" | "Expired" + - authenticate_quick_connect(secret) → AuthResult with token + + 2. Username/password (legacy): renders an HTML form, called via the REST API. + - render_login_form(token, discord_id) → HTML string + - authenticate(form_data) → AuthResult +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Optional + +import httpx + +from auth import AuthService, AuthResult, register_auth_service +from core.config import get_config + +logger = logging.getLogger("auth.jellyfin") + +# Emby-style authorization header required by Jellyfin's AuthenticateByName +_EMBY_HEADER = ( + 'MediaBrowser Client="AgentBot",' + 'Device="DiscordBot",' + 'DeviceId="agent-bot",' + 'Version="1.0"' +) + + +@dataclass +class QuickConnectResult: + """Result of a Quick Connect initiation.""" + secret: str + code: str + device_id: str + device_name: str + + +class JellyfinAuth(AuthService): + name = "jellyfin" + display_name = "Jellyfin" + + # ------------------------------------------------------------------ + # Quick Connect helpers + # ------------------------------------------------------------------ + + def _qc_headers(self) -> dict[str, str]: + """Return headers used by all Quick Connect API calls.""" + return { + "X-Emby-Authorization": ( + 'MediaBrowser Client="AgentBot",' + 'Device="DiscordBot",' + 'DeviceId="agent-bot-qc",' + 'Version="1.0"' + ) + } + + async def _resolve_url(self) -> str | None: + """ + Resolve the Jellyfin server URL. + 1. Check JELLYFIN_URL env var (used in deployment). + 2. Check if user already has a stored auth with a URL (from legacy login). + Returns None if no URL is configured. + """ + # First: explicit env var + env_url = get_config("JELLYFIN_URL") + if env_url: + return env_url.strip().rstrip("/") + return None + + # ------------------------------------------------------------------ + # Phase 1a: initiate Quick Connect + # ------------------------------------------------------------------ + + async def initiate_quick_connect(self, url: str | None = None) -> QuickConnectResult | None: + """ + Call Jellyfin's POST /QuickConnect/Initiate. + Returns a QuickConnectResult with {secret, code} or None on failure. + + The *code* is what the user enters on their Jellyfin page. + The *secret* is used internally to poll/authenticate. + """ + base_url = url or await self._resolve_url() + if not base_url: + logger.error("QuickConnect failed — no JELLYFIN_URL configured.") + return None + + logger.info("Initiating Quick Connect on %s", base_url) + + async with httpx.AsyncClient(timeout=10) as client: + try: + resp = await client.post( + f"{base_url}/QuickConnect/Initiate", + headers=self._qc_headers(), + json={}, + ) + if resp.status_code != 200: + logger.warning( + "QuickConnect init failed: HTTP %s — %s", + resp.status_code, resp.text[:200], + ) + return None + + data = resp.json() + secret = data.get("Secret", "") + code = data.get("Code", "") + device_id = data.get("DeviceId", "") + device_name = data.get("DeviceName", "") + + if not secret or not code: + logger.warning("QuickConnect init returned unexpected payload: %s", data) + return None + + logger.info( + "Quick Connect initiated: code=%s device=%s", + code, device_name, + ) + return QuickConnectResult( + secret=secret, + code=code, + device_id=device_id, + device_name=device_name, + ) + + except httpx.TimeoutException: + logger.error("QuickConnect init timed out reaching %s", base_url) + return None + except httpx.ConnectError: + logger.error("QuickConnect init — cannot connect to %s", base_url) + return None + except Exception: + logger.exception("Unexpected error during QuickConnect init") + return None + + # ------------------------------------------------------------------ + # Phase 1b: poll Quick Connect status + # ------------------------------------------------------------------ + + async def poll_quick_connect(self, secret: str, url: str | None = None) -> str: + """ + Call Jellyfin's GET /QuickConnect/Connect?secret=. + Returns one of: + - "Active" → user hasn't entered the code yet + - "Authorized" → user entered code AND approved + - "Expired" → code expired / unknown secret + - "Error" → network or unexpected failure + """ + base_url = url or await self._resolve_url() + if not base_url: + logger.error("QuickConnect poll failed — no JELLYFIN_URL configured.") + return "Error" + + async with httpx.AsyncClient(timeout=10) as client: + try: + resp = await client.get( + f"{base_url}/QuickConnect/Connect", + params={"secret": secret}, + headers=self._qc_headers(), + ) + if resp.status_code == 404: + return "Expired" + + if resp.status_code == 200: + data = resp.json() + # Jellyfin returns "Authenticated" (not "Authorized") + if data.get("Authenticated") is True: + return "Authorized" + # "Authenticated" is false, missing, or null → still active + return "Active" + + logger.warning( + "QuickConnect poll unexpected: HTTP %s — %s", + resp.status_code, resp.text[:200], + ) + return "Error" + + except (httpx.TimeoutException, httpx.ConnectError): + logger.warning("QuickConnect poll network error") + return "Error" + except Exception: + logger.exception("Unexpected error during QuickConnect poll") + return "Error" + + # ------------------------------------------------------------------ + # Phase 1c: exchange secret for token + # ------------------------------------------------------------------ + + async def authenticate_quick_connect( + self, secret: str, url: str | None = None + ) -> AuthResult: + """ + After poll_quick_connect returns "Authorized", call + POST /Users/AuthenticateWithQuickConnect to exchange the secret + for a real access token. + + Returns AuthResult with token, user_id, username on success. + """ + base_url = url or await self._resolve_url() + if not base_url: + return AuthResult( + success=False, + error_message="No Jellyfin server URL configured.", + ) + + logger.info("Exchanging QuickConnect secret for token on %s", base_url) + + async with httpx.AsyncClient(timeout=10) as client: + try: + resp = await client.post( + f"{base_url}/Users/AuthenticateWithQuickConnect", + json={"Secret": secret}, + headers=self._qc_headers(), + ) + if resp.status_code != 200: + logger.warning( + "QuickConnect auth exchange failed: HTTP %s", + resp.status_code, + ) + return AuthResult( + success=False, + error_message="Quick Connect authentication failed. The code may have expired.", + ) + + data = resp.json() + user = data.get("User", {}) + token = data.get("AccessToken", "") + + if not token: + return AuthResult( + success=False, + error_message="Jellyfin returned an unexpected response.", + ) + + logger.info( + "QuickConnect linked: user=%s (%s)", + user.get("Name", "?"), + user.get("Id", "?"), + ) + + return AuthResult( + success=True, + external_user_id=user.get("Id", ""), + external_name=user.get("Name", "?"), + credentials={ + "token": token, + "url": base_url, + "user_id": user.get("Id", ""), + }, + ) + + except httpx.TimeoutException: + return AuthResult( + success=False, + error_message=f"Could not reach {base_url} — connection timed out.", + ) + except httpx.ConnectError: + return AuthResult( + success=False, + error_message=f"Could not connect to {base_url}. Is the server running?", + ) + except Exception: + logger.exception("Unexpected error during QuickConnect auth exchange") + return AuthResult( + success=False, + error_message="An unexpected error occurred during authentication.", + ) + + # ------------------------------------------------------------------ + # Login form (legacy — used by the REST API) + # ------------------------------------------------------------------ + + def render_login_form(self, token: str, discord_id: int) -> str: + return f""" + + + + +Link Jellyfin + + + +

🔗 Link Jellyfin to Discord

+

Enter your Jellyfin server URL and credentials to link your account.

+ +
+ + + + + + + + + + + + + + +
+ +""" + + # ------------------------------------------------------------------ + # Authentication + # ------------------------------------------------------------------ + + async def authenticate(self, form_data: dict) -> AuthResult: + url = form_data.get("jellyfin_url", "").strip().rstrip("/") + username = form_data.get("username", "").strip() + password = form_data.get("password", "").strip() + + if not url or not username or not password: + return AuthResult( + success=False, + error_message="All fields are required (URL, username, password).", + ) + + logger.info("Attempting Jellyfin login for '%s' on %s", username, url) + + async with httpx.AsyncClient(timeout=10) as client: + try: + resp = await client.post( + f"{url}/Users/AuthenticateByName", + json={"Username": username, "Pw": password}, + headers={"X-Emby-Authorization": _EMBY_HEADER}, + ) + if resp.status_code != 200: + logger.warning( + "Jellyfin login failed for '%s': HTTP %s", username, resp.status_code + ) + return AuthResult( + success=False, + error_message=f"Login failed — check your server URL and credentials.", + ) + + data = resp.json() + user = data.get("User", {}) + token = data.get("AccessToken", "") + + if not token: + return AuthResult( + success=False, + error_message="Jellyfin returned an unexpected response.", + ) + + logger.info( + "Jellyfin login OK: user=%s (%s)", + user.get("Name", "?"), + user.get("Id", "?"), + ) + + return AuthResult( + success=True, + external_user_id=user.get("Id", ""), + external_name=user.get("Name", username), + credentials={ + "token": token, + "url": url, + "user_id": user.get("Id", ""), + }, + ) + + except httpx.TimeoutException: + return AuthResult( + success=False, + error_message=f"Could not reach {url} — connection timed out. Check the URL.", + ) + except httpx.ConnectError: + return AuthResult( + success=False, + error_message=f"Could not connect to {url}. Is the server running?", + ) + except Exception as exc: + logger.exception("Unexpected error during Jellyfin login") + return AuthResult( + success=False, + error_message=f"An unexpected error occurred. Please try again.", + ) + + +# Self-register at import time +register_auth_service(JellyfinAuth()) diff --git a/bot/discord_bot.py b/bot/discord_bot.py index 57b1d4b..3681d2a 100644 --- a/bot/discord_bot.py +++ b/bot/discord_bot.py @@ -27,6 +27,8 @@ from bot.conversation import ConversationStore from core.config import DEEPSEEK_API_KEY, get_config from core.graph import create_agent_graph from core.llm import create_client +from core import auth_store +from auth import list_auth_services, get_auth_service logger = logging.getLogger("bot.discord") @@ -36,6 +38,7 @@ logger = logging.getLogger("bot.discord") DISCORD_BOT_TOKEN = get_config("DISCORD_BOT_TOKEN") or "" DISCORD_MAX_HISTORY = int(get_config("DISCORD_MAX_HISTORY", "7")) DISCORD_DEFAULT_AGENT = get_config("DISCORD_DEFAULT_AGENT", "media-agent") +BASE_URL = get_config("BASE_URL", "http://localhost:8000").rstrip("/") # --------------------------------------------------------------------------- # LLM client shared by all agents (same as the REST API uses) @@ -138,6 +141,12 @@ class AgentBot(discord.Client): # |--------------------------------------------------------------| user_id = message.author.id + content = message.content.strip() + + # |-- Bot commands — handled directly, never sent to the LLM --| + if await self._handle_command(message, user_id, content): + return + # |--------------------------------------------------------------| # Show typing indicator while the graph runs async with message.channel.typing(): @@ -154,6 +163,140 @@ class AgentBot(discord.Client): "Please try again in a moment." ) + # ------------------------------------------------------------------ + # Bot commands + # ------------------------------------------------------------------ + + async def _handle_command( + self, message: discord.Message, user_id: int, content: str + ) -> bool: + """Handle bot commands (/login, /logout). Returns True if handled.""" + lower = content.lower() + + # --- /login [service] --- + if lower.startswith("/login"): + parts = content.split() + service = parts[1].lower() if len(parts) > 1 else None + + available = list_auth_services() + if not available: + await message.channel.send("No auth services are configured yet.") + return True + + if service is None: + svc_list = ", ".join(available) + await message.channel.send( + f"Available services to link: **{svc_list}**\n" + f"Use `/login ` — e.g. `/login jellyfin`" + ) + return True + + if service not in available: + await message.channel.send( + f"Unknown service '{service}'. Available: {', '.join(available)}" + ) + return True + + if auth_store.is_authenticated(user_id, service): + svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service + await message.channel.send( + f"You're already linked to **{svc_display}**! " + f"Use `/logout {service}` to unlink." + ) + return True + + # --- Quick Connect flow --- + svc = get_auth_service(service) + if svc is None: + await message.channel.send(f"Unknown service: {service}") + return True + + await message.channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…") + + qc_result = await svc.initiate_quick_connect() + if qc_result is None: + await message.channel.send( + f"❌ Could not start Quick Connect for **{svc.display_name}**.\n" + "Check that `JELLYFIN_URL` is configured and the server is reachable." + ) + return True + + await message.channel.send( + f"Open **{svc.display_name}** → **Quick Connect** and enter this code:\n\n" + f"**`{qc_result.code}`**\n\n" + f"⏳ Waiting for you to approve…" + ) + + # Poll for authorization + async with message.channel.typing(): + for attempt in range(24): # 24 × 5s = 2 minutes + await asyncio.sleep(5) + status = await svc.poll_quick_connect(qc_result.secret) + + if status == "Authorized": + auth_result = await svc.authenticate_quick_connect(qc_result.secret) + if auth_result.success: + auth_store.store_auth( + discord_user_id=user_id, + service=service, + external_user_id=auth_result.external_user_id or "", + external_name=auth_result.external_name or "", + credentials=auth_result.credentials, + ) + await message.channel.send( + f"✅ Linked to **{svc.display_name}** as " + f"**{auth_result.external_name}**!" + ) + else: + await message.channel.send( + f"❌ Authentication failed: " + f"{auth_result.error_message or 'Unknown error'}" + ) + return True + + elif status == "Expired": + await message.channel.send( + "⌛ The Quick Connect code expired. " + f"Use `/login {service}` to try again." + ) + return True + + # else: still "Active" — keep polling + + await message.channel.send( + "⌛ Timed out waiting for Quick Connect approval. " + f"Use `/login {service}` to try again." + ) + return True + + # --- /logout [service] --- + if lower.startswith("/logout"): + parts = content.split() + service = parts[1].lower() if len(parts) > 1 else None + + if service is None: + linked = auth_store.list_services(user_id) + if not linked: + await message.channel.send("You don't have any linked services.") + else: + svc_list = ", ".join(linked) + await message.channel.send( + f"Linked services: **{svc_list}**\n" + f"Use `/logout ` to unlink." + ) + return True + + if not auth_store.is_authenticated(user_id, service): + await message.channel.send(f"You're not linked to **{service}**.") + return True + + auth_store.revoke(user_id, service) + svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service + await message.channel.send(f"Unlinked from **{svc_display}**. Use `/login {service}` to re-link.") + return True + + return False + # ------------------------------------------------------------------ # Agent invocation # ------------------------------------------------------------------ @@ -163,7 +306,6 @@ class AgentBot(discord.Client): reply, and return the assistant's final text.""" # 1. Pick agent — defaults to DISCORD_DEFAULT_AGENT env var. - # Change DISCORD_DEFAULT_AGENT in .env to switch agents. agent_id = DISCORD_DEFAULT_AGENT # 2. Build message list from stored history + new user message @@ -172,7 +314,7 @@ class AgentBot(discord.Client): # 3. Run the LangGraph (tools execute inline if needed) graph = _get_graph(agent_id) - state = {"messages": messages} + state = {"messages": messages, "discord_user_id": user_id} result = await graph.ainvoke(state) last_msg = result["messages"][-1] diff --git a/core/auth_store.py b/core/auth_store.py new file mode 100644 index 0000000..cb57cb4 --- /dev/null +++ b/core/auth_store.py @@ -0,0 +1,316 @@ +""" +Auth Store — SQLite-backed persistence for Discord-to-service authentication. + +Two tables: + - link_tokens : one-time tokens sent via Discord DM to initiate login + - user_auth : per-user, per-service credentials (Jellyfin token, etc.) + +Thread-safe via WAL mode and a shared lock. No passwords are ever stored +— only opaque service tokens (e.g. Jellyfin AccessToken). +""" + +from __future__ import annotations + +import logging +import secrets +import sqlite3 +import threading +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Optional + +from core.config import get_config + +logger = logging.getLogger("auth_store") + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- +AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db") +TOKEN_EXPIRY_MINUTES = int(get_config("AUTH_TOKEN_EXPIRY", "10")) + +# --------------------------------------------------------------------------- +# Singleton handle +# --------------------------------------------------------------------------- + +_db_path: Path | None = None +_db_lock = threading.Lock() + + +def _resolve_path() -> Path: + """Turn AUTH_DB_PATH into an absolute path, creating parent dirs.""" + global _db_path + if _db_path is not None: + return _db_path + p = Path(AUTH_DB_PATH) + if not p.is_absolute(): + # Relative to the project root (two levels above this file) + project_root = Path(__file__).resolve().parent.parent + p = project_root / p + p.parent.mkdir(parents=True, exist_ok=True) + _db_path = p + return p + + +def _get_conn() -> sqlite3.Connection: + """Return a thread-local connection to the auth database.""" + import sqlite3 + + conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys=ON") + conn.row_factory = sqlite3.Row + return conn + + +# --------------------------------------------------------------------------- +# Schema +# --------------------------------------------------------------------------- + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS link_tokens ( + token TEXT PRIMARY KEY, + discord_user_id INTEGER NOT NULL, + service TEXT NOT NULL, + expires_at TEXT NOT NULL, + used INTEGER DEFAULT 0, + created_at TEXT DEFAULT (datetime('now')) +); + +CREATE TABLE IF NOT EXISTS user_auth ( + discord_user_id INTEGER NOT NULL, + service TEXT NOT NULL, + external_user_id TEXT, + external_name TEXT, + credentials TEXT, + linked_at TEXT DEFAULT (datetime('now')), + is_active INTEGER DEFAULT 1, + PRIMARY KEY (discord_user_id, service) +); +""" + +_initialized = False + + +def _ensure_schema() -> None: + global _initialized + if _initialized: + return + with _db_lock: + if _initialized: + return + conn = _get_conn() + conn.executescript(_SCHEMA) + conn.commit() + conn.close() + _initialized = True + logger.info("Auth store initialized at %s", _resolve_path()) + + +# --------------------------------------------------------------------------- +# Public API — Link Tokens +# --------------------------------------------------------------------------- + +def create_token(discord_user_id: int, service: str) -> str: + """Generate a one-time link token. Expires after TOKEN_EXPIRY_MINUTES.""" + _ensure_schema() + token = secrets.token_urlsafe(32) + expires = (datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRY_MINUTES)).isoformat() + + with _db_lock: + conn = _get_conn() + conn.execute( + "INSERT INTO link_tokens (token, discord_user_id, service, expires_at) VALUES (?, ?, ?, ?)", + (token, discord_user_id, service, expires), + ) + conn.commit() + conn.close() + + logger.info("Created link token for user %s / service %s", discord_user_id, service) + return token + + +def validate_token(token: str) -> tuple[int, str] | None: + """Read-only validation — does NOT consume the token. + + Returns (discord_user_id, service) if the token exists, is unused, + and has not expired. Returns None otherwise. + """ + _ensure_schema() + + with _db_lock: + conn = _get_conn() + row = conn.execute( + "SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?", + (token,), + ).fetchone() + conn.close() + + if row is None: + return None + if row["used"]: + return None + + expires = datetime.fromisoformat(row["expires_at"]) + if datetime.now(timezone.utc) > expires: + return None + + return (row["discord_user_id"], row["service"]) + + +def consume_token(token: str) -> tuple[int, str] | None: + """Validate and consume a link token. Returns (discord_user_id, service) or None. + + A token is valid if: + - It exists + - It has not been used + - It has not expired + """ + _ensure_schema() + + with _db_lock: + conn = _get_conn() + row = conn.execute( + "SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?", + (token,), + ).fetchone() + + if row is None: + conn.close() + return None + + if row["used"]: + conn.close() + logger.warning("Token already used: %s…", token[:8]) + return None + + expires = datetime.fromisoformat(row["expires_at"]) + if datetime.now(timezone.utc) > expires: + conn.close() + logger.warning("Token expired: %s…", token[:8]) + return None + + conn.execute("UPDATE link_tokens SET used = 1 WHERE token = ?", (token,)) + conn.commit() + result = (row["discord_user_id"], row["service"]) + conn.close() + logger.info("Token consumed: %s… → user=%s service=%s", token[:8], result[0], result[1]) + return result + + +# --------------------------------------------------------------------------- +# Public API — User Auth +# --------------------------------------------------------------------------- + +def store_auth( + discord_user_id: int, + service: str, + *, + external_user_id: str = "", + external_name: str = "", + credentials: dict | None = None, +) -> None: + """Store or update authentication for a user on a service.""" + _ensure_schema() + import json + + creds_json = json.dumps(credentials) if credentials else "{}" + + with _db_lock: + conn = _get_conn() + conn.execute( + """INSERT INTO user_auth (discord_user_id, service, external_user_id, external_name, credentials, linked_at) + VALUES (?, ?, ?, ?, ?, datetime('now')) + ON CONFLICT(discord_user_id, service) DO UPDATE SET + external_user_id = excluded.external_user_id, + external_name = excluded.external_name, + credentials = excluded.credentials, + linked_at = datetime('now'), + is_active = 1""", + (discord_user_id, service, external_user_id, external_name, creds_json), + ) + conn.commit() + conn.close() + + logger.info("Stored auth for user %s on %s as %s", discord_user_id, service, external_name) + + +def get_auth(discord_user_id: int, service: str) -> dict | None: + """Retrieve stored auth for a user on a service. Returns None if not linked.""" + _ensure_schema() + import json + + with _db_lock: + conn = _get_conn() + row = conn.execute( + "SELECT * FROM user_auth WHERE discord_user_id = ? AND service = ? AND is_active = 1", + (discord_user_id, service), + ).fetchone() + conn.close() + + if row is None: + return None + + credentials = json.loads(row["credentials"]) if row["credentials"] else {} + return { + "discord_user_id": row["discord_user_id"], + "service": row["service"], + "external_user_id": row["external_user_id"], + "external_name": row["external_name"], + "credentials": credentials, + "linked_at": row["linked_at"], + } + + +def is_authenticated(discord_user_id: int, service: str) -> bool: + """Quick check: is this user linked to this service?""" + return get_auth(discord_user_id, service) is not None + + +def list_services(discord_user_id: int) -> list[str]: + """Return list of service names this user has linked.""" + _ensure_schema() + + with _db_lock: + conn = _get_conn() + rows = conn.execute( + "SELECT service FROM user_auth WHERE discord_user_id = ? AND is_active = 1", + (discord_user_id,), + ).fetchall() + conn.close() + + return [r["service"] for r in rows] + + +def revoke(discord_user_id: int, service: str) -> None: + """Unlink a user from a service.""" + _ensure_schema() + + with _db_lock: + conn = _get_conn() + conn.execute( + "UPDATE user_auth SET is_active = 0 WHERE discord_user_id = ? AND service = ?", + (discord_user_id, service), + ) + conn.commit() + conn.close() + + logger.info("Revoked auth for user %s on %s", discord_user_id, service) + + +# --------------------------------------------------------------------------- +# Dev / testing — reset the entire store +# --------------------------------------------------------------------------- + +def reset_all() -> None: + """Truncate all auth tables — for development and testing only.""" + _ensure_schema() + + with _db_lock: + conn = _get_conn() + conn.execute("DELETE FROM link_tokens") + conn.execute("DELETE FROM user_auth") + conn.commit() + conn.close() + + logger.warning("Auth store RESET — all tokens and auth records cleared.") diff --git a/core/graph.py b/core/graph.py index 70f755a..4bb2675 100644 --- a/core/graph.py +++ b/core/graph.py @@ -1,12 +1,13 @@ """ LangGraph agent graph factory. -Builds a StateGraph that replaces the manual tool-calling loop in api/v1/chat.py. -The graph has two nodes: +Builds a StateGraph with two nodes: - agent_node : calls the LLM (with system prompt + tool definitions) - tool_node : executes tool calls via the existing skill system A conditional edge routes tool_calls back to the agent, or ends the run. +When a tool fails due to missing authentication, the failure message is +relayed to the LLM, which tells the user to use /login. """ from __future__ import annotations @@ -97,18 +98,14 @@ def _make_agent_node( full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}] for m in messages: if isinstance(m, dict): - # Already a plain dict — pass through. - # But fix tool_calls if they're in LangChain format. d = dict(m) tc = d.get("tool_calls") if tc and isinstance(tc, list) and tc and isinstance(tc[0], dict) and "function" not in tc[0]: d["tool_calls"] = _langchain_tc_to_openai(tc) full.append(d) else: - # LangChain message object → OpenAI-compatible dict role = _lc_role_to_openai(getattr(m, "type", "user")) d: dict[str, Any] = {"role": role, "content": getattr(m, "content", "")} - # Serialize tool_calls back to OpenAI format (if this is an AI msg) tc = getattr(m, "tool_calls", None) if tc: d["tool_calls"] = _langchain_tc_to_openai(tc) @@ -125,7 +122,6 @@ def _make_agent_node( ) choice = resp.choices[0] - # Convert OpenAI tool_calls to the dict format LangChain expects. raw_tool_calls = list(choice.message.tool_calls) if choice.message.tool_calls else [] tool_calls: list[dict[str, Any]] = [] for tc in raw_tool_calls: @@ -153,9 +149,9 @@ def _make_tool_node(skill_names: list[str]): """ Return a callable that executes tool_calls from the last AI message. - This replaces LangGraph's built-in ToolNode — we call our own - `execute_tool()` pipeline so that skill-level auth, httpx sessions, - and ToolResult handling are fully preserved. + If a tool fails because the user isn't authenticated, the failure + message (which tells the user to /login) is returned to the LLM. + The LLM naturally relays the instructions to the user. """ async def tool_node(state: AgentState) -> dict[str, list]: @@ -164,18 +160,16 @@ def _make_tool_node(skill_names: list[str]): if not tool_calls: return {"messages": []} + discord_user_id = state.get("discord_user_id") + results: list[ToolMessage] = [] for tc in tool_calls: - # Handle both LangChain format (top-level name/args) and - # OpenAI format (nested "function" key). if isinstance(tc, dict): if "function" in tc: - # OpenAI format: {"id":..., "function": {"name":..., "arguments":"..."}} fn = tc["function"] fn_name = fn.get("name", "") fn_args_raw = fn.get("arguments", "{}") else: - # LangChain format: {"name":..., "args":{...}, "id":...} fn_name = tc.get("name", "") fn_args_raw = tc.get("args", {}) tc_id = tc.get("id", "") @@ -184,13 +178,15 @@ def _make_tool_node(skill_names: list[str]): fn_args_raw = getattr(tc, "args", {}) tc_id = getattr(tc, "id", "") - # Parse args if they arrive as a JSON string if isinstance(fn_args_raw, str): fn_args = json.loads(fn_args_raw) else: fn_args = fn_args_raw - tr = await execute_tool(skill_names, fn_name, fn_args) + tr = await execute_tool( + skill_names, fn_name, fn_args, + discord_user_id=discord_user_id, + ) content = tr.content if tr else f"Tool '{fn_name}' is not available." results.append(ToolMessage(content=content, tool_call_id=tc_id)) @@ -224,27 +220,16 @@ def create_agent_graph( ) -> StateGraph: """ Build and compile a LangGraph StateGraph for a single agent. - - Parameters - ---------- - client : The OpenAI-compatible client (already authenticated). - agent_skills : Skill names assigned to the agent (e.g. ["seerr", "triage"]). - system_prompt : The fully-built system prompt (base + skill fragments). - model_name : Model identifier sent to the LLM provider. - - Returns - ------- - A compiled LangGraph graph ready for `.ainvoke()` or `.astream()`. """ tool_defs = get_all_tools(agent_skills) graph = StateGraph(AgentState) - # Nodes graph.add_node( "agent_node", _make_agent_node(client, system_prompt, tool_defs, model_name), ) + if tool_defs: graph.add_node("tool_node", _make_tool_node(agent_skills)) graph.add_conditional_edges("agent_node", _should_continue, { @@ -253,7 +238,6 @@ def create_agent_graph( }) graph.add_edge("tool_node", "agent_node") else: - # No tools — agent responds once and finishes graph.add_edge("agent_node", END) graph.set_entry_point("agent_node") diff --git a/core/state.py b/core/state.py index 3e434b2..28ec77e 100644 --- a/core/state.py +++ b/core/state.py @@ -18,3 +18,4 @@ class AgentState(TypedDict): """ messages: Annotated[list, add_messages] + discord_user_id: int | None # set by the Discord bot, None for REST API calls diff --git a/main.py b/main.py index b03c488..2800bd6 100644 --- a/main.py +++ b/main.py @@ -4,8 +4,9 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from api.v1.auth import router as auth_router from api.v1.chat import router as v1_router -from core.config import DEEPSEEK_API_KEY +from core.config import DEEPSEEK_API_KEY, get_config from core.llm import create_client # --------------------------------------------------------------------------- @@ -18,12 +19,14 @@ logging.basicConfig( ) # --------------------------------------------------------------------------- -# Load all agents & skills so they self-register at startup +# Load all agents, skills, AND auth services so they self-register at startup # --------------------------------------------------------------------------- from agents import load_all_agents # noqa: E402 load_all_agents() +import auth.jellyfin # noqa: E402 — self-registers JellyfinAuth + # --------------------------------------------------------------------------- # Lifespan # --------------------------------------------------------------------------- @@ -60,4 +63,5 @@ app.state.agent_graphs: dict = {} # --------------------------------------------------------------------------- # Routers # --------------------------------------------------------------------------- -app.include_router(v1_router, prefix="/v1") \ No newline at end of file +app.include_router(v1_router, prefix="/v1") +app.include_router(auth_router) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d1b0c2f..b62cbf6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ python-dotenv httpx langgraph langgraph-checkpoint -discord.py \ No newline at end of file +discord.py +python-multipart \ No newline at end of file diff --git a/skills/__init__.py b/skills/__init__.py index 1a3b32f..97ded41 100644 --- a/skills/__init__.py +++ b/skills/__init__.py @@ -47,6 +47,7 @@ class Skill: prompt_fragment: str = "" tools: List[Dict[str, Any]] = field(default_factory=list) execute: Optional[ToolExecutor] = None + requires_auth: List[str] = field(default_factory=list) # --------------------------------------------------------------------------- @@ -96,9 +97,15 @@ def get_all_tools(skill_names: list[str]) -> List[Dict[str, Any]]: async def execute_tool( - skill_names: list[str], tool_name: str, args: dict + skill_names: list[str], tool_name: str, args: dict, + discord_user_id: int | None = None, ) -> ToolResult | None: """Find the skill that owns *tool_name* and run its executor. + + If *discord_user_id* is provided, also checks whether the owning skill + requires authentication for any services. If auth is missing, returns + a friendly ToolResult.fail(...) telling the user how to log in. + Only logs failures to the console — successful calls are silent. """ import logging @@ -109,6 +116,24 @@ async def execute_tool( if s and s.execute: for t in s.tools: if t.get("function", {}).get("name") == tool_name: + # --- Auth gate --- + if s.requires_auth and discord_user_id is not None: + from core import auth_store + from auth import get_auth_service + missing: list[str] = [] + for svc in s.requires_auth: + if not auth_store.is_authenticated(discord_user_id, svc): + missing.append(svc) + if missing: + svc_displays = ", ".join( + (get_auth_service(m) and get_auth_service(m).display_name) or m + for m in missing + ) + return ToolResult.fail( + f"You need to log in to {svc_displays} first. " + + " ".join(f"Send `/login {m}` in a DM to get started." for m in missing) + ) + # --- End auth gate --- try: result = await s.execute(tool_name, args) if not result.success: diff --git a/skills/media_info.py b/skills/media_info.py index 8578cdf..3365900 100644 --- a/skills/media_info.py +++ b/skills/media_info.py @@ -23,6 +23,16 @@ When responding: suggest submitting a ticket if there's a problem. - Always confirm successful actions and warn about failures. +## Jellyfin & Authentication + +You are connected to the user's Jellyfin server. If a user asks you to +"connect to Jellyfin", "link my Jellyfin", or asks about their watch history, +simply call the `watch_history` tool. The system will automatically handle +authentication — if the user isn't linked yet, they'll be guided through +Quick Connect seamlessly. NEVER tell a user you "don't have access to +Jellyfin" or "can't connect" — always try the tool first and let the system +sort it out. + This is the base media assistant persona. Real API capabilities come from the attached skills (seerr, triage, etc.).""", ) diff --git a/skills/watch_history.py b/skills/watch_history.py new file mode 100644 index 0000000..29bc0b2 --- /dev/null +++ b/skills/watch_history.py @@ -0,0 +1,80 @@ +""" +Watch History skill — fetch the user's Jellyfin watch history. + +Currently a placeholder — returns a "coming soon" message. +The auth gate (`requires_auth=["jellyfin"]`) is already active: +users who haven't linked Jellyfin will be prompted to /login first. +""" + +from __future__ import annotations + +from skills import Skill, register, ToolResult + +# --------------------------------------------------------------------------- +# Tool definitions +# --------------------------------------------------------------------------- + +TOOLS = [ + { + "type": "function", + "function": { + "name": "watch_history", + "description": ( + "Get the user's recent Jellyfin watch history — movies and TV " + "episodes they have watched, sorted by most recent. " + "Call this when a user asks about their watching activity." + ), + "parameters": { + "type": "object", + "properties": { + "limit": { + "type": "integer", + "description": "How many items to return (default 10, max 20)", + } + }, + }, + }, + } +] + + +# --------------------------------------------------------------------------- +# Executor (placeholder) +# --------------------------------------------------------------------------- + +async def _execute(tool_name: str, args: dict) -> ToolResult: + if tool_name == "watch_history": + return ToolResult.ok( + "👷 **Watch History — Coming Soon!**\n\n" + "This feature is currently being built. Soon you'll be able to " + "see your recently watched movies and TV episodes right here.\n\n" + "In the meantime, you can check your watch history directly in Jellyfin." + ) + return ToolResult.fail(f"Unknown tool: {tool_name}") + + +# --------------------------------------------------------------------------- +# Skill registration +# --------------------------------------------------------------------------- + +watch_history_skill = Skill( + name="watch_history", + description="User's Jellyfin watch history (coming soon)", + requires_auth=["jellyfin"], + prompt_fragment="""## Watch History + +You can fetch the user's Jellyfin watch history with the `watch_history` tool. +Call it when users ask things like: +- "what have I watched?" +- "show my watch history" +- "what did I watch recently?" +- "what was the last movie I saw?" +- "what TV shows have I been watching?" + +The tool is currently a **placeholder** — it returns a "coming soon" message. +Tell the user this feature is being worked on and will be available soon.""", + tools=TOOLS, + execute=_execute, +) + +register(watch_history_skill)