Compare commits
1 Commits
f28f0d41ec
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| f21676eafd |
@@ -0,0 +1,8 @@
|
|||||||
|
.env
|
||||||
|
.venv/
|
||||||
|
__pycache__/
|
||||||
|
.git/
|
||||||
|
.gitea/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
data/
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
# CLAUDE.md
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run the server (reads .env for config)
|
||||||
|
uvicorn main:app --host 0.0.0.0 --port 8000
|
||||||
|
|
||||||
|
# Verify code compiles and imports cleanly
|
||||||
|
python -c "import main; print('OK')"
|
||||||
|
|
||||||
|
# Syntax check a specific file
|
||||||
|
python -m py_compile path/to/file.py
|
||||||
|
|
||||||
|
# Docker build
|
||||||
|
docker build -t agents-api -f docker/Dockerfile .
|
||||||
|
```
|
||||||
|
|
||||||
|
There is no test suite or linting setup yet.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
This is a **Discord bot** powered by a LangGraph agent with a pluggable skill system. The only interaction surface is Discord DMs — there is no web chat UI.
|
||||||
|
|
||||||
|
```
|
||||||
|
Discord DM → bot.py → LangGraph StateGraph → skills (tools) → external APIs
|
||||||
|
│
|
||||||
|
REST API (auth status, JellyStat)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Agent / skill system (`agents/`)
|
||||||
|
|
||||||
|
- **Agents** (`agents/__init__.py`) are thin wrappers pairing a system prompt with a list of skill names.
|
||||||
|
- **Skills** (`agents/skills/__init__.py`) provide prompt fragments, OpenAI tool definitions, and an async `execute` callback. Each skill self-registers via `register()` at import time.
|
||||||
|
- Agents and skills are loaded at startup by `load_all_agents()` in `main.py`, which triggers side-effecting imports of all agent/skill modules.
|
||||||
|
|
||||||
|
The media-agent (`agents/media_agent.py`) is the primary agent. Skills attached:
|
||||||
|
- `seerr` — search, trending, discover, request, submit issues via Seerr API
|
||||||
|
- `watch_history` — Jellyfin watch stats via the internal JellyStat API
|
||||||
|
- `media_info` — base persona (prompt-only)
|
||||||
|
- `triage` — fallback rules for unsupported actions (prompt-only)
|
||||||
|
- `easter_eggs` — theme-aware persona flavours (prompt-only)
|
||||||
|
|
||||||
|
### LangGraph graph (`src/graph.py`)
|
||||||
|
|
||||||
|
Two-node StateGraph: `agent_node → tool_node → agent_node`. The agent node calls DeepSeek (OpenAI-compatible) with system prompt + tool defs. The tool node executes tool calls through the skill system via `execute_tool()`. A custom tool node is used — not LangChain's ToolNode.
|
||||||
|
|
||||||
|
State is `AgentState` (`src/state.py`): a TypedDict with `messages` (LangGraph `add_messages` reducer) and `discord_user_id`.
|
||||||
|
|
||||||
|
### Discord bot (`gateway/discord/bot.py`)
|
||||||
|
|
||||||
|
Runs in a background daemon thread with its own asyncio event loop (separate from FastAPI's). DMs only — no server/channel messages. Requires users to share a guild with the bot. Maintains per-user conversation history via `ConversationStore` (in-memory dict, last N exchanges). Supports `/login <service>` (Quick Connect) and `/logout <service>`.
|
||||||
|
|
||||||
|
### Auth system (`src/auth_store.py`, `gateway/auth/`)
|
||||||
|
|
||||||
|
SQLite-backed (WAL mode) with a single `user_auth` table keyed on `(discord_user_id, service)`. Stores opaque service tokens, never passwords. `AuthService` is a minimal ABC with `name`/`display_name` properties; `JellyfinAuth` is the only implementation, using Jellyfin Quick Connect (initiate → poll → exchange secret for token). The auth gate in `execute_tool()` checks `skill.requires_auth` before executing tools.
|
||||||
|
|
||||||
|
### REST API (`main.py`)
|
||||||
|
|
||||||
|
Minimal — only two routers remain:
|
||||||
|
- `gateway/v1/auth.py` — `GET /api/v1/auth/Discord/status` (linked services lookup) and `POST /api/v1/auth/reset` (dev only)
|
||||||
|
- `gateway/jellystat/api.py` — `GET /jellystat/{history,genres,summary}/{user_id}` called internally by the `watch_history` skill
|
||||||
|
|
||||||
|
### JellyStat (`gateway/jellystat/`)
|
||||||
|
|
||||||
|
PostgreSQL connection pool stored on `app.state.jellystat_pool`. Database functions (`startup-functions.sql`) are deployed on startup via `CREATE OR REPLACE FUNCTION`. The watch_history skill calls these endpoints over HTTP (localhost) rather than querying the DB directly, keeping DB credentials isolated from the skill layer.
|
||||||
|
|
||||||
|
### Seerr session caching (`agents/skills/seerr.py`)
|
||||||
|
|
||||||
|
`httpx.AsyncClient` instances are cached per event loop (the Discord bot thread has its own loop separate from FastAPI). Cookie-based auth for Seerr is obtained once at startup via a sync login (thread-safe with double-check locking), then reused across all event-loop-specific clients.
|
||||||
|
|
||||||
|
## Key patterns
|
||||||
|
|
||||||
|
- **Self-registration**: agents, skills, and auth services all register at import time via module-level function calls. New modules just need to be imported once (see `load_all_agents()` and `import gateway.auth.jellyfin` in `main.py`).
|
||||||
|
- **Auth gate**: skills declare `requires_auth=["jellyfin"]` and the framework checks credentials before tool execution. Tools receive `_discord_user_id` injected into their args dict.
|
||||||
|
- **TMDb IDs as source of truth**: media tools display `[tmdb:123456]` tags and prefer IDs over title matching. The system prompt instructs the LLM to always show and use these IDs.
|
||||||
@@ -84,9 +84,9 @@ THEMES = {
|
|||||||
"airplane", "plane", "flight", "pilot", "cockpit", "turbulence"
|
"airplane", "plane", "flight", "pilot", "cockpit", "turbulence"
|
||||||
],
|
],
|
||||||
"persona": (
|
"persona": (
|
||||||
"mention airplane noises, and say that Erwin will be terrified "
|
"Mention airplane noises, and say that Erwin will be terrified. "
|
||||||
"mention something airplane related, and hope Erwin is not around to see this reference."
|
"Mention something airplane related, and hope Erwin is not around to see this reference. "
|
||||||
"Stay fully functional — carry out all requested actions normally, but only if something is requested"
|
"Stay fully functional — carry out all requested actions normally, but only if something is requested."
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -95,9 +95,9 @@ THEMES = {
|
|||||||
"unable to read", "dyslexia", "dislexia", "dislexic", "dyslexic", "typo", "trouble reading", "misspelled", "misspelling"
|
"unable to read", "dyslexia", "dislexia", "dislexic", "dyslexic", "typo", "trouble reading", "misspelled", "misspelling"
|
||||||
],
|
],
|
||||||
"persona": (
|
"persona": (
|
||||||
"mention that tim is dyslexic and has trouble reading in the response "
|
"Mention that Tim is dyslexic and has trouble reading in the response. "
|
||||||
"mention it is sometimes difficult to read, since the creator (which is Tim) of this agent is dyslexic. "
|
"Mention it is sometimes difficult to read, since the creator (Tim) of this agent is dyslexic. "
|
||||||
"Stay fully functional — carry out all requested actions normally, but only if something is requested"
|
"Stay fully functional — carry out all requested actions normally, but only if something is requested."
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+6
-25
@@ -565,26 +565,12 @@ async def _discover(args: dict) -> ToolResult:
|
|||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
results = r.json().get("results", [])
|
results = r.json().get("results", [])
|
||||||
desc = genre
|
desc = genre
|
||||||
elif studio:
|
elif studio or keyword:
|
||||||
# /discover/studio/{studioId} requires a numeric TMDB studio ID.
|
# Both use /search with free-text query (studio needs a numeric TMDB
|
||||||
# Fall back to searching by name via /search.
|
# studio ID not available here, so fall back to name search).
|
||||||
desc = studio
|
desc = studio or keyword
|
||||||
search_query = studio
|
|
||||||
endpoint = "/api/v1/search"
|
endpoint = "/api/v1/search"
|
||||||
params["query"] = search_query
|
params["query"] = desc
|
||||||
if language:
|
|
||||||
params["language"] = language
|
|
||||||
async with _client() as c:
|
|
||||||
r = await c.get(endpoint, params=params)
|
|
||||||
r.raise_for_status()
|
|
||||||
results = r.json().get("results", [])
|
|
||||||
# Filter to requested media type
|
|
||||||
results = [item for item in results if item.get("mediaType") == kind]
|
|
||||||
elif keyword:
|
|
||||||
# Free-text keyword → use /search, filtered by mediaType
|
|
||||||
desc = keyword
|
|
||||||
endpoint = "/api/v1/search"
|
|
||||||
params["query"] = keyword
|
|
||||||
if language:
|
if language:
|
||||||
params["language"] = language
|
params["language"] = language
|
||||||
async with _client() as c:
|
async with _client() as c:
|
||||||
@@ -616,7 +602,6 @@ async def _request_media(args: dict) -> ToolResult:
|
|||||||
async with _client() as c:
|
async with _client() as c:
|
||||||
# --- Fast-path: TMDb ID known — confirm the title and request directly ---
|
# --- Fast-path: TMDb ID known — confirm the title and request directly ---
|
||||||
if tmdb_id:
|
if tmdb_id:
|
||||||
# Quick lookup to get the correct title for the confirmation message
|
|
||||||
detail_r = await c.get(f"/api/v1/{kind}/{tmdb_id}")
|
detail_r = await c.get(f"/api/v1/{kind}/{tmdb_id}")
|
||||||
if detail_r.status_code == 200:
|
if detail_r.status_code == 200:
|
||||||
detail = detail_r.json()
|
detail = detail_r.json()
|
||||||
@@ -626,11 +611,6 @@ async def _request_media(args: dict) -> ToolResult:
|
|||||||
or detail.get("firstAirDate", "")[:4]
|
or detail.get("firstAirDate", "")[:4]
|
||||||
or "?"
|
or "?"
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# Detail lookup failed — fall back to title search
|
|
||||||
pass
|
|
||||||
|
|
||||||
if detail_r.status_code == 200:
|
|
||||||
# Submit directly with the known TMDb ID
|
# Submit directly with the known TMDb ID
|
||||||
request_body: dict = {"mediaType": kind, "mediaId": tmdb_id}
|
request_body: dict = {"mediaType": kind, "mediaId": tmdb_id}
|
||||||
if kind == "tv":
|
if kind == "tv":
|
||||||
@@ -651,6 +631,7 @@ async def _request_media(args: dict) -> ToolResult:
|
|||||||
f"❌ Failed to request **{media_title}** ({media_year}). "
|
f"❌ Failed to request **{media_title}** ({media_year}). "
|
||||||
f"Seerr responded with status {req_r.status_code}: {req_r.text[:500]}"
|
f"Seerr responded with status {req_r.status_code}: {req_r.text[:500]}"
|
||||||
)
|
)
|
||||||
|
# Detail lookup failed — fall through to slow-path title search
|
||||||
|
|
||||||
# --- Slow-path: search by title ---
|
# --- Slow-path: search by title ---
|
||||||
r = await c.get("/api/v1/search", params={"query": quote(title), "page": 1})
|
r = await c.get("/api/v1/search", params={"query": quote(title), "page": 1})
|
||||||
|
|||||||
@@ -10,17 +10,17 @@ Add a new service (Plex, Seerr, etc.) by:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# AuthResult — returned by AuthService.authenticate()
|
# AuthResult — returned by AuthService authentication
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AuthResult:
|
class AuthResult:
|
||||||
"""Outcome of a credential validation attempt."""
|
"""Outcome of an authentication attempt."""
|
||||||
success: bool
|
success: bool
|
||||||
external_user_id: Optional[str] = None
|
external_user_id: Optional[str] = None
|
||||||
external_name: Optional[str] = None
|
external_name: Optional[str] = None
|
||||||
@@ -38,8 +38,6 @@ class AuthService(ABC):
|
|||||||
Subclasses must implement:
|
Subclasses must implement:
|
||||||
- name : unique identifier used in URLs and DB keys
|
- name : unique identifier used in URLs and DB keys
|
||||||
- display_name : human-readable label shown in Discord
|
- display_name : human-readable label shown in Discord
|
||||||
- render_login_form(token, discord_id) → HTML string
|
|
||||||
- authenticate(form_data) → AuthResult
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -54,22 +52,6 @@ class AuthService(ABC):
|
|||||||
"""Human-readable: "Jellyfin", "Seerr", "Plex" """
|
"""Human-readable: "Jellyfin", "Seerr", "Plex" """
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def render_login_form(self, token: str, discord_id: int) -> str:
|
|
||||||
"""Return HTML string with a login form for this service.
|
|
||||||
|
|
||||||
The form MUST include these hidden fields:
|
|
||||||
<input type="hidden" name="token" value="{token}">
|
|
||||||
<input type="hidden" name="discord_id" value="{discord_id}">
|
|
||||||
<input type="hidden" name="service" value="{self.name}">
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def authenticate(self, form_data: dict) -> AuthResult:
|
|
||||||
"""Validate credentials against the service. Return AuthResult."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Global registry
|
# Global registry
|
||||||
|
|||||||
+6
-144
@@ -1,15 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Jellyfin AuthService — validates Jellyfin credentials and stores the session token.
|
Jellyfin AuthService — authenticates users via Jellyfin Quick Connect.
|
||||||
|
|
||||||
Two authentication flows:
|
Flow:
|
||||||
1. Quick Connect (primary): user enters a short code on their Jellyfin app.
|
1. initiate_quick_connect() → {code, secret}
|
||||||
- initiate_quick_connect() → {code, secret}
|
2. poll_quick_connect(secret) → "Active" | "Authorized" | "Expired"
|
||||||
- poll_quick_connect(secret) → "Active" | "Authorized" | "Expired"
|
3. authenticate_quick_connect(secret) → AuthResult with token
|
||||||
- authenticate_quick_connect(secret) → AuthResult with token
|
|
||||||
|
|
||||||
2. Username/password (legacy): renders an HTML form, called via the REST API.
|
|
||||||
- render_login_form(token, discord_id) → HTML string
|
|
||||||
- authenticate(form_data) → AuthResult
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -25,14 +20,6 @@ from src.config import get_config
|
|||||||
|
|
||||||
logger = logging.getLogger("auth.jellyfin")
|
logger = logging.getLogger("auth.jellyfin")
|
||||||
|
|
||||||
# Emby-style authorization header required by Jellyfin's AuthenticateByName
|
|
||||||
_EMBY_HEADER = (
|
|
||||||
'MediaBrowser Client="AgentBot",'
|
|
||||||
'Device="DiscordBot",'
|
|
||||||
'DeviceId="agent-bot",'
|
|
||||||
'Version="1.0"'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class QuickConnectResult:
|
class QuickConnectResult:
|
||||||
@@ -64,9 +51,7 @@ class JellyfinAuth(AuthService):
|
|||||||
|
|
||||||
async def _resolve_url(self) -> str | None:
|
async def _resolve_url(self) -> str | None:
|
||||||
"""
|
"""
|
||||||
Resolve the Jellyfin server URL.
|
Resolve the Jellyfin server URL from the JELLYFIN_URL env var.
|
||||||
1. Check JELLYFIN_URL env var (used in deployment).
|
|
||||||
2. Check if user already has a stored auth with a URL (from legacy login).
|
|
||||||
Returns None if no URL is configured.
|
Returns None if no URL is configured.
|
||||||
"""
|
"""
|
||||||
# First: explicit env var
|
# First: explicit env var
|
||||||
@@ -272,129 +257,6 @@ class JellyfinAuth(AuthService):
|
|||||||
error_message="An unexpected error occurred during authentication.",
|
error_message="An unexpected error occurred during authentication.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Login form (legacy — used by the REST API)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def render_login_form(self, token: str, discord_id: int) -> str:
|
|
||||||
return f"""<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="utf-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
|
||||||
<title>Link Jellyfin</title>
|
|
||||||
<style>
|
|
||||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
|
|
||||||
h2 {{ margin-bottom: 4px; }}
|
|
||||||
.sub {{ color: #666; margin-bottom: 24px; }}
|
|
||||||
label {{ display: block; margin-top: 16px; font-weight: 600; }}
|
|
||||||
input {{ width: 100%; padding: 10px; margin-top: 4px; border: 1px solid #ccc; border-radius: 6px; box-sizing: border-box; }}
|
|
||||||
button {{ margin-top: 24px; width: 100%; padding: 12px; background: #aa5cc3; color: #fff; border: none; border-radius: 6px; font-size: 16px; cursor: pointer; }}
|
|
||||||
button:hover {{ background: #9448b0; }}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<h2>🔗 Link Jellyfin to Discord</h2>
|
|
||||||
<p class="sub">Enter your Jellyfin server URL and credentials to link your account.</p>
|
|
||||||
|
|
||||||
<form method="POST" action="/api/v1/auth/login">
|
|
||||||
<input type="hidden" name="token" value="{token}">
|
|
||||||
<input type="hidden" name="discord_id" value="{discord_id}">
|
|
||||||
<input type="hidden" name="service" value="jellyfin">
|
|
||||||
|
|
||||||
<label for="jellyfin_url">Jellyfin Server URL</label>
|
|
||||||
<input id="jellyfin_url" name="jellyfin_url" type="url"
|
|
||||||
placeholder="https://jellyfin.example.com" required>
|
|
||||||
|
|
||||||
<label for="username">Username</label>
|
|
||||||
<input id="username" name="username" type="text"
|
|
||||||
placeholder="Your Jellyfin username" required autofocus>
|
|
||||||
|
|
||||||
<label for="password">Password</label>
|
|
||||||
<input id="password" name="password" type="password"
|
|
||||||
placeholder="Your Jellyfin password" required>
|
|
||||||
|
|
||||||
<button type="submit">Link Account</button>
|
|
||||||
</form>
|
|
||||||
</body>
|
|
||||||
</html>"""
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Authentication
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def authenticate(self, form_data: dict) -> AuthResult:
|
|
||||||
url = form_data.get("jellyfin_url", "").strip().rstrip("/")
|
|
||||||
username = form_data.get("username", "").strip()
|
|
||||||
password = form_data.get("password", "").strip()
|
|
||||||
|
|
||||||
if not url or not username or not password:
|
|
||||||
return AuthResult(
|
|
||||||
success=False,
|
|
||||||
error_message="All fields are required (URL, username, password).",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Attempting Jellyfin login for '%s' on %s", username, url)
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=10) as client:
|
|
||||||
try:
|
|
||||||
resp = await client.post(
|
|
||||||
f"{url}/Users/AuthenticateByName",
|
|
||||||
json={"Username": username, "Pw": password},
|
|
||||||
headers={"X-Emby-Authorization": _EMBY_HEADER},
|
|
||||||
)
|
|
||||||
if resp.status_code != 200:
|
|
||||||
logger.warning(
|
|
||||||
"Jellyfin login failed for '%s': HTTP %s", username, resp.status_code
|
|
||||||
)
|
|
||||||
return AuthResult(
|
|
||||||
success=False,
|
|
||||||
error_message=f"Login failed — check your server URL and credentials.",
|
|
||||||
)
|
|
||||||
|
|
||||||
data = resp.json()
|
|
||||||
user = data.get("User", {})
|
|
||||||
token = data.get("AccessToken", "")
|
|
||||||
|
|
||||||
if not token:
|
|
||||||
return AuthResult(
|
|
||||||
success=False,
|
|
||||||
error_message="Jellyfin returned an unexpected response.",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Jellyfin login OK: user=%s (%s)",
|
|
||||||
user.get("Name", "?"),
|
|
||||||
user.get("Id", "?"),
|
|
||||||
)
|
|
||||||
|
|
||||||
return AuthResult(
|
|
||||||
success=True,
|
|
||||||
external_user_id=user.get("Id", ""),
|
|
||||||
external_name=user.get("Name", username),
|
|
||||||
credentials={
|
|
||||||
"token": token,
|
|
||||||
"url": url,
|
|
||||||
"user_id": user.get("Id", ""),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
except httpx.TimeoutException:
|
|
||||||
return AuthResult(
|
|
||||||
success=False,
|
|
||||||
error_message=f"Could not reach {url} — connection timed out. Check the URL.",
|
|
||||||
)
|
|
||||||
except httpx.ConnectError:
|
|
||||||
return AuthResult(
|
|
||||||
success=False,
|
|
||||||
error_message=f"Could not connect to {url}. Is the server running?",
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.exception("Unexpected error during Jellyfin login")
|
|
||||||
return AuthResult(
|
|
||||||
success=False,
|
|
||||||
error_message=f"An unexpected error occurred. Please try again.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Self-register at import time
|
# Self-register at import time
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
from fastapi import Request
|
|
||||||
from openai import OpenAI
|
|
||||||
|
|
||||||
from src.graph import create_agent_graph
|
|
||||||
|
|
||||||
|
|
||||||
def get_llm_client(request: Request) -> OpenAI:
|
|
||||||
"""FastAPI dependency — returns the singleton OpenAI client from app.state."""
|
|
||||||
return request.app.state.llm_client
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_graph(agent_id: str, request: Request):
|
|
||||||
"""
|
|
||||||
FastAPI dependency — returns the compiled LangGraph graph for *agent_id*.
|
|
||||||
|
|
||||||
Graphs are lazily compiled on first use and cached on app.state so each
|
|
||||||
agent's graph is only built once per process lifetime.
|
|
||||||
"""
|
|
||||||
cache: dict = request.app.state.agent_graphs
|
|
||||||
|
|
||||||
if agent_id not in cache:
|
|
||||||
from agents import get as get_agent
|
|
||||||
|
|
||||||
agent = get_agent(agent_id)
|
|
||||||
if agent is None:
|
|
||||||
# Fall back to the naked agent if the requested one doesn't exist
|
|
||||||
agent_id = "naked"
|
|
||||||
agent = get_agent(agent_id)
|
|
||||||
|
|
||||||
cache[agent_id] = create_agent_graph(
|
|
||||||
client=request.app.state.llm_client,
|
|
||||||
agent_skills=agent.skills,
|
|
||||||
system_prompt=agent.build_system_prompt(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return cache[agent_id]
|
|
||||||
@@ -207,9 +207,6 @@ class AgentBot(discord.Client):
|
|||||||
|
|
||||||
# --- Quick Connect flow ---
|
# --- Quick Connect flow ---
|
||||||
svc = get_auth_service(service)
|
svc = get_auth_service(service)
|
||||||
if svc is None:
|
|
||||||
await message.channel.send(f"Unknown service: {service}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
await message.channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…")
|
await message.channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…")
|
||||||
|
|
||||||
|
|||||||
+9
-140
@@ -1,156 +1,27 @@
|
|||||||
"""
|
"""
|
||||||
Auth API — generic endpoints for linking Discord users to external services.
|
Auth API — endpoints for checking linked services and dev operations.
|
||||||
|
|
||||||
GET /api/v1/auth/login?service=X&token=Y&discord_id=Z
|
GET /api/v1/auth/Discord/status?discord_id=Z
|
||||||
Validates the link token and serves a service-specific login form.
|
|
||||||
|
|
||||||
POST /api/v1/auth/login
|
|
||||||
Accepts the form submission, validates credentials against the service,
|
|
||||||
stores the session, and returns a result page.
|
|
||||||
|
|
||||||
GET /api/v1/auth/status?discord_id=Z
|
|
||||||
Returns which services are linked for this Discord user.
|
Returns which services are linked for this Discord user.
|
||||||
|
|
||||||
|
POST /api/v1/auth/reset
|
||||||
|
Wipes the auth store (dev only — requires ALLOW_AUTH_RESET=true).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Form, HTTPException, Request
|
from fastapi import APIRouter, HTTPException
|
||||||
from fastapi.responses import HTMLResponse
|
|
||||||
|
|
||||||
from gateway.auth import get_auth_service, list_auth_services
|
|
||||||
from src import auth_store
|
from src import auth_store
|
||||||
|
from src.config import get_config
|
||||||
|
|
||||||
logger = logging.getLogger("gateway.auth")
|
logger = logging.getLogger("gateway.auth")
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# GET /auth/login — serve the login form
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.get("/login")
|
|
||||||
async def login_form(
|
|
||||||
service: str,
|
|
||||||
token: str,
|
|
||||||
discord_id: int,
|
|
||||||
):
|
|
||||||
"""Validate the one-time link token and return a service-specific login form."""
|
|
||||||
|
|
||||||
# Validate the token WITHOUT consuming it (the POST will consume it)
|
|
||||||
result = auth_store.validate_token(token)
|
|
||||||
if result is None:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
|
||||||
|
|
||||||
uid, svc = result
|
|
||||||
if uid != discord_id or svc != service:
|
|
||||||
raise HTTPException(status_code=400, detail="Token does not match the request.")
|
|
||||||
|
|
||||||
# Look up the AuthService
|
|
||||||
svc_obj = get_auth_service(service)
|
|
||||||
if svc_obj is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
|
||||||
|
|
||||||
logger.info("Serving login form: user=%s service=%s", discord_id, service)
|
|
||||||
return HTMLResponse(svc_obj.render_login_form(token, discord_id))
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# POST /auth/login — handle form submission
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.post("/login")
|
|
||||||
async def login_submit(request: Request):
|
|
||||||
"""Handle the login form POST: validate credentials, store auth, show result."""
|
|
||||||
|
|
||||||
# Parse form data
|
|
||||||
form = await request.form()
|
|
||||||
token = form.get("token", "")
|
|
||||||
discord_id_str = form.get("discord_id", "")
|
|
||||||
service = form.get("service", "")
|
|
||||||
|
|
||||||
if not token or not discord_id_str or not service:
|
|
||||||
raise HTTPException(status_code=400, detail="Missing required fields.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
discord_id = int(discord_id_str)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid discord_id.")
|
|
||||||
|
|
||||||
# Consume the token on POST (the GET only validated, didn't consume)
|
|
||||||
result = auth_store.consume_token(token)
|
|
||||||
if result is None:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
|
||||||
|
|
||||||
# Look up the AuthService
|
|
||||||
svc_obj = get_auth_service(service)
|
|
||||||
if svc_obj is None:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
|
||||||
|
|
||||||
# Collect service-specific form fields (everything except token, discord_id, service)
|
|
||||||
form_data: dict[str, str] = {}
|
|
||||||
for key, value in form.items():
|
|
||||||
if key not in ("token", "discord_id", "service"):
|
|
||||||
form_data[key] = str(value)
|
|
||||||
|
|
||||||
# Authenticate against the service
|
|
||||||
auth_result = await svc_obj.authenticate(form_data)
|
|
||||||
|
|
||||||
if not auth_result.success:
|
|
||||||
return HTMLResponse(
|
|
||||||
status_code=401,
|
|
||||||
content=f"""<!DOCTYPE html>
|
|
||||||
<html><head><meta charset="utf-8"><title>Login Failed</title>
|
|
||||||
<style>
|
|
||||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
|
|
||||||
h2 {{ color: #d32f2f; }}
|
|
||||||
a {{ color: #aa5cc3; }}
|
|
||||||
</style></head><body>
|
|
||||||
<h2>❌ Login Failed</h2>
|
|
||||||
<p>{auth_result.error_message or "Authentication failed. Please try again."}</p>
|
|
||||||
<p><a href="javascript:history.back()">← Go back and try again</a></p>
|
|
||||||
</body></html>""",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store the successful auth
|
|
||||||
auth_store.store_auth(
|
|
||||||
discord_user_id=discord_id,
|
|
||||||
service=service,
|
|
||||||
external_user_id=auth_result.external_user_id or "",
|
|
||||||
external_name=auth_result.external_name or "",
|
|
||||||
credentials=auth_result.credentials,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Auth linked: discord=%s → %s (%s)",
|
|
||||||
discord_id,
|
|
||||||
service,
|
|
||||||
auth_result.external_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
return HTMLResponse(f"""<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="utf-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
|
||||||
<title>Account Linked</title>
|
|
||||||
<style>
|
|
||||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; text-align: center; }}
|
|
||||||
h1 {{ color: #388e3c; }}
|
|
||||||
.name {{ font-weight: bold; color: #aa5cc3; }}
|
|
||||||
p {{ color: #666; }}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<h1>✅ Account Linked!</h1>
|
|
||||||
<p>Logged in as <span class="name">{auth_result.external_name}</span> on <strong>{svc_obj.display_name}</strong>.</p>
|
|
||||||
<p>You can close this page and return to Discord.</p>
|
|
||||||
</body>
|
|
||||||
</html>""")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GET /auth/status — get all linked services for a Discord user
|
# GET /auth/status — get all linked services for a Discord user
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -199,12 +70,10 @@ async def auth_status(discord_id: int):
|
|||||||
# POST /auth/reset — wipe auth store (DEV ONLY)
|
# POST /auth/reset — wipe auth store (DEV ONLY)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
from src.config import get_config # noqa: E402
|
|
||||||
|
|
||||||
@router.post("/reset")
|
@router.post("/reset")
|
||||||
async def reset_auth():
|
async def reset_auth():
|
||||||
"""
|
"""
|
||||||
Reset the entire auth store — clears all link tokens and user auth records.
|
Reset the entire auth store — clears all user auth records.
|
||||||
|
|
||||||
Only enabled when ALLOW_AUTH_RESET=true in the environment.
|
Only enabled when ALLOW_AUTH_RESET=true in the environment.
|
||||||
Returns 403 in production.
|
Returns 403 in production.
|
||||||
@@ -217,4 +86,4 @@ async def reset_auth():
|
|||||||
|
|
||||||
auth_store.reset_all()
|
auth_store.reset_all()
|
||||||
logger.warning("Auth store reset via API endpoint.")
|
logger.warning("Auth store reset via API endpoint.")
|
||||||
return {"status": "ok", "message": "Auth store cleared — all tokens and auth records removed."}
|
return {"status": "ok", "message": "Auth store cleared — all auth records removed."}
|
||||||
|
|||||||
@@ -1,241 +0,0 @@
|
|||||||
from fastapi import APIRouter, Depends, Request
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from openai import OpenAI
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import json
|
|
||||||
|
|
||||||
from gateway.dependencies import get_llm_client, get_agent_graph
|
|
||||||
from agents import get as get_agent, list_all as list_all_agents
|
|
||||||
from src.state import AgentState
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
|
||||||
message: str
|
|
||||||
session_id: str | None = None
|
|
||||||
agent_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
|
||||||
messages: list[dict]
|
|
||||||
stream: bool = False
|
|
||||||
model: str = "deepseek-chat"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Agent resolution
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _resolve_agent(agent_id: str | None = None, model: str | None = None):
|
|
||||||
"""
|
|
||||||
1. explicit agent_id
|
|
||||||
2. model field (OpenWebUI sends this — maps to agent_id if registered)
|
|
||||||
3. fallback to "naked"
|
|
||||||
"""
|
|
||||||
lookup = agent_id or model
|
|
||||||
if lookup is None:
|
|
||||||
return get_agent("naked")
|
|
||||||
agent = get_agent(lookup)
|
|
||||||
return agent if agent else get_agent("naked")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# LangGraph helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def _invoke_graph(graph, messages: list[dict]) -> str:
|
|
||||||
"""Run the graph synchronously (non-streaming) and return the final text."""
|
|
||||||
state: AgentState = {"messages": messages}
|
|
||||||
result = await graph.ainvoke(state)
|
|
||||||
last_msg = result["messages"][-1]
|
|
||||||
return last_msg.content or ""
|
|
||||||
|
|
||||||
|
|
||||||
async def _stream_graph(graph, messages: list[dict]):
|
|
||||||
"""
|
|
||||||
Run the graph and stream the final response token-by-token.
|
|
||||||
|
|
||||||
LangGraph's astream_events would require langchain-openai's ChatOpenAI
|
|
||||||
to intercept LLM chunks. Instead we run the graph to completion (tools
|
|
||||||
execute silently) and then stream the final text content character by
|
|
||||||
character — this gives the client a real SSE stream without adding new
|
|
||||||
dependencies.
|
|
||||||
"""
|
|
||||||
state: AgentState = {"messages": messages}
|
|
||||||
result = await graph.ainvoke(state)
|
|
||||||
content = result["messages"][-1].content or ""
|
|
||||||
# Yield token-by-token so the SSE client sees incremental output
|
|
||||||
for token in content:
|
|
||||||
yield token
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Non-streaming run (kept for /chat/sync and sync completions)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def run_agent_with_tools(
|
|
||||||
request: Request,
|
|
||||||
messages: list[dict],
|
|
||||||
agent_id: str | None = None,
|
|
||||||
model: str | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Send messages through the agent's LangGraph. Non-streaming."""
|
|
||||||
agent = _resolve_agent(agent_id, model)
|
|
||||||
graph = get_agent_graph(agent.agent_id, request)
|
|
||||||
return await _invoke_graph(graph, messages)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Streaming generator (kept for /chat and stream completions)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def run_agent_stream(
|
|
||||||
request: Request,
|
|
||||||
messages: list[dict],
|
|
||||||
agent_id: str | None = None,
|
|
||||||
model: str | None = None,
|
|
||||||
):
|
|
||||||
"""Async generator — yields tokens via the agent's LangGraph."""
|
|
||||||
agent = _resolve_agent(agent_id, model)
|
|
||||||
graph = get_agent_graph(agent.agent_id, request)
|
|
||||||
async for token in _stream_graph(graph, messages):
|
|
||||||
yield token
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@router.get("/")
|
|
||||||
def root():
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chat")
|
|
||||||
async def chat(
|
|
||||||
req: ChatRequest,
|
|
||||||
request: Request,
|
|
||||||
client: OpenAI = Depends(get_llm_client),
|
|
||||||
):
|
|
||||||
"""Streaming chat — single message, no history."""
|
|
||||||
messages = [{"role": "user", "content": req.message}]
|
|
||||||
|
|
||||||
async def event_stream():
|
|
||||||
async for token in run_agent_stream(request, messages, req.agent_id):
|
|
||||||
payload = json.dumps({"token": token, "session_id": req.session_id})
|
|
||||||
yield f"data: {payload}\n\n"
|
|
||||||
yield f"data: {json.dumps({'done': True, 'session_id': req.session_id})}\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
event_stream(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chat/sync")
|
|
||||||
async def chat_sync(
|
|
||||||
req: ChatRequest,
|
|
||||||
request: Request,
|
|
||||||
client: OpenAI = Depends(get_llm_client),
|
|
||||||
):
|
|
||||||
"""Non-streaming chat — single message."""
|
|
||||||
messages = [{"role": "user", "content": req.message}]
|
|
||||||
response = await run_agent_with_tools(request, messages, req.agent_id)
|
|
||||||
return {"response": response, "session_id": req.session_id}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/agents")
|
|
||||||
def list_agents():
|
|
||||||
"""Return all registered agents."""
|
|
||||||
return {
|
|
||||||
"agents": [
|
|
||||||
{
|
|
||||||
"agent_id": a.agent_id,
|
|
||||||
"description": a.description,
|
|
||||||
"skills": a.skills,
|
|
||||||
}
|
|
||||||
for a in list_all_agents().values()
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/models")
|
|
||||||
def list_models():
|
|
||||||
"""Return agents as selectable models for OpenWebUI."""
|
|
||||||
return {
|
|
||||||
"object": "list",
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"id": a.agent_id,
|
|
||||||
"object": "model",
|
|
||||||
"created": 0,
|
|
||||||
"owned_by": "local-agent",
|
|
||||||
}
|
|
||||||
for a in list_all_agents().values()
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chat/completions")
|
|
||||||
async def chat_completions(
|
|
||||||
req: ChatCompletionRequest,
|
|
||||||
request: Request,
|
|
||||||
client: OpenAI = Depends(get_llm_client),
|
|
||||||
):
|
|
||||||
"""OpenAI-compatible /chat/completions — supports stream=True.
|
|
||||||
Multi-turn: req.messages contains the FULL conversation history.
|
|
||||||
Agent resolved from the model field (OpenWebUI sends this).
|
|
||||||
"""
|
|
||||||
agent = _resolve_agent(model=req.model)
|
|
||||||
|
|
||||||
if req.stream:
|
|
||||||
async def sse_stream():
|
|
||||||
async for token in run_agent_stream(
|
|
||||||
request, req.messages, agent_id=agent.agent_id,
|
|
||||||
):
|
|
||||||
chunk = {
|
|
||||||
"id": "chatcmpl-local",
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"choices": [
|
|
||||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(chunk)}\n\n"
|
|
||||||
final_chunk = {
|
|
||||||
"id": "chatcmpl-local",
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(final_chunk)}\n\n"
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
sse_stream(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Non-streaming — full history, LangGraph agent
|
|
||||||
response = await run_agent_with_tools(
|
|
||||||
request, req.messages, agent_id=agent.agent_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": "chatcmpl-local",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 0,
|
|
||||||
"model": req.model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": response},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
@@ -5,10 +5,7 @@ from fastapi import FastAPI
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from gateway.v1.auth import router as auth_router
|
from gateway.v1.auth import router as auth_router
|
||||||
from gateway.v1.chat import router as v1_router
|
|
||||||
from gateway.jellystat.api import router as jellystat_router
|
from gateway.jellystat.api import router as jellystat_router
|
||||||
from src.config import DEEPSEEK_API_KEY, get_config
|
|
||||||
from src.llm import create_client
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Logging — tool calls will appear in the uvicorn console
|
# Logging — tool calls will appear in the uvicorn console
|
||||||
@@ -57,17 +54,8 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Singletons (stored on app.state so every module can reach them via Depends)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
app.state.llm_client = create_client(DEEPSEEK_API_KEY)
|
|
||||||
|
|
||||||
# Lazy-compiled LangGraph graphs — populated on first use per agent
|
|
||||||
app.state.agent_graphs: dict = {}
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Routers
|
# Routers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
app.include_router(v1_router, prefix="/v1")
|
|
||||||
app.include_router(auth_router)
|
app.include_router(auth_router)
|
||||||
app.include_router(jellystat_router)
|
app.include_router(jellystat_router)
|
||||||
+3
-113
@@ -1,23 +1,17 @@
|
|||||||
"""
|
"""
|
||||||
Auth Store — SQLite-backed persistence for Discord-to-service authentication.
|
Auth Store — SQLite-backed persistence for Discord-to-service authentication.
|
||||||
|
|
||||||
Two tables:
|
Stores per-user, per-service credentials (Jellyfin access tokens, etc.).
|
||||||
- link_tokens : one-time tokens sent via Discord DM to initiate login
|
|
||||||
- user_auth : per-user, per-service credentials (Jellyfin token, etc.)
|
|
||||||
|
|
||||||
Thread-safe via WAL mode and a shared lock. No passwords are ever stored
|
Thread-safe via WAL mode and a shared lock. No passwords are ever stored
|
||||||
— only opaque service tokens (e.g. Jellyfin AccessToken).
|
— only opaque service tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from src.config import get_config
|
from src.config import get_config
|
||||||
|
|
||||||
@@ -27,7 +21,6 @@ logger = logging.getLogger("auth_store")
|
|||||||
# Config
|
# Config
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db")
|
AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db")
|
||||||
TOKEN_EXPIRY_MINUTES = int(get_config("AUTH_TOKEN_EXPIRY", "10"))
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Singleton handle
|
# Singleton handle
|
||||||
@@ -54,8 +47,6 @@ def _resolve_path() -> Path:
|
|||||||
|
|
||||||
def _get_conn() -> sqlite3.Connection:
|
def _get_conn() -> sqlite3.Connection:
|
||||||
"""Return a thread-local connection to the auth database."""
|
"""Return a thread-local connection to the auth database."""
|
||||||
import sqlite3
|
|
||||||
|
|
||||||
conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False)
|
conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False)
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
conn.execute("PRAGMA foreign_keys=ON")
|
conn.execute("PRAGMA foreign_keys=ON")
|
||||||
@@ -68,15 +59,6 @@ def _get_conn() -> sqlite3.Connection:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
_SCHEMA = """
|
_SCHEMA = """
|
||||||
CREATE TABLE IF NOT EXISTS link_tokens (
|
|
||||||
token TEXT PRIMARY KEY,
|
|
||||||
discord_user_id INTEGER NOT NULL,
|
|
||||||
service TEXT NOT NULL,
|
|
||||||
expires_at TEXT NOT NULL,
|
|
||||||
used INTEGER DEFAULT 0,
|
|
||||||
created_at TEXT DEFAULT (datetime('now'))
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS user_auth (
|
CREATE TABLE IF NOT EXISTS user_auth (
|
||||||
discord_user_id INTEGER NOT NULL,
|
discord_user_id INTEGER NOT NULL,
|
||||||
service TEXT NOT NULL,
|
service TEXT NOT NULL,
|
||||||
@@ -107,97 +89,6 @@ def _ensure_schema() -> None:
|
|||||||
logger.info("Auth store initialized at %s", _resolve_path())
|
logger.info("Auth store initialized at %s", _resolve_path())
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Public API — Link Tokens
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def create_token(discord_user_id: int, service: str) -> str:
|
|
||||||
"""Generate a one-time link token. Expires after TOKEN_EXPIRY_MINUTES."""
|
|
||||||
_ensure_schema()
|
|
||||||
token = secrets.token_urlsafe(32)
|
|
||||||
expires = (datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRY_MINUTES)).isoformat()
|
|
||||||
|
|
||||||
with _db_lock:
|
|
||||||
conn = _get_conn()
|
|
||||||
conn.execute(
|
|
||||||
"INSERT INTO link_tokens (token, discord_user_id, service, expires_at) VALUES (?, ?, ?, ?)",
|
|
||||||
(token, discord_user_id, service, expires),
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
logger.info("Created link token for user %s / service %s", discord_user_id, service)
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
def validate_token(token: str) -> tuple[int, str] | None:
|
|
||||||
"""Read-only validation — does NOT consume the token.
|
|
||||||
|
|
||||||
Returns (discord_user_id, service) if the token exists, is unused,
|
|
||||||
and has not expired. Returns None otherwise.
|
|
||||||
"""
|
|
||||||
_ensure_schema()
|
|
||||||
|
|
||||||
with _db_lock:
|
|
||||||
conn = _get_conn()
|
|
||||||
row = conn.execute(
|
|
||||||
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
|
|
||||||
(token,),
|
|
||||||
).fetchone()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if row is None:
|
|
||||||
return None
|
|
||||||
if row["used"]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
expires = datetime.fromisoformat(row["expires_at"])
|
|
||||||
if datetime.now(timezone.utc) > expires:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return (row["discord_user_id"], row["service"])
|
|
||||||
|
|
||||||
|
|
||||||
def consume_token(token: str) -> tuple[int, str] | None:
|
|
||||||
"""Validate and consume a link token. Returns (discord_user_id, service) or None.
|
|
||||||
|
|
||||||
A token is valid if:
|
|
||||||
- It exists
|
|
||||||
- It has not been used
|
|
||||||
- It has not expired
|
|
||||||
"""
|
|
||||||
_ensure_schema()
|
|
||||||
|
|
||||||
with _db_lock:
|
|
||||||
conn = _get_conn()
|
|
||||||
row = conn.execute(
|
|
||||||
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
|
|
||||||
(token,),
|
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
if row is None:
|
|
||||||
conn.close()
|
|
||||||
return None
|
|
||||||
|
|
||||||
if row["used"]:
|
|
||||||
conn.close()
|
|
||||||
logger.warning("Token already used: %s…", token[:8])
|
|
||||||
return None
|
|
||||||
|
|
||||||
expires = datetime.fromisoformat(row["expires_at"])
|
|
||||||
if datetime.now(timezone.utc) > expires:
|
|
||||||
conn.close()
|
|
||||||
logger.warning("Token expired: %s…", token[:8])
|
|
||||||
return None
|
|
||||||
|
|
||||||
conn.execute("UPDATE link_tokens SET used = 1 WHERE token = ?", (token,))
|
|
||||||
conn.commit()
|
|
||||||
result = (row["discord_user_id"], row["service"])
|
|
||||||
conn.close()
|
|
||||||
logger.info("Token consumed: %s… → user=%s service=%s", token[:8], result[0], result[1])
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Public API — User Auth
|
# Public API — User Auth
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -350,9 +241,8 @@ def reset_all() -> None:
|
|||||||
|
|
||||||
with _db_lock:
|
with _db_lock:
|
||||||
conn = _get_conn()
|
conn = _get_conn()
|
||||||
conn.execute("DELETE FROM link_tokens")
|
|
||||||
conn.execute("DELETE FROM user_auth")
|
conn.execute("DELETE FROM user_auth")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
logger.warning("Auth store RESET — all tokens and auth records cleared.")
|
logger.warning("Auth store RESET — all auth records cleared.")
|
||||||
|
|||||||
@@ -1,51 +0,0 @@
|
|||||||
"""
|
|
||||||
Tools adapter — bridges the existing skill/tool system with LangGraph's ToolNode.
|
|
||||||
|
|
||||||
LangGraph's ToolNode expects callable tools (typically @tool-decorated functions).
|
|
||||||
This module wraps our skill-based tool definitions and async executors so
|
|
||||||
ToolNode can invoke them without any changes to the skills/ layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from agents.skills import get_all_tools, execute_tool
|
|
||||||
|
|
||||||
|
|
||||||
def build_langgraph_tools(skill_names: list[str]) -> list:
|
|
||||||
"""
|
|
||||||
Convert the registered skill tool definitions into LangChain-compatible
|
|
||||||
@tool-decorated functions that ToolNode can call.
|
|
||||||
|
|
||||||
Each tool wraps the existing `execute_tool()` pipeline, so the skill
|
|
||||||
system's ToolResult + httpx session handling is fully preserved.
|
|
||||||
"""
|
|
||||||
tool_defs = get_all_tools(skill_names)
|
|
||||||
wrapped: list = []
|
|
||||||
|
|
||||||
for td in tool_defs:
|
|
||||||
fn_def = td.get("function", {})
|
|
||||||
fn_name = fn_def.get("name", "")
|
|
||||||
fn_desc = fn_def.get("description", "")
|
|
||||||
|
|
||||||
# Create a unique factory so each closure captures the right fn_name
|
|
||||||
def _make_tool(name: str, desc: str, skills: list[str]):
|
|
||||||
@tool(name, description=desc)
|
|
||||||
async def _wrapped(**kwargs: Any) -> str:
|
|
||||||
"""Execute the tool via the skill system and return its content."""
|
|
||||||
result = await execute_tool(skills, name, kwargs)
|
|
||||||
if result is None:
|
|
||||||
return f"Tool '{name}' is not available."
|
|
||||||
return result.content
|
|
||||||
|
|
||||||
# Stash the original OpenAI schema so LangGraph can use it
|
|
||||||
_wrapped.metadata = fn_def
|
|
||||||
return _wrapped
|
|
||||||
|
|
||||||
wrapped.append(_make_tool(fn_name, fn_desc, skill_names))
|
|
||||||
|
|
||||||
return wrapped
|
|
||||||
Reference in New Issue
Block a user