small refactor of the structure
This commit is contained in:
+235
@@ -0,0 +1,235 @@
|
||||
# API Architecture — Agent + Skill + Graph Pipeline
|
||||
|
||||
This document explains how the API routes user messages through the
|
||||
agent / skill / LangGraph pipeline to produce responses.
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ OpenWebUI / Client │
|
||||
│ POST /v1/chat/completions { model, messages, stream } │
|
||||
└──────────────────────────────┬──────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────────────────────────────────────┐
|
||||
│ api/v1/chat.py — chat_completions() │
|
||||
│ │
|
||||
│ 1. _resolve_agent(req.model) → Agent │
|
||||
│ 2. get_agent_graph(agent_id) → compiled StateGraph │
|
||||
│ 3. graph.ainvoke(state) or _stream_graph(graph, messages) │
|
||||
└──────────────────────────────┬───────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────────────────────────────────────┐
|
||||
│ LangGraph StateGraph (core/graph.py) │
|
||||
│ │
|
||||
│ ┌──────────────┐ tool_calls? ┌──────────────┐ │
|
||||
│ │ agent_node │ ───────────────▶ │ tool_node │ │
|
||||
│ │ (LLM call) │ ◀─────────────── │ (skill exec) │ │
|
||||
│ └──────┬───────┘ └──────────────┘ │
|
||||
│ │ no tool_calls │
|
||||
│ ▼ │
|
||||
│ [END] │
|
||||
└──────────────────────────────────────────────────────────────────┘
|
||||
|
||||
## Key Concepts
|
||||
|
||||
### 1. Agent
|
||||
|
||||
An **Agent** is a persona + skill bundle. Defined in `agents/`.
|
||||
|
||||
```python
|
||||
# agents/media_agent.py
|
||||
Agent(
|
||||
agent_id="media-agent",
|
||||
description="Media assistant with Seerr integration",
|
||||
skills=["media_info", "seerr", "triage"],
|
||||
base_prompt="You are a media assistant...",
|
||||
)
|
||||
```
|
||||
|
||||
- `agent_id` — unique name, exposed as a model in OpenWebUI
|
||||
- `skills` — list of skill names to load
|
||||
- `base_prompt` — starting system prompt, combined with skill fragments
|
||||
- `build_system_prompt()` — merges base_prompt + all skill prompt fragments
|
||||
|
||||
Agents self-register at import time via `agents/__init__.py`'s `register()`.
|
||||
`main.py` calls `load_all_agents()` at startup to import every agent and skill
|
||||
module.
|
||||
|
||||
### 2. Skill
|
||||
|
||||
A **Skill** is a capability bundle. Defined in `skills/`.
|
||||
|
||||
```python
|
||||
# skills/seerr.py
|
||||
Skill(
|
||||
name="seerr",
|
||||
description="Seerr integration — trending, discover, request media, submit issues",
|
||||
prompt_fragment="## Seerr Media Tools\n...",
|
||||
tools=[...], # OpenAI function-calling schema
|
||||
execute=_execute, # async handler: tool_name + args → ToolResult
|
||||
)
|
||||
```
|
||||
|
||||
- `prompt_fragment` — injected into the agent's system prompt.
|
||||
- `tools` — list of OpenAI function definitions (name, description, parameters).
|
||||
- `execute` — async callable that routes tool calls to API handlers.
|
||||
|
||||
### 3. Graph
|
||||
|
||||
Each agent gets a **compiled LangGraph StateGraph** built by
|
||||
`core/graph.py:create_agent_graph()`. The graph is compiled lazily on the
|
||||
first request and cached on `app.state.agent_graphs` for the lifetime of the
|
||||
process.
|
||||
|
||||
| Graph node / edge | What it does |
|
||||
|---|---|
|
||||
| `agent_node` | Converts state messages to OpenAI dicts, calls the LLM with the agent's system prompt + tool definitions, returns an `AIMessage` |
|
||||
| `tool_node` | Reads `tool_calls` from the last AI message, calls `execute_tool()` from the skill system, returns `ToolMessage` results |
|
||||
| `_should_continue` | Conditional edge — returns `"tool_node"` if the AI message has `tool_calls`, else `END` |
|
||||
|
||||
### 4. State
|
||||
|
||||
Defined in `core/state.py`:
|
||||
|
||||
```python
|
||||
class AgentState(TypedDict):
|
||||
messages: Annotated[list, add_messages]
|
||||
```
|
||||
|
||||
LangGraph's `add_messages` reducer appends new messages and replaces messages
|
||||
with matching IDs (so tool-call results overwrite their placeholders).
|
||||
|
||||
### 5. Message Conversion
|
||||
|
||||
Because we use the raw `openai` client (not `langchain-openai`), messages must
|
||||
be converted between LangChain and OpenAI formats at every LLM call:
|
||||
|
||||
- **LangChain → OpenAI** (`_lc_role_to_openai`, `_langchain_tc_to_openai`):
|
||||
Maps `type` → `role` and converts top-level `name`/`args` tool-calls into
|
||||
the nested `function` sub-object that the OpenAI API expects.
|
||||
|
||||
- **OpenAI → LangChain** (inside `agent_node`):
|
||||
Converts the `ChatCompletionMessage` response into an `AIMessage` with
|
||||
LangChain-format `tool_calls` (top-level `name`/`args`/`id`).
|
||||
|
||||
---
|
||||
|
||||
## Full Request Flow
|
||||
|
||||
### Step-by-step: "What are trending movies?"
|
||||
|
||||
```
|
||||
1. OpenWebUI sends:
|
||||
POST /v1/chat/completions
|
||||
{
|
||||
"model": "media-agent",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What are trending movies?"}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
|
||||
2. chat_completions():
|
||||
→ _resolve_agent(model="media-agent")
|
||||
→ get_agent("media-agent") → Agent(skills=["media_info", "seerr", "triage"])
|
||||
→ get_agent_graph("media-agent", request)
|
||||
→ looks up app.state.agent_graphs["media-agent"]
|
||||
→ first call → create_agent_graph() compiles the graph with 7 Seerr tools
|
||||
→ run_agent_with_tools(request, messages, agent_id)
|
||||
→ _invoke_graph(graph, messages)
|
||||
|
||||
3. Graph — Pass 1 (agent_node):
|
||||
→ LLM receives: [system prompt] + [user: "What are trending movies?"]
|
||||
→ LLM responds with tool_calls: seerr_trending(kind="movie")
|
||||
→ agent_node returns AIMessage with tool_calls in LangChain format
|
||||
|
||||
4. Graph — _should_continue:
|
||||
→ AIMessage has tool_calls → route to "tool_node"
|
||||
|
||||
5. Graph — tool_node:
|
||||
→ Reads tool_call: name="seerr_trending", args={"kind": "movie"}
|
||||
→ execute_tool(["media_info", "seerr", "triage"], "seerr_trending", ...)
|
||||
→ Seerr API → GET /api/v1/discover/trending?mediaType=movie
|
||||
→ Returns ToolMessage with formatted results including [tmdb:IDs]
|
||||
|
||||
6. Graph — Pass 2 (agent_node):
|
||||
→ LLM receives previous exchange + tool result
|
||||
→ LLM responds with text only (no tool_calls)
|
||||
→ agent_node returns AIMessage(content="Here are the top trending movies!...")
|
||||
|
||||
7. Graph — _should_continue:
|
||||
→ No tool_calls → route to END
|
||||
|
||||
8. chat_completions() returns:
|
||||
{ "choices": [{"message": {"role": "assistant", "content": "Here are the top..."}}] }
|
||||
```
|
||||
|
||||
### Step-by-step: "Request the 2026 one" (multi-turn context)
|
||||
|
||||
```
|
||||
1. OpenWebUI sends the FULL history:
|
||||
{
|
||||
"model": "media-agent",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What are trending movies?"},
|
||||
{"role": "assistant", "content": "Here are the top 10 trending movies!
|
||||
1. **Mortal Kombat II** (2026) [tmdb:931285] — ..."},
|
||||
{"role": "user", "content": "could request the mortal kombat one?"},
|
||||
{"role": "assistant", "content": "There are several Mortal Kombat entries! ..."},
|
||||
{"role": "user", "content": "the 2026 one"}
|
||||
]
|
||||
}
|
||||
|
||||
2. chat_completions():
|
||||
→ req.messages contains the ENTIRE conversation history
|
||||
→ graph.ainvoke({"messages": all_messages})
|
||||
→ agent_node prepends system prompt and sends everything to the LLM
|
||||
|
||||
3. LLM reasons from full context:
|
||||
- Previously listed Mortal Kombat II (2026) with [tmdb:931285]
|
||||
- The user said "request the mortal kombat one" → I searched and showed 4 options
|
||||
- Now they say "the 2026 one" → that matches Mortal Kombat II (2026) [tmdb:931285]
|
||||
- I should call seerr_request_media(kind="movie", title="Mortal Kombat II", tmdb_id=931285)
|
||||
|
||||
4. tool_node executes the request → ✅ Success
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Streaming
|
||||
|
||||
Streaming works slightly differently from the sync path:
|
||||
|
||||
```
|
||||
chat_completions(stream=True)
|
||||
→ _stream_graph(graph, messages)
|
||||
→ graph.ainvoke(state) # runs graph to completion (tools execute silently)
|
||||
→ yields content character-by-character via SSE
|
||||
```
|
||||
|
||||
For true token-level streaming (tokens appear as the LLM generates them),
|
||||
the agent_node would need to use `langchain-openai`'s `ChatOpenAI` instead of
|
||||
the raw `openai` client. The current approach is a pragmatic middle ground
|
||||
that avoids adding another dependency while still giving the SSE client
|
||||
incremental output.
|
||||
|
||||
---
|
||||
|
||||
## File Map
|
||||
|
||||
| File | Responsibility |
|
||||
|---|---|
|
||||
| `main.py` | FastAPI app, singleton creation, router mounting |
|
||||
| `api/v1/chat.py` | Endpoints — resolves agent, invokes graph, formats responses |
|
||||
| `api/dependencies.py` | `get_llm_client()`, `get_agent_graph()` — FastAPI `Depends` |
|
||||
| `core/graph.py` | `create_agent_graph()` — builds the StateGraph |
|
||||
| `core/state.py` | `AgentState` TypedDict |
|
||||
| `core/llm.py` | `create_client()` — OpenAI client factory |
|
||||
| `core/config.py` | Environment variable loader |
|
||||
| `agents/` | Agent definitions (dataclass + self-registration) |
|
||||
| `skills/` | Skill definitions (prompt fragments + tools + executors) |
|
||||
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Auth Service registry — generic, pluggable authentication for any service.
|
||||
|
||||
Add a new service (Plex, Seerr, etc.) by:
|
||||
1. Subclassing AuthService
|
||||
2. Dropping the module in this package
|
||||
3. Calling register_auth_service() at import time
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AuthResult — returned by AuthService.authenticate()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class AuthResult:
|
||||
"""Outcome of a credential validation attempt."""
|
||||
success: bool
|
||||
external_user_id: Optional[str] = None
|
||||
external_name: Optional[str] = None
|
||||
credentials: Optional[dict] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AuthService — abstract base class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AuthService(ABC):
|
||||
"""A service that users can authenticate against (Jellyfin, Seerr, Plex, etc.)
|
||||
|
||||
Subclasses must implement:
|
||||
- name : unique identifier used in URLs and DB keys
|
||||
- display_name : human-readable label shown in Discord
|
||||
- render_login_form(token, discord_id) → HTML string
|
||||
- authenticate(form_data) → AuthResult
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Unique service name: "jellyfin", "seerr", etc."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def display_name(self) -> str:
|
||||
"""Human-readable: "Jellyfin", "Seerr", "Plex" """
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def render_login_form(self, token: str, discord_id: int) -> str:
|
||||
"""Return HTML string with a login form for this service.
|
||||
|
||||
The form MUST include these hidden fields:
|
||||
<input type="hidden" name="token" value="{token}">
|
||||
<input type="hidden" name="discord_id" value="{discord_id}">
|
||||
<input type="hidden" name="service" value="{self.name}">
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, form_data: dict) -> AuthResult:
|
||||
"""Validate credentials against the service. Return AuthResult."""
|
||||
...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Global registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_registry: dict[str, AuthService] = {}
|
||||
|
||||
|
||||
def register_auth_service(svc: AuthService) -> None:
|
||||
"""Register an AuthService so it can be looked up by name."""
|
||||
_registry[svc.name] = svc
|
||||
|
||||
|
||||
def get_auth_service(name: str) -> AuthService | None:
|
||||
"""Look up a registered AuthService by name."""
|
||||
return _registry.get(name)
|
||||
|
||||
|
||||
def list_auth_services() -> list[str]:
|
||||
"""Return names of all registered auth services."""
|
||||
return list(_registry.keys())
|
||||
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
Jellyfin AuthService — validates Jellyfin credentials and stores the session token.
|
||||
|
||||
Two authentication flows:
|
||||
1. Quick Connect (primary): user enters a short code on their Jellyfin app.
|
||||
- initiate_quick_connect() → {code, secret}
|
||||
- poll_quick_connect(secret) → "Active" | "Authorized" | "Expired"
|
||||
- authenticate_quick_connect(secret) → AuthResult with token
|
||||
|
||||
2. Username/password (legacy): renders an HTML form, called via the REST API.
|
||||
- render_login_form(token, discord_id) → HTML string
|
||||
- authenticate(form_data) → AuthResult
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from gateway.auth import AuthService, AuthResult, register_auth_service
|
||||
from src.config import get_config
|
||||
|
||||
logger = logging.getLogger("auth.jellyfin")
|
||||
|
||||
# Emby-style authorization header required by Jellyfin's AuthenticateByName
|
||||
_EMBY_HEADER = (
|
||||
'MediaBrowser Client="AgentBot",'
|
||||
'Device="DiscordBot",'
|
||||
'DeviceId="agent-bot",'
|
||||
'Version="1.0"'
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuickConnectResult:
|
||||
"""Result of a Quick Connect initiation."""
|
||||
secret: str
|
||||
code: str
|
||||
device_id: str
|
||||
device_name: str
|
||||
|
||||
|
||||
class JellyfinAuth(AuthService):
|
||||
name = "jellyfin"
|
||||
display_name = "Jellyfin"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Quick Connect helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _qc_headers(self) -> dict[str, str]:
|
||||
"""Return headers used by all Quick Connect API calls."""
|
||||
return {
|
||||
"X-Emby-Authorization": (
|
||||
'MediaBrowser Client="AgentBot",'
|
||||
'Device="DiscordBot",'
|
||||
'DeviceId="agent-bot-qc",'
|
||||
'Version="1.0"'
|
||||
)
|
||||
}
|
||||
|
||||
async def _resolve_url(self) -> str | None:
|
||||
"""
|
||||
Resolve the Jellyfin server URL.
|
||||
1. Check JELLYFIN_URL env var (used in deployment).
|
||||
2. Check if user already has a stored auth with a URL (from legacy login).
|
||||
Returns None if no URL is configured.
|
||||
"""
|
||||
# First: explicit env var
|
||||
env_url = get_config("JELLYFIN_URL")
|
||||
if env_url:
|
||||
return env_url.strip().rstrip("/")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1a: initiate Quick Connect
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def initiate_quick_connect(self, url: str | None = None) -> QuickConnectResult | None:
|
||||
"""
|
||||
Call Jellyfin's POST /QuickConnect/Initiate.
|
||||
Returns a QuickConnectResult with {secret, code} or None on failure.
|
||||
|
||||
The *code* is what the user enters on their Jellyfin page.
|
||||
The *secret* is used internally to poll/authenticate.
|
||||
"""
|
||||
base_url = url or await self._resolve_url()
|
||||
if not base_url:
|
||||
logger.error("QuickConnect failed — no JELLYFIN_URL configured.")
|
||||
return None
|
||||
|
||||
logger.info("Initiating Quick Connect on %s", base_url)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{base_url}/QuickConnect/Initiate",
|
||||
headers=self._qc_headers(),
|
||||
json={},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
"QuickConnect init failed: HTTP %s — %s",
|
||||
resp.status_code, resp.text[:200],
|
||||
)
|
||||
return None
|
||||
|
||||
data = resp.json()
|
||||
secret = data.get("Secret", "")
|
||||
code = data.get("Code", "")
|
||||
device_id = data.get("DeviceId", "")
|
||||
device_name = data.get("DeviceName", "")
|
||||
|
||||
if not secret or not code:
|
||||
logger.warning("QuickConnect init returned unexpected payload: %s", data)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
"Quick Connect initiated: code=%s device=%s",
|
||||
code, device_name,
|
||||
)
|
||||
return QuickConnectResult(
|
||||
secret=secret,
|
||||
code=code,
|
||||
device_id=device_id,
|
||||
device_name=device_name,
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("QuickConnect init timed out reaching %s", base_url)
|
||||
return None
|
||||
except httpx.ConnectError:
|
||||
logger.error("QuickConnect init — cannot connect to %s", base_url)
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during QuickConnect init")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1b: poll Quick Connect status
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def poll_quick_connect(self, secret: str, url: str | None = None) -> str:
|
||||
"""
|
||||
Call Jellyfin's GET /QuickConnect/Connect?secret=<secret>.
|
||||
Returns one of:
|
||||
- "Active" → user hasn't entered the code yet
|
||||
- "Authorized" → user entered code AND approved
|
||||
- "Expired" → code expired / unknown secret
|
||||
- "Error" → network or unexpected failure
|
||||
"""
|
||||
base_url = url or await self._resolve_url()
|
||||
if not base_url:
|
||||
logger.error("QuickConnect poll failed — no JELLYFIN_URL configured.")
|
||||
return "Error"
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.get(
|
||||
f"{base_url}/QuickConnect/Connect",
|
||||
params={"secret": secret},
|
||||
headers=self._qc_headers(),
|
||||
)
|
||||
if resp.status_code == 404:
|
||||
return "Expired"
|
||||
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
# Jellyfin returns "Authenticated" (not "Authorized")
|
||||
if data.get("Authenticated") is True:
|
||||
return "Authorized"
|
||||
# "Authenticated" is false, missing, or null → still active
|
||||
return "Active"
|
||||
|
||||
logger.warning(
|
||||
"QuickConnect poll unexpected: HTTP %s — %s",
|
||||
resp.status_code, resp.text[:200],
|
||||
)
|
||||
return "Error"
|
||||
|
||||
except (httpx.TimeoutException, httpx.ConnectError):
|
||||
logger.warning("QuickConnect poll network error")
|
||||
return "Error"
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during QuickConnect poll")
|
||||
return "Error"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1c: exchange secret for token
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def authenticate_quick_connect(
|
||||
self, secret: str, url: str | None = None
|
||||
) -> AuthResult:
|
||||
"""
|
||||
After poll_quick_connect returns "Authorized", call
|
||||
POST /Users/AuthenticateWithQuickConnect to exchange the secret
|
||||
for a real access token.
|
||||
|
||||
Returns AuthResult with token, user_id, username on success.
|
||||
"""
|
||||
base_url = url or await self._resolve_url()
|
||||
if not base_url:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="No Jellyfin server URL configured.",
|
||||
)
|
||||
|
||||
logger.info("Exchanging QuickConnect secret for token on %s", base_url)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{base_url}/Users/AuthenticateWithQuickConnect",
|
||||
json={"Secret": secret},
|
||||
headers=self._qc_headers(),
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
"QuickConnect auth exchange failed: HTTP %s",
|
||||
resp.status_code,
|
||||
)
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="Quick Connect authentication failed. The code may have expired.",
|
||||
)
|
||||
|
||||
data = resp.json()
|
||||
user = data.get("User", {})
|
||||
token = data.get("AccessToken", "")
|
||||
|
||||
if not token:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="Jellyfin returned an unexpected response.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"QuickConnect linked: user=%s (%s)",
|
||||
user.get("Name", "?"),
|
||||
user.get("Id", "?"),
|
||||
)
|
||||
|
||||
return AuthResult(
|
||||
success=True,
|
||||
external_user_id=user.get("Id", ""),
|
||||
external_name=user.get("Name", "?"),
|
||||
credentials={
|
||||
"token": token,
|
||||
"url": base_url,
|
||||
"user_id": user.get("Id", ""),
|
||||
},
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not reach {base_url} — connection timed out.",
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not connect to {base_url}. Is the server running?",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during QuickConnect auth exchange")
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="An unexpected error occurred during authentication.",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Login form (legacy — used by the REST API)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def render_login_form(self, token: str, discord_id: int) -> str:
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Link Jellyfin</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
|
||||
h2 {{ margin-bottom: 4px; }}
|
||||
.sub {{ color: #666; margin-bottom: 24px; }}
|
||||
label {{ display: block; margin-top: 16px; font-weight: 600; }}
|
||||
input {{ width: 100%; padding: 10px; margin-top: 4px; border: 1px solid #ccc; border-radius: 6px; box-sizing: border-box; }}
|
||||
button {{ margin-top: 24px; width: 100%; padding: 12px; background: #aa5cc3; color: #fff; border: none; border-radius: 6px; font-size: 16px; cursor: pointer; }}
|
||||
button:hover {{ background: #9448b0; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h2>🔗 Link Jellyfin to Discord</h2>
|
||||
<p class="sub">Enter your Jellyfin server URL and credentials to link your account.</p>
|
||||
|
||||
<form method="POST" action="/api/v1/auth/login">
|
||||
<input type="hidden" name="token" value="{token}">
|
||||
<input type="hidden" name="discord_id" value="{discord_id}">
|
||||
<input type="hidden" name="service" value="jellyfin">
|
||||
|
||||
<label for="jellyfin_url">Jellyfin Server URL</label>
|
||||
<input id="jellyfin_url" name="jellyfin_url" type="url"
|
||||
placeholder="https://jellyfin.example.com" required>
|
||||
|
||||
<label for="username">Username</label>
|
||||
<input id="username" name="username" type="text"
|
||||
placeholder="Your Jellyfin username" required autofocus>
|
||||
|
||||
<label for="password">Password</label>
|
||||
<input id="password" name="password" type="password"
|
||||
placeholder="Your Jellyfin password" required>
|
||||
|
||||
<button type="submit">Link Account</button>
|
||||
</form>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authentication
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def authenticate(self, form_data: dict) -> AuthResult:
|
||||
url = form_data.get("jellyfin_url", "").strip().rstrip("/")
|
||||
username = form_data.get("username", "").strip()
|
||||
password = form_data.get("password", "").strip()
|
||||
|
||||
if not url or not username or not password:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="All fields are required (URL, username, password).",
|
||||
)
|
||||
|
||||
logger.info("Attempting Jellyfin login for '%s' on %s", username, url)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{url}/Users/AuthenticateByName",
|
||||
json={"Username": username, "Pw": password},
|
||||
headers={"X-Emby-Authorization": _EMBY_HEADER},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
"Jellyfin login failed for '%s': HTTP %s", username, resp.status_code
|
||||
)
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Login failed — check your server URL and credentials.",
|
||||
)
|
||||
|
||||
data = resp.json()
|
||||
user = data.get("User", {})
|
||||
token = data.get("AccessToken", "")
|
||||
|
||||
if not token:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="Jellyfin returned an unexpected response.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Jellyfin login OK: user=%s (%s)",
|
||||
user.get("Name", "?"),
|
||||
user.get("Id", "?"),
|
||||
)
|
||||
|
||||
return AuthResult(
|
||||
success=True,
|
||||
external_user_id=user.get("Id", ""),
|
||||
external_name=user.get("Name", username),
|
||||
credentials={
|
||||
"token": token,
|
||||
"url": url,
|
||||
"user_id": user.get("Id", ""),
|
||||
},
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not reach {url} — connection timed out. Check the URL.",
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not connect to {url}. Is the server running?",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Unexpected error during Jellyfin login")
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"An unexpected error occurred. Please try again.",
|
||||
)
|
||||
|
||||
|
||||
# Self-register at import time
|
||||
register_auth_service(JellyfinAuth())
|
||||
@@ -0,0 +1,36 @@
|
||||
from fastapi import Request
|
||||
from openai import OpenAI
|
||||
|
||||
from src.graph import create_agent_graph
|
||||
|
||||
|
||||
def get_llm_client(request: Request) -> OpenAI:
|
||||
"""FastAPI dependency — returns the singleton OpenAI client from app.state."""
|
||||
return request.app.state.llm_client
|
||||
|
||||
|
||||
def get_agent_graph(agent_id: str, request: Request):
|
||||
"""
|
||||
FastAPI dependency — returns the compiled LangGraph graph for *agent_id*.
|
||||
|
||||
Graphs are lazily compiled on first use and cached on app.state so each
|
||||
agent's graph is only built once per process lifetime.
|
||||
"""
|
||||
cache: dict = request.app.state.agent_graphs
|
||||
|
||||
if agent_id not in cache:
|
||||
from agents import get as get_agent
|
||||
|
||||
agent = get_agent(agent_id)
|
||||
if agent is None:
|
||||
# Fall back to the naked agent if the requested one doesn't exist
|
||||
agent_id = "naked"
|
||||
agent = get_agent(agent_id)
|
||||
|
||||
cache[agent_id] = create_agent_graph(
|
||||
client=request.app.state.llm_client,
|
||||
agent_skills=agent.skills,
|
||||
system_prompt=agent.build_system_prompt(),
|
||||
)
|
||||
|
||||
return cache[agent_id]
|
||||
@@ -0,0 +1 @@
|
||||
# Discord bot package
|
||||
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
Discord bot that connects users to the LangGraph agent via private messages.
|
||||
|
||||
Architecture
|
||||
------------
|
||||
- The bot runs in-process alongside FastAPI (on a background asyncio task).
|
||||
- Private messages (DMs) are routed through the same LangGraph graphs that
|
||||
power the REST API — no HTTP loopback needed.
|
||||
- Per-user conversation history is maintained so the LLM has context.
|
||||
|
||||
Environment
|
||||
-----------
|
||||
DISCORD_BOT_TOKEN – the bot token from the Discord Developer Portal
|
||||
DISCORD_MAX_HISTORY – how many past messages to keep per user (default 7)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
import discord
|
||||
|
||||
from agents import list_all as list_all_agents
|
||||
from gateway.discord.conversation import ConversationStore
|
||||
from src.config import DEEPSEEK_API_KEY, get_config
|
||||
from src.graph import create_agent_graph
|
||||
from src.llm import create_client
|
||||
from src import auth_store
|
||||
from gateway.auth import list_auth_services, get_auth_service
|
||||
|
||||
logger = logging.getLogger("bot.discord")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
DISCORD_BOT_TOKEN = get_config("DISCORD_BOT_TOKEN") or ""
|
||||
DISCORD_MAX_HISTORY = int(get_config("DISCORD_MAX_HISTORY", "7"))
|
||||
DISCORD_DEFAULT_AGENT = get_config("DISCORD_DEFAULT_AGENT", "media-agent")
|
||||
BASE_URL = get_config("BASE_URL", "http://localhost:8000").rstrip("/")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM client shared by all agents (same as the REST API uses)
|
||||
# ---------------------------------------------------------------------------
|
||||
_llm_client = create_client(DEEPSEEK_API_KEY)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Conversation store — one per process
|
||||
# ---------------------------------------------------------------------------
|
||||
_conversations = ConversationStore(max_history=DISCORD_MAX_HISTORY)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graph cache — lazy-compiled per agent, same pattern as api/dependencies.py
|
||||
# ---------------------------------------------------------------------------
|
||||
_agent_graphs: dict[str, object] = {}
|
||||
|
||||
|
||||
def _get_graph(agent_id: str):
|
||||
"""Return a compiled LangGraph for *agent_id*, building it on first use."""
|
||||
if agent_id not in _agent_graphs:
|
||||
agents = list_all_agents()
|
||||
agent = agents.get(agent_id, agents.get("naked"))
|
||||
_agent_graphs[agent_id] = create_agent_graph(
|
||||
client=_llm_client,
|
||||
agent_skills=agent.skills if agent else [],
|
||||
system_prompt=agent.build_system_prompt() if agent else (
|
||||
"You are a helpful, general-purpose assistant."
|
||||
),
|
||||
)
|
||||
return _agent_graphs[agent_id]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AgentBot(discord.Client):
|
||||
"""A discord.py Client that connects users to the LangGraph agent."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# message_content lets us read DM text.
|
||||
# guilds is required so that mutual_guilds is populated — without it
|
||||
# every DM is silently ignored.
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.guilds = True
|
||||
super().__init__(intents=intents)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
logger.info("Bot logged in as %s (ID %s)", self.user, self.user.id)
|
||||
# Print a ready banner so the dev knows it's alive
|
||||
print(f"\n🤖 Discord bot ready — logged in as {self.user}\n")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared-guild helper — uses the REST API, no privileged intents
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _shares_guild(self, user: discord.User) -> bool:
|
||||
"""Return True if *user* and the bot share at least one guild."""
|
||||
for guild in self.guilds:
|
||||
try:
|
||||
member = await guild.fetch_member(user.id)
|
||||
if member is not None:
|
||||
return True
|
||||
except (discord.NotFound, discord.Forbidden, discord.HTTPException):
|
||||
continue
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message handler — DMs only
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
# Never reply to ourselves
|
||||
if message.author == self.user:
|
||||
return
|
||||
|
||||
# |-- DM channel only for now ----------------------------------|
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
logger.debug("Ignoring message from #%s (not a DM)", message.channel)
|
||||
return
|
||||
# |--------------------------------------------------------------|
|
||||
|
||||
# |-- Shared-server gate — only users who share at least one --|
|
||||
# | guild with the bot can interact via DM. --|
|
||||
# | We use fetch_member (REST API) instead of --|
|
||||
# | User.mutual_guilds because the latter requires the --|
|
||||
# | privileged "members" intent. This way no privileged --|
|
||||
# | intents are needed. --|
|
||||
if not await self._shares_guild(message.author):
|
||||
logger.warning(
|
||||
"Blocking DM from %s — no mutual guilds.",
|
||||
message.author.name,
|
||||
)
|
||||
return
|
||||
# |--------------------------------------------------------------|
|
||||
|
||||
user_id = message.author.id
|
||||
content = message.content.strip()
|
||||
|
||||
# |-- Bot commands — handled directly, never sent to the LLM --|
|
||||
if await self._handle_command(message, user_id, content):
|
||||
return
|
||||
# |--------------------------------------------------------------|
|
||||
|
||||
# Show typing indicator while the graph runs
|
||||
async with message.channel.typing():
|
||||
try:
|
||||
reply = await self._run_agent(
|
||||
user_id=user_id,
|
||||
user_msg=message.content,
|
||||
)
|
||||
await message.channel.send(reply)
|
||||
except Exception:
|
||||
logger.exception("Agent run failed for user %s", user_id)
|
||||
await message.channel.send(
|
||||
"Sorry, something went wrong processing your request. "
|
||||
"Please try again in a moment."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bot commands
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_command(
|
||||
self, message: discord.Message, user_id: int, content: str
|
||||
) -> bool:
|
||||
"""Handle bot commands (/login, /logout). Returns True if handled."""
|
||||
lower = content.lower()
|
||||
|
||||
# --- /login [service] ---
|
||||
if lower.startswith("/login"):
|
||||
parts = content.split()
|
||||
service = parts[1].lower() if len(parts) > 1 else None
|
||||
|
||||
available = list_auth_services()
|
||||
if not available:
|
||||
await message.channel.send("No auth services are configured yet.")
|
||||
return True
|
||||
|
||||
if service is None:
|
||||
svc_list = ", ".join(available)
|
||||
await message.channel.send(
|
||||
f"Available services to link: **{svc_list}**\n"
|
||||
f"Use `/login <service>` — e.g. `/login jellyfin`"
|
||||
)
|
||||
return True
|
||||
|
||||
if service not in available:
|
||||
await message.channel.send(
|
||||
f"Unknown service '{service}'. Available: {', '.join(available)}"
|
||||
)
|
||||
return True
|
||||
|
||||
if auth_store.is_authenticated(user_id, service):
|
||||
svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service
|
||||
await message.channel.send(
|
||||
f"You're already linked to **{svc_display}**! "
|
||||
f"Use `/logout {service}` to unlink."
|
||||
)
|
||||
return True
|
||||
|
||||
# --- Quick Connect flow ---
|
||||
svc = get_auth_service(service)
|
||||
if svc is None:
|
||||
await message.channel.send(f"Unknown service: {service}")
|
||||
return True
|
||||
|
||||
await message.channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…")
|
||||
|
||||
qc_result = await svc.initiate_quick_connect()
|
||||
if qc_result is None:
|
||||
await message.channel.send(
|
||||
f"❌ Could not start Quick Connect for **{svc.display_name}**.\n"
|
||||
"Check that `JELLYFIN_URL` is configured and the server is reachable."
|
||||
)
|
||||
return True
|
||||
|
||||
await message.channel.send(
|
||||
f"Open **{svc.display_name}** → **Quick Connect** and enter this code:\n\n"
|
||||
f"**`{qc_result.code}`**\n\n"
|
||||
f"⏳ Waiting for you to approve…"
|
||||
)
|
||||
|
||||
# Poll for authorization
|
||||
async with message.channel.typing():
|
||||
for attempt in range(24): # 24 × 5s = 2 minutes
|
||||
await asyncio.sleep(5)
|
||||
status = await svc.poll_quick_connect(qc_result.secret)
|
||||
|
||||
if status == "Authorized":
|
||||
auth_result = await svc.authenticate_quick_connect(qc_result.secret)
|
||||
if auth_result.success:
|
||||
auth_store.store_auth(
|
||||
discord_user_id=user_id,
|
||||
service=service,
|
||||
external_user_id=auth_result.external_user_id or "",
|
||||
external_name=auth_result.external_name or "",
|
||||
credentials=auth_result.credentials,
|
||||
)
|
||||
await message.channel.send(
|
||||
f"✅ Linked to **{svc.display_name}** as "
|
||||
f"**{auth_result.external_name}**!"
|
||||
)
|
||||
else:
|
||||
await message.channel.send(
|
||||
f"❌ Authentication failed: "
|
||||
f"{auth_result.error_message or 'Unknown error'}"
|
||||
)
|
||||
return True
|
||||
|
||||
elif status == "Expired":
|
||||
await message.channel.send(
|
||||
"⌛ The Quick Connect code expired. "
|
||||
f"Use `/login {service}` to try again."
|
||||
)
|
||||
return True
|
||||
|
||||
# else: still "Active" — keep polling
|
||||
|
||||
await message.channel.send(
|
||||
"⌛ Timed out waiting for Quick Connect approval. "
|
||||
f"Use `/login {service}` to try again."
|
||||
)
|
||||
return True
|
||||
|
||||
# --- /logout [service] ---
|
||||
if lower.startswith("/logout"):
|
||||
parts = content.split()
|
||||
service = parts[1].lower() if len(parts) > 1 else None
|
||||
|
||||
if service is None:
|
||||
linked = auth_store.list_services(user_id)
|
||||
if not linked:
|
||||
await message.channel.send("You don't have any linked services.")
|
||||
else:
|
||||
svc_list = ", ".join(linked)
|
||||
await message.channel.send(
|
||||
f"Linked services: **{svc_list}**\n"
|
||||
f"Use `/logout <service>` to unlink."
|
||||
)
|
||||
return True
|
||||
|
||||
if not auth_store.is_authenticated(user_id, service):
|
||||
await message.channel.send(f"You're not linked to **{service}**.")
|
||||
return True
|
||||
|
||||
auth_store.revoke(user_id, service)
|
||||
svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service
|
||||
await message.channel.send(f"Unlinked from **{svc_display}**. Use `/login {service}` to re-link.")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent invocation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _run_agent(self, *, user_id: int, user_msg: str) -> str:
|
||||
"""Build the message list from history, invoke the graph, store the
|
||||
reply, and return the assistant's final text."""
|
||||
|
||||
# 1. Pick agent — defaults to DISCORD_DEFAULT_AGENT env var.
|
||||
agent_id = DISCORD_DEFAULT_AGENT
|
||||
|
||||
# 2. Build message list from stored history + new user message
|
||||
history = _conversations.get_history(user_id)
|
||||
messages = [*history, {"role": "user", "content": user_msg}]
|
||||
|
||||
# 3. Run the LangGraph (tools execute inline if needed)
|
||||
graph = _get_graph(agent_id)
|
||||
state = {"messages": messages, "discord_user_id": user_id}
|
||||
result = await graph.ainvoke(state)
|
||||
|
||||
last_msg = result["messages"][-1]
|
||||
reply = last_msg.content or ""
|
||||
|
||||
# 4. Persist the conversation
|
||||
_conversations.append(user_id, user_msg, reply)
|
||||
|
||||
return reply
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bootstrap helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _start_bot_sync(token: str) -> None:
|
||||
"""Synchronous entry-point that runs the bot in a new asyncio event loop.
|
||||
|
||||
Called from a background thread so the main thread can keep running the
|
||||
FastAPI / uvicorn server.
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
async def _run() -> None:
|
||||
bot = AgentBot()
|
||||
try:
|
||||
await bot.start(token)
|
||||
except discord.LoginFailure:
|
||||
logger.error(
|
||||
"Discord login failed — check DISCORD_BOT_TOKEN in your .env file."
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Unhandled exception in bot event loop.")
|
||||
|
||||
loop.run_until_complete(_run())
|
||||
|
||||
|
||||
def start_in_background(token: str | None = None) -> None:
|
||||
"""Launch the Discord bot in a daemon thread.
|
||||
|
||||
Pass *token* explicitly if you already have it; otherwise it is read
|
||||
from the DISCORD_BOT_TOKEN env variable.
|
||||
"""
|
||||
token = token or DISCORD_BOT_TOKEN
|
||||
if not token:
|
||||
logger.warning(
|
||||
"DISCORD_BOT_TOKEN is not set — Discord bot will NOT start."
|
||||
)
|
||||
return
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(
|
||||
target=_start_bot_sync,
|
||||
args=(token,),
|
||||
daemon=True,
|
||||
name="discord-bot",
|
||||
)
|
||||
t.start()
|
||||
logger.info("Discord bot thread started.")
|
||||
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Per-user conversation history store.
|
||||
|
||||
Each Discord user gets their own isolated message list. Only the last
|
||||
`max_history` messages are kept — older ones are silently dropped so the
|
||||
LLM context stays small.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
logger = logging.getLogger("bot.conversation")
|
||||
|
||||
# role we assign to user messages inside the OpenAI-style message list
|
||||
_USER_ROLE = "user"
|
||||
# role we assign to bot responses
|
||||
_ASSISTANT_ROLE = "assistant"
|
||||
|
||||
|
||||
class ConversationStore:
|
||||
"""Thread-safe-ish in-memory store keyed by Discord user ID (int)."""
|
||||
|
||||
def __init__(self, max_history: int = 7) -> None:
|
||||
self._max = max_history
|
||||
self._store: Dict[int, List[dict]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_history(self, user_id: int) -> list[dict]:
|
||||
"""Return the last *max_history* messages for *user_id*."""
|
||||
return list(self._store.get(user_id, []))
|
||||
|
||||
def append(self, user_id: int, user_msg: str, assistant_reply: str) -> None:
|
||||
"""Store the user message + assistant reply, then trim to max."""
|
||||
if user_id not in self._store:
|
||||
self._store[user_id] = []
|
||||
|
||||
history = self._store[user_id]
|
||||
history.append({"role": _USER_ROLE, "content": user_msg})
|
||||
history.append({"role": _ASSISTANT_ROLE, "content": assistant_reply})
|
||||
|
||||
# Trim oldest messages if we exceeded the limit
|
||||
while len(history) > self._max:
|
||||
history.pop(0)
|
||||
|
||||
def clear(self, user_id: int) -> None:
|
||||
"""Wipe the conversation for a user."""
|
||||
self._store.pop(user_id, None)
|
||||
logger.info("Cleared conversation for user %s", user_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# debug / introspection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def user_count(self) -> int:
|
||||
return len(self._store)
|
||||
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Auth API — generic endpoints for linking Discord users to external services.
|
||||
|
||||
GET /api/v1/auth/login?service=X&token=Y&discord_id=Z
|
||||
Validates the link token and serves a service-specific login form.
|
||||
|
||||
POST /api/v1/auth/login
|
||||
Accepts the form submission, validates credentials against the service,
|
||||
stores the session, and returns a result page.
|
||||
|
||||
GET /api/v1/auth/status?discord_id=Z
|
||||
Returns which services are linked for this Discord user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from gateway.auth import get_auth_service, list_auth_services
|
||||
from src import auth_store
|
||||
|
||||
logger = logging.getLogger("gateway.auth")
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /auth/login — serve the login form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/login")
|
||||
async def login_form(
|
||||
service: str,
|
||||
token: str,
|
||||
discord_id: int,
|
||||
):
|
||||
"""Validate the one-time link token and return a service-specific login form."""
|
||||
|
||||
# Validate the token WITHOUT consuming it (the POST will consume it)
|
||||
result = auth_store.validate_token(token)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
||||
|
||||
uid, svc = result
|
||||
if uid != discord_id or svc != service:
|
||||
raise HTTPException(status_code=400, detail="Token does not match the request.")
|
||||
|
||||
# Look up the AuthService
|
||||
svc_obj = get_auth_service(service)
|
||||
if svc_obj is None:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
||||
|
||||
logger.info("Serving login form: user=%s service=%s", discord_id, service)
|
||||
return HTMLResponse(svc_obj.render_login_form(token, discord_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/login — handle form submission
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/login")
|
||||
async def login_submit(request: Request):
|
||||
"""Handle the login form POST: validate credentials, store auth, show result."""
|
||||
|
||||
# Parse form data
|
||||
form = await request.form()
|
||||
token = form.get("token", "")
|
||||
discord_id_str = form.get("discord_id", "")
|
||||
service = form.get("service", "")
|
||||
|
||||
if not token or not discord_id_str or not service:
|
||||
raise HTTPException(status_code=400, detail="Missing required fields.")
|
||||
|
||||
try:
|
||||
discord_id = int(discord_id_str)
|
||||
except (ValueError, TypeError):
|
||||
raise HTTPException(status_code=400, detail="Invalid discord_id.")
|
||||
|
||||
# Consume the token on POST (the GET only validated, didn't consume)
|
||||
result = auth_store.consume_token(token)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
||||
|
||||
# Look up the AuthService
|
||||
svc_obj = get_auth_service(service)
|
||||
if svc_obj is None:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
||||
|
||||
# Collect service-specific form fields (everything except token, discord_id, service)
|
||||
form_data: dict[str, str] = {}
|
||||
for key, value in form.items():
|
||||
if key not in ("token", "discord_id", "service"):
|
||||
form_data[key] = str(value)
|
||||
|
||||
# Authenticate against the service
|
||||
auth_result = await svc_obj.authenticate(form_data)
|
||||
|
||||
if not auth_result.success:
|
||||
return HTMLResponse(
|
||||
status_code=401,
|
||||
content=f"""<!DOCTYPE html>
|
||||
<html><head><meta charset="utf-8"><title>Login Failed</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
|
||||
h2 {{ color: #d32f2f; }}
|
||||
a {{ color: #aa5cc3; }}
|
||||
</style></head><body>
|
||||
<h2>❌ Login Failed</h2>
|
||||
<p>{auth_result.error_message or "Authentication failed. Please try again."}</p>
|
||||
<p><a href="javascript:history.back()">← Go back and try again</a></p>
|
||||
</body></html>""",
|
||||
)
|
||||
|
||||
# Store the successful auth
|
||||
auth_store.store_auth(
|
||||
discord_user_id=discord_id,
|
||||
service=service,
|
||||
external_user_id=auth_result.external_user_id or "",
|
||||
external_name=auth_result.external_name or "",
|
||||
credentials=auth_result.credentials,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Auth linked: discord=%s → %s (%s)",
|
||||
discord_id,
|
||||
service,
|
||||
auth_result.external_name,
|
||||
)
|
||||
|
||||
return HTMLResponse(f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Account Linked</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; text-align: center; }}
|
||||
h1 {{ color: #388e3c; }}
|
||||
.name {{ font-weight: bold; color: #aa5cc3; }}
|
||||
p {{ color: #666; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>✅ Account Linked!</h1>
|
||||
<p>Logged in as <span class="name">{auth_result.external_name}</span> on <strong>{svc_obj.display_name}</strong>.</p>
|
||||
<p>You can close this page and return to Discord.</p>
|
||||
</body>
|
||||
</html>""")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /auth/status — get all linked services for a Discord user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/Discord/status")
|
||||
async def auth_status(discord_id: int):
|
||||
"""
|
||||
Return all services linked to this Discord user with full details.
|
||||
|
||||
Response:
|
||||
{
|
||||
"discord_id": 123456789,
|
||||
"linked_services": {
|
||||
"jellyfin": {
|
||||
"external_user_id": "abc123",
|
||||
"external_name": "Tim",
|
||||
"linked_at": "2026-05-25T10:00:00",
|
||||
"credentials": {
|
||||
"token": "jwt...",
|
||||
"url": "http://jellyfin:8096",
|
||||
"user_id": "abc123"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
auths = auth_store.get_all_auths(discord_id)
|
||||
|
||||
linked_services: dict[str, dict] = {}
|
||||
for auth in auths:
|
||||
svc_name = auth["service"]
|
||||
linked_services[svc_name] = {
|
||||
"external_user_id": auth["external_user_id"],
|
||||
"external_name": auth["external_name"],
|
||||
"linked_at": auth["linked_at"],
|
||||
"credentials": auth["credentials"],
|
||||
}
|
||||
|
||||
return {
|
||||
"discord_id": discord_id,
|
||||
"linked_services": linked_services,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/reset — wipe auth store (DEV ONLY)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from src.config import get_config # noqa: E402
|
||||
|
||||
@router.post("/reset")
|
||||
async def reset_auth():
|
||||
"""
|
||||
Reset the entire auth store — clears all link tokens and user auth records.
|
||||
|
||||
Only enabled when ALLOW_AUTH_RESET=true in the environment.
|
||||
Returns 403 in production.
|
||||
"""
|
||||
if get_config("ALLOW_AUTH_RESET", "false").lower() != "true":
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Auth reset is disabled. Set ALLOW_AUTH_RESET=true to enable (dev only).",
|
||||
)
|
||||
|
||||
auth_store.reset_all()
|
||||
logger.warning("Auth store reset via API endpoint.")
|
||||
return {"status": "ok", "message": "Auth store cleared — all tokens and auth records removed."}
|
||||
@@ -0,0 +1,241 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
from gateway.dependencies import get_llm_client, get_agent_graph
|
||||
from agents import get as get_agent, list_all as list_all_agents
|
||||
from src.state import AgentState
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
session_id: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
messages: list[dict]
|
||||
stream: bool = False
|
||||
model: str = "deepseek-chat"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _resolve_agent(agent_id: str | None = None, model: str | None = None):
|
||||
"""
|
||||
1. explicit agent_id
|
||||
2. model field (OpenWebUI sends this — maps to agent_id if registered)
|
||||
3. fallback to "naked"
|
||||
"""
|
||||
lookup = agent_id or model
|
||||
if lookup is None:
|
||||
return get_agent("naked")
|
||||
agent = get_agent(lookup)
|
||||
return agent if agent else get_agent("naked")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LangGraph helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _invoke_graph(graph, messages: list[dict]) -> str:
|
||||
"""Run the graph synchronously (non-streaming) and return the final text."""
|
||||
state: AgentState = {"messages": messages}
|
||||
result = await graph.ainvoke(state)
|
||||
last_msg = result["messages"][-1]
|
||||
return last_msg.content or ""
|
||||
|
||||
|
||||
async def _stream_graph(graph, messages: list[dict]):
|
||||
"""
|
||||
Run the graph and stream the final response token-by-token.
|
||||
|
||||
LangGraph's astream_events would require langchain-openai's ChatOpenAI
|
||||
to intercept LLM chunks. Instead we run the graph to completion (tools
|
||||
execute silently) and then stream the final text content character by
|
||||
character — this gives the client a real SSE stream without adding new
|
||||
dependencies.
|
||||
"""
|
||||
state: AgentState = {"messages": messages}
|
||||
result = await graph.ainvoke(state)
|
||||
content = result["messages"][-1].content or ""
|
||||
# Yield token-by-token so the SSE client sees incremental output
|
||||
for token in content:
|
||||
yield token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-streaming run (kept for /chat/sync and sync completions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_agent_with_tools(
|
||||
request: Request,
|
||||
messages: list[dict],
|
||||
agent_id: str | None = None,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""Send messages through the agent's LangGraph. Non-streaming."""
|
||||
agent = _resolve_agent(agent_id, model)
|
||||
graph = get_agent_graph(agent.agent_id, request)
|
||||
return await _invoke_graph(graph, messages)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming generator (kept for /chat and stream completions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_agent_stream(
|
||||
request: Request,
|
||||
messages: list[dict],
|
||||
agent_id: str | None = None,
|
||||
model: str | None = None,
|
||||
):
|
||||
"""Async generator — yields tokens via the agent's LangGraph."""
|
||||
agent = _resolve_agent(agent_id, model)
|
||||
graph = get_agent_graph(agent.agent_id, request)
|
||||
async for token in _stream_graph(graph, messages):
|
||||
yield token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/")
|
||||
def root():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def chat(
|
||||
req: ChatRequest,
|
||||
request: Request,
|
||||
client: OpenAI = Depends(get_llm_client),
|
||||
):
|
||||
"""Streaming chat — single message, no history."""
|
||||
messages = [{"role": "user", "content": req.message}]
|
||||
|
||||
async def event_stream():
|
||||
async for token in run_agent_stream(request, messages, req.agent_id):
|
||||
payload = json.dumps({"token": token, "session_id": req.session_id})
|
||||
yield f"data: {payload}\n\n"
|
||||
yield f"data: {json.dumps({'done': True, 'session_id': req.session_id})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/sync")
|
||||
async def chat_sync(
|
||||
req: ChatRequest,
|
||||
request: Request,
|
||||
client: OpenAI = Depends(get_llm_client),
|
||||
):
|
||||
"""Non-streaming chat — single message."""
|
||||
messages = [{"role": "user", "content": req.message}]
|
||||
response = await run_agent_with_tools(request, messages, req.agent_id)
|
||||
return {"response": response, "session_id": req.session_id}
|
||||
|
||||
|
||||
@router.get("/agents")
|
||||
def list_agents():
|
||||
"""Return all registered agents."""
|
||||
return {
|
||||
"agents": [
|
||||
{
|
||||
"agent_id": a.agent_id,
|
||||
"description": a.description,
|
||||
"skills": a.skills,
|
||||
}
|
||||
for a in list_all_agents().values()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
def list_models():
|
||||
"""Return agents as selectable models for OpenWebUI."""
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": a.agent_id,
|
||||
"object": "model",
|
||||
"created": 0,
|
||||
"owned_by": "local-agent",
|
||||
}
|
||||
for a in list_all_agents().values()
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.post("/chat/completions")
|
||||
async def chat_completions(
|
||||
req: ChatCompletionRequest,
|
||||
request: Request,
|
||||
client: OpenAI = Depends(get_llm_client),
|
||||
):
|
||||
"""OpenAI-compatible /chat/completions — supports stream=True.
|
||||
Multi-turn: req.messages contains the FULL conversation history.
|
||||
Agent resolved from the model field (OpenWebUI sends this).
|
||||
"""
|
||||
agent = _resolve_agent(model=req.model)
|
||||
|
||||
if req.stream:
|
||||
async def sse_stream():
|
||||
async for token in run_agent_stream(
|
||||
request, req.messages, agent_id=agent.agent_id,
|
||||
):
|
||||
chunk = {
|
||||
"id": "chatcmpl-local",
|
||||
"object": "chat.completion.chunk",
|
||||
"choices": [
|
||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
final_chunk = {
|
||||
"id": "chatcmpl-local",
|
||||
"object": "chat.completion.chunk",
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
yield f"data: {json.dumps(final_chunk)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
sse_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
# Non-streaming — full history, LangGraph agent
|
||||
response = await run_agent_with_tools(
|
||||
request, req.messages, agent_id=agent.agent_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": "chatcmpl-local",
|
||||
"object": "chat.completion",
|
||||
"created": 0,
|
||||
"model": req.model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": response},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
Reference in New Issue
Block a user