Compare commits

..

6 Commits

39 changed files with 2528 additions and 303 deletions
+35 -6
View File
@@ -1,16 +1,19 @@
# --------------------------------------------------------------------------- # =============================================================================
# Agent Backend — Environment Variables # Agent Bot — Environment Configuration
# Copy this to .env and fill in your values. # =============================================================================
# --------------------------------------------------------------------------- # Copy this file to .env and fill in your values.
# .env is git-ignored — never commit real secrets.
# ---------------------------------------------------------------------------
# LLM — DeepSeek (OpenAI-compatible) # LLM — DeepSeek (OpenAI-compatible)
# ---------------------------------------------------------------------------
DEEPSEEK_API_KEY=sk-your-deepseek-api-key DEEPSEEK_API_KEY=sk-your-deepseek-api-key
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Discord Bot # Discord Bot
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
DISCORD_BOT_TOKEN=your-discord-bot-token-here DISCORD_BOT_TOKEN=your-discord-bot-token-here
# DISCORD_MAX_HISTORY=7 # optional, defaults to 7 (max past messages per user) # DISCORD_MAX_HISTORY=7 # optional, defaults to 7 (max past messages per user)
# DISCORD_DEFAULT_AGENT=media-agent # optional, which agent the DM bot uses # DISCORD_DEFAULT_AGENT=media-agent # optional, which agent the DM bot uses
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -18,4 +21,30 @@ DISCORD_BOT_TOKEN=your-discord-bot-token-here
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
SEERR_URL=https://seerr.example.com SEERR_URL=https://seerr.example.com
SEERR_API_KEY=your-seerr-api-key SEERR_API_KEY=your-seerr-api-key
# SEERR_TIMEOUT=30 # optional, defaults to 30 seconds # SEERR_USERNAME=your-username # alternative: username+password auth
# SEERR_PASSWORD=your-password
# SEERR_TIMEOUT=30 # optional, defaults to 30 seconds
# ---------------------------------------------------------------------------
# Auth System (Discord ↔ external services)
# ---------------------------------------------------------------------------
# The public-facing URL where users reach this bot's web API.
# Used to build the "Click here to link" URLs sent via Discord DM.
# For local dev: http://localhost:8000
# For production behind a reverse proxy: https://bot.yourdomain.com
BASE_URL=http://localhost:8000
# Where the auth SQLite database lives (relative to project root)
# AUTH_DB_PATH=data/auth.db
# Link token expiry in minutes (default 10)
# AUTH_TOKEN_EXPIRY=10
# ---------------------------------------------------------------------------
# JellyStat — PostgreSQL watch-history database
# ---------------------------------------------------------------------------
JELLYSTAT_DB_HOST=localhost
JELLYSTAT_DB_PORT=5432
JELLYSTAT_DB_USER=postgres
JELLYSTAT_DB_PASSWORD=
JELLYSTAT_DB_NAME=jfstat
+1
View File
@@ -175,3 +175,4 @@ cython_debug/
.pypirc .pypirc
.docs/ .docs/
data/
+6 -5
View File
@@ -12,7 +12,7 @@ An Agent is a lightweight wrapper:
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List from typing import Dict, List
from skills import Skill, get_combined_prompt, list_all as list_all_skills from agents.skills import Skill, get_combined_prompt, list_all as list_all_skills
@dataclass @dataclass
@@ -61,7 +61,8 @@ def load_all_agents() -> None:
import agents.media_agent # noqa: F401 import agents.media_agent # noqa: F401
# Also import skill modules so they self-register # Also import skill modules so they self-register
import skills.media_info # noqa: F401 import agents.skills.media_info # noqa: F401
import skills.seerr # noqa: F401 import agents.skills.seerr # noqa: F401
import skills.triage # noqa: F401 import agents.skills.triage # noqa: F401
import skills.easter_eggs # noqa: F401 import agents.skills.easter_eggs # noqa: F401
import agents.skills.watch_history # noqa: F401
+7 -2
View File
@@ -14,11 +14,16 @@ media_agent = Agent(
agent_id="media-agent", agent_id="media-agent",
description="Media assistant — handles movie/TV/subtitle/ticket requests " description="Media assistant — handles movie/TV/subtitle/ticket requests "
"via Seerr, Jellyfin, Sonarr, etc.", "via Seerr, Jellyfin, Sonarr, etc.",
skills=["media_info", "seerr", "triage", "easter_eggs"], skills=["media_info", "seerr", "triage", "easter_eggs", "watch_history"],
base_prompt=( base_prompt=(
"You are a media assistant connected to Seerr and other media services. " "You are a media assistant connected to Seerr and other media services. "
"Help users discover, request, and troubleshoot their media library. " "Help users discover, request, and troubleshoot their media library. "
"Use the tools provided to perform real actions." "Use the tools provided to perform real actions.\n\n"
"## Authentication\n"
"If a tool returns a message saying the user needs to log in first, "
"tell the user to type `/login <service>` in their DM (e.g. `/login jellyfin`). "
"This opens Quick Connect on their Jellyfin app so they can link their account. "
"Do NOT tell the user you 'can't connect' or 'don't have access' — just relay the login instructions."
), ),
) )
@@ -12,7 +12,7 @@ A Skill is a lightweight object with:
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Dict, List, Optional from typing import Any, Awaitable, Callable, Dict, List, Optional
from core.config import get_config # re-export so every skill can use it from src.config import get_config # re-export so every skill can use it
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -47,6 +47,7 @@ class Skill:
prompt_fragment: str = "" prompt_fragment: str = ""
tools: List[Dict[str, Any]] = field(default_factory=list) tools: List[Dict[str, Any]] = field(default_factory=list)
execute: Optional[ToolExecutor] = None execute: Optional[ToolExecutor] = None
requires_auth: List[str] = field(default_factory=list)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -96,9 +97,15 @@ def get_all_tools(skill_names: list[str]) -> List[Dict[str, Any]]:
async def execute_tool( async def execute_tool(
skill_names: list[str], tool_name: str, args: dict skill_names: list[str], tool_name: str, args: dict,
discord_user_id: int | None = None,
) -> ToolResult | None: ) -> ToolResult | None:
"""Find the skill that owns *tool_name* and run its executor. """Find the skill that owns *tool_name* and run its executor.
If *discord_user_id* is provided, also checks whether the owning skill
requires authentication for any services. If auth is missing, returns
a friendly ToolResult.fail(...) telling the user how to log in.
Only logs failures to the console successful calls are silent. Only logs failures to the console successful calls are silent.
""" """
import logging import logging
@@ -109,6 +116,27 @@ async def execute_tool(
if s and s.execute: if s and s.execute:
for t in s.tools: for t in s.tools:
if t.get("function", {}).get("name") == tool_name: if t.get("function", {}).get("name") == tool_name:
# --- Auth gate ---
if s.requires_auth and discord_user_id is not None:
from src import auth_store
from gateway.auth import get_auth_service
missing: list[str] = []
for svc in s.requires_auth:
if not auth_store.is_authenticated(discord_user_id, svc):
missing.append(svc)
if missing:
svc_displays = ", ".join(
(get_auth_service(m) and get_auth_service(m).display_name) or m
for m in missing
)
return ToolResult.fail(
f"You need to log in to {svc_displays} first. "
+ " ".join(f"Send `/login {m}` in a DM to get started." for m in missing)
)
# --- End auth gate ---
# Inject discord_user_id so skills can resolve external user IDs
if discord_user_id is not None:
args = {**args, "_discord_user_id": discord_user_id}
try: try:
result = await s.execute(tool_name, args) result = await s.execute(tool_name, args)
if not result.success: if not result.success:
@@ -8,7 +8,7 @@ requested actions normally. Functionality is never sacrificed for a reference.
Add a new theme by adding one entry to THEMES no code changes needed. Add a new theme by adding one entry to THEMES no code changes needed.
""" """
from skills import Skill, register from agents.skills import Skill, register
THEMES = { THEMES = {
"naruto": { "naruto": {
@@ -5,7 +5,7 @@ A lightweight base skill that teaches the agent it is a media assistant.
Real API capabilities come from other skills (seerr, triage, etc.). Real API capabilities come from other skills (seerr, triage, etc.).
""" """
from skills import Skill, register from agents.skills import Skill, register
media_info_skill = Skill( media_info_skill = Skill(
name="media_info", name="media_info",
@@ -23,6 +23,16 @@ When responding:
suggest submitting a ticket if there's a problem. suggest submitting a ticket if there's a problem.
- Always confirm successful actions and warn about failures. - Always confirm successful actions and warn about failures.
## Jellyfin & Authentication
You are connected to the user's Jellyfin server. If a user asks you to
"connect to Jellyfin", "link my Jellyfin", or asks about their watch history,
simply call the `watch_history` tool. The system will automatically handle
authentication if the user isn't linked yet, they'll be guided through
Quick Connect seamlessly. NEVER tell a user you "don't have access to
Jellyfin" or "can't connect" — always try the tool first and let the system
sort it out.
This is the base media assistant persona. Real API capabilities come from the This is the base media assistant persona. Real API capabilities come from the
attached skills (seerr, triage, etc.).""", attached skills (seerr, triage, etc.).""",
) )
+1 -1
View File
@@ -24,7 +24,7 @@ from urllib.parse import quote
import httpx import httpx
from skills import Skill, register, ToolResult, get_config from agents.skills import Skill, register, ToolResult, get_config
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Config # Config
+1 -1
View File
@@ -10,7 +10,7 @@ cancelling requests, banning users), this skill teaches the LLM to:
3. Use the seerr_submit_issue tool (if available) to create the ticket. 3. Use the seerr_submit_issue tool (if available) to create the ticket.
""" """
from skills import Skill, register from agents.skills import Skill, register
# This skill has no tools of its own — it guides the LLM's behavior. # This skill has no tools of its own — it guides the LLM's behavior.
# The actual ticket submission is handled by seerr_submit_issue. # The actual ticket submission is handled by seerr_submit_issue.
+273
View File
@@ -0,0 +1,273 @@
"""
Watch History skill — fetch the user's Jellyfin watch history via JellyStat API.
Requires the user to have linked Jellyfin via `/login jellyfin` in Discord.
The auth gate (`requires_auth=["jellyfin"]`) is already active — users who
haven't linked Jellyfin will be prompted to /login first.
Architecture
------------
This skill calls the JellyStat REST API (same FastAPI process, via HTTP)
rather than accessing the PostgreSQL database directly. This keeps the
bot isolated from database credentials.
"""
from __future__ import annotations
import httpx
from agents.skills import Skill, register, ToolResult
from src import auth_store
from src.config import get_config
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
BASE_URL = (get_config("BASE_URL") or "http://localhost:8000").rstrip("/")
# ---------------------------------------------------------------------------
# Tool definitions
# ---------------------------------------------------------------------------
TOOLS = [
{
"type": "function",
"function": {
"name": "watch_history",
"description": (
"Get the user's Jellyfin watch history — titles grouped by total "
"watch time in a configurable time window. Use this when a user "
"asks what they've watched, what they've been watching recently, "
"or wants to see their viewing activity."
),
"parameters": {
"type": "object",
"properties": {
"limit": {
"type": "integer",
"description": "How many titles to return (default 10, max 20).",
},
"minutes": {
"type": "integer",
"description": (
"Time window in minutes. Default 10080 (7 days). "
"Use a large number like 525600 for 'all time' (1 year)."
),
},
},
},
},
},
{
"type": "function",
"function": {
"name": "watch_genres",
"description": (
"Get the user's most-watched genres from Jellyfin, ranked by "
"total watch time. Use this when a user asks what kinds of "
"content they watch most, their favourite genres, or what "
"categories dominate their viewing."
),
"parameters": {
"type": "object",
"properties": {
"minutes": {
"type": "integer",
"description": (
"Time window in minutes. Default 10080 (7 days). "
"Use a large number like 525600 for 'all time'."
),
},
},
},
},
},
{
"type": "function",
"function": {
"name": "watch_summary",
"description": (
"Get an all-time Jellyfin watch summary — total watch time, "
"most-watched series, most-watched movie, 30-day and 7-day "
"activity, and top 3 genres. Use this when a user asks for "
"their overall stats, a dashboard, or 'how much have I watched?'."
),
"parameters": {"type": "object", "properties": {}},
},
},
]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _resolve_jellyfin_id(args: dict) -> str | None:
"""Extract the Jellyfin user ID from auth_store using the injected Discord ID."""
discord_user_id = args.pop("_discord_user_id", None)
if discord_user_id is None:
return None # not called from Discord — shouldn't happen with auth gate
auth = auth_store.get_auth(discord_user_id, "jellyfin")
if auth is None or not auth.get("external_user_id"):
return None
return auth["external_user_id"]
async def _fetch_json(url: str) -> dict:
"""GET *url* and return the parsed JSON body, or {} on failure."""
async with httpx.AsyncClient(timeout=10) as client:
resp = await client.get(url)
resp.raise_for_status()
return resp.json()
def _format_seconds(total: float) -> str:
"""Convert seconds to a human-friendly string."""
total = max(total, 0)
hours = int(total // 3600)
minutes = int((total % 3600) // 60)
if hours and minutes:
return f"{hours}h {minutes}m"
if hours:
return f"{hours}h"
if minutes:
return f"{minutes}m"
return f"{int(total)}s"
def _format_history(data: dict, limit: int) -> ToolResult:
"""Format a watch-history API response for the LLM."""
items = data.get("items", [])[:limit]
if not items:
return ToolResult.ok("You haven't watched anything in this time window.")
lines = [f"**Watch History** (last {data.get('window_minutes', '?')} minutes):"]
for i, item in enumerate(items, 1):
duration = _format_seconds(item["watch_time_sec"])
icon = "📺" if item["media_type"] == "series" else "🎬"
lines.append(f"{i}. {icon} **{item['title']}** — {duration}")
return ToolResult.ok("\n".join(lines))
def _format_genres(data: dict) -> ToolResult:
"""Format a genre-summary API response for the LLM."""
genres = data.get("genres", [])
if not genres:
return ToolResult.ok("No genre data available for this time window.")
lines = [f"**Top Genres** (last {data.get('window_minutes', '?')} minutes):"]
for i, g in enumerate(genres, 1):
duration = _format_seconds(g["watch_time_sec"])
lines.append(f"{i}. **{g['genre']}** — {duration}")
return ToolResult.ok("\n".join(lines))
def _format_summary(data: dict) -> ToolResult:
"""Format a user-summary API response for the LLM."""
total = _format_seconds(data.get("total_watch_time_sec", 0))
last_30 = _format_seconds(data.get("total_last_30d_sec", 0))
last_7 = _format_seconds(data.get("total_last_7d_sec", 0))
top_series = data.get("most_watched_series") or ""
top_movie = data.get("most_watched_movie") or ""
top_genres = data.get("top_genres", [])
genres_str = ", ".join(top_genres) if top_genres else ""
lines = [
"**Your Jellyfin Summary** (all time):",
f"⏱️ Total watch time: **{total}**",
f"📺 Most-watched series: **{top_series}**",
f"🎬 Most-watched movie: **{top_movie}**",
f"📅 Last 30 days: **{last_30}**",
f"📅 Last 7 days: **{last_7}**",
f"🏷️ Top genres: {genres_str}",
]
return ToolResult.ok("\n".join(lines))
# ---------------------------------------------------------------------------
# Executor
# ---------------------------------------------------------------------------
async def _execute(tool_name: str, args: dict) -> ToolResult:
# 1. Resolve Jellyfin user ID
jellyfin_id = _resolve_jellyfin_id(args)
if jellyfin_id is None:
return ToolResult.fail(
"Your Jellyfin account is not linked. Use `/login jellyfin` in a DM to connect."
)
# 2. Route to the right JellyStat endpoint
try:
match tool_name:
case "watch_history":
limit = args.get("limit", 10)
minutes = args.get("minutes", 10080)
url = f"{BASE_URL}/jellystat/history/{jellyfin_id}?minutes={minutes}"
data = await _fetch_json(url)
return _format_history(data, limit)
case "watch_genres":
minutes = args.get("minutes", 10080)
url = f"{BASE_URL}/jellystat/genres/{jellyfin_id}?minutes={minutes}"
data = await _fetch_json(url)
return _format_genres(data)
case "watch_summary":
url = f"{BASE_URL}/jellystat/summary/{jellyfin_id}"
data = await _fetch_json(url)
return _format_summary(data)
case _:
return ToolResult.fail(f"Unknown tool: {tool_name}")
except httpx.HTTPError:
return ToolResult.fail(
"Could not reach the watch-history service right now. "
"Please try again in a moment."
)
# ---------------------------------------------------------------------------
# Skill registration
# ---------------------------------------------------------------------------
_PROMPT = (
"## Watch History\n"
"\n"
"You have THREE tools to answer questions about the user's Jellyfin watch activity:\n"
"\n"
"1. **`watch_history`** — per-title watch time in a time window (default: 7 days).\n"
" Use when a user asks what they've watched, to show their history,\n"
" or what they watched this week or yesterday.\n"
"\n"
"2. **`watch_genres`** — watch time broken down by genre.\n"
" Use when a user asks what genres they watch, whether they watch more\n"
" comedy than drama, or what their most-watched genre is.\n"
"\n"
"3. **`watch_summary`** — all-time dashboard: total watch time, most-watched\n"
" series and movie, 30-day and 7-day activity, and top 3 genres.\n"
" Use when a user asks for their stats, how much they've watched in\n"
" total, or what their favourites are.\n"
"\n"
"Always call the appropriate tool before answering — NEVER guess at watch data.\n"
"Format watch times in a human-readable way (hours and minutes), but keep the\n"
"raw data visible too."
)
watch_history_skill = Skill(
name="watch_history",
description="User's Jellyfin watch history, genres, and summary stats",
requires_auth=["jellyfin"],
prompt_fragment=_PROMPT,
tools=TOOLS,
execute=_execute,
)
register(watch_history_skill)
-235
View File
@@ -1,235 +0,0 @@
# API Architecture — Agent + Skill + Graph Pipeline
This document explains how the API routes user messages through the
agent / skill / LangGraph pipeline to produce responses.
---
## Overview
```
┌─────────────────────────────────────────────────────────────────┐
│ OpenWebUI / Client │
│ POST /v1/chat/completions { model, messages, stream } │
└──────────────────────────────┬──────────────────────────────────┘
┌──────────────────────────────────────────────────────────────────┐
│ api/v1/chat.py — chat_completions() │
│ │
│ 1. _resolve_agent(req.model) → Agent │
│ 2. get_agent_graph(agent_id) → compiled StateGraph │
│ 3. graph.ainvoke(state) or _stream_graph(graph, messages) │
└──────────────────────────────┬───────────────────────────────────┘
┌──────────────────────────────────────────────────────────────────┐
│ LangGraph StateGraph (core/graph.py) │
│ │
│ ┌──────────────┐ tool_calls? ┌──────────────┐ │
│ │ agent_node │ ───────────────▶ │ tool_node │ │
│ │ (LLM call) │ ◀─────────────── │ (skill exec) │ │
│ └──────┬───────┘ └──────────────┘ │
│ │ no tool_calls │
│ ▼ │
│ [END] │
└──────────────────────────────────────────────────────────────────┘
## Key Concepts
### 1. Agent
An **Agent** is a persona + skill bundle. Defined in `agents/`.
```python
# agents/media_agent.py
Agent(
agent_id="media-agent",
description="Media assistant with Seerr integration",
skills=["media_info", "seerr", "triage"],
base_prompt="You are a media assistant...",
)
```
- `agent_id` — unique name, exposed as a model in OpenWebUI
- `skills` — list of skill names to load
- `base_prompt` — starting system prompt, combined with skill fragments
- `build_system_prompt()` — merges base_prompt + all skill prompt fragments
Agents self-register at import time via `agents/__init__.py`'s `register()`.
`main.py` calls `load_all_agents()` at startup to import every agent and skill
module.
### 2. Skill
A **Skill** is a capability bundle. Defined in `skills/`.
```python
# skills/seerr.py
Skill(
name="seerr",
description="Seerr integration — trending, discover, request media, submit issues",
prompt_fragment="## Seerr Media Tools\n...",
tools=[...], # OpenAI function-calling schema
execute=_execute, # async handler: tool_name + args → ToolResult
)
```
- `prompt_fragment` — injected into the agent's system prompt.
- `tools` — list of OpenAI function definitions (name, description, parameters).
- `execute` — async callable that routes tool calls to API handlers.
### 3. Graph
Each agent gets a **compiled LangGraph StateGraph** built by
`core/graph.py:create_agent_graph()`. The graph is compiled lazily on the
first request and cached on `app.state.agent_graphs` for the lifetime of the
process.
| Graph node / edge | What it does |
|---|---|
| `agent_node` | Converts state messages to OpenAI dicts, calls the LLM with the agent's system prompt + tool definitions, returns an `AIMessage` |
| `tool_node` | Reads `tool_calls` from the last AI message, calls `execute_tool()` from the skill system, returns `ToolMessage` results |
| `_should_continue` | Conditional edge — returns `"tool_node"` if the AI message has `tool_calls`, else `END` |
### 4. State
Defined in `core/state.py`:
```python
class AgentState(TypedDict):
messages: Annotated[list, add_messages]
```
LangGraph's `add_messages` reducer appends new messages and replaces messages
with matching IDs (so tool-call results overwrite their placeholders).
### 5. Message Conversion
Because we use the raw `openai` client (not `langchain-openai`), messages must
be converted between LangChain and OpenAI formats at every LLM call:
- **LangChain → OpenAI** (`_lc_role_to_openai`, `_langchain_tc_to_openai`):
Maps `type``role` and converts top-level `name`/`args` tool-calls into
the nested `function` sub-object that the OpenAI API expects.
- **OpenAI → LangChain** (inside `agent_node`):
Converts the `ChatCompletionMessage` response into an `AIMessage` with
LangChain-format `tool_calls` (top-level `name`/`args`/`id`).
---
## Full Request Flow
### Step-by-step: "What are trending movies?"
```
1. OpenWebUI sends:
POST /v1/chat/completions
{
"model": "media-agent",
"messages": [
{"role": "user", "content": "What are trending movies?"}
],
"stream": false
}
2. chat_completions():
→ _resolve_agent(model="media-agent")
→ get_agent("media-agent") → Agent(skills=["media_info", "seerr", "triage"])
→ get_agent_graph("media-agent", request)
→ looks up app.state.agent_graphs["media-agent"]
→ first call → create_agent_graph() compiles the graph with 7 Seerr tools
→ run_agent_with_tools(request, messages, agent_id)
→ _invoke_graph(graph, messages)
3. Graph — Pass 1 (agent_node):
→ LLM receives: [system prompt] + [user: "What are trending movies?"]
→ LLM responds with tool_calls: seerr_trending(kind="movie")
→ agent_node returns AIMessage with tool_calls in LangChain format
4. Graph — _should_continue:
→ AIMessage has tool_calls → route to "tool_node"
5. Graph — tool_node:
→ Reads tool_call: name="seerr_trending", args={"kind": "movie"}
→ execute_tool(["media_info", "seerr", "triage"], "seerr_trending", ...)
→ Seerr API → GET /api/v1/discover/trending?mediaType=movie
→ Returns ToolMessage with formatted results including [tmdb:IDs]
6. Graph — Pass 2 (agent_node):
→ LLM receives previous exchange + tool result
→ LLM responds with text only (no tool_calls)
→ agent_node returns AIMessage(content="Here are the top trending movies!...")
7. Graph — _should_continue:
→ No tool_calls → route to END
8. chat_completions() returns:
{ "choices": [{"message": {"role": "assistant", "content": "Here are the top..."}}] }
```
### Step-by-step: "Request the 2026 one" (multi-turn context)
```
1. OpenWebUI sends the FULL history:
{
"model": "media-agent",
"messages": [
{"role": "user", "content": "What are trending movies?"},
{"role": "assistant", "content": "Here are the top 10 trending movies!
1. **Mortal Kombat II** (2026) [tmdb:931285] — ..."},
{"role": "user", "content": "could request the mortal kombat one?"},
{"role": "assistant", "content": "There are several Mortal Kombat entries! ..."},
{"role": "user", "content": "the 2026 one"}
]
}
2. chat_completions():
→ req.messages contains the ENTIRE conversation history
→ graph.ainvoke({"messages": all_messages})
→ agent_node prepends system prompt and sends everything to the LLM
3. LLM reasons from full context:
- Previously listed Mortal Kombat II (2026) with [tmdb:931285]
- The user said "request the mortal kombat one" → I searched and showed 4 options
- Now they say "the 2026 one" → that matches Mortal Kombat II (2026) [tmdb:931285]
- I should call seerr_request_media(kind="movie", title="Mortal Kombat II", tmdb_id=931285)
4. tool_node executes the request → ✅ Success
```
---
## Streaming
Streaming works slightly differently from the sync path:
```
chat_completions(stream=True)
→ _stream_graph(graph, messages)
→ graph.ainvoke(state) # runs graph to completion (tools execute silently)
→ yields content character-by-character via SSE
```
For true token-level streaming (tokens appear as the LLM generates them),
the agent_node would need to use `langchain-openai`'s `ChatOpenAI` instead of
the raw `openai` client. The current approach is a pragmatic middle ground
that avoids adding another dependency while still giving the SSE client
incremental output.
---
## File Map
| File | Responsibility |
|---|---|
| `main.py` | FastAPI app, singleton creation, router mounting |
| `api/v1/chat.py` | Endpoints — resolves agent, invokes graph, formats responses |
| `api/dependencies.py` | `get_llm_client()`, `get_agent_graph()` — FastAPI `Depends` |
| `core/graph.py` | `create_agent_graph()` — builds the StateGraph |
| `core/state.py` | `AgentState` TypedDict |
| `core/llm.py` | `create_client()` — OpenAI client factory |
| `core/config.py` | Environment variable loader |
| `agents/` | Agent definitions (dataclass + self-registration) |
| `skills/` | Skill definitions (prompt fragments + tools + executors) |
+75
View File
@@ -0,0 +1,75 @@
# Gateway Architecture — Agent + Skill + Graph Pipeline
This is the **interface layer** of the Agents project. Everything that connects
the outside world to the agent system lives here — REST APIs, Discord bot,
and authentication.
---
## Directory Map
| Path | Description | Docs |
|---|---|---|
| `gateway/v1/` | REST API endpoints — chat, agent listing, OpenAI-compatible completions | [v1.md](v1/v1.md) |
| `gateway/discord/` | Discord bot connector — in-process DM handler with LangGraph integration | [discord.md](discord/discord.md) |
| `gateway/auth/` | Auth service registry + Jellyfin Quick Connect implementation | [auth.md](auth/auth.md) |
---
## Supporting Modules
| Path | Purpose |
|---|---|
| `gateway/dependencies.py` | FastAPI `Depends` providers — `get_llm_client()`, `get_agent_graph()` |
| `src/config.py` | `.env` loader and config accessor |
| `src/llm.py` | OpenAI-compatible client factory (DeepSeek) |
| `src/state.py` | LangGraph `AgentState` TypedDict |
| `src/graph.py` | LangGraph StateGraph factory — agent_node, tool_node, routing |
| `src/tools_adapter.py` | Wraps skill tools as LangChain `@tool` functions |
| `src/auth_store.py` | SQLite persistence for Discord → service auth linking |
| `agents/` | Agent definitions (dataclass + registry) |
| `agents/skills/` | Skill definitions — prompt fragments, tool schemas, executors |
---
## High-Level Request Flow
```
┌──────────────────────────────┐
│ Client (OpenWebUI / HTTP) │
└──────────────┬───────────────┘
│ POST /v1/chat/completions
┌──────────────────────────────┐
│ gateway/v1/chat.py │ ← resolves agent, invokes graph
└──────────────┬───────────────┘
┌──────────────────────────────┐
│ LangGraph StateGraph │ ← src/graph.py
│ ┌──────────┐ ┌──────────┐│
│ │agent_node│──▶│tool_node ││
│ │(LLM call)│◀──│(skills) ││
│ └──────────┘ └──────────┘│
└──────────────┬───────────────┘
┌──────────────────────────────┐
│ agents/skills/ │ ← Seerr API, Jellyfin API, etc.
└──────────────────────────────┘
```
For a detailed step-by-step walkthrough of the graph execution (including
multi-turn context and tool-calling loops), see [v1.md](v1/v1.md).
---
## Startup
`main.py` is the entry point. At startup it:
1. Loads `.env` → creates the LLM client (DeepSeek) → stores on `app.state.llm_client`
2. Calls `load_all_agents()` → imports every agent and skill module (they self-register)
3. Imports `gateway.auth.jellyfin` → self-registers the Jellyfin auth service
4. Mounts routers: `/v1/*` (chat endpoints) and `/api/v1/auth/*` (auth endpoints)
5. Starts the Discord bot as a background asyncio task (lifespan)
+93
View File
@@ -0,0 +1,93 @@
"""
Auth Service registry — generic, pluggable authentication for any service.
Add a new service (Plex, Seerr, etc.) by:
1. Subclassing AuthService
2. Dropping the module in this package
3. Calling register_auth_service() at import time
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional
# ---------------------------------------------------------------------------
# AuthResult — returned by AuthService.authenticate()
# ---------------------------------------------------------------------------
@dataclass
class AuthResult:
"""Outcome of a credential validation attempt."""
success: bool
external_user_id: Optional[str] = None
external_name: Optional[str] = None
credentials: Optional[dict] = None
error_message: Optional[str] = None
# ---------------------------------------------------------------------------
# AuthService — abstract base class
# ---------------------------------------------------------------------------
class AuthService(ABC):
"""A service that users can authenticate against (Jellyfin, Seerr, Plex, etc.)
Subclasses must implement:
- name : unique identifier used in URLs and DB keys
- display_name : human-readable label shown in Discord
- render_login_form(token, discord_id) → HTML string
- authenticate(form_data) → AuthResult
"""
@property
@abstractmethod
def name(self) -> str:
"""Unique service name: "jellyfin", "seerr", etc."""
...
@property
@abstractmethod
def display_name(self) -> str:
"""Human-readable: "Jellyfin", "Seerr", "Plex" """
...
@abstractmethod
def render_login_form(self, token: str, discord_id: int) -> str:
"""Return HTML string with a login form for this service.
The form MUST include these hidden fields:
<input type="hidden" name="token" value="{token}">
<input type="hidden" name="discord_id" value="{discord_id}">
<input type="hidden" name="service" value="{self.name}">
"""
...
@abstractmethod
async def authenticate(self, form_data: dict) -> AuthResult:
"""Validate credentials against the service. Return AuthResult."""
...
# ---------------------------------------------------------------------------
# Global registry
# ---------------------------------------------------------------------------
_registry: dict[str, AuthService] = {}
def register_auth_service(svc: AuthService) -> None:
"""Register an AuthService so it can be looked up by name."""
_registry[svc.name] = svc
def get_auth_service(name: str) -> AuthService | None:
"""Look up a registered AuthService by name."""
return _registry.get(name)
def list_auth_services() -> list[str]:
"""Return names of all registered auth services."""
return list(_registry.keys())
+152
View File
@@ -0,0 +1,152 @@
# Auth — Service Registry & Persistence
The authentication system lets Discord users link their accounts to external
services (currently **Jellyfin**) so the agent can perform actions on their
behalf (e.g. checking watch history).
---
## Architecture
```
gateway/auth/ gateway/v1/auth.py
┌──────────────────────┐ ┌──────────────────────────────┐
│ AuthService (ABC) │ │ GET /api/v1/auth/login │
│ ├─ JellyfinAuth │◀─────────│ POST /api/v1/auth/login │
│ └─ (Plex, Seerr…) │ │ GET /api/v1/auth/status │
│ │ │ GET /api/v1/auth/reset │
└─────────┬────────────┘ └──────────────────────────────┘
src/auth_store.py
┌──────────────────────┐
│ SQLite │
│ ├─ link_tokens │ one-time tokens sent via Discord DM
│ └─ user_auth │ per-user, per-service credentials
└──────────────────────┘
```
---
## Files
| File | Purpose |
|---|---|
| `gateway/auth/__init__.py` | Abstract `AuthService` base class + global registry |
| `gateway/auth/jellyfin.py` | Jellyfin implementation — Quick Connect + username/password |
| `gateway/v1/auth.py` | REST endpoints for the web-based login flow |
| `src/auth_store.py` | SQLite persistence for link tokens and stored credentials |
---
## Flow: Discord User Links Jellyfin
```
Discord DM Web Browser Jellyfin Server
│ │ │
│ 1. /login jellyfin │ │
│ ──────────────────────────────▶│ │
│ Bot creates link token in │ │
│ SQLite, DMs the user a URL │ │
│ │ │
│ 2. User clicks link │ │
│ ◀─────────────────────────────▶│ │
│ │ GET /api/v1/auth/login │
│ │ ?service=jellyfin │
│ │ &token=xxx&discord_id=123 │
│ │ │
│ │ 3. Serve Quick Connect form │
│ │ ◀──────────────────────────── │
│ │ │
│ │ 4. Initiate Quick Connect │
│ │ ─────────────────────────────▶│
│ │ POST /QuickConnect/Initiate │
│ │ ◀── { Code: "ABC123" } │
│ │ │
│ 5. User enters code in │ │
│ Jellyfin app │ │
│ │ │
│ │ 6. Poll: is it authorized? │
│ │ ─────────────────────────────▶│
│ │ GET /QuickConnect/Connect │
│ │ ◀── Authenticated + Token │
│ │ │
│ 7. auth_store saves: │ │
│ (discord_id, jellyfin, │ │
│ AccessToken, username) │ │
│ │ │
│ 8. "✅ Linked to Jellyfin!" │ │
│ ◀───────────────────────────── │ │
```
---
## AuthService Base Class
```python
class AuthService(ABC):
name: str # "jellyfin"
display_name: str # "Jellyfin"
def render_login_form(token, discord_id) -> str: ...
async def authenticate(form_data) -> AuthResult: ...
```
Add a new service (e.g. Plex, Seerr) by subclassing `AuthService`, dropping
the module in `gateway/auth/`, and calling `register_auth_service()` at import
time. The REST endpoints and auth store work generically — no changes needed.
---
## Current Implementation: Jellyfin
`gateway/auth/jellyfin.py` supports two flows:
| Method | How it works |
|---|---|
| **Quick Connect** (primary) | Calls Jellyfin's `/QuickConnect/Initiate` → polls `/QuickConnect/Connect` → stores the `AccessToken` |
| **Username/Password** (fallback) | Renders an HTML form → user submits credentials → calls `/Users/AuthenticateByName` → stores the `AccessToken` |
The stored credentials include:
- `external_user_id` — Jellyfin user ID
- `external_name` — Jellyfin username
- `credentials` dict — `{"AccessToken": "...", "ServerURL": "..."}`
---
## Auth Store (SQLite)
Two tables in `data/auth.db`:
```sql
-- One-time tokens for the web login flow (expire after 10 min)
CREATE TABLE link_tokens (
token TEXT PRIMARY KEY,
discord_id INTEGER NOT NULL,
service TEXT NOT NULL,
created_at TEXT NOT NULL,
used INTEGER DEFAULT 0
);
-- Per-user, per-service stored credentials
CREATE TABLE user_auth (
discord_id INTEGER NOT NULL,
service TEXT NOT NULL,
external_user_id TEXT,
external_name TEXT,
credentials TEXT, -- JSON
created_at TEXT NOT NULL,
PRIMARY KEY (discord_id, service)
);
```
---
## Skill-Level Auth Gating
Skills can declare `requires_auth=["jellyfin"]`. When a tool is executed,
the skill system checks the auth store. If the user isn't linked:
1. The tool returns `ToolResult.fail("Please login first using /login jellyfin")`
2. The LLM relays this message to the user in Discord
3. The user types `/login jellyfin` → Quick Connect flow → re-linked → try again
+401
View File
@@ -0,0 +1,401 @@
"""
Jellyfin AuthService — validates Jellyfin credentials and stores the session token.
Two authentication flows:
1. Quick Connect (primary): user enters a short code on their Jellyfin app.
- initiate_quick_connect() → {code, secret}
- poll_quick_connect(secret) → "Active" | "Authorized" | "Expired"
- authenticate_quick_connect(secret) → AuthResult with token
2. Username/password (legacy): renders an HTML form, called via the REST API.
- render_login_form(token, discord_id) → HTML string
- authenticate(form_data) → AuthResult
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Optional
import httpx
from gateway.auth import AuthService, AuthResult, register_auth_service
from src.config import get_config
logger = logging.getLogger("auth.jellyfin")
# Emby-style authorization header required by Jellyfin's AuthenticateByName
_EMBY_HEADER = (
'MediaBrowser Client="AgentBot",'
'Device="DiscordBot",'
'DeviceId="agent-bot",'
'Version="1.0"'
)
@dataclass
class QuickConnectResult:
"""Result of a Quick Connect initiation."""
secret: str
code: str
device_id: str
device_name: str
class JellyfinAuth(AuthService):
name = "jellyfin"
display_name = "Jellyfin"
# ------------------------------------------------------------------
# Quick Connect helpers
# ------------------------------------------------------------------
def _qc_headers(self) -> dict[str, str]:
"""Return headers used by all Quick Connect API calls."""
return {
"X-Emby-Authorization": (
'MediaBrowser Client="AgentBot",'
'Device="DiscordBot",'
'DeviceId="agent-bot-qc",'
'Version="1.0"'
)
}
async def _resolve_url(self) -> str | None:
"""
Resolve the Jellyfin server URL.
1. Check JELLYFIN_URL env var (used in deployment).
2. Check if user already has a stored auth with a URL (from legacy login).
Returns None if no URL is configured.
"""
# First: explicit env var
env_url = get_config("JELLYFIN_URL")
if env_url:
return env_url.strip().rstrip("/")
return None
# ------------------------------------------------------------------
# Phase 1a: initiate Quick Connect
# ------------------------------------------------------------------
async def initiate_quick_connect(self, url: str | None = None) -> QuickConnectResult | None:
"""
Call Jellyfin's POST /QuickConnect/Initiate.
Returns a QuickConnectResult with {secret, code} or None on failure.
The *code* is what the user enters on their Jellyfin page.
The *secret* is used internally to poll/authenticate.
"""
base_url = url or await self._resolve_url()
if not base_url:
logger.error("QuickConnect failed — no JELLYFIN_URL configured.")
return None
logger.info("Initiating Quick Connect on %s", base_url)
async with httpx.AsyncClient(timeout=10) as client:
try:
resp = await client.post(
f"{base_url}/QuickConnect/Initiate",
headers=self._qc_headers(),
json={},
)
if resp.status_code != 200:
logger.warning(
"QuickConnect init failed: HTTP %s%s",
resp.status_code, resp.text[:200],
)
return None
data = resp.json()
secret = data.get("Secret", "")
code = data.get("Code", "")
device_id = data.get("DeviceId", "")
device_name = data.get("DeviceName", "")
if not secret or not code:
logger.warning("QuickConnect init returned unexpected payload: %s", data)
return None
logger.info(
"Quick Connect initiated: code=%s device=%s",
code, device_name,
)
return QuickConnectResult(
secret=secret,
code=code,
device_id=device_id,
device_name=device_name,
)
except httpx.TimeoutException:
logger.error("QuickConnect init timed out reaching %s", base_url)
return None
except httpx.ConnectError:
logger.error("QuickConnect init — cannot connect to %s", base_url)
return None
except Exception:
logger.exception("Unexpected error during QuickConnect init")
return None
# ------------------------------------------------------------------
# Phase 1b: poll Quick Connect status
# ------------------------------------------------------------------
async def poll_quick_connect(self, secret: str, url: str | None = None) -> str:
"""
Call Jellyfin's GET /QuickConnect/Connect?secret=<secret>.
Returns one of:
- "Active" → user hasn't entered the code yet
- "Authorized" → user entered code AND approved
- "Expired" → code expired / unknown secret
- "Error" → network or unexpected failure
"""
base_url = url or await self._resolve_url()
if not base_url:
logger.error("QuickConnect poll failed — no JELLYFIN_URL configured.")
return "Error"
async with httpx.AsyncClient(timeout=10) as client:
try:
resp = await client.get(
f"{base_url}/QuickConnect/Connect",
params={"secret": secret},
headers=self._qc_headers(),
)
if resp.status_code == 404:
return "Expired"
if resp.status_code == 200:
data = resp.json()
# Jellyfin returns "Authenticated" (not "Authorized")
if data.get("Authenticated") is True:
return "Authorized"
# "Authenticated" is false, missing, or null → still active
return "Active"
logger.warning(
"QuickConnect poll unexpected: HTTP %s%s",
resp.status_code, resp.text[:200],
)
return "Error"
except (httpx.TimeoutException, httpx.ConnectError):
logger.warning("QuickConnect poll network error")
return "Error"
except Exception:
logger.exception("Unexpected error during QuickConnect poll")
return "Error"
# ------------------------------------------------------------------
# Phase 1c: exchange secret for token
# ------------------------------------------------------------------
async def authenticate_quick_connect(
self, secret: str, url: str | None = None
) -> AuthResult:
"""
After poll_quick_connect returns "Authorized", call
POST /Users/AuthenticateWithQuickConnect to exchange the secret
for a real access token.
Returns AuthResult with token, user_id, username on success.
"""
base_url = url or await self._resolve_url()
if not base_url:
return AuthResult(
success=False,
error_message="No Jellyfin server URL configured.",
)
logger.info("Exchanging QuickConnect secret for token on %s", base_url)
async with httpx.AsyncClient(timeout=10) as client:
try:
resp = await client.post(
f"{base_url}/Users/AuthenticateWithQuickConnect",
json={"Secret": secret},
headers=self._qc_headers(),
)
if resp.status_code != 200:
logger.warning(
"QuickConnect auth exchange failed: HTTP %s",
resp.status_code,
)
return AuthResult(
success=False,
error_message="Quick Connect authentication failed. The code may have expired.",
)
data = resp.json()
user = data.get("User", {})
token = data.get("AccessToken", "")
if not token:
return AuthResult(
success=False,
error_message="Jellyfin returned an unexpected response.",
)
logger.info(
"QuickConnect linked: user=%s (%s)",
user.get("Name", "?"),
user.get("Id", "?"),
)
return AuthResult(
success=True,
external_user_id=user.get("Id", ""),
external_name=user.get("Name", "?"),
credentials={
"token": token,
"url": base_url,
"user_id": user.get("Id", ""),
},
)
except httpx.TimeoutException:
return AuthResult(
success=False,
error_message=f"Could not reach {base_url} — connection timed out.",
)
except httpx.ConnectError:
return AuthResult(
success=False,
error_message=f"Could not connect to {base_url}. Is the server running?",
)
except Exception:
logger.exception("Unexpected error during QuickConnect auth exchange")
return AuthResult(
success=False,
error_message="An unexpected error occurred during authentication.",
)
# ------------------------------------------------------------------
# Login form (legacy — used by the REST API)
# ------------------------------------------------------------------
def render_login_form(self, token: str, discord_id: int) -> str:
return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Link Jellyfin</title>
<style>
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
h2 {{ margin-bottom: 4px; }}
.sub {{ color: #666; margin-bottom: 24px; }}
label {{ display: block; margin-top: 16px; font-weight: 600; }}
input {{ width: 100%; padding: 10px; margin-top: 4px; border: 1px solid #ccc; border-radius: 6px; box-sizing: border-box; }}
button {{ margin-top: 24px; width: 100%; padding: 12px; background: #aa5cc3; color: #fff; border: none; border-radius: 6px; font-size: 16px; cursor: pointer; }}
button:hover {{ background: #9448b0; }}
</style>
</head>
<body>
<h2>🔗 Link Jellyfin to Discord</h2>
<p class="sub">Enter your Jellyfin server URL and credentials to link your account.</p>
<form method="POST" action="/api/v1/auth/login">
<input type="hidden" name="token" value="{token}">
<input type="hidden" name="discord_id" value="{discord_id}">
<input type="hidden" name="service" value="jellyfin">
<label for="jellyfin_url">Jellyfin Server URL</label>
<input id="jellyfin_url" name="jellyfin_url" type="url"
placeholder="https://jellyfin.example.com" required>
<label for="username">Username</label>
<input id="username" name="username" type="text"
placeholder="Your Jellyfin username" required autofocus>
<label for="password">Password</label>
<input id="password" name="password" type="password"
placeholder="Your Jellyfin password" required>
<button type="submit">Link Account</button>
</form>
</body>
</html>"""
# ------------------------------------------------------------------
# Authentication
# ------------------------------------------------------------------
async def authenticate(self, form_data: dict) -> AuthResult:
url = form_data.get("jellyfin_url", "").strip().rstrip("/")
username = form_data.get("username", "").strip()
password = form_data.get("password", "").strip()
if not url or not username or not password:
return AuthResult(
success=False,
error_message="All fields are required (URL, username, password).",
)
logger.info("Attempting Jellyfin login for '%s' on %s", username, url)
async with httpx.AsyncClient(timeout=10) as client:
try:
resp = await client.post(
f"{url}/Users/AuthenticateByName",
json={"Username": username, "Pw": password},
headers={"X-Emby-Authorization": _EMBY_HEADER},
)
if resp.status_code != 200:
logger.warning(
"Jellyfin login failed for '%s': HTTP %s", username, resp.status_code
)
return AuthResult(
success=False,
error_message=f"Login failed — check your server URL and credentials.",
)
data = resp.json()
user = data.get("User", {})
token = data.get("AccessToken", "")
if not token:
return AuthResult(
success=False,
error_message="Jellyfin returned an unexpected response.",
)
logger.info(
"Jellyfin login OK: user=%s (%s)",
user.get("Name", "?"),
user.get("Id", "?"),
)
return AuthResult(
success=True,
external_user_id=user.get("Id", ""),
external_name=user.get("Name", username),
credentials={
"token": token,
"url": url,
"user_id": user.get("Id", ""),
},
)
except httpx.TimeoutException:
return AuthResult(
success=False,
error_message=f"Could not reach {url} — connection timed out. Check the URL.",
)
except httpx.ConnectError:
return AuthResult(
success=False,
error_message=f"Could not connect to {url}. Is the server running?",
)
except Exception as exc:
logger.exception("Unexpected error during Jellyfin login")
return AuthResult(
success=False,
error_message=f"An unexpected error occurred. Please try again.",
)
# Self-register at import time
register_auth_service(JellyfinAuth())
@@ -1,7 +1,7 @@
from fastapi import Request from fastapi import Request
from openai import OpenAI from openai import OpenAI
from core.graph import create_agent_graph from src.graph import create_agent_graph
def get_llm_client(request: Request) -> OpenAI: def get_llm_client(request: Request) -> OpenAI:
+148 -6
View File
@@ -23,10 +23,12 @@ import os
import discord import discord
from agents import list_all as list_all_agents from agents import list_all as list_all_agents
from bot.conversation import ConversationStore from gateway.discord.conversation import ConversationStore
from core.config import DEEPSEEK_API_KEY, get_config from src.config import DEEPSEEK_API_KEY, get_config
from core.graph import create_agent_graph from src.graph import create_agent_graph
from core.llm import create_client from src.llm import create_client
from src import auth_store
from gateway.auth import list_auth_services, get_auth_service
logger = logging.getLogger("bot.discord") logger = logging.getLogger("bot.discord")
@@ -36,6 +38,7 @@ logger = logging.getLogger("bot.discord")
DISCORD_BOT_TOKEN = get_config("DISCORD_BOT_TOKEN") or "" DISCORD_BOT_TOKEN = get_config("DISCORD_BOT_TOKEN") or ""
DISCORD_MAX_HISTORY = int(get_config("DISCORD_MAX_HISTORY", "7")) DISCORD_MAX_HISTORY = int(get_config("DISCORD_MAX_HISTORY", "7"))
DISCORD_DEFAULT_AGENT = get_config("DISCORD_DEFAULT_AGENT", "media-agent") DISCORD_DEFAULT_AGENT = get_config("DISCORD_DEFAULT_AGENT", "media-agent")
BASE_URL = get_config("BASE_URL", "http://localhost:8000").rstrip("/")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# LLM client shared by all agents (same as the REST API uses) # LLM client shared by all agents (same as the REST API uses)
@@ -138,6 +141,12 @@ class AgentBot(discord.Client):
# |--------------------------------------------------------------| # |--------------------------------------------------------------|
user_id = message.author.id user_id = message.author.id
content = message.content.strip()
# |-- Bot commands — handled directly, never sent to the LLM --|
if await self._handle_command(message, user_id, content):
return
# |--------------------------------------------------------------|
# Show typing indicator while the graph runs # Show typing indicator while the graph runs
async with message.channel.typing(): async with message.channel.typing():
@@ -154,6 +163,140 @@ class AgentBot(discord.Client):
"Please try again in a moment." "Please try again in a moment."
) )
# ------------------------------------------------------------------
# Bot commands
# ------------------------------------------------------------------
async def _handle_command(
self, message: discord.Message, user_id: int, content: str
) -> bool:
"""Handle bot commands (/login, /logout). Returns True if handled."""
lower = content.lower()
# --- /login [service] ---
if lower.startswith("/login"):
parts = content.split()
service = parts[1].lower() if len(parts) > 1 else None
available = list_auth_services()
if not available:
await message.channel.send("No auth services are configured yet.")
return True
if service is None:
svc_list = ", ".join(available)
await message.channel.send(
f"Available services to link: **{svc_list}**\n"
f"Use `/login <service>` — e.g. `/login jellyfin`"
)
return True
if service not in available:
await message.channel.send(
f"Unknown service '{service}'. Available: {', '.join(available)}"
)
return True
if auth_store.is_authenticated(user_id, service):
svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service
await message.channel.send(
f"You're already linked to **{svc_display}**! "
f"Use `/logout {service}` to unlink."
)
return True
# --- Quick Connect flow ---
svc = get_auth_service(service)
if svc is None:
await message.channel.send(f"Unknown service: {service}")
return True
await message.channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…")
qc_result = await svc.initiate_quick_connect()
if qc_result is None:
await message.channel.send(
f"❌ Could not start Quick Connect for **{svc.display_name}**.\n"
"Check that `JELLYFIN_URL` is configured and the server is reachable."
)
return True
await message.channel.send(
f"Open **{svc.display_name}** → **Quick Connect** and enter this code:\n\n"
f"**`{qc_result.code}`**\n\n"
f"⏳ Waiting for you to approve…"
)
# Poll for authorization
async with message.channel.typing():
for attempt in range(24): # 24 × 5s = 2 minutes
await asyncio.sleep(5)
status = await svc.poll_quick_connect(qc_result.secret)
if status == "Authorized":
auth_result = await svc.authenticate_quick_connect(qc_result.secret)
if auth_result.success:
auth_store.store_auth(
discord_user_id=user_id,
service=service,
external_user_id=auth_result.external_user_id or "",
external_name=auth_result.external_name or "",
credentials=auth_result.credentials,
)
await message.channel.send(
f"✅ Linked to **{svc.display_name}** as "
f"**{auth_result.external_name}**!"
)
else:
await message.channel.send(
f"❌ Authentication failed: "
f"{auth_result.error_message or 'Unknown error'}"
)
return True
elif status == "Expired":
await message.channel.send(
"⌛ The Quick Connect code expired. "
f"Use `/login {service}` to try again."
)
return True
# else: still "Active" — keep polling
await message.channel.send(
"⌛ Timed out waiting for Quick Connect approval. "
f"Use `/login {service}` to try again."
)
return True
# --- /logout [service] ---
if lower.startswith("/logout"):
parts = content.split()
service = parts[1].lower() if len(parts) > 1 else None
if service is None:
linked = auth_store.list_services(user_id)
if not linked:
await message.channel.send("You don't have any linked services.")
else:
svc_list = ", ".join(linked)
await message.channel.send(
f"Linked services: **{svc_list}**\n"
f"Use `/logout <service>` to unlink."
)
return True
if not auth_store.is_authenticated(user_id, service):
await message.channel.send(f"You're not linked to **{service}**.")
return True
auth_store.revoke(user_id, service)
svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service
await message.channel.send(f"Unlinked from **{svc_display}**. Use `/login {service}` to re-link.")
return True
return False
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Agent invocation # Agent invocation
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -163,7 +306,6 @@ class AgentBot(discord.Client):
reply, and return the assistant's final text.""" reply, and return the assistant's final text."""
# 1. Pick agent — defaults to DISCORD_DEFAULT_AGENT env var. # 1. Pick agent — defaults to DISCORD_DEFAULT_AGENT env var.
# Change DISCORD_DEFAULT_AGENT in .env to switch agents.
agent_id = DISCORD_DEFAULT_AGENT agent_id = DISCORD_DEFAULT_AGENT
# 2. Build message list from stored history + new user message # 2. Build message list from stored history + new user message
@@ -172,7 +314,7 @@ class AgentBot(discord.Client):
# 3. Run the LangGraph (tools execute inline if needed) # 3. Run the LangGraph (tools execute inline if needed)
graph = _get_graph(agent_id) graph = _get_graph(agent_id)
state = {"messages": messages} state = {"messages": messages, "discord_user_id": user_id}
result = await graph.ainvoke(state) result = await graph.ainvoke(state)
last_msg = result["messages"][-1] last_msg = result["messages"][-1]
+73
View File
@@ -0,0 +1,73 @@
# Discord — Connector
The Discord module embeds a Discord bot **in-process** alongside FastAPI.
It uses the same LangGraph graphs and LLM client as the REST API — there is
no HTTP loopback, no separate process, and no code duplication.
---
## Files
| File | Purpose |
|---|---|
| `bot.py` | Discord `Client` subclass (`AgentBot`) — DM handler, command parser, graph invoker, Quick Connect orchestrator |
| `conversation.py` | In-memory conversation history store, keyed by Discord user ID |
---
## Architecture
```
Discord Gateway (websocket)
│ DM: "What's trending?"
discord.Client.on_message()
│ 1. Check: is this a DM? shares a guild? not a command?
│ 2. Build message history from ConversationStore
│ 3. Append user message
_create_agent_graph(agent_id="media-agent")
│ Uses the exact same create_agent_graph() from src/graph.py
│ as the REST API — same LLM client, same tools, same cache.
graph.ainvoke({"messages": [...]})
│ LangGraph runs agent_node → tool_node → agent_node → END
Response text → split into ≤2000-char Discord messages → sent to user
```
---
## Commands
Commands are DMs that start with `/`. The bot parses them before hitting the
LLM:
| Command | Action |
|---|---|
| `/login <service>` | Generate a one-time auth link, DM it to the user |
| `/jellyfin login` | Alias for `/login jellyfin` |
| `/help` | Show available agents and commands |
| `/<agent_id>` | Switch to a different agent for future messages |
---
## Auth Flow (Quick Connect)
When a user types `/login jellyfin`:
1. Bot generates a one-time token via `auth_store`
2. Bot calls `auth_store.create_link_token(discord_id, "jellyfin")`
3. Bot DMs the user: `https://<BASE_URL>/api/v1/auth/login?service=jellyfin&token=...&discord_id=...`
4. User clicks the link → FastAPI serves the Jellyfin login form (or Quick Connect prompt)
5. User authenticates → credentials stored in `auth_store`
6. Future tool calls (e.g. `watch_history`) automatically use the stored Jellyfin session
---
## Conversation Persistence
- Per-user history stored in `ConversationStore` (in-memory dict)
- Max history length configurable via `DISCORD_MAX_HISTORY` env var (default: 7)
- Oldest messages are silently dropped when the limit is exceeded
- History is NOT persisted across restarts (future: could use SQLite)
+106
View File
@@ -0,0 +1,106 @@
"""JellyStat REST API — watch history, genre summary, and user summary."""
from __future__ import annotations
import asyncpg
from fastapi import APIRouter, Depends, Query
from gateway.jellystat.db import get_pool
from gateway.jellystat.models import (
GenreSummaryResponse,
UserSummaryResponse,
WatchHistoryResponse,
)
router = APIRouter(prefix="/jellystat", tags=["jellystat"])
DEFAULT_WINDOW_MINUTES = 10080 # 7 days
# ---------------------------------------------------------------------------
# GET /jellystat/history/{user_id}
# ---------------------------------------------------------------------------
@router.get("/history/{user_id}", response_model=WatchHistoryResponse)
async def get_watch_history(
user_id: str,
minutes: int = Query(
default=DEFAULT_WINDOW_MINUTES, ge=1, description="Time window in minutes"
),
pool: asyncpg.Pool = Depends(get_pool),
):
"""Return watch history grouped by title, ordered by most-watched first."""
rows = await pool.fetch(
"SELECT * FROM fn_user_watch_history($1, $2)", user_id, minutes
)
return WatchHistoryResponse(
user_id=user_id,
window_minutes=minutes,
items=[
{
"title": r["title"],
"watch_time_sec": float(r["watch_time_sec"]),
"media_type": r["media_type"],
}
for r in rows
],
)
# ---------------------------------------------------------------------------
# GET /jellystat/genres/{user_id}
# ---------------------------------------------------------------------------
@router.get("/genres/{user_id}", response_model=GenreSummaryResponse)
async def get_genre_summary(
user_id: str,
minutes: int = Query(
default=DEFAULT_WINDOW_MINUTES, ge=1, description="Time window in minutes"
),
pool: asyncpg.Pool = Depends(get_pool),
):
"""Return total watch time per genre, ordered by most-watched first."""
rows = await pool.fetch(
"SELECT * FROM fn_user_genre_summary($1, $2)", user_id, minutes
)
return GenreSummaryResponse(
user_id=user_id,
window_minutes=minutes,
genres=[
{"genre": r["genre"], "watch_time_sec": float(r["watch_time_sec"])}
for r in rows
],
)
# ---------------------------------------------------------------------------
# GET /jellystat/summary/{user_id}
# ---------------------------------------------------------------------------
@router.get("/summary/{user_id}", response_model=UserSummaryResponse)
async def get_user_summary(
user_id: str,
pool: asyncpg.Pool = Depends(get_pool),
):
"""Return all-time summary: total watch time, most-watched titles, top genres."""
rows = await pool.fetch("SELECT * FROM fn_user_summary($1)", user_id)
# fn_user_summary returns key-value rows — build a dict
# asyncpg already deserialises JSONB → Python objects
metrics: dict[str, object] = {r["metric"]: r["value"] for r in rows}
top_genres_raw = metrics.get("top_genres", [])
top_genres: list[str] = top_genres_raw if isinstance(top_genres_raw, list) else []
return UserSummaryResponse(
user_id=user_id,
total_watch_time_sec=float(metrics.get("total_watch_time", 0)),
most_watched_series=metrics.get("most_watched_series"),
most_watched_movie=metrics.get("most_watched_movie"),
total_last_30d_sec=float(metrics.get("total_last_30d", 0)),
total_last_7d_sec=float(metrics.get("total_last_7d", 0)),
top_genres=top_genres,
)
+130
View File
@@ -0,0 +1,130 @@
"""PostgreSQL connection pool for the JellyStat database."""
from __future__ import annotations
import logging
from pathlib import Path
import asyncpg
from fastapi import FastAPI, Request
from src.config import get_config
logger = logging.getLogger("gateway.jellystat")
# ---------------------------------------------------------------------------
# DSN builder
# ---------------------------------------------------------------------------
def _build_dsn() -> str:
"""Build a PostgreSQL DSN from individual environment variables."""
host = get_config("JELLYSTAT_DB_HOST", "localhost")
port = get_config("JELLYSTAT_DB_PORT", "5432")
user = get_config("JELLYSTAT_DB_USER", "postgres")
password = get_config("JELLYSTAT_DB_PASSWORD", "")
dbname = get_config("JELLYSTAT_DB_NAME", "jfstat")
return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
# ---------------------------------------------------------------------------
# Pool lifecycle (called from main.py lifespan)
# ---------------------------------------------------------------------------
async def init_pool(app: FastAPI) -> None:
"""Create the connection pool and store it on app.state."""
dsn = _build_dsn()
safe = dsn.split("@")[1] if "@" in dsn else dsn
logger.info("Connecting to JellyStat database at %s", safe)
pool = await asyncpg.create_pool(dsn, min_size=1, max_size=5)
app.state.jellystat_pool = pool
# Deploy functions on every startup (CREATE OR REPLACE is idempotent)
await _ensure_functions(pool)
async def close_pool(app: FastAPI) -> None:
"""Close the pool on shutdown."""
pool: asyncpg.Pool | None = getattr(app.state, "jellystat_pool", None)
if pool:
await pool.close()
logger.info("JellyStat pool closed")
# ---------------------------------------------------------------------------
# FastAPI dependency
# ---------------------------------------------------------------------------
async def get_pool(request: Request) -> asyncpg.Pool:
"""Return the JellyStat connection pool from app state."""
return request.app.state.jellystat_pool
# ---------------------------------------------------------------------------
# Function deployment
# ---------------------------------------------------------------------------
async def _ensure_functions(pool: asyncpg.Pool) -> None:
"""Run startup-functions.sql to create or replace all JellyStat functions."""
sql_path = Path(__file__).parent / "startup-functions.sql"
if not sql_path.exists():
logger.warning("startup-functions.sql not found — skipping function deployment")
return
sql = sql_path.read_text()
statements = _split_sql(sql)
async with pool.acquire() as conn:
for stmt in statements:
try:
await conn.execute(stmt)
except Exception:
# Log but don't crash — functions might already exist
logger.exception("Failed to deploy SQL statement — continuing")
logger.info("JellyStat functions deployed (%d statements)", len(statements))
def _split_sql(sql: str) -> list[str]:
"""
Split a multi-statement SQL string into individual statements.
Respects $$ dollar-quoting so that semicolons inside function bodies
don't cause premature splits. Pure comment lines (starting with ``--``)
outside dollar-quoted blocks are stripped.
"""
statements: list[str] = []
current: list[str] = []
in_dollar_quote = False
for line in sql.split("\n"):
stripped = line.strip()
# Skip pure comment lines outside of dollar-quoted blocks
if not in_dollar_quote and stripped.startswith("--"):
continue
# Toggle dollar-quote state whenever we see $$
if "$$" in line:
in_dollar_quote = not in_dollar_quote
current.append(line)
# Statement terminator: semicolon at end of line, outside $$ block
if not in_dollar_quote and line.rstrip().endswith(";"):
stmt = "\n".join(current).strip()
if stmt:
statements.append(stmt)
current = []
# Catch any trailing statement that wasn't terminated by a semicolon
if current:
stmt = "\n".join(current).strip()
if stmt:
statements.append(stmt)
return statements
+36
View File
@@ -0,0 +1,36 @@
"""Pydantic response models for the JellyStat API."""
from pydantic import BaseModel
class WatchHistoryItem(BaseModel):
title: str
watch_time_sec: float
media_type: str
class WatchHistoryResponse(BaseModel):
user_id: str
window_minutes: int
items: list[WatchHistoryItem]
class GenreSummaryItem(BaseModel):
genre: str
watch_time_sec: float
class GenreSummaryResponse(BaseModel):
user_id: str
window_minutes: int
genres: list[GenreSummaryItem]
class UserSummaryResponse(BaseModel):
user_id: str
total_watch_time_sec: float
most_watched_series: str | None
most_watched_movie: str | None
total_last_30d_sec: float
total_last_7d_sec: float
top_genres: list[str]
+224
View File
@@ -0,0 +1,224 @@
-- ============================================================================
-- JellyStat API Functions
-- Parameterized database functions callable by the API layer as:
-- SELECT * FROM fn_user_watch_history('user_id_here', 10080);
-- SELECT * FROM fn_user_genre_summary('user_id_here', 10080);
-- SELECT * FROM fn_user_summary('user_id_here');
-- ============================================================================
-- ----------------------------------------------------------------------------
-- 1. User Watch History
-- Returns every distinct title watched in the last N minutes,
-- grouped and summed by title, ordered by most-watched first.
-- ----------------------------------------------------------------------------
CREATE OR REPLACE FUNCTION public.fn_user_watch_history(
p_user_id TEXT,
p_minutes INTEGER DEFAULT 10080 -- 7 days in minutes
)
RETURNS TABLE(
title TEXT,
watch_time_sec NUMERIC,
media_type TEXT
)
LANGUAGE sql
STABLE
AS $$
SELECT
COALESCE(a."SeriesName", a."NowPlayingItemName") AS title,
SUM(a."PlaybackDuration")::NUMERIC AS watch_time_sec,
CASE
WHEN a."SeriesName" IS NOT NULL THEN 'series'
ELSE 'movie'
END AS media_type
FROM jf_playback_activity a
WHERE a."UserId" = p_user_id
AND a."ActivityDateInserted"
>= NOW() - (p_minutes * INTERVAL '1 minute')
GROUP BY
COALESCE(a."SeriesName", a."NowPlayingItemName"),
CASE WHEN a."SeriesName" IS NOT NULL THEN 'series' ELSE 'movie' END
ORDER BY watch_time_sec DESC;
$$;
-- ----------------------------------------------------------------------------
-- 2. Genre Summary
-- Returns total watch time per genre for a user over the last N minutes.
-- Resolves genres for both movies (directly on the item) and series
-- episodes (via jf_library_episodes → jf_library_items chain).
-- ----------------------------------------------------------------------------
CREATE OR REPLACE FUNCTION public.fn_user_genre_summary(
p_user_id TEXT,
p_minutes INTEGER DEFAULT 10080
)
RETURNS TABLE(
genre TEXT,
watch_time_sec NUMERIC
)
LANGUAGE sql
STABLE
AS $$
WITH movie_genres AS (
-- Movies: join playback directly to library_items on NowPlayingItemId
SELECT
genre_item.value AS genre,
SUM(a."PlaybackDuration") AS watch_time_sec
FROM jf_playback_activity a
JOIN jf_library_items i
ON i."Id" = a."NowPlayingItemId"
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
WHERE a."UserId" = p_user_id
AND a."SeriesName" IS NULL -- movies only
AND a."ActivityDateInserted"
>= NOW() - (p_minutes * INTERVAL '1 minute')
AND i."Genres" IS NOT NULL
AND jsonb_array_length(i."Genres") > 0
GROUP BY genre_item.value
),
series_genres AS (
-- Series: playback → episodes → series item → genres
SELECT
genre_item.value AS genre,
SUM(a."PlaybackDuration") AS watch_time_sec
FROM jf_playback_activity a
JOIN jf_library_episodes e
ON e."EpisodeId" = a."EpisodeId"
JOIN jf_library_items i
ON i."Id" = e."SeriesId"
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
WHERE a."UserId" = p_user_id
AND a."SeriesName" IS NOT NULL -- TV episodes only
AND a."ActivityDateInserted"
>= NOW() - (p_minutes * INTERVAL '1 minute')
AND i."Genres" IS NOT NULL
AND jsonb_array_length(i."Genres") > 0
GROUP BY genre_item.value
),
combined AS (
SELECT genre, watch_time_sec FROM movie_genres
UNION ALL
SELECT genre, watch_time_sec FROM series_genres
)
SELECT
genre,
SUM(watch_time_sec)::NUMERIC AS watch_time_sec
FROM combined
GROUP BY genre
ORDER BY watch_time_sec DESC;
$$;
-- ----------------------------------------------------------------------------
-- 3. User Summary
-- One-shot dashboard: all-time stats + recent windows + top genres.
-- Returns key-value rows that the API trivially converts to a JSON object
-- with Object.fromEntries() or similar.
-- ----------------------------------------------------------------------------
CREATE OR REPLACE FUNCTION public.fn_user_summary(
p_user_id TEXT
)
RETURNS TABLE(
metric TEXT,
value JSONB
)
LANGUAGE sql
STABLE
AS $$
-- total_watch_time (all time)
SELECT 'total_watch_time'::TEXT AS metric,
to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value
FROM jf_playback_activity
WHERE "UserId" = p_user_id
UNION ALL
-- most_watched_series (by total watch time)
SELECT 'most_watched_series'::TEXT AS metric,
COALESCE(
(SELECT to_jsonb("SeriesName")
FROM jf_playback_activity
WHERE "UserId" = p_user_id
AND "SeriesName" IS NOT NULL
GROUP BY "SeriesName"
ORDER BY SUM("PlaybackDuration") DESC
LIMIT 1),
'null'::JSONB
) AS value
UNION ALL
-- most_watched_movie (by total watch time)
SELECT 'most_watched_movie'::TEXT AS metric,
COALESCE(
(SELECT to_jsonb("NowPlayingItemName")
FROM jf_playback_activity
WHERE "UserId" = p_user_id
AND "SeriesName" IS NULL
GROUP BY "NowPlayingItemName"
ORDER BY SUM("PlaybackDuration") DESC
LIMIT 1),
'null'::JSONB
) AS value
UNION ALL
-- total_watch_time_last_month (last 30 days)
SELECT 'total_last_30d'::TEXT AS metric,
to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value
FROM jf_playback_activity
WHERE "UserId" = p_user_id
AND "ActivityDateInserted" >= NOW() - INTERVAL '30 days'
UNION ALL
-- total_watch_time_last_week (last 7 days)
SELECT 'total_last_7d'::TEXT AS metric,
to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value
FROM jf_playback_activity
WHERE "UserId" = p_user_id
AND "ActivityDateInserted" >= NOW() - INTERVAL '7 days'
UNION ALL
-- top_genres (top 3 all-time, as a JSON array)
SELECT 'top_genres'::TEXT AS metric,
COALESCE(
(SELECT jsonb_agg(genre ORDER BY watch_time_sec DESC)
FROM (
SELECT genre, SUM(watch_time_sec) AS watch_time_sec
FROM (
-- movies
SELECT
genre_item.value AS genre,
SUM(a."PlaybackDuration") AS watch_time_sec
FROM jf_playback_activity a
JOIN jf_library_items i ON i."Id" = a."NowPlayingItemId"
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
WHERE a."UserId" = p_user_id
AND a."SeriesName" IS NULL
AND i."Genres" IS NOT NULL
AND jsonb_array_length(i."Genres") > 0
GROUP BY genre_item.value
UNION ALL
-- series
SELECT
genre_item.value AS genre,
SUM(a."PlaybackDuration") AS watch_time_sec
FROM jf_playback_activity a
JOIN jf_library_episodes e ON e."EpisodeId" = a."EpisodeId"
JOIN jf_library_items i ON i."Id" = e."SeriesId"
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
WHERE a."UserId" = p_user_id
AND a."SeriesName" IS NOT NULL
AND i."Genres" IS NOT NULL
AND jsonb_array_length(i."Genres") > 0
GROUP BY genre_item.value
) combined
GROUP BY genre
ORDER BY SUM(watch_time_sec) DESC
LIMIT 3
) top3
),
'[]'::JSONB
) AS value;
$$;
+220
View File
@@ -0,0 +1,220 @@
"""
Auth API — generic endpoints for linking Discord users to external services.
GET /api/v1/auth/login?service=X&token=Y&discord_id=Z
Validates the link token and serves a service-specific login form.
POST /api/v1/auth/login
Accepts the form submission, validates credentials against the service,
stores the session, and returns a result page.
GET /api/v1/auth/status?discord_id=Z
Returns which services are linked for this Discord user.
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import HTMLResponse
from gateway.auth import get_auth_service, list_auth_services
from src import auth_store
logger = logging.getLogger("gateway.auth")
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
# ---------------------------------------------------------------------------
# GET /auth/login — serve the login form
# ---------------------------------------------------------------------------
@router.get("/login")
async def login_form(
service: str,
token: str,
discord_id: int,
):
"""Validate the one-time link token and return a service-specific login form."""
# Validate the token WITHOUT consuming it (the POST will consume it)
result = auth_store.validate_token(token)
if result is None:
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
uid, svc = result
if uid != discord_id or svc != service:
raise HTTPException(status_code=400, detail="Token does not match the request.")
# Look up the AuthService
svc_obj = get_auth_service(service)
if svc_obj is None:
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
logger.info("Serving login form: user=%s service=%s", discord_id, service)
return HTMLResponse(svc_obj.render_login_form(token, discord_id))
# ---------------------------------------------------------------------------
# POST /auth/login — handle form submission
# ---------------------------------------------------------------------------
@router.post("/login")
async def login_submit(request: Request):
"""Handle the login form POST: validate credentials, store auth, show result."""
# Parse form data
form = await request.form()
token = form.get("token", "")
discord_id_str = form.get("discord_id", "")
service = form.get("service", "")
if not token or not discord_id_str or not service:
raise HTTPException(status_code=400, detail="Missing required fields.")
try:
discord_id = int(discord_id_str)
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="Invalid discord_id.")
# Consume the token on POST (the GET only validated, didn't consume)
result = auth_store.consume_token(token)
if result is None:
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
# Look up the AuthService
svc_obj = get_auth_service(service)
if svc_obj is None:
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
# Collect service-specific form fields (everything except token, discord_id, service)
form_data: dict[str, str] = {}
for key, value in form.items():
if key not in ("token", "discord_id", "service"):
form_data[key] = str(value)
# Authenticate against the service
auth_result = await svc_obj.authenticate(form_data)
if not auth_result.success:
return HTMLResponse(
status_code=401,
content=f"""<!DOCTYPE html>
<html><head><meta charset="utf-8"><title>Login Failed</title>
<style>
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
h2 {{ color: #d32f2f; }}
a {{ color: #aa5cc3; }}
</style></head><body>
<h2>❌ Login Failed</h2>
<p>{auth_result.error_message or "Authentication failed. Please try again."}</p>
<p><a href="javascript:history.back()">← Go back and try again</a></p>
</body></html>""",
)
# Store the successful auth
auth_store.store_auth(
discord_user_id=discord_id,
service=service,
external_user_id=auth_result.external_user_id or "",
external_name=auth_result.external_name or "",
credentials=auth_result.credentials,
)
logger.info(
"Auth linked: discord=%s%s (%s)",
discord_id,
service,
auth_result.external_name,
)
return HTMLResponse(f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Account Linked</title>
<style>
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; text-align: center; }}
h1 {{ color: #388e3c; }}
.name {{ font-weight: bold; color: #aa5cc3; }}
p {{ color: #666; }}
</style>
</head>
<body>
<h1>✅ Account Linked!</h1>
<p>Logged in as <span class="name">{auth_result.external_name}</span> on <strong>{svc_obj.display_name}</strong>.</p>
<p>You can close this page and return to Discord.</p>
</body>
</html>""")
# ---------------------------------------------------------------------------
# GET /auth/status — get all linked services for a Discord user
# ---------------------------------------------------------------------------
@router.get("/Discord/status")
async def auth_status(discord_id: int):
"""
Return all services linked to this Discord user with full details.
Response:
{
"discord_id": 123456789,
"linked_services": {
"jellyfin": {
"external_user_id": "abc123",
"external_name": "Tim",
"linked_at": "2026-05-25T10:00:00",
"credentials": {
"token": "jwt...",
"url": "http://jellyfin:8096",
"user_id": "abc123"
}
}
}
}
"""
auths = auth_store.get_all_auths(discord_id)
linked_services: dict[str, dict] = {}
for auth in auths:
svc_name = auth["service"]
linked_services[svc_name] = {
"external_user_id": auth["external_user_id"],
"external_name": auth["external_name"],
"linked_at": auth["linked_at"],
"credentials": auth["credentials"],
}
return {
"discord_id": discord_id,
"linked_services": linked_services,
}
# ---------------------------------------------------------------------------
# POST /auth/reset — wipe auth store (DEV ONLY)
# ---------------------------------------------------------------------------
from src.config import get_config # noqa: E402
@router.post("/reset")
async def reset_auth():
"""
Reset the entire auth store — clears all link tokens and user auth records.
Only enabled when ALLOW_AUTH_RESET=true in the environment.
Returns 403 in production.
"""
if get_config("ALLOW_AUTH_RESET", "false").lower() != "true":
raise HTTPException(
status_code=403,
detail="Auth reset is disabled. Set ALLOW_AUTH_RESET=true to enable (dev only).",
)
auth_store.reset_all()
logger.warning("Auth store reset via API endpoint.")
return {"status": "ok", "message": "Auth store cleared — all tokens and auth records removed."}
+2 -2
View File
@@ -4,9 +4,9 @@ from openai import OpenAI
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from api.dependencies import get_llm_client, get_agent_graph from gateway.dependencies import get_llm_client, get_agent_graph
from agents import get as get_agent, list_all as list_all_agents from agents import get as get_agent, list_all as list_all_agents
from core.state import AgentState from src.state import AgentState
router = APIRouter() router = APIRouter()
+106
View File
@@ -0,0 +1,106 @@
# V1 — Chat & Agent API Endpoints
This is the primary HTTP API surface for the chatbot agent system. It exposes
both a custom streaming chat endpoint and an OpenAI-compatible
`/chat/completions` endpoint so it works as a drop-in backend for OpenWebUI,
LibreChat, or any OpenAI-compatible client.
---
## Endpoints
| Method | Path | Description |
|---|---|---|
| `GET ` | `/v1/` | Health check — returns `{"status": "ok"}` |
| `GET ` | `/v1/agents` | List all registered agents (id + description) |
| `GET ` | `/v1/models` | OpenAI-compatible model list (one entry per agent) |
| `POST` | `/v1/chat` | Chat with an agent — streaming (SSE) |
| `POST` | `/v1/chat/sync` | Chat with an agent — non-streaming |
| `POST` | `/v1/chat/completions` | OpenAI-compatible chat completions (supports `stream: true`) |
All `/v1/*` endpoints are mounted by `main.py` via:
```python
app.include_router(v1_router, prefix="/v1")
```
---
## Agent Resolution
Each request can target a specific agent. The resolution order is:
1. **Explicit `agent_id`** field in the request body
2. **OpenAI `model` field** (OpenWebUI sends this — mapped to `agent_id` if a matching agent is registered)
3. **Fallback** to the `"naked"` agent (a plain LLM with no tools)
This means an OpenWebUI client can simply set `model: "media-agent"` and get
the full Media Agent with Seerr tools.
---
## Request Flow
```
Client (OpenWebUI / HTTP)
│ POST /v1/chat/completions
│ { model: "media-agent", messages: [...], stream: true/false }
chat_completions()
│ 1. _resolve_agent(req.model) → Agent(id="media-agent", skills=[...])
│ 2. get_agent_graph("media-agent", request)
│ → lazy-compiled LangGraph StateGraph, cached on app.state
│ 3. stream=True → _stream_graph(graph, messages) → SSE token stream
│ stream=False → _invoke_graph(graph, messages) → plain response
LangGraph StateGraph (src/graph.py)
├── agent_node: calls LLM with system prompt + tool definitions
│ └── LLM returns text OR tool_calls
├── _should_continue: if tool_calls → tool_node, else → END
└── tool_node: executes tool via agents/skills system → ToolMessage
└── loops back to agent_node with the result
```
For a detailed walkthrough, see [api.md](../api.md).
---
## Streaming
Two streaming modes exist:
### SSE (Server-Sent Events) — `/v1/chat`
```
data: {"token": "Here"}
data: {"token": " are"}
data: {"token": " the"}
...
data: [DONE]
```
The graph runs to completion (tools execute silently), then the final text is
yielded token-by-token as SSE events.
### OpenAI-compatible — `/v1/chat/completions` with `stream: true`
```
data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"}}]}
data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{"content":"!"}}]}
data: [DONE]
```
> **Future improvement:** true token-level streaming (tokens appear as the LLM
> generates them) would require using `langchain-openai`'s `ChatOpenAI` in
> place of the raw `openai` client. The current approach avoids adding that
> dependency.
---
## Dependencies
Endpoints receive shared singletons via FastAPI `Depends`:
- **`get_llm_client(request)`** → returns `request.app.state.llm_client` (OpenAI client singleton, created once in `main.py`)
- **`get_agent_graph(agent_id, request)`** → returns a lazy-compiled LangGraph from `request.app.state.agent_graphs`
+15 -5
View File
@@ -4,9 +4,11 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from api.v1.chat import router as v1_router from gateway.v1.auth import router as auth_router
from core.config import DEEPSEEK_API_KEY from gateway.v1.chat import router as v1_router
from core.llm import create_client from gateway.jellystat.api import router as jellystat_router
from src.config import DEEPSEEK_API_KEY, get_config
from src.llm import create_client
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Logging — tool calls will appear in the uvicorn console # Logging — tool calls will appear in the uvicorn console
@@ -18,23 +20,29 @@ logging.basicConfig(
) )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Load all agents & skills so they self-register at startup # Load all agents, skills, AND auth services so they self-register at startup
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
from agents import load_all_agents # noqa: E402 from agents import load_all_agents # noqa: E402
load_all_agents() load_all_agents()
import gateway.auth.jellyfin # noqa: E402 — self-registers JellyfinAuth
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Lifespan # Lifespan
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
from bot.discord_bot import start_in_background # noqa: E402 from gateway.discord.bot import start_in_background # noqa: E402
from gateway.jellystat.db import init_pool, close_pool # noqa: E402
await init_pool(app)
start_in_background() start_in_background()
yield yield
await close_pool(app)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# App # App
@@ -61,3 +69,5 @@ app.state.agent_graphs: dict = {}
# Routers # Routers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
app.include_router(v1_router, prefix="/v1") app.include_router(v1_router, prefix="/v1")
app.include_router(auth_router)
app.include_router(jellystat_router)
+2
View File
@@ -6,3 +6,5 @@ httpx
langgraph langgraph
langgraph-checkpoint langgraph-checkpoint
discord.py discord.py
python-multipart
asyncpg
View File
+358
View File
@@ -0,0 +1,358 @@
"""
Auth Store — SQLite-backed persistence for Discord-to-service authentication.
Two tables:
- link_tokens : one-time tokens sent via Discord DM to initiate login
- user_auth : per-user, per-service credentials (Jellyfin token, etc.)
Thread-safe via WAL mode and a shared lock. No passwords are ever stored
— only opaque service tokens (e.g. Jellyfin AccessToken).
"""
from __future__ import annotations
import logging
import secrets
import sqlite3
import threading
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
from src.config import get_config
logger = logging.getLogger("auth_store")
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db")
TOKEN_EXPIRY_MINUTES = int(get_config("AUTH_TOKEN_EXPIRY", "10"))
# ---------------------------------------------------------------------------
# Singleton handle
# ---------------------------------------------------------------------------
_db_path: Path | None = None
_db_lock = threading.Lock()
def _resolve_path() -> Path:
"""Turn AUTH_DB_PATH into an absolute path, creating parent dirs."""
global _db_path
if _db_path is not None:
return _db_path
p = Path(AUTH_DB_PATH)
if not p.is_absolute():
# Relative to the project root (two levels above this file)
project_root = Path(__file__).resolve().parent.parent
p = project_root / p
p.parent.mkdir(parents=True, exist_ok=True)
_db_path = p
return p
def _get_conn() -> sqlite3.Connection:
"""Return a thread-local connection to the auth database."""
import sqlite3
conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False)
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
conn.row_factory = sqlite3.Row
return conn
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
_SCHEMA = """
CREATE TABLE IF NOT EXISTS link_tokens (
token TEXT PRIMARY KEY,
discord_user_id INTEGER NOT NULL,
service TEXT NOT NULL,
expires_at TEXT NOT NULL,
used INTEGER DEFAULT 0,
created_at TEXT DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS user_auth (
discord_user_id INTEGER NOT NULL,
service TEXT NOT NULL,
external_user_id TEXT,
external_name TEXT,
credentials TEXT,
linked_at TEXT DEFAULT (datetime('now')),
is_active INTEGER DEFAULT 1,
PRIMARY KEY (discord_user_id, service)
);
"""
_initialized = False
def _ensure_schema() -> None:
global _initialized
if _initialized:
return
with _db_lock:
if _initialized:
return
conn = _get_conn()
conn.executescript(_SCHEMA)
conn.commit()
conn.close()
_initialized = True
logger.info("Auth store initialized at %s", _resolve_path())
# ---------------------------------------------------------------------------
# Public API — Link Tokens
# ---------------------------------------------------------------------------
def create_token(discord_user_id: int, service: str) -> str:
"""Generate a one-time link token. Expires after TOKEN_EXPIRY_MINUTES."""
_ensure_schema()
token = secrets.token_urlsafe(32)
expires = (datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRY_MINUTES)).isoformat()
with _db_lock:
conn = _get_conn()
conn.execute(
"INSERT INTO link_tokens (token, discord_user_id, service, expires_at) VALUES (?, ?, ?, ?)",
(token, discord_user_id, service, expires),
)
conn.commit()
conn.close()
logger.info("Created link token for user %s / service %s", discord_user_id, service)
return token
def validate_token(token: str) -> tuple[int, str] | None:
"""Read-only validation — does NOT consume the token.
Returns (discord_user_id, service) if the token exists, is unused,
and has not expired. Returns None otherwise.
"""
_ensure_schema()
with _db_lock:
conn = _get_conn()
row = conn.execute(
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
(token,),
).fetchone()
conn.close()
if row is None:
return None
if row["used"]:
return None
expires = datetime.fromisoformat(row["expires_at"])
if datetime.now(timezone.utc) > expires:
return None
return (row["discord_user_id"], row["service"])
def consume_token(token: str) -> tuple[int, str] | None:
"""Validate and consume a link token. Returns (discord_user_id, service) or None.
A token is valid if:
- It exists
- It has not been used
- It has not expired
"""
_ensure_schema()
with _db_lock:
conn = _get_conn()
row = conn.execute(
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
(token,),
).fetchone()
if row is None:
conn.close()
return None
if row["used"]:
conn.close()
logger.warning("Token already used: %s", token[:8])
return None
expires = datetime.fromisoformat(row["expires_at"])
if datetime.now(timezone.utc) > expires:
conn.close()
logger.warning("Token expired: %s", token[:8])
return None
conn.execute("UPDATE link_tokens SET used = 1 WHERE token = ?", (token,))
conn.commit()
result = (row["discord_user_id"], row["service"])
conn.close()
logger.info("Token consumed: %s… → user=%s service=%s", token[:8], result[0], result[1])
return result
# ---------------------------------------------------------------------------
# Public API — User Auth
# ---------------------------------------------------------------------------
def store_auth(
discord_user_id: int,
service: str,
*,
external_user_id: str = "",
external_name: str = "",
credentials: dict | None = None,
) -> None:
"""Store or update authentication for a user on a service."""
_ensure_schema()
import json
creds_json = json.dumps(credentials) if credentials else "{}"
with _db_lock:
conn = _get_conn()
conn.execute(
"""INSERT INTO user_auth (discord_user_id, service, external_user_id, external_name, credentials, linked_at)
VALUES (?, ?, ?, ?, ?, datetime('now'))
ON CONFLICT(discord_user_id, service) DO UPDATE SET
external_user_id = excluded.external_user_id,
external_name = excluded.external_name,
credentials = excluded.credentials,
linked_at = datetime('now'),
is_active = 1""",
(discord_user_id, service, external_user_id, external_name, creds_json),
)
conn.commit()
conn.close()
logger.info("Stored auth for user %s on %s as %s", discord_user_id, service, external_name)
def get_auth(discord_user_id: int, service: str) -> dict | None:
"""Retrieve stored auth for a user on a service. Returns None if not linked."""
_ensure_schema()
import json
with _db_lock:
conn = _get_conn()
row = conn.execute(
"SELECT * FROM user_auth WHERE discord_user_id = ? AND service = ? AND is_active = 1",
(discord_user_id, service),
).fetchone()
conn.close()
if row is None:
return None
credentials = json.loads(row["credentials"]) if row["credentials"] else {}
return {
"discord_user_id": row["discord_user_id"],
"service": row["service"],
"external_user_id": row["external_user_id"],
"external_name": row["external_name"],
"credentials": credentials,
"linked_at": row["linked_at"],
}
def is_authenticated(discord_user_id: int, service: str) -> bool:
"""Quick check: is this user linked to this service?"""
return get_auth(discord_user_id, service) is not None
def list_services(discord_user_id: int) -> list[str]:
"""Return list of service names this user has linked."""
_ensure_schema()
with _db_lock:
conn = _get_conn()
rows = conn.execute(
"SELECT service FROM user_auth WHERE discord_user_id = ? AND is_active = 1",
(discord_user_id,),
).fetchall()
conn.close()
return [r["service"] for r in rows]
def revoke(discord_user_id: int, service: str) -> None:
"""Unlink a user from a service."""
_ensure_schema()
with _db_lock:
conn = _get_conn()
conn.execute(
"UPDATE user_auth SET is_active = 0 WHERE discord_user_id = ? AND service = ?",
(discord_user_id, service),
)
conn.commit()
conn.close()
logger.info("Revoked auth for user %s on %s", discord_user_id, service)
def get_all_auths(discord_user_id: int) -> list[dict]:
"""
Return all active auth records for a Discord user.
Each record includes service name, external user id, external name,
linked_at timestamp, and the raw credentials (e.g. Jellyfin token + URL).
Used by the /api/v1/auth/status endpoint so other services can discover
linked accounts for a given Discord ID.
"""
_ensure_schema()
import json
with _db_lock:
conn = _get_conn()
rows = conn.execute(
"""SELECT service, external_user_id, external_name, credentials, linked_at
FROM user_auth
WHERE discord_user_id = ? AND is_active = 1
ORDER BY linked_at DESC""",
(discord_user_id,),
).fetchall()
conn.close()
results: list[dict] = []
for row in rows:
creds = {}
if row["credentials"]:
try:
creds = json.loads(row["credentials"])
except (json.JSONDecodeError, TypeError):
creds = {}
results.append({
"service": row["service"],
"external_user_id": row["external_user_id"] or "",
"external_name": row["external_name"] or "",
"linked_at": row["linked_at"] or "",
"credentials": creds,
})
return results
# ---------------------------------------------------------------------------
# Dev / testing — reset the entire store
# ---------------------------------------------------------------------------
def reset_all() -> None:
"""Truncate all auth tables — for development and testing only."""
_ensure_schema()
with _db_lock:
conn = _get_conn()
conn.execute("DELETE FROM link_tokens")
conn.execute("DELETE FROM user_auth")
conn.commit()
conn.close()
logger.warning("Auth store RESET — all tokens and auth records cleared.")
View File
+15 -31
View File
@@ -1,12 +1,13 @@
""" """
LangGraph agent graph factory. LangGraph agent graph factory.
Builds a StateGraph that replaces the manual tool-calling loop in api/v1/chat.py. Builds a StateGraph with two nodes:
The graph has two nodes:
- agent_node : calls the LLM (with system prompt + tool definitions) - agent_node : calls the LLM (with system prompt + tool definitions)
- tool_node : executes tool calls via the existing skill system - tool_node : executes tool calls via the existing skill system
A conditional edge routes tool_calls back to the agent, or ends the run. A conditional edge routes tool_calls back to the agent, or ends the run.
When a tool fails due to missing authentication, the failure message is
relayed to the LLM, which tells the user to use /login.
""" """
from __future__ import annotations from __future__ import annotations
@@ -19,8 +20,8 @@ from langchain_core.messages import AIMessage, ToolMessage
from langgraph.graph import END, StateGraph from langgraph.graph import END, StateGraph
from openai import OpenAI from openai import OpenAI
from core.state import AgentState from src.state import AgentState
from skills import get_all_tools, execute_tool from agents.skills import get_all_tools, execute_tool
logger = logging.getLogger("graph") logger = logging.getLogger("graph")
@@ -97,18 +98,14 @@ def _make_agent_node(
full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}] full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
for m in messages: for m in messages:
if isinstance(m, dict): if isinstance(m, dict):
# Already a plain dict — pass through.
# But fix tool_calls if they're in LangChain format.
d = dict(m) d = dict(m)
tc = d.get("tool_calls") tc = d.get("tool_calls")
if tc and isinstance(tc, list) and tc and isinstance(tc[0], dict) and "function" not in tc[0]: if tc and isinstance(tc, list) and tc and isinstance(tc[0], dict) and "function" not in tc[0]:
d["tool_calls"] = _langchain_tc_to_openai(tc) d["tool_calls"] = _langchain_tc_to_openai(tc)
full.append(d) full.append(d)
else: else:
# LangChain message object → OpenAI-compatible dict
role = _lc_role_to_openai(getattr(m, "type", "user")) role = _lc_role_to_openai(getattr(m, "type", "user"))
d: dict[str, Any] = {"role": role, "content": getattr(m, "content", "")} d: dict[str, Any] = {"role": role, "content": getattr(m, "content", "")}
# Serialize tool_calls back to OpenAI format (if this is an AI msg)
tc = getattr(m, "tool_calls", None) tc = getattr(m, "tool_calls", None)
if tc: if tc:
d["tool_calls"] = _langchain_tc_to_openai(tc) d["tool_calls"] = _langchain_tc_to_openai(tc)
@@ -125,7 +122,6 @@ def _make_agent_node(
) )
choice = resp.choices[0] choice = resp.choices[0]
# Convert OpenAI tool_calls to the dict format LangChain expects.
raw_tool_calls = list(choice.message.tool_calls) if choice.message.tool_calls else [] raw_tool_calls = list(choice.message.tool_calls) if choice.message.tool_calls else []
tool_calls: list[dict[str, Any]] = [] tool_calls: list[dict[str, Any]] = []
for tc in raw_tool_calls: for tc in raw_tool_calls:
@@ -153,9 +149,9 @@ def _make_tool_node(skill_names: list[str]):
""" """
Return a callable that executes tool_calls from the last AI message. Return a callable that executes tool_calls from the last AI message.
This replaces LangGraph's built-in ToolNode — we call our own If a tool fails because the user isn't authenticated, the failure
`execute_tool()` pipeline so that skill-level auth, httpx sessions, message (which tells the user to /login) is returned to the LLM.
and ToolResult handling are fully preserved. The LLM naturally relays the instructions to the user.
""" """
async def tool_node(state: AgentState) -> dict[str, list]: async def tool_node(state: AgentState) -> dict[str, list]:
@@ -164,18 +160,16 @@ def _make_tool_node(skill_names: list[str]):
if not tool_calls: if not tool_calls:
return {"messages": []} return {"messages": []}
discord_user_id = state.get("discord_user_id")
results: list[ToolMessage] = [] results: list[ToolMessage] = []
for tc in tool_calls: for tc in tool_calls:
# Handle both LangChain format (top-level name/args) and
# OpenAI format (nested "function" key).
if isinstance(tc, dict): if isinstance(tc, dict):
if "function" in tc: if "function" in tc:
# OpenAI format: {"id":..., "function": {"name":..., "arguments":"..."}}
fn = tc["function"] fn = tc["function"]
fn_name = fn.get("name", "") fn_name = fn.get("name", "")
fn_args_raw = fn.get("arguments", "{}") fn_args_raw = fn.get("arguments", "{}")
else: else:
# LangChain format: {"name":..., "args":{...}, "id":...}
fn_name = tc.get("name", "") fn_name = tc.get("name", "")
fn_args_raw = tc.get("args", {}) fn_args_raw = tc.get("args", {})
tc_id = tc.get("id", "") tc_id = tc.get("id", "")
@@ -184,13 +178,15 @@ def _make_tool_node(skill_names: list[str]):
fn_args_raw = getattr(tc, "args", {}) fn_args_raw = getattr(tc, "args", {})
tc_id = getattr(tc, "id", "") tc_id = getattr(tc, "id", "")
# Parse args if they arrive as a JSON string
if isinstance(fn_args_raw, str): if isinstance(fn_args_raw, str):
fn_args = json.loads(fn_args_raw) fn_args = json.loads(fn_args_raw)
else: else:
fn_args = fn_args_raw fn_args = fn_args_raw
tr = await execute_tool(skill_names, fn_name, fn_args) tr = await execute_tool(
skill_names, fn_name, fn_args,
discord_user_id=discord_user_id,
)
content = tr.content if tr else f"Tool '{fn_name}' is not available." content = tr.content if tr else f"Tool '{fn_name}' is not available."
results.append(ToolMessage(content=content, tool_call_id=tc_id)) results.append(ToolMessage(content=content, tool_call_id=tc_id))
@@ -224,27 +220,16 @@ def create_agent_graph(
) -> StateGraph: ) -> StateGraph:
""" """
Build and compile a LangGraph StateGraph for a single agent. Build and compile a LangGraph StateGraph for a single agent.
Parameters
----------
client : The OpenAI-compatible client (already authenticated).
agent_skills : Skill names assigned to the agent (e.g. ["seerr", "triage"]).
system_prompt : The fully-built system prompt (base + skill fragments).
model_name : Model identifier sent to the LLM provider.
Returns
-------
A compiled LangGraph graph ready for `.ainvoke()` or `.astream()`.
""" """
tool_defs = get_all_tools(agent_skills) tool_defs = get_all_tools(agent_skills)
graph = StateGraph(AgentState) graph = StateGraph(AgentState)
# Nodes
graph.add_node( graph.add_node(
"agent_node", "agent_node",
_make_agent_node(client, system_prompt, tool_defs, model_name), _make_agent_node(client, system_prompt, tool_defs, model_name),
) )
if tool_defs: if tool_defs:
graph.add_node("tool_node", _make_tool_node(agent_skills)) graph.add_node("tool_node", _make_tool_node(agent_skills))
graph.add_conditional_edges("agent_node", _should_continue, { graph.add_conditional_edges("agent_node", _should_continue, {
@@ -253,7 +238,6 @@ def create_agent_graph(
}) })
graph.add_edge("tool_node", "agent_node") graph.add_edge("tool_node", "agent_node")
else: else:
# No tools — agent responds once and finishes
graph.add_edge("agent_node", END) graph.add_edge("agent_node", END)
graph.set_entry_point("agent_node") graph.set_entry_point("agent_node")
View File
+1
View File
@@ -18,3 +18,4 @@ class AgentState(TypedDict):
""" """
messages: Annotated[list, add_messages] messages: Annotated[list, add_messages]
discord_user_id: int | None # set by the Discord bot, None for REST API calls
@@ -13,7 +13,7 @@ from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
from skills import get_all_tools, execute_tool from agents.skills import get_all_tools, execute_tool
def build_langgraph_tools(skill_names: list[str]) -> list: def build_langgraph_tools(skill_names: list[str]) -> list: