added quick connect auth from jellyfin, still needs to have some more cleaning before push to prod #2
+24
-4
@@ -1,9 +1,12 @@
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent Backend — Environment Variables
|
||||
# Copy this to .env and fill in your values.
|
||||
# ---------------------------------------------------------------------------
|
||||
# =============================================================================
|
||||
# Agent Bot — Environment Configuration
|
||||
# =============================================================================
|
||||
# Copy this file to .env and fill in your values.
|
||||
# .env is git-ignored — never commit real secrets.
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM — DeepSeek (OpenAI-compatible)
|
||||
# ---------------------------------------------------------------------------
|
||||
DEEPSEEK_API_KEY=sk-your-deepseek-api-key
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -18,4 +21,21 @@ DISCORD_BOT_TOKEN=your-discord-bot-token-here
|
||||
# ---------------------------------------------------------------------------
|
||||
SEERR_URL=https://seerr.example.com
|
||||
SEERR_API_KEY=your-seerr-api-key
|
||||
# SEERR_USERNAME=your-username # alternative: username+password auth
|
||||
# SEERR_PASSWORD=your-password
|
||||
# SEERR_TIMEOUT=30 # optional, defaults to 30 seconds
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth System (Discord ↔ external services)
|
||||
# ---------------------------------------------------------------------------
|
||||
# The public-facing URL where users reach this bot's web API.
|
||||
# Used to build the "Click here to link" URLs sent via Discord DM.
|
||||
# For local dev: http://localhost:8000
|
||||
# For production behind a reverse proxy: https://bot.yourdomain.com
|
||||
BASE_URL=http://localhost:8000
|
||||
|
||||
# Where the auth SQLite database lives (relative to project root)
|
||||
# AUTH_DB_PATH=data/auth.db
|
||||
|
||||
# Link token expiry in minutes (default 10)
|
||||
# AUTH_TOKEN_EXPIRY=10
|
||||
|
||||
@@ -175,3 +175,4 @@ cython_debug/
|
||||
.pypirc
|
||||
|
||||
.docs/
|
||||
data/
|
||||
@@ -65,3 +65,4 @@ def load_all_agents() -> None:
|
||||
import skills.seerr # noqa: F401
|
||||
import skills.triage # noqa: F401
|
||||
import skills.easter_eggs # noqa: F401
|
||||
import skills.watch_history # noqa: F401
|
||||
|
||||
@@ -14,11 +14,16 @@ media_agent = Agent(
|
||||
agent_id="media-agent",
|
||||
description="Media assistant — handles movie/TV/subtitle/ticket requests "
|
||||
"via Seerr, Jellyfin, Sonarr, etc.",
|
||||
skills=["media_info", "seerr", "triage", "easter_eggs"],
|
||||
skills=["media_info", "seerr", "triage", "easter_eggs", "watch_history"],
|
||||
base_prompt=(
|
||||
"You are a media assistant connected to Seerr and other media services. "
|
||||
"Help users discover, request, and troubleshoot their media library. "
|
||||
"Use the tools provided to perform real actions."
|
||||
"Use the tools provided to perform real actions.\n\n"
|
||||
"## Authentication\n"
|
||||
"If a tool returns a message saying the user needs to log in first, "
|
||||
"tell the user to type `/login <service>` in their DM (e.g. `/login jellyfin`). "
|
||||
"This opens Quick Connect on their Jellyfin app so they can link their account. "
|
||||
"Do NOT tell the user you 'can't connect' or 'don't have access' — just relay the login instructions."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
+189
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Auth API — generic endpoints for linking Discord users to external services.
|
||||
|
||||
GET /api/v1/auth/login?service=X&token=Y&discord_id=Z
|
||||
Validates the link token and serves a service-specific login form.
|
||||
|
||||
POST /api/v1/auth/login
|
||||
Accepts the form submission, validates credentials against the service,
|
||||
stores the session, and returns a result page.
|
||||
|
||||
GET /api/v1/auth/status?discord_id=Z
|
||||
Returns which services are linked for this Discord user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from auth import get_auth_service, list_auth_services
|
||||
from core import auth_store
|
||||
|
||||
logger = logging.getLogger("api.auth")
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /auth/login — serve the login form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/login")
|
||||
async def login_form(
|
||||
service: str,
|
||||
token: str,
|
||||
discord_id: int,
|
||||
):
|
||||
"""Validate the one-time link token and return a service-specific login form."""
|
||||
|
||||
# Validate the token WITHOUT consuming it (the POST will consume it)
|
||||
result = auth_store.validate_token(token)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
||||
|
||||
uid, svc = result
|
||||
if uid != discord_id or svc != service:
|
||||
raise HTTPException(status_code=400, detail="Token does not match the request.")
|
||||
|
||||
# Look up the AuthService
|
||||
svc_obj = get_auth_service(service)
|
||||
if svc_obj is None:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
||||
|
||||
logger.info("Serving login form: user=%s service=%s", discord_id, service)
|
||||
return HTMLResponse(svc_obj.render_login_form(token, discord_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/login — handle form submission
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/login")
|
||||
async def login_submit(request: Request):
|
||||
"""Handle the login form POST: validate credentials, store auth, show result."""
|
||||
|
||||
# Parse form data
|
||||
form = await request.form()
|
||||
token = form.get("token", "")
|
||||
discord_id_str = form.get("discord_id", "")
|
||||
service = form.get("service", "")
|
||||
|
||||
if not token or not discord_id_str or not service:
|
||||
raise HTTPException(status_code=400, detail="Missing required fields.")
|
||||
|
||||
try:
|
||||
discord_id = int(discord_id_str)
|
||||
except (ValueError, TypeError):
|
||||
raise HTTPException(status_code=400, detail="Invalid discord_id.")
|
||||
|
||||
# Consume the token on POST (the GET only validated, didn't consume)
|
||||
result = auth_store.consume_token(token)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
||||
|
||||
# Look up the AuthService
|
||||
svc_obj = get_auth_service(service)
|
||||
if svc_obj is None:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
||||
|
||||
# Collect service-specific form fields (everything except token, discord_id, service)
|
||||
form_data: dict[str, str] = {}
|
||||
for key, value in form.items():
|
||||
if key not in ("token", "discord_id", "service"):
|
||||
form_data[key] = str(value)
|
||||
|
||||
# Authenticate against the service
|
||||
auth_result = await svc_obj.authenticate(form_data)
|
||||
|
||||
if not auth_result.success:
|
||||
return HTMLResponse(
|
||||
status_code=401,
|
||||
content=f"""<!DOCTYPE html>
|
||||
<html><head><meta charset="utf-8"><title>Login Failed</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
|
||||
h2 {{ color: #d32f2f; }}
|
||||
a {{ color: #aa5cc3; }}
|
||||
</style></head><body>
|
||||
<h2>❌ Login Failed</h2>
|
||||
<p>{auth_result.error_message or "Authentication failed. Please try again."}</p>
|
||||
<p><a href="javascript:history.back()">← Go back and try again</a></p>
|
||||
</body></html>""",
|
||||
)
|
||||
|
||||
# Store the successful auth
|
||||
auth_store.store_auth(
|
||||
discord_user_id=discord_id,
|
||||
service=service,
|
||||
external_user_id=auth_result.external_user_id or "",
|
||||
external_name=auth_result.external_name or "",
|
||||
credentials=auth_result.credentials,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Auth linked: discord=%s → %s (%s)",
|
||||
discord_id,
|
||||
service,
|
||||
auth_result.external_name,
|
||||
)
|
||||
|
||||
return HTMLResponse(f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Account Linked</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; text-align: center; }}
|
||||
h1 {{ color: #388e3c; }}
|
||||
.name {{ font-weight: bold; color: #aa5cc3; }}
|
||||
p {{ color: #666; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>✅ Account Linked!</h1>
|
||||
<p>Logged in as <span class="name">{auth_result.external_name}</span> on <strong>{svc_obj.display_name}</strong>.</p>
|
||||
<p>You can close this page and return to Discord.</p>
|
||||
</body>
|
||||
</html>""")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /auth/status — check which services are linked
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/status")
|
||||
async def auth_status(discord_id: int):
|
||||
"""Return which services this Discord user has linked."""
|
||||
services: dict[str, bool] = {}
|
||||
for svc_name in list_auth_services():
|
||||
services[svc_name] = auth_store.is_authenticated(discord_id, svc_name)
|
||||
return {"discord_id": discord_id, "services": services}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/reset — wipe auth store (DEV ONLY)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from core.config import get_config # noqa: E402
|
||||
|
||||
@router.post("/reset")
|
||||
async def reset_auth():
|
||||
"""
|
||||
Reset the entire auth store — clears all link tokens and user auth records.
|
||||
|
||||
Only enabled when ALLOW_AUTH_RESET=true in the environment.
|
||||
Returns 403 in production.
|
||||
"""
|
||||
if get_config("ALLOW_AUTH_RESET", "false").lower() != "true":
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Auth reset is disabled. Set ALLOW_AUTH_RESET=true to enable (dev only).",
|
||||
)
|
||||
|
||||
auth_store.reset_all()
|
||||
logger.warning("Auth store reset via API endpoint.")
|
||||
return {"status": "ok", "message": "Auth store cleared — all tokens and auth records removed."}
|
||||
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Auth Service registry — generic, pluggable authentication for any service.
|
||||
|
||||
Add a new service (Plex, Seerr, etc.) by:
|
||||
1. Subclassing AuthService
|
||||
2. Dropping the module in this package
|
||||
3. Calling register_auth_service() at import time
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AuthResult — returned by AuthService.authenticate()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class AuthResult:
|
||||
"""Outcome of a credential validation attempt."""
|
||||
success: bool
|
||||
external_user_id: Optional[str] = None
|
||||
external_name: Optional[str] = None
|
||||
credentials: Optional[dict] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AuthService — abstract base class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AuthService(ABC):
|
||||
"""A service that users can authenticate against (Jellyfin, Seerr, Plex, etc.)
|
||||
|
||||
Subclasses must implement:
|
||||
- name : unique identifier used in URLs and DB keys
|
||||
- display_name : human-readable label shown in Discord
|
||||
- render_login_form(token, discord_id) → HTML string
|
||||
- authenticate(form_data) → AuthResult
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Unique service name: "jellyfin", "seerr", etc."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def display_name(self) -> str:
|
||||
"""Human-readable: "Jellyfin", "Seerr", "Plex" """
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def render_login_form(self, token: str, discord_id: int) -> str:
|
||||
"""Return HTML string with a login form for this service.
|
||||
|
||||
The form MUST include these hidden fields:
|
||||
<input type="hidden" name="token" value="{token}">
|
||||
<input type="hidden" name="discord_id" value="{discord_id}">
|
||||
<input type="hidden" name="service" value="{self.name}">
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, form_data: dict) -> AuthResult:
|
||||
"""Validate credentials against the service. Return AuthResult."""
|
||||
...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Global registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_registry: dict[str, AuthService] = {}
|
||||
|
||||
|
||||
def register_auth_service(svc: AuthService) -> None:
|
||||
"""Register an AuthService so it can be looked up by name."""
|
||||
_registry[svc.name] = svc
|
||||
|
||||
|
||||
def get_auth_service(name: str) -> AuthService | None:
|
||||
"""Look up a registered AuthService by name."""
|
||||
return _registry.get(name)
|
||||
|
||||
|
||||
def list_auth_services() -> list[str]:
|
||||
"""Return names of all registered auth services."""
|
||||
return list(_registry.keys())
|
||||
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
Jellyfin AuthService — validates Jellyfin credentials and stores the session token.
|
||||
|
||||
Two authentication flows:
|
||||
1. Quick Connect (primary): user enters a short code on their Jellyfin app.
|
||||
- initiate_quick_connect() → {code, secret}
|
||||
- poll_quick_connect(secret) → "Active" | "Authorized" | "Expired"
|
||||
- authenticate_quick_connect(secret) → AuthResult with token
|
||||
|
||||
2. Username/password (legacy): renders an HTML form, called via the REST API.
|
||||
- render_login_form(token, discord_id) → HTML string
|
||||
- authenticate(form_data) → AuthResult
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from auth import AuthService, AuthResult, register_auth_service
|
||||
from core.config import get_config
|
||||
|
||||
logger = logging.getLogger("auth.jellyfin")
|
||||
|
||||
# Emby-style authorization header required by Jellyfin's AuthenticateByName
|
||||
_EMBY_HEADER = (
|
||||
'MediaBrowser Client="AgentBot",'
|
||||
'Device="DiscordBot",'
|
||||
'DeviceId="agent-bot",'
|
||||
'Version="1.0"'
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuickConnectResult:
|
||||
"""Result of a Quick Connect initiation."""
|
||||
secret: str
|
||||
code: str
|
||||
device_id: str
|
||||
device_name: str
|
||||
|
||||
|
||||
class JellyfinAuth(AuthService):
|
||||
name = "jellyfin"
|
||||
display_name = "Jellyfin"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Quick Connect helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _qc_headers(self) -> dict[str, str]:
|
||||
"""Return headers used by all Quick Connect API calls."""
|
||||
return {
|
||||
"X-Emby-Authorization": (
|
||||
'MediaBrowser Client="AgentBot",'
|
||||
'Device="DiscordBot",'
|
||||
'DeviceId="agent-bot-qc",'
|
||||
'Version="1.0"'
|
||||
)
|
||||
}
|
||||
|
||||
async def _resolve_url(self) -> str | None:
|
||||
"""
|
||||
Resolve the Jellyfin server URL.
|
||||
1. Check JELLYFIN_URL env var (used in deployment).
|
||||
2. Check if user already has a stored auth with a URL (from legacy login).
|
||||
Returns None if no URL is configured.
|
||||
"""
|
||||
# First: explicit env var
|
||||
env_url = get_config("JELLYFIN_URL")
|
||||
if env_url:
|
||||
return env_url.strip().rstrip("/")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1a: initiate Quick Connect
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def initiate_quick_connect(self, url: str | None = None) -> QuickConnectResult | None:
|
||||
"""
|
||||
Call Jellyfin's POST /QuickConnect/Initiate.
|
||||
Returns a QuickConnectResult with {secret, code} or None on failure.
|
||||
|
||||
The *code* is what the user enters on their Jellyfin page.
|
||||
The *secret* is used internally to poll/authenticate.
|
||||
"""
|
||||
base_url = url or await self._resolve_url()
|
||||
if not base_url:
|
||||
logger.error("QuickConnect failed — no JELLYFIN_URL configured.")
|
||||
return None
|
||||
|
||||
logger.info("Initiating Quick Connect on %s", base_url)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{base_url}/QuickConnect/Initiate",
|
||||
headers=self._qc_headers(),
|
||||
json={},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
"QuickConnect init failed: HTTP %s — %s",
|
||||
resp.status_code, resp.text[:200],
|
||||
)
|
||||
return None
|
||||
|
||||
data = resp.json()
|
||||
secret = data.get("Secret", "")
|
||||
code = data.get("Code", "")
|
||||
device_id = data.get("DeviceId", "")
|
||||
device_name = data.get("DeviceName", "")
|
||||
|
||||
if not secret or not code:
|
||||
logger.warning("QuickConnect init returned unexpected payload: %s", data)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
"Quick Connect initiated: code=%s device=%s",
|
||||
code, device_name,
|
||||
)
|
||||
return QuickConnectResult(
|
||||
secret=secret,
|
||||
code=code,
|
||||
device_id=device_id,
|
||||
device_name=device_name,
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("QuickConnect init timed out reaching %s", base_url)
|
||||
return None
|
||||
except httpx.ConnectError:
|
||||
logger.error("QuickConnect init — cannot connect to %s", base_url)
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during QuickConnect init")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1b: poll Quick Connect status
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def poll_quick_connect(self, secret: str, url: str | None = None) -> str:
|
||||
"""
|
||||
Call Jellyfin's GET /QuickConnect/Connect?secret=<secret>.
|
||||
Returns one of:
|
||||
- "Active" → user hasn't entered the code yet
|
||||
- "Authorized" → user entered code AND approved
|
||||
- "Expired" → code expired / unknown secret
|
||||
- "Error" → network or unexpected failure
|
||||
"""
|
||||
base_url = url or await self._resolve_url()
|
||||
if not base_url:
|
||||
logger.error("QuickConnect poll failed — no JELLYFIN_URL configured.")
|
||||
return "Error"
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.get(
|
||||
f"{base_url}/QuickConnect/Connect",
|
||||
params={"secret": secret},
|
||||
headers=self._qc_headers(),
|
||||
)
|
||||
if resp.status_code == 404:
|
||||
return "Expired"
|
||||
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
# Jellyfin returns "Authenticated" (not "Authorized")
|
||||
if data.get("Authenticated") is True:
|
||||
return "Authorized"
|
||||
# "Authenticated" is false, missing, or null → still active
|
||||
return "Active"
|
||||
|
||||
logger.warning(
|
||||
"QuickConnect poll unexpected: HTTP %s — %s",
|
||||
resp.status_code, resp.text[:200],
|
||||
)
|
||||
return "Error"
|
||||
|
||||
except (httpx.TimeoutException, httpx.ConnectError):
|
||||
logger.warning("QuickConnect poll network error")
|
||||
return "Error"
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during QuickConnect poll")
|
||||
return "Error"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1c: exchange secret for token
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def authenticate_quick_connect(
|
||||
self, secret: str, url: str | None = None
|
||||
) -> AuthResult:
|
||||
"""
|
||||
After poll_quick_connect returns "Authorized", call
|
||||
POST /Users/AuthenticateWithQuickConnect to exchange the secret
|
||||
for a real access token.
|
||||
|
||||
Returns AuthResult with token, user_id, username on success.
|
||||
"""
|
||||
base_url = url or await self._resolve_url()
|
||||
if not base_url:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="No Jellyfin server URL configured.",
|
||||
)
|
||||
|
||||
logger.info("Exchanging QuickConnect secret for token on %s", base_url)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{base_url}/Users/AuthenticateWithQuickConnect",
|
||||
json={"Secret": secret},
|
||||
headers=self._qc_headers(),
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
"QuickConnect auth exchange failed: HTTP %s",
|
||||
resp.status_code,
|
||||
)
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="Quick Connect authentication failed. The code may have expired.",
|
||||
)
|
||||
|
||||
data = resp.json()
|
||||
user = data.get("User", {})
|
||||
token = data.get("AccessToken", "")
|
||||
|
||||
if not token:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="Jellyfin returned an unexpected response.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"QuickConnect linked: user=%s (%s)",
|
||||
user.get("Name", "?"),
|
||||
user.get("Id", "?"),
|
||||
)
|
||||
|
||||
return AuthResult(
|
||||
success=True,
|
||||
external_user_id=user.get("Id", ""),
|
||||
external_name=user.get("Name", "?"),
|
||||
credentials={
|
||||
"token": token,
|
||||
"url": base_url,
|
||||
"user_id": user.get("Id", ""),
|
||||
},
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not reach {base_url} — connection timed out.",
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not connect to {base_url}. Is the server running?",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during QuickConnect auth exchange")
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="An unexpected error occurred during authentication.",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Login form (legacy — used by the REST API)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def render_login_form(self, token: str, discord_id: int) -> str:
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Link Jellyfin</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
|
||||
h2 {{ margin-bottom: 4px; }}
|
||||
.sub {{ color: #666; margin-bottom: 24px; }}
|
||||
label {{ display: block; margin-top: 16px; font-weight: 600; }}
|
||||
input {{ width: 100%; padding: 10px; margin-top: 4px; border: 1px solid #ccc; border-radius: 6px; box-sizing: border-box; }}
|
||||
button {{ margin-top: 24px; width: 100%; padding: 12px; background: #aa5cc3; color: #fff; border: none; border-radius: 6px; font-size: 16px; cursor: pointer; }}
|
||||
button:hover {{ background: #9448b0; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h2>🔗 Link Jellyfin to Discord</h2>
|
||||
<p class="sub">Enter your Jellyfin server URL and credentials to link your account.</p>
|
||||
|
||||
<form method="POST" action="/api/v1/auth/login">
|
||||
<input type="hidden" name="token" value="{token}">
|
||||
<input type="hidden" name="discord_id" value="{discord_id}">
|
||||
<input type="hidden" name="service" value="jellyfin">
|
||||
|
||||
<label for="jellyfin_url">Jellyfin Server URL</label>
|
||||
<input id="jellyfin_url" name="jellyfin_url" type="url"
|
||||
placeholder="https://jellyfin.example.com" required>
|
||||
|
||||
<label for="username">Username</label>
|
||||
<input id="username" name="username" type="text"
|
||||
placeholder="Your Jellyfin username" required autofocus>
|
||||
|
||||
<label for="password">Password</label>
|
||||
<input id="password" name="password" type="password"
|
||||
placeholder="Your Jellyfin password" required>
|
||||
|
||||
<button type="submit">Link Account</button>
|
||||
</form>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authentication
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def authenticate(self, form_data: dict) -> AuthResult:
|
||||
url = form_data.get("jellyfin_url", "").strip().rstrip("/")
|
||||
username = form_data.get("username", "").strip()
|
||||
password = form_data.get("password", "").strip()
|
||||
|
||||
if not url or not username or not password:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="All fields are required (URL, username, password).",
|
||||
)
|
||||
|
||||
logger.info("Attempting Jellyfin login for '%s' on %s", username, url)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{url}/Users/AuthenticateByName",
|
||||
json={"Username": username, "Pw": password},
|
||||
headers={"X-Emby-Authorization": _EMBY_HEADER},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
"Jellyfin login failed for '%s': HTTP %s", username, resp.status_code
|
||||
)
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Login failed — check your server URL and credentials.",
|
||||
)
|
||||
|
||||
data = resp.json()
|
||||
user = data.get("User", {})
|
||||
token = data.get("AccessToken", "")
|
||||
|
||||
if not token:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="Jellyfin returned an unexpected response.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Jellyfin login OK: user=%s (%s)",
|
||||
user.get("Name", "?"),
|
||||
user.get("Id", "?"),
|
||||
)
|
||||
|
||||
return AuthResult(
|
||||
success=True,
|
||||
external_user_id=user.get("Id", ""),
|
||||
external_name=user.get("Name", username),
|
||||
credentials={
|
||||
"token": token,
|
||||
"url": url,
|
||||
"user_id": user.get("Id", ""),
|
||||
},
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not reach {url} — connection timed out. Check the URL.",
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not connect to {url}. Is the server running?",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Unexpected error during Jellyfin login")
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"An unexpected error occurred. Please try again.",
|
||||
)
|
||||
|
||||
|
||||
# Self-register at import time
|
||||
register_auth_service(JellyfinAuth())
|
||||
+144
-2
@@ -27,6 +27,8 @@ from bot.conversation import ConversationStore
|
||||
from core.config import DEEPSEEK_API_KEY, get_config
|
||||
from core.graph import create_agent_graph
|
||||
from core.llm import create_client
|
||||
from core import auth_store
|
||||
from auth import list_auth_services, get_auth_service
|
||||
|
||||
logger = logging.getLogger("bot.discord")
|
||||
|
||||
@@ -36,6 +38,7 @@ logger = logging.getLogger("bot.discord")
|
||||
DISCORD_BOT_TOKEN = get_config("DISCORD_BOT_TOKEN") or ""
|
||||
DISCORD_MAX_HISTORY = int(get_config("DISCORD_MAX_HISTORY", "7"))
|
||||
DISCORD_DEFAULT_AGENT = get_config("DISCORD_DEFAULT_AGENT", "media-agent")
|
||||
BASE_URL = get_config("BASE_URL", "http://localhost:8000").rstrip("/")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM client shared by all agents (same as the REST API uses)
|
||||
@@ -138,6 +141,12 @@ class AgentBot(discord.Client):
|
||||
# |--------------------------------------------------------------|
|
||||
|
||||
user_id = message.author.id
|
||||
content = message.content.strip()
|
||||
|
||||
# |-- Bot commands — handled directly, never sent to the LLM --|
|
||||
if await self._handle_command(message, user_id, content):
|
||||
return
|
||||
# |--------------------------------------------------------------|
|
||||
|
||||
# Show typing indicator while the graph runs
|
||||
async with message.channel.typing():
|
||||
@@ -154,6 +163,140 @@ class AgentBot(discord.Client):
|
||||
"Please try again in a moment."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bot commands
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_command(
|
||||
self, message: discord.Message, user_id: int, content: str
|
||||
) -> bool:
|
||||
"""Handle bot commands (/login, /logout). Returns True if handled."""
|
||||
lower = content.lower()
|
||||
|
||||
# --- /login [service] ---
|
||||
if lower.startswith("/login"):
|
||||
parts = content.split()
|
||||
service = parts[1].lower() if len(parts) > 1 else None
|
||||
|
||||
available = list_auth_services()
|
||||
if not available:
|
||||
await message.channel.send("No auth services are configured yet.")
|
||||
return True
|
||||
|
||||
if service is None:
|
||||
svc_list = ", ".join(available)
|
||||
await message.channel.send(
|
||||
f"Available services to link: **{svc_list}**\n"
|
||||
f"Use `/login <service>` — e.g. `/login jellyfin`"
|
||||
)
|
||||
return True
|
||||
|
||||
if service not in available:
|
||||
await message.channel.send(
|
||||
f"Unknown service '{service}'. Available: {', '.join(available)}"
|
||||
)
|
||||
return True
|
||||
|
||||
if auth_store.is_authenticated(user_id, service):
|
||||
svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service
|
||||
await message.channel.send(
|
||||
f"You're already linked to **{svc_display}**! "
|
||||
f"Use `/logout {service}` to unlink."
|
||||
)
|
||||
return True
|
||||
|
||||
# --- Quick Connect flow ---
|
||||
svc = get_auth_service(service)
|
||||
if svc is None:
|
||||
await message.channel.send(f"Unknown service: {service}")
|
||||
return True
|
||||
|
||||
await message.channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…")
|
||||
|
||||
qc_result = await svc.initiate_quick_connect()
|
||||
if qc_result is None:
|
||||
await message.channel.send(
|
||||
f"❌ Could not start Quick Connect for **{svc.display_name}**.\n"
|
||||
"Check that `JELLYFIN_URL` is configured and the server is reachable."
|
||||
)
|
||||
return True
|
||||
|
||||
await message.channel.send(
|
||||
f"Open **{svc.display_name}** → **Quick Connect** and enter this code:\n\n"
|
||||
f"**`{qc_result.code}`**\n\n"
|
||||
f"⏳ Waiting for you to approve…"
|
||||
)
|
||||
|
||||
# Poll for authorization
|
||||
async with message.channel.typing():
|
||||
for attempt in range(24): # 24 × 5s = 2 minutes
|
||||
await asyncio.sleep(5)
|
||||
status = await svc.poll_quick_connect(qc_result.secret)
|
||||
|
||||
if status == "Authorized":
|
||||
auth_result = await svc.authenticate_quick_connect(qc_result.secret)
|
||||
if auth_result.success:
|
||||
auth_store.store_auth(
|
||||
discord_user_id=user_id,
|
||||
service=service,
|
||||
external_user_id=auth_result.external_user_id or "",
|
||||
external_name=auth_result.external_name or "",
|
||||
credentials=auth_result.credentials,
|
||||
)
|
||||
await message.channel.send(
|
||||
f"✅ Linked to **{svc.display_name}** as "
|
||||
f"**{auth_result.external_name}**!"
|
||||
)
|
||||
else:
|
||||
await message.channel.send(
|
||||
f"❌ Authentication failed: "
|
||||
f"{auth_result.error_message or 'Unknown error'}"
|
||||
)
|
||||
return True
|
||||
|
||||
elif status == "Expired":
|
||||
await message.channel.send(
|
||||
"⌛ The Quick Connect code expired. "
|
||||
f"Use `/login {service}` to try again."
|
||||
)
|
||||
return True
|
||||
|
||||
# else: still "Active" — keep polling
|
||||
|
||||
await message.channel.send(
|
||||
"⌛ Timed out waiting for Quick Connect approval. "
|
||||
f"Use `/login {service}` to try again."
|
||||
)
|
||||
return True
|
||||
|
||||
# --- /logout [service] ---
|
||||
if lower.startswith("/logout"):
|
||||
parts = content.split()
|
||||
service = parts[1].lower() if len(parts) > 1 else None
|
||||
|
||||
if service is None:
|
||||
linked = auth_store.list_services(user_id)
|
||||
if not linked:
|
||||
await message.channel.send("You don't have any linked services.")
|
||||
else:
|
||||
svc_list = ", ".join(linked)
|
||||
await message.channel.send(
|
||||
f"Linked services: **{svc_list}**\n"
|
||||
f"Use `/logout <service>` to unlink."
|
||||
)
|
||||
return True
|
||||
|
||||
if not auth_store.is_authenticated(user_id, service):
|
||||
await message.channel.send(f"You're not linked to **{service}**.")
|
||||
return True
|
||||
|
||||
auth_store.revoke(user_id, service)
|
||||
svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service
|
||||
await message.channel.send(f"Unlinked from **{svc_display}**. Use `/login {service}` to re-link.")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent invocation
|
||||
# ------------------------------------------------------------------
|
||||
@@ -163,7 +306,6 @@ class AgentBot(discord.Client):
|
||||
reply, and return the assistant's final text."""
|
||||
|
||||
# 1. Pick agent — defaults to DISCORD_DEFAULT_AGENT env var.
|
||||
# Change DISCORD_DEFAULT_AGENT in .env to switch agents.
|
||||
agent_id = DISCORD_DEFAULT_AGENT
|
||||
|
||||
# 2. Build message list from stored history + new user message
|
||||
@@ -172,7 +314,7 @@ class AgentBot(discord.Client):
|
||||
|
||||
# 3. Run the LangGraph (tools execute inline if needed)
|
||||
graph = _get_graph(agent_id)
|
||||
state = {"messages": messages}
|
||||
state = {"messages": messages, "discord_user_id": user_id}
|
||||
result = await graph.ainvoke(state)
|
||||
|
||||
last_msg = result["messages"][-1]
|
||||
|
||||
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
Auth Store — SQLite-backed persistence for Discord-to-service authentication.
|
||||
|
||||
Two tables:
|
||||
- link_tokens : one-time tokens sent via Discord DM to initiate login
|
||||
- user_auth : per-user, per-service credentials (Jellyfin token, etc.)
|
||||
|
||||
Thread-safe via WAL mode and a shared lock. No passwords are ever stored
|
||||
— only opaque service tokens (e.g. Jellyfin AccessToken).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from core.config import get_config
|
||||
|
||||
logger = logging.getLogger("auth_store")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db")
|
||||
TOKEN_EXPIRY_MINUTES = int(get_config("AUTH_TOKEN_EXPIRY", "10"))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton handle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_db_path: Path | None = None
|
||||
_db_lock = threading.Lock()
|
||||
|
||||
|
||||
def _resolve_path() -> Path:
|
||||
"""Turn AUTH_DB_PATH into an absolute path, creating parent dirs."""
|
||||
global _db_path
|
||||
if _db_path is not None:
|
||||
return _db_path
|
||||
p = Path(AUTH_DB_PATH)
|
||||
if not p.is_absolute():
|
||||
# Relative to the project root (two levels above this file)
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
p = project_root / p
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
_db_path = p
|
||||
return p
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
"""Return a thread-local connection to the auth database."""
|
||||
import sqlite3
|
||||
|
||||
conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS link_tokens (
|
||||
token TEXT PRIMARY KEY,
|
||||
discord_user_id INTEGER NOT NULL,
|
||||
service TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
used INTEGER DEFAULT 0,
|
||||
created_at TEXT DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_auth (
|
||||
discord_user_id INTEGER NOT NULL,
|
||||
service TEXT NOT NULL,
|
||||
external_user_id TEXT,
|
||||
external_name TEXT,
|
||||
credentials TEXT,
|
||||
linked_at TEXT DEFAULT (datetime('now')),
|
||||
is_active INTEGER DEFAULT 1,
|
||||
PRIMARY KEY (discord_user_id, service)
|
||||
);
|
||||
"""
|
||||
|
||||
_initialized = False
|
||||
|
||||
|
||||
def _ensure_schema() -> None:
|
||||
global _initialized
|
||||
if _initialized:
|
||||
return
|
||||
with _db_lock:
|
||||
if _initialized:
|
||||
return
|
||||
conn = _get_conn()
|
||||
conn.executescript(_SCHEMA)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
_initialized = True
|
||||
logger.info("Auth store initialized at %s", _resolve_path())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — Link Tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_token(discord_user_id: int, service: str) -> str:
|
||||
"""Generate a one-time link token. Expires after TOKEN_EXPIRY_MINUTES."""
|
||||
_ensure_schema()
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires = (datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRY_MINUTES)).isoformat()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"INSERT INTO link_tokens (token, discord_user_id, service, expires_at) VALUES (?, ?, ?, ?)",
|
||||
(token, discord_user_id, service, expires),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info("Created link token for user %s / service %s", discord_user_id, service)
|
||||
return token
|
||||
|
||||
|
||||
def validate_token(token: str) -> tuple[int, str] | None:
|
||||
"""Read-only validation — does NOT consume the token.
|
||||
|
||||
Returns (discord_user_id, service) if the token exists, is unused,
|
||||
and has not expired. Returns None otherwise.
|
||||
"""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
|
||||
(token,),
|
||||
).fetchone()
|
||||
conn.close()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
if row["used"]:
|
||||
return None
|
||||
|
||||
expires = datetime.fromisoformat(row["expires_at"])
|
||||
if datetime.now(timezone.utc) > expires:
|
||||
return None
|
||||
|
||||
return (row["discord_user_id"], row["service"])
|
||||
|
||||
|
||||
def consume_token(token: str) -> tuple[int, str] | None:
|
||||
"""Validate and consume a link token. Returns (discord_user_id, service) or None.
|
||||
|
||||
A token is valid if:
|
||||
- It exists
|
||||
- It has not been used
|
||||
- It has not expired
|
||||
"""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
|
||||
(token,),
|
||||
).fetchone()
|
||||
|
||||
if row is None:
|
||||
conn.close()
|
||||
return None
|
||||
|
||||
if row["used"]:
|
||||
conn.close()
|
||||
logger.warning("Token already used: %s…", token[:8])
|
||||
return None
|
||||
|
||||
expires = datetime.fromisoformat(row["expires_at"])
|
||||
if datetime.now(timezone.utc) > expires:
|
||||
conn.close()
|
||||
logger.warning("Token expired: %s…", token[:8])
|
||||
return None
|
||||
|
||||
conn.execute("UPDATE link_tokens SET used = 1 WHERE token = ?", (token,))
|
||||
conn.commit()
|
||||
result = (row["discord_user_id"], row["service"])
|
||||
conn.close()
|
||||
logger.info("Token consumed: %s… → user=%s service=%s", token[:8], result[0], result[1])
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — User Auth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def store_auth(
|
||||
discord_user_id: int,
|
||||
service: str,
|
||||
*,
|
||||
external_user_id: str = "",
|
||||
external_name: str = "",
|
||||
credentials: dict | None = None,
|
||||
) -> None:
|
||||
"""Store or update authentication for a user on a service."""
|
||||
_ensure_schema()
|
||||
import json
|
||||
|
||||
creds_json = json.dumps(credentials) if credentials else "{}"
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO user_auth (discord_user_id, service, external_user_id, external_name, credentials, linked_at)
|
||||
VALUES (?, ?, ?, ?, ?, datetime('now'))
|
||||
ON CONFLICT(discord_user_id, service) DO UPDATE SET
|
||||
external_user_id = excluded.external_user_id,
|
||||
external_name = excluded.external_name,
|
||||
credentials = excluded.credentials,
|
||||
linked_at = datetime('now'),
|
||||
is_active = 1""",
|
||||
(discord_user_id, service, external_user_id, external_name, creds_json),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info("Stored auth for user %s on %s as %s", discord_user_id, service, external_name)
|
||||
|
||||
|
||||
def get_auth(discord_user_id: int, service: str) -> dict | None:
|
||||
"""Retrieve stored auth for a user on a service. Returns None if not linked."""
|
||||
_ensure_schema()
|
||||
import json
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT * FROM user_auth WHERE discord_user_id = ? AND service = ? AND is_active = 1",
|
||||
(discord_user_id, service),
|
||||
).fetchone()
|
||||
conn.close()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
credentials = json.loads(row["credentials"]) if row["credentials"] else {}
|
||||
return {
|
||||
"discord_user_id": row["discord_user_id"],
|
||||
"service": row["service"],
|
||||
"external_user_id": row["external_user_id"],
|
||||
"external_name": row["external_name"],
|
||||
"credentials": credentials,
|
||||
"linked_at": row["linked_at"],
|
||||
}
|
||||
|
||||
|
||||
def is_authenticated(discord_user_id: int, service: str) -> bool:
|
||||
"""Quick check: is this user linked to this service?"""
|
||||
return get_auth(discord_user_id, service) is not None
|
||||
|
||||
|
||||
def list_services(discord_user_id: int) -> list[str]:
|
||||
"""Return list of service names this user has linked."""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT service FROM user_auth WHERE discord_user_id = ? AND is_active = 1",
|
||||
(discord_user_id,),
|
||||
).fetchall()
|
||||
conn.close()
|
||||
|
||||
return [r["service"] for r in rows]
|
||||
|
||||
|
||||
def revoke(discord_user_id: int, service: str) -> None:
|
||||
"""Unlink a user from a service."""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"UPDATE user_auth SET is_active = 0 WHERE discord_user_id = ? AND service = ?",
|
||||
(discord_user_id, service),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info("Revoked auth for user %s on %s", discord_user_id, service)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dev / testing — reset the entire store
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def reset_all() -> None:
|
||||
"""Truncate all auth tables — for development and testing only."""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute("DELETE FROM link_tokens")
|
||||
conn.execute("DELETE FROM user_auth")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.warning("Auth store RESET — all tokens and auth records cleared.")
|
||||
+13
-29
@@ -1,12 +1,13 @@
|
||||
"""
|
||||
LangGraph agent graph factory.
|
||||
|
||||
Builds a StateGraph that replaces the manual tool-calling loop in api/v1/chat.py.
|
||||
The graph has two nodes:
|
||||
Builds a StateGraph with two nodes:
|
||||
- agent_node : calls the LLM (with system prompt + tool definitions)
|
||||
- tool_node : executes tool calls via the existing skill system
|
||||
|
||||
A conditional edge routes tool_calls back to the agent, or ends the run.
|
||||
When a tool fails due to missing authentication, the failure message is
|
||||
relayed to the LLM, which tells the user to use /login.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -97,18 +98,14 @@ def _make_agent_node(
|
||||
full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
|
||||
for m in messages:
|
||||
if isinstance(m, dict):
|
||||
# Already a plain dict — pass through.
|
||||
# But fix tool_calls if they're in LangChain format.
|
||||
d = dict(m)
|
||||
tc = d.get("tool_calls")
|
||||
if tc and isinstance(tc, list) and tc and isinstance(tc[0], dict) and "function" not in tc[0]:
|
||||
d["tool_calls"] = _langchain_tc_to_openai(tc)
|
||||
full.append(d)
|
||||
else:
|
||||
# LangChain message object → OpenAI-compatible dict
|
||||
role = _lc_role_to_openai(getattr(m, "type", "user"))
|
||||
d: dict[str, Any] = {"role": role, "content": getattr(m, "content", "")}
|
||||
# Serialize tool_calls back to OpenAI format (if this is an AI msg)
|
||||
tc = getattr(m, "tool_calls", None)
|
||||
if tc:
|
||||
d["tool_calls"] = _langchain_tc_to_openai(tc)
|
||||
@@ -125,7 +122,6 @@ def _make_agent_node(
|
||||
)
|
||||
choice = resp.choices[0]
|
||||
|
||||
# Convert OpenAI tool_calls to the dict format LangChain expects.
|
||||
raw_tool_calls = list(choice.message.tool_calls) if choice.message.tool_calls else []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
for tc in raw_tool_calls:
|
||||
@@ -153,9 +149,9 @@ def _make_tool_node(skill_names: list[str]):
|
||||
"""
|
||||
Return a callable that executes tool_calls from the last AI message.
|
||||
|
||||
This replaces LangGraph's built-in ToolNode — we call our own
|
||||
`execute_tool()` pipeline so that skill-level auth, httpx sessions,
|
||||
and ToolResult handling are fully preserved.
|
||||
If a tool fails because the user isn't authenticated, the failure
|
||||
message (which tells the user to /login) is returned to the LLM.
|
||||
The LLM naturally relays the instructions to the user.
|
||||
"""
|
||||
|
||||
async def tool_node(state: AgentState) -> dict[str, list]:
|
||||
@@ -164,18 +160,16 @@ def _make_tool_node(skill_names: list[str]):
|
||||
if not tool_calls:
|
||||
return {"messages": []}
|
||||
|
||||
discord_user_id = state.get("discord_user_id")
|
||||
|
||||
results: list[ToolMessage] = []
|
||||
for tc in tool_calls:
|
||||
# Handle both LangChain format (top-level name/args) and
|
||||
# OpenAI format (nested "function" key).
|
||||
if isinstance(tc, dict):
|
||||
if "function" in tc:
|
||||
# OpenAI format: {"id":..., "function": {"name":..., "arguments":"..."}}
|
||||
fn = tc["function"]
|
||||
fn_name = fn.get("name", "")
|
||||
fn_args_raw = fn.get("arguments", "{}")
|
||||
else:
|
||||
# LangChain format: {"name":..., "args":{...}, "id":...}
|
||||
fn_name = tc.get("name", "")
|
||||
fn_args_raw = tc.get("args", {})
|
||||
tc_id = tc.get("id", "")
|
||||
@@ -184,13 +178,15 @@ def _make_tool_node(skill_names: list[str]):
|
||||
fn_args_raw = getattr(tc, "args", {})
|
||||
tc_id = getattr(tc, "id", "")
|
||||
|
||||
# Parse args if they arrive as a JSON string
|
||||
if isinstance(fn_args_raw, str):
|
||||
fn_args = json.loads(fn_args_raw)
|
||||
else:
|
||||
fn_args = fn_args_raw
|
||||
|
||||
tr = await execute_tool(skill_names, fn_name, fn_args)
|
||||
tr = await execute_tool(
|
||||
skill_names, fn_name, fn_args,
|
||||
discord_user_id=discord_user_id,
|
||||
)
|
||||
content = tr.content if tr else f"Tool '{fn_name}' is not available."
|
||||
results.append(ToolMessage(content=content, tool_call_id=tc_id))
|
||||
|
||||
@@ -224,27 +220,16 @@ def create_agent_graph(
|
||||
) -> StateGraph:
|
||||
"""
|
||||
Build and compile a LangGraph StateGraph for a single agent.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client : The OpenAI-compatible client (already authenticated).
|
||||
agent_skills : Skill names assigned to the agent (e.g. ["seerr", "triage"]).
|
||||
system_prompt : The fully-built system prompt (base + skill fragments).
|
||||
model_name : Model identifier sent to the LLM provider.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A compiled LangGraph graph ready for `.ainvoke()` or `.astream()`.
|
||||
"""
|
||||
tool_defs = get_all_tools(agent_skills)
|
||||
|
||||
graph = StateGraph(AgentState)
|
||||
|
||||
# Nodes
|
||||
graph.add_node(
|
||||
"agent_node",
|
||||
_make_agent_node(client, system_prompt, tool_defs, model_name),
|
||||
)
|
||||
|
||||
if tool_defs:
|
||||
graph.add_node("tool_node", _make_tool_node(agent_skills))
|
||||
graph.add_conditional_edges("agent_node", _should_continue, {
|
||||
@@ -253,7 +238,6 @@ def create_agent_graph(
|
||||
})
|
||||
graph.add_edge("tool_node", "agent_node")
|
||||
else:
|
||||
# No tools — agent responds once and finishes
|
||||
graph.add_edge("agent_node", END)
|
||||
|
||||
graph.set_entry_point("agent_node")
|
||||
|
||||
@@ -18,3 +18,4 @@ class AgentState(TypedDict):
|
||||
"""
|
||||
|
||||
messages: Annotated[list, add_messages]
|
||||
discord_user_id: int | None # set by the Discord bot, None for REST API calls
|
||||
|
||||
@@ -4,8 +4,9 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from api.v1.auth import router as auth_router
|
||||
from api.v1.chat import router as v1_router
|
||||
from core.config import DEEPSEEK_API_KEY
|
||||
from core.config import DEEPSEEK_API_KEY, get_config
|
||||
from core.llm import create_client
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -18,12 +19,14 @@ logging.basicConfig(
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load all agents & skills so they self-register at startup
|
||||
# Load all agents, skills, AND auth services so they self-register at startup
|
||||
# ---------------------------------------------------------------------------
|
||||
from agents import load_all_agents # noqa: E402
|
||||
|
||||
load_all_agents()
|
||||
|
||||
import auth.jellyfin # noqa: E402 — self-registers JellyfinAuth
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lifespan
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -61,3 +64,4 @@ app.state.agent_graphs: dict = {}
|
||||
# Routers
|
||||
# ---------------------------------------------------------------------------
|
||||
app.include_router(v1_router, prefix="/v1")
|
||||
app.include_router(auth_router)
|
||||
@@ -6,3 +6,4 @@ httpx
|
||||
langgraph
|
||||
langgraph-checkpoint
|
||||
discord.py
|
||||
python-multipart
|
||||
+26
-1
@@ -47,6 +47,7 @@ class Skill:
|
||||
prompt_fragment: str = ""
|
||||
tools: List[Dict[str, Any]] = field(default_factory=list)
|
||||
execute: Optional[ToolExecutor] = None
|
||||
requires_auth: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -96,9 +97,15 @@ def get_all_tools(skill_names: list[str]) -> List[Dict[str, Any]]:
|
||||
|
||||
|
||||
async def execute_tool(
|
||||
skill_names: list[str], tool_name: str, args: dict
|
||||
skill_names: list[str], tool_name: str, args: dict,
|
||||
discord_user_id: int | None = None,
|
||||
) -> ToolResult | None:
|
||||
"""Find the skill that owns *tool_name* and run its executor.
|
||||
|
||||
If *discord_user_id* is provided, also checks whether the owning skill
|
||||
requires authentication for any services. If auth is missing, returns
|
||||
a friendly ToolResult.fail(...) telling the user how to log in.
|
||||
|
||||
Only logs failures to the console — successful calls are silent.
|
||||
"""
|
||||
import logging
|
||||
@@ -109,6 +116,24 @@ async def execute_tool(
|
||||
if s and s.execute:
|
||||
for t in s.tools:
|
||||
if t.get("function", {}).get("name") == tool_name:
|
||||
# --- Auth gate ---
|
||||
if s.requires_auth and discord_user_id is not None:
|
||||
from core import auth_store
|
||||
from auth import get_auth_service
|
||||
missing: list[str] = []
|
||||
for svc in s.requires_auth:
|
||||
if not auth_store.is_authenticated(discord_user_id, svc):
|
||||
missing.append(svc)
|
||||
if missing:
|
||||
svc_displays = ", ".join(
|
||||
(get_auth_service(m) and get_auth_service(m).display_name) or m
|
||||
for m in missing
|
||||
)
|
||||
return ToolResult.fail(
|
||||
f"You need to log in to {svc_displays} first. "
|
||||
+ " ".join(f"Send `/login {m}` in a DM to get started." for m in missing)
|
||||
)
|
||||
# --- End auth gate ---
|
||||
try:
|
||||
result = await s.execute(tool_name, args)
|
||||
if not result.success:
|
||||
|
||||
@@ -23,6 +23,16 @@ When responding:
|
||||
suggest submitting a ticket if there's a problem.
|
||||
- Always confirm successful actions and warn about failures.
|
||||
|
||||
## Jellyfin & Authentication
|
||||
|
||||
You are connected to the user's Jellyfin server. If a user asks you to
|
||||
"connect to Jellyfin", "link my Jellyfin", or asks about their watch history,
|
||||
simply call the `watch_history` tool. The system will automatically handle
|
||||
authentication — if the user isn't linked yet, they'll be guided through
|
||||
Quick Connect seamlessly. NEVER tell a user you "don't have access to
|
||||
Jellyfin" or "can't connect" — always try the tool first and let the system
|
||||
sort it out.
|
||||
|
||||
This is the base media assistant persona. Real API capabilities come from the
|
||||
attached skills (seerr, triage, etc.).""",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Watch History skill — fetch the user's Jellyfin watch history.
|
||||
|
||||
Currently a placeholder — returns a "coming soon" message.
|
||||
The auth gate (`requires_auth=["jellyfin"]`) is already active:
|
||||
users who haven't linked Jellyfin will be prompted to /login first.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from skills import Skill, register, ToolResult
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool definitions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "watch_history",
|
||||
"description": (
|
||||
"Get the user's recent Jellyfin watch history — movies and TV "
|
||||
"episodes they have watched, sorted by most recent. "
|
||||
"Call this when a user asks about their watching activity."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "How many items to return (default 10, max 20)",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor (placeholder)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _execute(tool_name: str, args: dict) -> ToolResult:
|
||||
if tool_name == "watch_history":
|
||||
return ToolResult.ok(
|
||||
"👷 **Watch History — Coming Soon!**\n\n"
|
||||
"This feature is currently being built. Soon you'll be able to "
|
||||
"see your recently watched movies and TV episodes right here.\n\n"
|
||||
"In the meantime, you can check your watch history directly in Jellyfin."
|
||||
)
|
||||
return ToolResult.fail(f"Unknown tool: {tool_name}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skill registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
watch_history_skill = Skill(
|
||||
name="watch_history",
|
||||
description="User's Jellyfin watch history (coming soon)",
|
||||
requires_auth=["jellyfin"],
|
||||
prompt_fragment="""## Watch History
|
||||
|
||||
You can fetch the user's Jellyfin watch history with the `watch_history` tool.
|
||||
Call it when users ask things like:
|
||||
- "what have I watched?"
|
||||
- "show my watch history"
|
||||
- "what did I watch recently?"
|
||||
- "what was the last movie I saw?"
|
||||
- "what TV shows have I been watching?"
|
||||
|
||||
The tool is currently a **placeholder** — it returns a "coming soon" message.
|
||||
Tell the user this feature is being worked on and will be available soon.""",
|
||||
tools=TOOLS,
|
||||
execute=_execute,
|
||||
)
|
||||
|
||||
register(watch_history_skill)
|
||||
Reference in New Issue
Block a user