Compare commits

2 Commits

Author SHA1 Message Date
TimHoogervorst f21676eafd removed bunch of old code, added /init bunch of cleanups
Build and Push Agent API / build (push) Successful in 4s
2026-05-25 18:07:01 +02:00
TimHoogervorst f28f0d41ec Merge pull request 'added quick connect auth from jellyfin, still needs to have some more cleaning before push to prod' (#2) from auth into main
Build and Push Agent API / build (push) Successful in 24s
Reviewed-on: #2
2026-05-25 14:21:33 +00:00
13 changed files with 119 additions and 792 deletions
+8
View File
@@ -0,0 +1,8 @@
.env
.venv/
__pycache__/
.git/
.gitea/
*.pyc
*.pyo
data/
+78
View File
@@ -0,0 +1,78 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Commands
```bash
# Run the server (reads .env for config)
uvicorn main:app --host 0.0.0.0 --port 8000
# Verify code compiles and imports cleanly
python -c "import main; print('OK')"
# Syntax check a specific file
python -m py_compile path/to/file.py
# Docker build
docker build -t agents-api -f docker/Dockerfile .
```
There is no test suite or linting setup yet.
## Architecture
This is a **Discord bot** powered by a LangGraph agent with a pluggable skill system. The only interaction surface is Discord DMs — there is no web chat UI.
```
Discord DM → bot.py → LangGraph StateGraph → skills (tools) → external APIs
REST API (auth status, JellyStat)
```
### Agent / skill system (`agents/`)
- **Agents** (`agents/__init__.py`) are thin wrappers pairing a system prompt with a list of skill names.
- **Skills** (`agents/skills/__init__.py`) provide prompt fragments, OpenAI tool definitions, and an async `execute` callback. Each skill self-registers via `register()` at import time.
- Agents and skills are loaded at startup by `load_all_agents()` in `main.py`, which triggers side-effecting imports of all agent/skill modules.
The media-agent (`agents/media_agent.py`) is the primary agent. Skills attached:
- `seerr` — search, trending, discover, request, submit issues via Seerr API
- `watch_history` — Jellyfin watch stats via the internal JellyStat API
- `media_info` — base persona (prompt-only)
- `triage` — fallback rules for unsupported actions (prompt-only)
- `easter_eggs` — theme-aware persona flavours (prompt-only)
### LangGraph graph (`src/graph.py`)
Two-node StateGraph: `agent_node → tool_node → agent_node`. The agent node calls DeepSeek (OpenAI-compatible) with system prompt + tool defs. The tool node executes tool calls through the skill system via `execute_tool()`. A custom tool node is used — not LangChain's ToolNode.
State is `AgentState` (`src/state.py`): a TypedDict with `messages` (LangGraph `add_messages` reducer) and `discord_user_id`.
### Discord bot (`gateway/discord/bot.py`)
Runs in a background daemon thread with its own asyncio event loop (separate from FastAPI's). DMs only — no server/channel messages. Requires users to share a guild with the bot. Maintains per-user conversation history via `ConversationStore` (in-memory dict, last N exchanges). Supports `/login <service>` (Quick Connect) and `/logout <service>`.
### Auth system (`src/auth_store.py`, `gateway/auth/`)
SQLite-backed (WAL mode) with a single `user_auth` table keyed on `(discord_user_id, service)`. Stores opaque service tokens, never passwords. `AuthService` is a minimal ABC with `name`/`display_name` properties; `JellyfinAuth` is the only implementation, using Jellyfin Quick Connect (initiate → poll → exchange secret for token). The auth gate in `execute_tool()` checks `skill.requires_auth` before executing tools.
### REST API (`main.py`)
Minimal — only two routers remain:
- `gateway/v1/auth.py``GET /api/v1/auth/Discord/status` (linked services lookup) and `POST /api/v1/auth/reset` (dev only)
- `gateway/jellystat/api.py``GET /jellystat/{history,genres,summary}/{user_id}` called internally by the `watch_history` skill
### JellyStat (`gateway/jellystat/`)
PostgreSQL connection pool stored on `app.state.jellystat_pool`. Database functions (`startup-functions.sql`) are deployed on startup via `CREATE OR REPLACE FUNCTION`. The watch_history skill calls these endpoints over HTTP (localhost) rather than querying the DB directly, keeping DB credentials isolated from the skill layer.
### Seerr session caching (`agents/skills/seerr.py`)
`httpx.AsyncClient` instances are cached per event loop (the Discord bot thread has its own loop separate from FastAPI). Cookie-based auth for Seerr is obtained once at startup via a sync login (thread-safe with double-check locking), then reused across all event-loop-specific clients.
## Key patterns
- **Self-registration**: agents, skills, and auth services all register at import time via module-level function calls. New modules just need to be imported once (see `load_all_agents()` and `import gateway.auth.jellyfin` in `main.py`).
- **Auth gate**: skills declare `requires_auth=["jellyfin"]` and the framework checks credentials before tool execution. Tools receive `_discord_user_id` injected into their args dict.
- **TMDb IDs as source of truth**: media tools display `[tmdb:123456]` tags and prefer IDs over title matching. The system prompt instructs the LLM to always show and use these IDs.
+6 -6
View File
@@ -84,9 +84,9 @@ THEMES = {
"airplane", "plane", "flight", "pilot", "cockpit", "turbulence" "airplane", "plane", "flight", "pilot", "cockpit", "turbulence"
], ],
"persona": ( "persona": (
"mention airplane noises, and say that Erwin will be terrified " "Mention airplane noises, and say that Erwin will be terrified. "
"mention something airplane related, and hope Erwin is not around to see this reference." "Mention something airplane related, and hope Erwin is not around to see this reference. "
"Stay fully functional — carry out all requested actions normally, but only if something is requested" "Stay fully functional — carry out all requested actions normally, but only if something is requested."
) )
}, },
@@ -95,9 +95,9 @@ THEMES = {
"unable to read", "dyslexia", "dislexia", "dislexic", "dyslexic", "typo", "trouble reading", "misspelled", "misspelling" "unable to read", "dyslexia", "dislexia", "dislexic", "dyslexic", "typo", "trouble reading", "misspelled", "misspelling"
], ],
"persona": ( "persona": (
"mention that tim is dyslexic and has trouble reading in the response " "Mention that Tim is dyslexic and has trouble reading in the response. "
"mention it is sometimes difficult to read, since the creator (which is Tim) of this agent is dyslexic. " "Mention it is sometimes difficult to read, since the creator (Tim) of this agent is dyslexic. "
"Stay fully functional — carry out all requested actions normally, but only if something is requested" "Stay fully functional — carry out all requested actions normally, but only if something is requested."
) )
} }
} }
+6 -25
View File
@@ -565,26 +565,12 @@ async def _discover(args: dict) -> ToolResult:
r.raise_for_status() r.raise_for_status()
results = r.json().get("results", []) results = r.json().get("results", [])
desc = genre desc = genre
elif studio: elif studio or keyword:
# /discover/studio/{studioId} requires a numeric TMDB studio ID. # Both use /search with free-text query (studio needs a numeric TMDB
# Fall back to searching by name via /search. # studio ID not available here, so fall back to name search).
desc = studio desc = studio or keyword
search_query = studio
endpoint = "/api/v1/search" endpoint = "/api/v1/search"
params["query"] = search_query params["query"] = desc
if language:
params["language"] = language
async with _client() as c:
r = await c.get(endpoint, params=params)
r.raise_for_status()
results = r.json().get("results", [])
# Filter to requested media type
results = [item for item in results if item.get("mediaType") == kind]
elif keyword:
# Free-text keyword → use /search, filtered by mediaType
desc = keyword
endpoint = "/api/v1/search"
params["query"] = keyword
if language: if language:
params["language"] = language params["language"] = language
async with _client() as c: async with _client() as c:
@@ -616,7 +602,6 @@ async def _request_media(args: dict) -> ToolResult:
async with _client() as c: async with _client() as c:
# --- Fast-path: TMDb ID known — confirm the title and request directly --- # --- Fast-path: TMDb ID known — confirm the title and request directly ---
if tmdb_id: if tmdb_id:
# Quick lookup to get the correct title for the confirmation message
detail_r = await c.get(f"/api/v1/{kind}/{tmdb_id}") detail_r = await c.get(f"/api/v1/{kind}/{tmdb_id}")
if detail_r.status_code == 200: if detail_r.status_code == 200:
detail = detail_r.json() detail = detail_r.json()
@@ -626,11 +611,6 @@ async def _request_media(args: dict) -> ToolResult:
or detail.get("firstAirDate", "")[:4] or detail.get("firstAirDate", "")[:4]
or "?" or "?"
) )
else:
# Detail lookup failed — fall back to title search
pass
if detail_r.status_code == 200:
# Submit directly with the known TMDb ID # Submit directly with the known TMDb ID
request_body: dict = {"mediaType": kind, "mediaId": tmdb_id} request_body: dict = {"mediaType": kind, "mediaId": tmdb_id}
if kind == "tv": if kind == "tv":
@@ -651,6 +631,7 @@ async def _request_media(args: dict) -> ToolResult:
f"❌ Failed to request **{media_title}** ({media_year}). " f"❌ Failed to request **{media_title}** ({media_year}). "
f"Seerr responded with status {req_r.status_code}: {req_r.text[:500]}" f"Seerr responded with status {req_r.status_code}: {req_r.text[:500]}"
) )
# Detail lookup failed — fall through to slow-path title search
# --- Slow-path: search by title --- # --- Slow-path: search by title ---
r = await c.get("/api/v1/search", params={"query": quote(title), "page": 1}) r = await c.get("/api/v1/search", params={"query": quote(title), "page": 1})
+3 -21
View File
@@ -10,17 +10,17 @@ Add a new service (Plex, Seerr, etc.) by:
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import Optional from typing import Optional
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# AuthResult — returned by AuthService.authenticate() # AuthResult — returned by AuthService authentication
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@dataclass @dataclass
class AuthResult: class AuthResult:
"""Outcome of a credential validation attempt.""" """Outcome of an authentication attempt."""
success: bool success: bool
external_user_id: Optional[str] = None external_user_id: Optional[str] = None
external_name: Optional[str] = None external_name: Optional[str] = None
@@ -38,8 +38,6 @@ class AuthService(ABC):
Subclasses must implement: Subclasses must implement:
- name : unique identifier used in URLs and DB keys - name : unique identifier used in URLs and DB keys
- display_name : human-readable label shown in Discord - display_name : human-readable label shown in Discord
- render_login_form(token, discord_id) → HTML string
- authenticate(form_data) → AuthResult
""" """
@property @property
@@ -54,22 +52,6 @@ class AuthService(ABC):
"""Human-readable: "Jellyfin", "Seerr", "Plex" """ """Human-readable: "Jellyfin", "Seerr", "Plex" """
... ...
@abstractmethod
def render_login_form(self, token: str, discord_id: int) -> str:
"""Return HTML string with a login form for this service.
The form MUST include these hidden fields:
<input type="hidden" name="token" value="{token}">
<input type="hidden" name="discord_id" value="{discord_id}">
<input type="hidden" name="service" value="{self.name}">
"""
...
@abstractmethod
async def authenticate(self, form_data: dict) -> AuthResult:
"""Validate credentials against the service. Return AuthResult."""
...
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Global registry # Global registry
+6 -144
View File
@@ -1,15 +1,10 @@
""" """
Jellyfin AuthService — validates Jellyfin credentials and stores the session token. Jellyfin AuthService — authenticates users via Jellyfin Quick Connect.
Two authentication flows: Flow:
1. Quick Connect (primary): user enters a short code on their Jellyfin app. 1. initiate_quick_connect() → {code, secret}
- initiate_quick_connect() {code, secret} 2. poll_quick_connect(secret) → "Active" | "Authorized" | "Expired"
- poll_quick_connect(secret) → "Active" | "Authorized" | "Expired" 3. authenticate_quick_connect(secret) → AuthResult with token
- authenticate_quick_connect(secret) → AuthResult with token
2. Username/password (legacy): renders an HTML form, called via the REST API.
- render_login_form(token, discord_id) → HTML string
- authenticate(form_data) → AuthResult
""" """
from __future__ import annotations from __future__ import annotations
@@ -25,14 +20,6 @@ from src.config import get_config
logger = logging.getLogger("auth.jellyfin") logger = logging.getLogger("auth.jellyfin")
# Emby-style authorization header required by Jellyfin's AuthenticateByName
_EMBY_HEADER = (
'MediaBrowser Client="AgentBot",'
'Device="DiscordBot",'
'DeviceId="agent-bot",'
'Version="1.0"'
)
@dataclass @dataclass
class QuickConnectResult: class QuickConnectResult:
@@ -64,9 +51,7 @@ class JellyfinAuth(AuthService):
async def _resolve_url(self) -> str | None: async def _resolve_url(self) -> str | None:
""" """
Resolve the Jellyfin server URL. Resolve the Jellyfin server URL from the JELLYFIN_URL env var.
1. Check JELLYFIN_URL env var (used in deployment).
2. Check if user already has a stored auth with a URL (from legacy login).
Returns None if no URL is configured. Returns None if no URL is configured.
""" """
# First: explicit env var # First: explicit env var
@@ -272,129 +257,6 @@ class JellyfinAuth(AuthService):
error_message="An unexpected error occurred during authentication.", error_message="An unexpected error occurred during authentication.",
) )
# ------------------------------------------------------------------
# Login form (legacy — used by the REST API)
# ------------------------------------------------------------------
def render_login_form(self, token: str, discord_id: int) -> str:
return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Link Jellyfin</title>
<style>
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
h2 {{ margin-bottom: 4px; }}
.sub {{ color: #666; margin-bottom: 24px; }}
label {{ display: block; margin-top: 16px; font-weight: 600; }}
input {{ width: 100%; padding: 10px; margin-top: 4px; border: 1px solid #ccc; border-radius: 6px; box-sizing: border-box; }}
button {{ margin-top: 24px; width: 100%; padding: 12px; background: #aa5cc3; color: #fff; border: none; border-radius: 6px; font-size: 16px; cursor: pointer; }}
button:hover {{ background: #9448b0; }}
</style>
</head>
<body>
<h2>🔗 Link Jellyfin to Discord</h2>
<p class="sub">Enter your Jellyfin server URL and credentials to link your account.</p>
<form method="POST" action="/api/v1/auth/login">
<input type="hidden" name="token" value="{token}">
<input type="hidden" name="discord_id" value="{discord_id}">
<input type="hidden" name="service" value="jellyfin">
<label for="jellyfin_url">Jellyfin Server URL</label>
<input id="jellyfin_url" name="jellyfin_url" type="url"
placeholder="https://jellyfin.example.com" required>
<label for="username">Username</label>
<input id="username" name="username" type="text"
placeholder="Your Jellyfin username" required autofocus>
<label for="password">Password</label>
<input id="password" name="password" type="password"
placeholder="Your Jellyfin password" required>
<button type="submit">Link Account</button>
</form>
</body>
</html>"""
# ------------------------------------------------------------------
# Authentication
# ------------------------------------------------------------------
async def authenticate(self, form_data: dict) -> AuthResult:
url = form_data.get("jellyfin_url", "").strip().rstrip("/")
username = form_data.get("username", "").strip()
password = form_data.get("password", "").strip()
if not url or not username or not password:
return AuthResult(
success=False,
error_message="All fields are required (URL, username, password).",
)
logger.info("Attempting Jellyfin login for '%s' on %s", username, url)
async with httpx.AsyncClient(timeout=10) as client:
try:
resp = await client.post(
f"{url}/Users/AuthenticateByName",
json={"Username": username, "Pw": password},
headers={"X-Emby-Authorization": _EMBY_HEADER},
)
if resp.status_code != 200:
logger.warning(
"Jellyfin login failed for '%s': HTTP %s", username, resp.status_code
)
return AuthResult(
success=False,
error_message=f"Login failed — check your server URL and credentials.",
)
data = resp.json()
user = data.get("User", {})
token = data.get("AccessToken", "")
if not token:
return AuthResult(
success=False,
error_message="Jellyfin returned an unexpected response.",
)
logger.info(
"Jellyfin login OK: user=%s (%s)",
user.get("Name", "?"),
user.get("Id", "?"),
)
return AuthResult(
success=True,
external_user_id=user.get("Id", ""),
external_name=user.get("Name", username),
credentials={
"token": token,
"url": url,
"user_id": user.get("Id", ""),
},
)
except httpx.TimeoutException:
return AuthResult(
success=False,
error_message=f"Could not reach {url} — connection timed out. Check the URL.",
)
except httpx.ConnectError:
return AuthResult(
success=False,
error_message=f"Could not connect to {url}. Is the server running?",
)
except Exception as exc:
logger.exception("Unexpected error during Jellyfin login")
return AuthResult(
success=False,
error_message=f"An unexpected error occurred. Please try again.",
)
# Self-register at import time # Self-register at import time
-36
View File
@@ -1,36 +0,0 @@
from fastapi import Request
from openai import OpenAI
from src.graph import create_agent_graph
def get_llm_client(request: Request) -> OpenAI:
"""FastAPI dependency — returns the singleton OpenAI client from app.state."""
return request.app.state.llm_client
def get_agent_graph(agent_id: str, request: Request):
"""
FastAPI dependency — returns the compiled LangGraph graph for *agent_id*.
Graphs are lazily compiled on first use and cached on app.state so each
agent's graph is only built once per process lifetime.
"""
cache: dict = request.app.state.agent_graphs
if agent_id not in cache:
from agents import get as get_agent
agent = get_agent(agent_id)
if agent is None:
# Fall back to the naked agent if the requested one doesn't exist
agent_id = "naked"
agent = get_agent(agent_id)
cache[agent_id] = create_agent_graph(
client=request.app.state.llm_client,
agent_skills=agent.skills,
system_prompt=agent.build_system_prompt(),
)
return cache[agent_id]
-3
View File
@@ -207,9 +207,6 @@ class AgentBot(discord.Client):
# --- Quick Connect flow --- # --- Quick Connect flow ---
svc = get_auth_service(service) svc = get_auth_service(service)
if svc is None:
await message.channel.send(f"Unknown service: {service}")
return True
await message.channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…") await message.channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…")
+9 -140
View File
@@ -1,156 +1,27 @@
""" """
Auth API — generic endpoints for linking Discord users to external services. Auth API — endpoints for checking linked services and dev operations.
GET /api/v1/auth/login?service=X&token=Y&discord_id=Z GET /api/v1/auth/Discord/status?discord_id=Z
Validates the link token and serves a service-specific login form.
POST /api/v1/auth/login
Accepts the form submission, validates credentials against the service,
stores the session, and returns a result page.
GET /api/v1/auth/status?discord_id=Z
Returns which services are linked for this Discord user. Returns which services are linked for this Discord user.
POST /api/v1/auth/reset
Wipes the auth store (dev only — requires ALLOW_AUTH_RESET=true).
""" """
from __future__ import annotations from __future__ import annotations
import logging import logging
from fastapi import APIRouter, Form, HTTPException, Request from fastapi import APIRouter, HTTPException
from fastapi.responses import HTMLResponse
from gateway.auth import get_auth_service, list_auth_services
from src import auth_store from src import auth_store
from src.config import get_config
logger = logging.getLogger("gateway.auth") logger = logging.getLogger("gateway.auth")
router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
# ---------------------------------------------------------------------------
# GET /auth/login — serve the login form
# ---------------------------------------------------------------------------
@router.get("/login")
async def login_form(
service: str,
token: str,
discord_id: int,
):
"""Validate the one-time link token and return a service-specific login form."""
# Validate the token WITHOUT consuming it (the POST will consume it)
result = auth_store.validate_token(token)
if result is None:
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
uid, svc = result
if uid != discord_id or svc != service:
raise HTTPException(status_code=400, detail="Token does not match the request.")
# Look up the AuthService
svc_obj = get_auth_service(service)
if svc_obj is None:
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
logger.info("Serving login form: user=%s service=%s", discord_id, service)
return HTMLResponse(svc_obj.render_login_form(token, discord_id))
# ---------------------------------------------------------------------------
# POST /auth/login — handle form submission
# ---------------------------------------------------------------------------
@router.post("/login")
async def login_submit(request: Request):
"""Handle the login form POST: validate credentials, store auth, show result."""
# Parse form data
form = await request.form()
token = form.get("token", "")
discord_id_str = form.get("discord_id", "")
service = form.get("service", "")
if not token or not discord_id_str or not service:
raise HTTPException(status_code=400, detail="Missing required fields.")
try:
discord_id = int(discord_id_str)
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="Invalid discord_id.")
# Consume the token on POST (the GET only validated, didn't consume)
result = auth_store.consume_token(token)
if result is None:
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
# Look up the AuthService
svc_obj = get_auth_service(service)
if svc_obj is None:
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
# Collect service-specific form fields (everything except token, discord_id, service)
form_data: dict[str, str] = {}
for key, value in form.items():
if key not in ("token", "discord_id", "service"):
form_data[key] = str(value)
# Authenticate against the service
auth_result = await svc_obj.authenticate(form_data)
if not auth_result.success:
return HTMLResponse(
status_code=401,
content=f"""<!DOCTYPE html>
<html><head><meta charset="utf-8"><title>Login Failed</title>
<style>
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
h2 {{ color: #d32f2f; }}
a {{ color: #aa5cc3; }}
</style></head><body>
<h2>❌ Login Failed</h2>
<p>{auth_result.error_message or "Authentication failed. Please try again."}</p>
<p><a href="javascript:history.back()">← Go back and try again</a></p>
</body></html>""",
)
# Store the successful auth
auth_store.store_auth(
discord_user_id=discord_id,
service=service,
external_user_id=auth_result.external_user_id or "",
external_name=auth_result.external_name or "",
credentials=auth_result.credentials,
)
logger.info(
"Auth linked: discord=%s%s (%s)",
discord_id,
service,
auth_result.external_name,
)
return HTMLResponse(f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Account Linked</title>
<style>
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; text-align: center; }}
h1 {{ color: #388e3c; }}
.name {{ font-weight: bold; color: #aa5cc3; }}
p {{ color: #666; }}
</style>
</head>
<body>
<h1>✅ Account Linked!</h1>
<p>Logged in as <span class="name">{auth_result.external_name}</span> on <strong>{svc_obj.display_name}</strong>.</p>
<p>You can close this page and return to Discord.</p>
</body>
</html>""")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GET /auth/status — get all linked services for a Discord user # GET /auth/status — get all linked services for a Discord user
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -199,12 +70,10 @@ async def auth_status(discord_id: int):
# POST /auth/reset — wipe auth store (DEV ONLY) # POST /auth/reset — wipe auth store (DEV ONLY)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
from src.config import get_config # noqa: E402
@router.post("/reset") @router.post("/reset")
async def reset_auth(): async def reset_auth():
""" """
Reset the entire auth store — clears all link tokens and user auth records. Reset the entire auth store — clears all user auth records.
Only enabled when ALLOW_AUTH_RESET=true in the environment. Only enabled when ALLOW_AUTH_RESET=true in the environment.
Returns 403 in production. Returns 403 in production.
@@ -217,4 +86,4 @@ async def reset_auth():
auth_store.reset_all() auth_store.reset_all()
logger.warning("Auth store reset via API endpoint.") logger.warning("Auth store reset via API endpoint.")
return {"status": "ok", "message": "Auth store cleared — all tokens and auth records removed."} return {"status": "ok", "message": "Auth store cleared — all auth records removed."}
-241
View File
@@ -1,241 +0,0 @@
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from openai import OpenAI
from pydantic import BaseModel
import json
from gateway.dependencies import get_llm_client, get_agent_graph
from agents import get as get_agent, list_all as list_all_agents
from src.state import AgentState
router = APIRouter()
class ChatRequest(BaseModel):
message: str
session_id: str | None = None
agent_id: str | None = None
class ChatCompletionRequest(BaseModel):
messages: list[dict]
stream: bool = False
model: str = "deepseek-chat"
# ---------------------------------------------------------------------------
# Agent resolution
# ---------------------------------------------------------------------------
def _resolve_agent(agent_id: str | None = None, model: str | None = None):
"""
1. explicit agent_id
2. model field (OpenWebUI sends this — maps to agent_id if registered)
3. fallback to "naked"
"""
lookup = agent_id or model
if lookup is None:
return get_agent("naked")
agent = get_agent(lookup)
return agent if agent else get_agent("naked")
# ---------------------------------------------------------------------------
# LangGraph helpers
# ---------------------------------------------------------------------------
async def _invoke_graph(graph, messages: list[dict]) -> str:
"""Run the graph synchronously (non-streaming) and return the final text."""
state: AgentState = {"messages": messages}
result = await graph.ainvoke(state)
last_msg = result["messages"][-1]
return last_msg.content or ""
async def _stream_graph(graph, messages: list[dict]):
"""
Run the graph and stream the final response token-by-token.
LangGraph's astream_events would require langchain-openai's ChatOpenAI
to intercept LLM chunks. Instead we run the graph to completion (tools
execute silently) and then stream the final text content character by
character — this gives the client a real SSE stream without adding new
dependencies.
"""
state: AgentState = {"messages": messages}
result = await graph.ainvoke(state)
content = result["messages"][-1].content or ""
# Yield token-by-token so the SSE client sees incremental output
for token in content:
yield token
# ---------------------------------------------------------------------------
# Non-streaming run (kept for /chat/sync and sync completions)
# ---------------------------------------------------------------------------
async def run_agent_with_tools(
request: Request,
messages: list[dict],
agent_id: str | None = None,
model: str | None = None,
) -> str:
"""Send messages through the agent's LangGraph. Non-streaming."""
agent = _resolve_agent(agent_id, model)
graph = get_agent_graph(agent.agent_id, request)
return await _invoke_graph(graph, messages)
# ---------------------------------------------------------------------------
# Streaming generator (kept for /chat and stream completions)
# ---------------------------------------------------------------------------
async def run_agent_stream(
request: Request,
messages: list[dict],
agent_id: str | None = None,
model: str | None = None,
):
"""Async generator — yields tokens via the agent's LangGraph."""
agent = _resolve_agent(agent_id, model)
graph = get_agent_graph(agent.agent_id, request)
async for token in _stream_graph(graph, messages):
yield token
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.get("/")
def root():
return {"status": "ok"}
@router.post("/chat")
async def chat(
req: ChatRequest,
request: Request,
client: OpenAI = Depends(get_llm_client),
):
"""Streaming chat — single message, no history."""
messages = [{"role": "user", "content": req.message}]
async def event_stream():
async for token in run_agent_stream(request, messages, req.agent_id):
payload = json.dumps({"token": token, "session_id": req.session_id})
yield f"data: {payload}\n\n"
yield f"data: {json.dumps({'done': True, 'session_id': req.session_id})}\n\n"
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.post("/chat/sync")
async def chat_sync(
req: ChatRequest,
request: Request,
client: OpenAI = Depends(get_llm_client),
):
"""Non-streaming chat — single message."""
messages = [{"role": "user", "content": req.message}]
response = await run_agent_with_tools(request, messages, req.agent_id)
return {"response": response, "session_id": req.session_id}
@router.get("/agents")
def list_agents():
"""Return all registered agents."""
return {
"agents": [
{
"agent_id": a.agent_id,
"description": a.description,
"skills": a.skills,
}
for a in list_all_agents().values()
]
}
@router.get("/models")
def list_models():
"""Return agents as selectable models for OpenWebUI."""
return {
"object": "list",
"data": [
{
"id": a.agent_id,
"object": "model",
"created": 0,
"owned_by": "local-agent",
}
for a in list_all_agents().values()
],
}
@router.post("/chat/completions")
async def chat_completions(
req: ChatCompletionRequest,
request: Request,
client: OpenAI = Depends(get_llm_client),
):
"""OpenAI-compatible /chat/completions — supports stream=True.
Multi-turn: req.messages contains the FULL conversation history.
Agent resolved from the model field (OpenWebUI sends this).
"""
agent = _resolve_agent(model=req.model)
if req.stream:
async def sse_stream():
async for token in run_agent_stream(
request, req.messages, agent_id=agent.agent_id,
):
chunk = {
"id": "chatcmpl-local",
"object": "chat.completion.chunk",
"choices": [
{"index": 0, "delta": {"content": token}, "finish_reason": None}
],
}
yield f"data: {json.dumps(chunk)}\n\n"
final_chunk = {
"id": "chatcmpl-local",
"object": "chat.completion.chunk",
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
yield f"data: {json.dumps(final_chunk)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
sse_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
# Non-streaming — full history, LangGraph agent
response = await run_agent_with_tools(
request, req.messages, agent_id=agent.agent_id,
)
return {
"id": "chatcmpl-local",
"object": "chat.completion",
"created": 0,
"model": req.model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": response},
"finish_reason": "stop",
}
],
}
-12
View File
@@ -5,10 +5,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from gateway.v1.auth import router as auth_router from gateway.v1.auth import router as auth_router
from gateway.v1.chat import router as v1_router
from gateway.jellystat.api import router as jellystat_router from gateway.jellystat.api import router as jellystat_router
from src.config import DEEPSEEK_API_KEY, get_config
from src.llm import create_client
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Logging — tool calls will appear in the uvicorn console # Logging — tool calls will appear in the uvicorn console
@@ -57,17 +54,8 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# ---------------------------------------------------------------------------
# Singletons (stored on app.state so every module can reach them via Depends)
# ---------------------------------------------------------------------------
app.state.llm_client = create_client(DEEPSEEK_API_KEY)
# Lazy-compiled LangGraph graphs — populated on first use per agent
app.state.agent_graphs: dict = {}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Routers # Routers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
app.include_router(v1_router, prefix="/v1")
app.include_router(auth_router) app.include_router(auth_router)
app.include_router(jellystat_router) app.include_router(jellystat_router)
+3 -113
View File
@@ -1,23 +1,17 @@
""" """
Auth Store — SQLite-backed persistence for Discord-to-service authentication. Auth Store — SQLite-backed persistence for Discord-to-service authentication.
Two tables: Stores per-user, per-service credentials (Jellyfin access tokens, etc.).
- link_tokens : one-time tokens sent via Discord DM to initiate login
- user_auth : per-user, per-service credentials (Jellyfin token, etc.)
Thread-safe via WAL mode and a shared lock. No passwords are ever stored Thread-safe via WAL mode and a shared lock. No passwords are ever stored
— only opaque service tokens (e.g. Jellyfin AccessToken). — only opaque service tokens.
""" """
from __future__ import annotations from __future__ import annotations
import logging import logging
import secrets
import sqlite3 import sqlite3
import threading import threading
from datetime import datetime, timedelta, timezone
from pathlib import Path from pathlib import Path
from typing import Optional
from src.config import get_config from src.config import get_config
@@ -27,7 +21,6 @@ logger = logging.getLogger("auth_store")
# Config # Config
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db") AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db")
TOKEN_EXPIRY_MINUTES = int(get_config("AUTH_TOKEN_EXPIRY", "10"))
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Singleton handle # Singleton handle
@@ -54,8 +47,6 @@ def _resolve_path() -> Path:
def _get_conn() -> sqlite3.Connection: def _get_conn() -> sqlite3.Connection:
"""Return a thread-local connection to the auth database.""" """Return a thread-local connection to the auth database."""
import sqlite3
conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False) conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False)
conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON") conn.execute("PRAGMA foreign_keys=ON")
@@ -68,15 +59,6 @@ def _get_conn() -> sqlite3.Connection:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
_SCHEMA = """ _SCHEMA = """
CREATE TABLE IF NOT EXISTS link_tokens (
token TEXT PRIMARY KEY,
discord_user_id INTEGER NOT NULL,
service TEXT NOT NULL,
expires_at TEXT NOT NULL,
used INTEGER DEFAULT 0,
created_at TEXT DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS user_auth ( CREATE TABLE IF NOT EXISTS user_auth (
discord_user_id INTEGER NOT NULL, discord_user_id INTEGER NOT NULL,
service TEXT NOT NULL, service TEXT NOT NULL,
@@ -107,97 +89,6 @@ def _ensure_schema() -> None:
logger.info("Auth store initialized at %s", _resolve_path()) logger.info("Auth store initialized at %s", _resolve_path())
# ---------------------------------------------------------------------------
# Public API — Link Tokens
# ---------------------------------------------------------------------------
def create_token(discord_user_id: int, service: str) -> str:
"""Generate a one-time link token. Expires after TOKEN_EXPIRY_MINUTES."""
_ensure_schema()
token = secrets.token_urlsafe(32)
expires = (datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRY_MINUTES)).isoformat()
with _db_lock:
conn = _get_conn()
conn.execute(
"INSERT INTO link_tokens (token, discord_user_id, service, expires_at) VALUES (?, ?, ?, ?)",
(token, discord_user_id, service, expires),
)
conn.commit()
conn.close()
logger.info("Created link token for user %s / service %s", discord_user_id, service)
return token
def validate_token(token: str) -> tuple[int, str] | None:
"""Read-only validation — does NOT consume the token.
Returns (discord_user_id, service) if the token exists, is unused,
and has not expired. Returns None otherwise.
"""
_ensure_schema()
with _db_lock:
conn = _get_conn()
row = conn.execute(
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
(token,),
).fetchone()
conn.close()
if row is None:
return None
if row["used"]:
return None
expires = datetime.fromisoformat(row["expires_at"])
if datetime.now(timezone.utc) > expires:
return None
return (row["discord_user_id"], row["service"])
def consume_token(token: str) -> tuple[int, str] | None:
"""Validate and consume a link token. Returns (discord_user_id, service) or None.
A token is valid if:
- It exists
- It has not been used
- It has not expired
"""
_ensure_schema()
with _db_lock:
conn = _get_conn()
row = conn.execute(
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
(token,),
).fetchone()
if row is None:
conn.close()
return None
if row["used"]:
conn.close()
logger.warning("Token already used: %s", token[:8])
return None
expires = datetime.fromisoformat(row["expires_at"])
if datetime.now(timezone.utc) > expires:
conn.close()
logger.warning("Token expired: %s", token[:8])
return None
conn.execute("UPDATE link_tokens SET used = 1 WHERE token = ?", (token,))
conn.commit()
result = (row["discord_user_id"], row["service"])
conn.close()
logger.info("Token consumed: %s… → user=%s service=%s", token[:8], result[0], result[1])
return result
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Public API — User Auth # Public API — User Auth
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -350,9 +241,8 @@ def reset_all() -> None:
with _db_lock: with _db_lock:
conn = _get_conn() conn = _get_conn()
conn.execute("DELETE FROM link_tokens")
conn.execute("DELETE FROM user_auth") conn.execute("DELETE FROM user_auth")
conn.commit() conn.commit()
conn.close() conn.close()
logger.warning("Auth store RESET — all tokens and auth records cleared.") logger.warning("Auth store RESET — all auth records cleared.")
-51
View File
@@ -1,51 +0,0 @@
"""
Tools adapter — bridges the existing skill/tool system with LangGraph's ToolNode.
LangGraph's ToolNode expects callable tools (typically @tool-decorated functions).
This module wraps our skill-based tool definitions and async executors so
ToolNode can invoke them without any changes to the skills/ layer.
"""
from __future__ import annotations
import json
from typing import Any
from langchain_core.tools import tool
from agents.skills import get_all_tools, execute_tool
def build_langgraph_tools(skill_names: list[str]) -> list:
"""
Convert the registered skill tool definitions into LangChain-compatible
@tool-decorated functions that ToolNode can call.
Each tool wraps the existing `execute_tool()` pipeline, so the skill
system's ToolResult + httpx session handling is fully preserved.
"""
tool_defs = get_all_tools(skill_names)
wrapped: list = []
for td in tool_defs:
fn_def = td.get("function", {})
fn_name = fn_def.get("name", "")
fn_desc = fn_def.get("description", "")
# Create a unique factory so each closure captures the right fn_name
def _make_tool(name: str, desc: str, skills: list[str]):
@tool(name, description=desc)
async def _wrapped(**kwargs: Any) -> str:
"""Execute the tool via the skill system and return its content."""
result = await execute_tool(skills, name, kwargs)
if result is None:
return f"Tool '{name}' is not available."
return result.content
# Stash the original OpenAI schema so LangGraph can use it
_wrapped.metadata = fn_def
return _wrapped
wrapped.append(_make_tool(fn_name, fn_desc, skill_names))
return wrapped