Compare commits
5 Commits
3cd2e4dfbb
..
auth
| Author | SHA1 | Date | |
|---|---|---|---|
| 0151c8210e | |||
| 4b87b817a8 | |||
| b0f10b6bb1 | |||
| 51e099acdd | |||
| bf358f7248 |
+35
-6
@@ -1,16 +1,19 @@
|
|||||||
# ---------------------------------------------------------------------------
|
# =============================================================================
|
||||||
# Agent Backend — Environment Variables
|
# Agent Bot — Environment Configuration
|
||||||
# Copy this to .env and fill in your values.
|
# =============================================================================
|
||||||
# ---------------------------------------------------------------------------
|
# Copy this file to .env and fill in your values.
|
||||||
|
# .env is git-ignored — never commit real secrets.
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
# LLM — DeepSeek (OpenAI-compatible)
|
# LLM — DeepSeek (OpenAI-compatible)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
DEEPSEEK_API_KEY=sk-your-deepseek-api-key
|
DEEPSEEK_API_KEY=sk-your-deepseek-api-key
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Discord Bot
|
# Discord Bot
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
DISCORD_BOT_TOKEN=your-discord-bot-token-here
|
DISCORD_BOT_TOKEN=your-discord-bot-token-here
|
||||||
# DISCORD_MAX_HISTORY=7 # optional, defaults to 7 (max past messages per user)
|
# DISCORD_MAX_HISTORY=7 # optional, defaults to 7 (max past messages per user)
|
||||||
# DISCORD_DEFAULT_AGENT=media-agent # optional, which agent the DM bot uses
|
# DISCORD_DEFAULT_AGENT=media-agent # optional, which agent the DM bot uses
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -18,4 +21,30 @@ DISCORD_BOT_TOKEN=your-discord-bot-token-here
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
SEERR_URL=https://seerr.example.com
|
SEERR_URL=https://seerr.example.com
|
||||||
SEERR_API_KEY=your-seerr-api-key
|
SEERR_API_KEY=your-seerr-api-key
|
||||||
# SEERR_TIMEOUT=30 # optional, defaults to 30 seconds
|
# SEERR_USERNAME=your-username # alternative: username+password auth
|
||||||
|
# SEERR_PASSWORD=your-password
|
||||||
|
# SEERR_TIMEOUT=30 # optional, defaults to 30 seconds
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Auth System (Discord ↔ external services)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# The public-facing URL where users reach this bot's web API.
|
||||||
|
# Used to build the "Click here to link" URLs sent via Discord DM.
|
||||||
|
# For local dev: http://localhost:8000
|
||||||
|
# For production behind a reverse proxy: https://bot.yourdomain.com
|
||||||
|
BASE_URL=http://localhost:8000
|
||||||
|
|
||||||
|
# Where the auth SQLite database lives (relative to project root)
|
||||||
|
# AUTH_DB_PATH=data/auth.db
|
||||||
|
|
||||||
|
# Link token expiry in minutes (default 10)
|
||||||
|
# AUTH_TOKEN_EXPIRY=10
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# JellyStat — PostgreSQL watch-history database
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
JELLYSTAT_DB_HOST=localhost
|
||||||
|
JELLYSTAT_DB_PORT=5432
|
||||||
|
JELLYSTAT_DB_USER=postgres
|
||||||
|
JELLYSTAT_DB_PASSWORD=
|
||||||
|
JELLYSTAT_DB_NAME=jfstat
|
||||||
|
|||||||
@@ -175,3 +175,4 @@ cython_debug/
|
|||||||
.pypirc
|
.pypirc
|
||||||
|
|
||||||
.docs/
|
.docs/
|
||||||
|
data/
|
||||||
+6
-5
@@ -12,7 +12,7 @@ An Agent is a lightweight wrapper:
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from skills import Skill, get_combined_prompt, list_all as list_all_skills
|
from agents.skills import Skill, get_combined_prompt, list_all as list_all_skills
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -61,7 +61,8 @@ def load_all_agents() -> None:
|
|||||||
import agents.media_agent # noqa: F401
|
import agents.media_agent # noqa: F401
|
||||||
|
|
||||||
# Also import skill modules so they self-register
|
# Also import skill modules so they self-register
|
||||||
import skills.media_info # noqa: F401
|
import agents.skills.media_info # noqa: F401
|
||||||
import skills.seerr # noqa: F401
|
import agents.skills.seerr # noqa: F401
|
||||||
import skills.triage # noqa: F401
|
import agents.skills.triage # noqa: F401
|
||||||
import skills.easter_eggs # noqa: F401
|
import agents.skills.easter_eggs # noqa: F401
|
||||||
|
import agents.skills.watch_history # noqa: F401
|
||||||
|
|||||||
@@ -14,11 +14,16 @@ media_agent = Agent(
|
|||||||
agent_id="media-agent",
|
agent_id="media-agent",
|
||||||
description="Media assistant — handles movie/TV/subtitle/ticket requests "
|
description="Media assistant — handles movie/TV/subtitle/ticket requests "
|
||||||
"via Seerr, Jellyfin, Sonarr, etc.",
|
"via Seerr, Jellyfin, Sonarr, etc.",
|
||||||
skills=["media_info", "seerr", "triage", "easter_eggs"],
|
skills=["media_info", "seerr", "triage", "easter_eggs", "watch_history"],
|
||||||
base_prompt=(
|
base_prompt=(
|
||||||
"You are a media assistant connected to Seerr and other media services. "
|
"You are a media assistant connected to Seerr and other media services. "
|
||||||
"Help users discover, request, and troubleshoot their media library. "
|
"Help users discover, request, and troubleshoot their media library. "
|
||||||
"Use the tools provided to perform real actions."
|
"Use the tools provided to perform real actions.\n\n"
|
||||||
|
"## Authentication\n"
|
||||||
|
"If a tool returns a message saying the user needs to log in first, "
|
||||||
|
"tell the user to type `/login <service>` in their DM (e.g. `/login jellyfin`). "
|
||||||
|
"This opens Quick Connect on their Jellyfin app so they can link their account. "
|
||||||
|
"Do NOT tell the user you 'can't connect' or 'don't have access' — just relay the login instructions."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ A Skill is a lightweight object with:
|
|||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||||
from core.config import get_config # re-export so every skill can use it
|
from src.config import get_config # re-export so every skill can use it
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -47,6 +47,7 @@ class Skill:
|
|||||||
prompt_fragment: str = ""
|
prompt_fragment: str = ""
|
||||||
tools: List[Dict[str, Any]] = field(default_factory=list)
|
tools: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
execute: Optional[ToolExecutor] = None
|
execute: Optional[ToolExecutor] = None
|
||||||
|
requires_auth: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -96,9 +97,15 @@ def get_all_tools(skill_names: list[str]) -> List[Dict[str, Any]]:
|
|||||||
|
|
||||||
|
|
||||||
async def execute_tool(
|
async def execute_tool(
|
||||||
skill_names: list[str], tool_name: str, args: dict
|
skill_names: list[str], tool_name: str, args: dict,
|
||||||
|
discord_user_id: int | None = None,
|
||||||
) -> ToolResult | None:
|
) -> ToolResult | None:
|
||||||
"""Find the skill that owns *tool_name* and run its executor.
|
"""Find the skill that owns *tool_name* and run its executor.
|
||||||
|
|
||||||
|
If *discord_user_id* is provided, also checks whether the owning skill
|
||||||
|
requires authentication for any services. If auth is missing, returns
|
||||||
|
a friendly ToolResult.fail(...) telling the user how to log in.
|
||||||
|
|
||||||
Only logs failures to the console — successful calls are silent.
|
Only logs failures to the console — successful calls are silent.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
@@ -109,6 +116,27 @@ async def execute_tool(
|
|||||||
if s and s.execute:
|
if s and s.execute:
|
||||||
for t in s.tools:
|
for t in s.tools:
|
||||||
if t.get("function", {}).get("name") == tool_name:
|
if t.get("function", {}).get("name") == tool_name:
|
||||||
|
# --- Auth gate ---
|
||||||
|
if s.requires_auth and discord_user_id is not None:
|
||||||
|
from src import auth_store
|
||||||
|
from gateway.auth import get_auth_service
|
||||||
|
missing: list[str] = []
|
||||||
|
for svc in s.requires_auth:
|
||||||
|
if not auth_store.is_authenticated(discord_user_id, svc):
|
||||||
|
missing.append(svc)
|
||||||
|
if missing:
|
||||||
|
svc_displays = ", ".join(
|
||||||
|
(get_auth_service(m) and get_auth_service(m).display_name) or m
|
||||||
|
for m in missing
|
||||||
|
)
|
||||||
|
return ToolResult.fail(
|
||||||
|
f"You need to log in to {svc_displays} first. "
|
||||||
|
+ " ".join(f"Send `/login {m}` in a DM to get started." for m in missing)
|
||||||
|
)
|
||||||
|
# --- End auth gate ---
|
||||||
|
# Inject discord_user_id so skills can resolve external user IDs
|
||||||
|
if discord_user_id is not None:
|
||||||
|
args = {**args, "_discord_user_id": discord_user_id}
|
||||||
try:
|
try:
|
||||||
result = await s.execute(tool_name, args)
|
result = await s.execute(tool_name, args)
|
||||||
if not result.success:
|
if not result.success:
|
||||||
@@ -8,7 +8,7 @@ requested actions normally. Functionality is never sacrificed for a reference.
|
|||||||
Add a new theme by adding one entry to THEMES — no code changes needed.
|
Add a new theme by adding one entry to THEMES — no code changes needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from skills import Skill, register
|
from agents.skills import Skill, register
|
||||||
|
|
||||||
THEMES = {
|
THEMES = {
|
||||||
"naruto": {
|
"naruto": {
|
||||||
@@ -5,7 +5,7 @@ A lightweight base skill that teaches the agent it is a media assistant.
|
|||||||
Real API capabilities come from other skills (seerr, triage, etc.).
|
Real API capabilities come from other skills (seerr, triage, etc.).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from skills import Skill, register
|
from agents.skills import Skill, register
|
||||||
|
|
||||||
media_info_skill = Skill(
|
media_info_skill = Skill(
|
||||||
name="media_info",
|
name="media_info",
|
||||||
@@ -23,6 +23,16 @@ When responding:
|
|||||||
suggest submitting a ticket if there's a problem.
|
suggest submitting a ticket if there's a problem.
|
||||||
- Always confirm successful actions and warn about failures.
|
- Always confirm successful actions and warn about failures.
|
||||||
|
|
||||||
|
## Jellyfin & Authentication
|
||||||
|
|
||||||
|
You are connected to the user's Jellyfin server. If a user asks you to
|
||||||
|
"connect to Jellyfin", "link my Jellyfin", or asks about their watch history,
|
||||||
|
simply call the `watch_history` tool. The system will automatically handle
|
||||||
|
authentication — if the user isn't linked yet, they'll be guided through
|
||||||
|
Quick Connect seamlessly. NEVER tell a user you "don't have access to
|
||||||
|
Jellyfin" or "can't connect" — always try the tool first and let the system
|
||||||
|
sort it out.
|
||||||
|
|
||||||
This is the base media assistant persona. Real API capabilities come from the
|
This is the base media assistant persona. Real API capabilities come from the
|
||||||
attached skills (seerr, triage, etc.).""",
|
attached skills (seerr, triage, etc.).""",
|
||||||
)
|
)
|
||||||
@@ -24,7 +24,7 @@ from urllib.parse import quote
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from skills import Skill, register, ToolResult, get_config
|
from agents.skills import Skill, register, ToolResult, get_config
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Config
|
# Config
|
||||||
@@ -10,7 +10,7 @@ cancelling requests, banning users), this skill teaches the LLM to:
|
|||||||
3. Use the seerr_submit_issue tool (if available) to create the ticket.
|
3. Use the seerr_submit_issue tool (if available) to create the ticket.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from skills import Skill, register
|
from agents.skills import Skill, register
|
||||||
|
|
||||||
# This skill has no tools of its own — it guides the LLM's behavior.
|
# This skill has no tools of its own — it guides the LLM's behavior.
|
||||||
# The actual ticket submission is handled by seerr_submit_issue.
|
# The actual ticket submission is handled by seerr_submit_issue.
|
||||||
@@ -0,0 +1,273 @@
|
|||||||
|
"""
|
||||||
|
Watch History skill — fetch the user's Jellyfin watch history via JellyStat API.
|
||||||
|
|
||||||
|
Requires the user to have linked Jellyfin via `/login jellyfin` in Discord.
|
||||||
|
The auth gate (`requires_auth=["jellyfin"]`) is already active — users who
|
||||||
|
haven't linked Jellyfin will be prompted to /login first.
|
||||||
|
|
||||||
|
Architecture
|
||||||
|
------------
|
||||||
|
This skill calls the JellyStat REST API (same FastAPI process, via HTTP)
|
||||||
|
rather than accessing the PostgreSQL database directly. This keeps the
|
||||||
|
bot isolated from database credentials.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from agents.skills import Skill, register, ToolResult
|
||||||
|
from src import auth_store
|
||||||
|
from src.config import get_config
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
BASE_URL = (get_config("BASE_URL") or "http://localhost:8000").rstrip("/")
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tool definitions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
TOOLS = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "watch_history",
|
||||||
|
"description": (
|
||||||
|
"Get the user's Jellyfin watch history — titles grouped by total "
|
||||||
|
"watch time in a configurable time window. Use this when a user "
|
||||||
|
"asks what they've watched, what they've been watching recently, "
|
||||||
|
"or wants to see their viewing activity."
|
||||||
|
),
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "How many titles to return (default 10, max 20).",
|
||||||
|
},
|
||||||
|
"minutes": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": (
|
||||||
|
"Time window in minutes. Default 10080 (7 days). "
|
||||||
|
"Use a large number like 525600 for 'all time' (1 year)."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "watch_genres",
|
||||||
|
"description": (
|
||||||
|
"Get the user's most-watched genres from Jellyfin, ranked by "
|
||||||
|
"total watch time. Use this when a user asks what kinds of "
|
||||||
|
"content they watch most, their favourite genres, or what "
|
||||||
|
"categories dominate their viewing."
|
||||||
|
),
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"minutes": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": (
|
||||||
|
"Time window in minutes. Default 10080 (7 days). "
|
||||||
|
"Use a large number like 525600 for 'all time'."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "watch_summary",
|
||||||
|
"description": (
|
||||||
|
"Get an all-time Jellyfin watch summary — total watch time, "
|
||||||
|
"most-watched series, most-watched movie, 30-day and 7-day "
|
||||||
|
"activity, and top 3 genres. Use this when a user asks for "
|
||||||
|
"their overall stats, a dashboard, or 'how much have I watched?'."
|
||||||
|
),
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_jellyfin_id(args: dict) -> str | None:
|
||||||
|
"""Extract the Jellyfin user ID from auth_store using the injected Discord ID."""
|
||||||
|
discord_user_id = args.pop("_discord_user_id", None)
|
||||||
|
if discord_user_id is None:
|
||||||
|
return None # not called from Discord — shouldn't happen with auth gate
|
||||||
|
|
||||||
|
auth = auth_store.get_auth(discord_user_id, "jellyfin")
|
||||||
|
if auth is None or not auth.get("external_user_id"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return auth["external_user_id"]
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_json(url: str) -> dict:
|
||||||
|
"""GET *url* and return the parsed JSON body, or {} on failure."""
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
resp = await client.get(url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
def _format_seconds(total: float) -> str:
|
||||||
|
"""Convert seconds to a human-friendly string."""
|
||||||
|
total = max(total, 0)
|
||||||
|
hours = int(total // 3600)
|
||||||
|
minutes = int((total % 3600) // 60)
|
||||||
|
if hours and minutes:
|
||||||
|
return f"{hours}h {minutes}m"
|
||||||
|
if hours:
|
||||||
|
return f"{hours}h"
|
||||||
|
if minutes:
|
||||||
|
return f"{minutes}m"
|
||||||
|
return f"{int(total)}s"
|
||||||
|
|
||||||
|
|
||||||
|
def _format_history(data: dict, limit: int) -> ToolResult:
|
||||||
|
"""Format a watch-history API response for the LLM."""
|
||||||
|
items = data.get("items", [])[:limit]
|
||||||
|
if not items:
|
||||||
|
return ToolResult.ok("You haven't watched anything in this time window.")
|
||||||
|
|
||||||
|
lines = [f"**Watch History** (last {data.get('window_minutes', '?')} minutes):"]
|
||||||
|
for i, item in enumerate(items, 1):
|
||||||
|
duration = _format_seconds(item["watch_time_sec"])
|
||||||
|
icon = "📺" if item["media_type"] == "series" else "🎬"
|
||||||
|
lines.append(f"{i}. {icon} **{item['title']}** — {duration}")
|
||||||
|
|
||||||
|
return ToolResult.ok("\n".join(lines))
|
||||||
|
|
||||||
|
|
||||||
|
def _format_genres(data: dict) -> ToolResult:
|
||||||
|
"""Format a genre-summary API response for the LLM."""
|
||||||
|
genres = data.get("genres", [])
|
||||||
|
if not genres:
|
||||||
|
return ToolResult.ok("No genre data available for this time window.")
|
||||||
|
|
||||||
|
lines = [f"**Top Genres** (last {data.get('window_minutes', '?')} minutes):"]
|
||||||
|
for i, g in enumerate(genres, 1):
|
||||||
|
duration = _format_seconds(g["watch_time_sec"])
|
||||||
|
lines.append(f"{i}. **{g['genre']}** — {duration}")
|
||||||
|
|
||||||
|
return ToolResult.ok("\n".join(lines))
|
||||||
|
|
||||||
|
|
||||||
|
def _format_summary(data: dict) -> ToolResult:
|
||||||
|
"""Format a user-summary API response for the LLM."""
|
||||||
|
total = _format_seconds(data.get("total_watch_time_sec", 0))
|
||||||
|
last_30 = _format_seconds(data.get("total_last_30d_sec", 0))
|
||||||
|
last_7 = _format_seconds(data.get("total_last_7d_sec", 0))
|
||||||
|
|
||||||
|
top_series = data.get("most_watched_series") or "—"
|
||||||
|
top_movie = data.get("most_watched_movie") or "—"
|
||||||
|
top_genres = data.get("top_genres", [])
|
||||||
|
genres_str = ", ".join(top_genres) if top_genres else "—"
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
"**Your Jellyfin Summary** (all time):",
|
||||||
|
f"⏱️ Total watch time: **{total}**",
|
||||||
|
f"📺 Most-watched series: **{top_series}**",
|
||||||
|
f"🎬 Most-watched movie: **{top_movie}**",
|
||||||
|
f"📅 Last 30 days: **{last_30}**",
|
||||||
|
f"📅 Last 7 days: **{last_7}**",
|
||||||
|
f"🏷️ Top genres: {genres_str}",
|
||||||
|
]
|
||||||
|
return ToolResult.ok("\n".join(lines))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Executor
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute(tool_name: str, args: dict) -> ToolResult:
|
||||||
|
# 1. Resolve Jellyfin user ID
|
||||||
|
jellyfin_id = _resolve_jellyfin_id(args)
|
||||||
|
if jellyfin_id is None:
|
||||||
|
return ToolResult.fail(
|
||||||
|
"Your Jellyfin account is not linked. Use `/login jellyfin` in a DM to connect."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Route to the right JellyStat endpoint
|
||||||
|
try:
|
||||||
|
match tool_name:
|
||||||
|
case "watch_history":
|
||||||
|
limit = args.get("limit", 10)
|
||||||
|
minutes = args.get("minutes", 10080)
|
||||||
|
url = f"{BASE_URL}/jellystat/history/{jellyfin_id}?minutes={minutes}"
|
||||||
|
data = await _fetch_json(url)
|
||||||
|
return _format_history(data, limit)
|
||||||
|
|
||||||
|
case "watch_genres":
|
||||||
|
minutes = args.get("minutes", 10080)
|
||||||
|
url = f"{BASE_URL}/jellystat/genres/{jellyfin_id}?minutes={minutes}"
|
||||||
|
data = await _fetch_json(url)
|
||||||
|
return _format_genres(data)
|
||||||
|
|
||||||
|
case "watch_summary":
|
||||||
|
url = f"{BASE_URL}/jellystat/summary/{jellyfin_id}"
|
||||||
|
data = await _fetch_json(url)
|
||||||
|
return _format_summary(data)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
return ToolResult.fail(f"Unknown tool: {tool_name}")
|
||||||
|
|
||||||
|
except httpx.HTTPError:
|
||||||
|
return ToolResult.fail(
|
||||||
|
"Could not reach the watch-history service right now. "
|
||||||
|
"Please try again in a moment."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Skill registration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_PROMPT = (
|
||||||
|
"## Watch History\n"
|
||||||
|
"\n"
|
||||||
|
"You have THREE tools to answer questions about the user's Jellyfin watch activity:\n"
|
||||||
|
"\n"
|
||||||
|
"1. **`watch_history`** — per-title watch time in a time window (default: 7 days).\n"
|
||||||
|
" Use when a user asks what they've watched, to show their history,\n"
|
||||||
|
" or what they watched this week or yesterday.\n"
|
||||||
|
"\n"
|
||||||
|
"2. **`watch_genres`** — watch time broken down by genre.\n"
|
||||||
|
" Use when a user asks what genres they watch, whether they watch more\n"
|
||||||
|
" comedy than drama, or what their most-watched genre is.\n"
|
||||||
|
"\n"
|
||||||
|
"3. **`watch_summary`** — all-time dashboard: total watch time, most-watched\n"
|
||||||
|
" series and movie, 30-day and 7-day activity, and top 3 genres.\n"
|
||||||
|
" Use when a user asks for their stats, how much they've watched in\n"
|
||||||
|
" total, or what their favourites are.\n"
|
||||||
|
"\n"
|
||||||
|
"Always call the appropriate tool before answering — NEVER guess at watch data.\n"
|
||||||
|
"Format watch times in a human-readable way (hours and minutes), but keep the\n"
|
||||||
|
"raw data visible too."
|
||||||
|
)
|
||||||
|
|
||||||
|
watch_history_skill = Skill(
|
||||||
|
name="watch_history",
|
||||||
|
description="User's Jellyfin watch history, genres, and summary stats",
|
||||||
|
requires_auth=["jellyfin"],
|
||||||
|
prompt_fragment=_PROMPT,
|
||||||
|
tools=TOOLS,
|
||||||
|
execute=_execute,
|
||||||
|
)
|
||||||
|
|
||||||
|
register(watch_history_skill)
|
||||||
-235
@@ -1,235 +0,0 @@
|
|||||||
# API Architecture — Agent + Skill + Graph Pipeline
|
|
||||||
|
|
||||||
This document explains how the API routes user messages through the
|
|
||||||
agent / skill / LangGraph pipeline to produce responses.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────────────────────────────────────────────────────┐
|
|
||||||
│ OpenWebUI / Client │
|
|
||||||
│ POST /v1/chat/completions { model, messages, stream } │
|
|
||||||
└──────────────────────────────┬──────────────────────────────────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌──────────────────────────────────────────────────────────────────┐
|
|
||||||
│ api/v1/chat.py — chat_completions() │
|
|
||||||
│ │
|
|
||||||
│ 1. _resolve_agent(req.model) → Agent │
|
|
||||||
│ 2. get_agent_graph(agent_id) → compiled StateGraph │
|
|
||||||
│ 3. graph.ainvoke(state) or _stream_graph(graph, messages) │
|
|
||||||
└──────────────────────────────┬───────────────────────────────────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌──────────────────────────────────────────────────────────────────┐
|
|
||||||
│ LangGraph StateGraph (core/graph.py) │
|
|
||||||
│ │
|
|
||||||
│ ┌──────────────┐ tool_calls? ┌──────────────┐ │
|
|
||||||
│ │ agent_node │ ───────────────▶ │ tool_node │ │
|
|
||||||
│ │ (LLM call) │ ◀─────────────── │ (skill exec) │ │
|
|
||||||
│ └──────┬───────┘ └──────────────┘ │
|
|
||||||
│ │ no tool_calls │
|
|
||||||
│ ▼ │
|
|
||||||
│ [END] │
|
|
||||||
└──────────────────────────────────────────────────────────────────┘
|
|
||||||
|
|
||||||
## Key Concepts
|
|
||||||
|
|
||||||
### 1. Agent
|
|
||||||
|
|
||||||
An **Agent** is a persona + skill bundle. Defined in `agents/`.
|
|
||||||
|
|
||||||
```python
|
|
||||||
# agents/media_agent.py
|
|
||||||
Agent(
|
|
||||||
agent_id="media-agent",
|
|
||||||
description="Media assistant with Seerr integration",
|
|
||||||
skills=["media_info", "seerr", "triage"],
|
|
||||||
base_prompt="You are a media assistant...",
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
- `agent_id` — unique name, exposed as a model in OpenWebUI
|
|
||||||
- `skills` — list of skill names to load
|
|
||||||
- `base_prompt` — starting system prompt, combined with skill fragments
|
|
||||||
- `build_system_prompt()` — merges base_prompt + all skill prompt fragments
|
|
||||||
|
|
||||||
Agents self-register at import time via `agents/__init__.py`'s `register()`.
|
|
||||||
`main.py` calls `load_all_agents()` at startup to import every agent and skill
|
|
||||||
module.
|
|
||||||
|
|
||||||
### 2. Skill
|
|
||||||
|
|
||||||
A **Skill** is a capability bundle. Defined in `skills/`.
|
|
||||||
|
|
||||||
```python
|
|
||||||
# skills/seerr.py
|
|
||||||
Skill(
|
|
||||||
name="seerr",
|
|
||||||
description="Seerr integration — trending, discover, request media, submit issues",
|
|
||||||
prompt_fragment="## Seerr Media Tools\n...",
|
|
||||||
tools=[...], # OpenAI function-calling schema
|
|
||||||
execute=_execute, # async handler: tool_name + args → ToolResult
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
- `prompt_fragment` — injected into the agent's system prompt.
|
|
||||||
- `tools` — list of OpenAI function definitions (name, description, parameters).
|
|
||||||
- `execute` — async callable that routes tool calls to API handlers.
|
|
||||||
|
|
||||||
### 3. Graph
|
|
||||||
|
|
||||||
Each agent gets a **compiled LangGraph StateGraph** built by
|
|
||||||
`core/graph.py:create_agent_graph()`. The graph is compiled lazily on the
|
|
||||||
first request and cached on `app.state.agent_graphs` for the lifetime of the
|
|
||||||
process.
|
|
||||||
|
|
||||||
| Graph node / edge | What it does |
|
|
||||||
|---|---|
|
|
||||||
| `agent_node` | Converts state messages to OpenAI dicts, calls the LLM with the agent's system prompt + tool definitions, returns an `AIMessage` |
|
|
||||||
| `tool_node` | Reads `tool_calls` from the last AI message, calls `execute_tool()` from the skill system, returns `ToolMessage` results |
|
|
||||||
| `_should_continue` | Conditional edge — returns `"tool_node"` if the AI message has `tool_calls`, else `END` |
|
|
||||||
|
|
||||||
### 4. State
|
|
||||||
|
|
||||||
Defined in `core/state.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class AgentState(TypedDict):
|
|
||||||
messages: Annotated[list, add_messages]
|
|
||||||
```
|
|
||||||
|
|
||||||
LangGraph's `add_messages` reducer appends new messages and replaces messages
|
|
||||||
with matching IDs (so tool-call results overwrite their placeholders).
|
|
||||||
|
|
||||||
### 5. Message Conversion
|
|
||||||
|
|
||||||
Because we use the raw `openai` client (not `langchain-openai`), messages must
|
|
||||||
be converted between LangChain and OpenAI formats at every LLM call:
|
|
||||||
|
|
||||||
- **LangChain → OpenAI** (`_lc_role_to_openai`, `_langchain_tc_to_openai`):
|
|
||||||
Maps `type` → `role` and converts top-level `name`/`args` tool-calls into
|
|
||||||
the nested `function` sub-object that the OpenAI API expects.
|
|
||||||
|
|
||||||
- **OpenAI → LangChain** (inside `agent_node`):
|
|
||||||
Converts the `ChatCompletionMessage` response into an `AIMessage` with
|
|
||||||
LangChain-format `tool_calls` (top-level `name`/`args`/`id`).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Full Request Flow
|
|
||||||
|
|
||||||
### Step-by-step: "What are trending movies?"
|
|
||||||
|
|
||||||
```
|
|
||||||
1. OpenWebUI sends:
|
|
||||||
POST /v1/chat/completions
|
|
||||||
{
|
|
||||||
"model": "media-agent",
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "What are trending movies?"}
|
|
||||||
],
|
|
||||||
"stream": false
|
|
||||||
}
|
|
||||||
|
|
||||||
2. chat_completions():
|
|
||||||
→ _resolve_agent(model="media-agent")
|
|
||||||
→ get_agent("media-agent") → Agent(skills=["media_info", "seerr", "triage"])
|
|
||||||
→ get_agent_graph("media-agent", request)
|
|
||||||
→ looks up app.state.agent_graphs["media-agent"]
|
|
||||||
→ first call → create_agent_graph() compiles the graph with 7 Seerr tools
|
|
||||||
→ run_agent_with_tools(request, messages, agent_id)
|
|
||||||
→ _invoke_graph(graph, messages)
|
|
||||||
|
|
||||||
3. Graph — Pass 1 (agent_node):
|
|
||||||
→ LLM receives: [system prompt] + [user: "What are trending movies?"]
|
|
||||||
→ LLM responds with tool_calls: seerr_trending(kind="movie")
|
|
||||||
→ agent_node returns AIMessage with tool_calls in LangChain format
|
|
||||||
|
|
||||||
4. Graph — _should_continue:
|
|
||||||
→ AIMessage has tool_calls → route to "tool_node"
|
|
||||||
|
|
||||||
5. Graph — tool_node:
|
|
||||||
→ Reads tool_call: name="seerr_trending", args={"kind": "movie"}
|
|
||||||
→ execute_tool(["media_info", "seerr", "triage"], "seerr_trending", ...)
|
|
||||||
→ Seerr API → GET /api/v1/discover/trending?mediaType=movie
|
|
||||||
→ Returns ToolMessage with formatted results including [tmdb:IDs]
|
|
||||||
|
|
||||||
6. Graph — Pass 2 (agent_node):
|
|
||||||
→ LLM receives previous exchange + tool result
|
|
||||||
→ LLM responds with text only (no tool_calls)
|
|
||||||
→ agent_node returns AIMessage(content="Here are the top trending movies!...")
|
|
||||||
|
|
||||||
7. Graph — _should_continue:
|
|
||||||
→ No tool_calls → route to END
|
|
||||||
|
|
||||||
8. chat_completions() returns:
|
|
||||||
{ "choices": [{"message": {"role": "assistant", "content": "Here are the top..."}}] }
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step-by-step: "Request the 2026 one" (multi-turn context)
|
|
||||||
|
|
||||||
```
|
|
||||||
1. OpenWebUI sends the FULL history:
|
|
||||||
{
|
|
||||||
"model": "media-agent",
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "What are trending movies?"},
|
|
||||||
{"role": "assistant", "content": "Here are the top 10 trending movies!
|
|
||||||
1. **Mortal Kombat II** (2026) [tmdb:931285] — ..."},
|
|
||||||
{"role": "user", "content": "could request the mortal kombat one?"},
|
|
||||||
{"role": "assistant", "content": "There are several Mortal Kombat entries! ..."},
|
|
||||||
{"role": "user", "content": "the 2026 one"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
2. chat_completions():
|
|
||||||
→ req.messages contains the ENTIRE conversation history
|
|
||||||
→ graph.ainvoke({"messages": all_messages})
|
|
||||||
→ agent_node prepends system prompt and sends everything to the LLM
|
|
||||||
|
|
||||||
3. LLM reasons from full context:
|
|
||||||
- Previously listed Mortal Kombat II (2026) with [tmdb:931285]
|
|
||||||
- The user said "request the mortal kombat one" → I searched and showed 4 options
|
|
||||||
- Now they say "the 2026 one" → that matches Mortal Kombat II (2026) [tmdb:931285]
|
|
||||||
- I should call seerr_request_media(kind="movie", title="Mortal Kombat II", tmdb_id=931285)
|
|
||||||
|
|
||||||
4. tool_node executes the request → ✅ Success
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Streaming
|
|
||||||
|
|
||||||
Streaming works slightly differently from the sync path:
|
|
||||||
|
|
||||||
```
|
|
||||||
chat_completions(stream=True)
|
|
||||||
→ _stream_graph(graph, messages)
|
|
||||||
→ graph.ainvoke(state) # runs graph to completion (tools execute silently)
|
|
||||||
→ yields content character-by-character via SSE
|
|
||||||
```
|
|
||||||
|
|
||||||
For true token-level streaming (tokens appear as the LLM generates them),
|
|
||||||
the agent_node would need to use `langchain-openai`'s `ChatOpenAI` instead of
|
|
||||||
the raw `openai` client. The current approach is a pragmatic middle ground
|
|
||||||
that avoids adding another dependency while still giving the SSE client
|
|
||||||
incremental output.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## File Map
|
|
||||||
|
|
||||||
| File | Responsibility |
|
|
||||||
|---|---|
|
|
||||||
| `main.py` | FastAPI app, singleton creation, router mounting |
|
|
||||||
| `api/v1/chat.py` | Endpoints — resolves agent, invokes graph, formats responses |
|
|
||||||
| `api/dependencies.py` | `get_llm_client()`, `get_agent_graph()` — FastAPI `Depends` |
|
|
||||||
| `core/graph.py` | `create_agent_graph()` — builds the StateGraph |
|
|
||||||
| `core/state.py` | `AgentState` TypedDict |
|
|
||||||
| `core/llm.py` | `create_client()` — OpenAI client factory |
|
|
||||||
| `core/config.py` | Environment variable loader |
|
|
||||||
| `agents/` | Agent definitions (dataclass + self-registration) |
|
|
||||||
| `skills/` | Skill definitions (prompt fragments + tools + executors) |
|
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
# Gateway Architecture — Agent + Skill + Graph Pipeline
|
||||||
|
|
||||||
|
This is the **interface layer** of the Agents project. Everything that connects
|
||||||
|
the outside world to the agent system lives here — REST APIs, Discord bot,
|
||||||
|
and authentication.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Directory Map
|
||||||
|
|
||||||
|
| Path | Description | Docs |
|
||||||
|
|---|---|---|
|
||||||
|
| `gateway/v1/` | REST API endpoints — chat, agent listing, OpenAI-compatible completions | [v1.md](v1/v1.md) |
|
||||||
|
| `gateway/discord/` | Discord bot connector — in-process DM handler with LangGraph integration | [discord.md](discord/discord.md) |
|
||||||
|
| `gateway/auth/` | Auth service registry + Jellyfin Quick Connect implementation | [auth.md](auth/auth.md) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Supporting Modules
|
||||||
|
|
||||||
|
| Path | Purpose |
|
||||||
|
|---|---|
|
||||||
|
| `gateway/dependencies.py` | FastAPI `Depends` providers — `get_llm_client()`, `get_agent_graph()` |
|
||||||
|
| `src/config.py` | `.env` loader and config accessor |
|
||||||
|
| `src/llm.py` | OpenAI-compatible client factory (DeepSeek) |
|
||||||
|
| `src/state.py` | LangGraph `AgentState` TypedDict |
|
||||||
|
| `src/graph.py` | LangGraph StateGraph factory — agent_node, tool_node, routing |
|
||||||
|
| `src/tools_adapter.py` | Wraps skill tools as LangChain `@tool` functions |
|
||||||
|
| `src/auth_store.py` | SQLite persistence for Discord → service auth linking |
|
||||||
|
| `agents/` | Agent definitions (dataclass + registry) |
|
||||||
|
| `agents/skills/` | Skill definitions — prompt fragments, tool schemas, executors |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## High-Level Request Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────┐
|
||||||
|
│ Client (OpenWebUI / HTTP) │
|
||||||
|
└──────────────┬───────────────┘
|
||||||
|
│ POST /v1/chat/completions
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────┐
|
||||||
|
│ gateway/v1/chat.py │ ← resolves agent, invokes graph
|
||||||
|
└──────────────┬───────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────┐
|
||||||
|
│ LangGraph StateGraph │ ← src/graph.py
|
||||||
|
│ ┌──────────┐ ┌──────────┐│
|
||||||
|
│ │agent_node│──▶│tool_node ││
|
||||||
|
│ │(LLM call)│◀──│(skills) ││
|
||||||
|
│ └──────────┘ └──────────┘│
|
||||||
|
└──────────────┬───────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────┐
|
||||||
|
│ agents/skills/ │ ← Seerr API, Jellyfin API, etc.
|
||||||
|
└──────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
For a detailed step-by-step walkthrough of the graph execution (including
|
||||||
|
multi-turn context and tool-calling loops), see [v1.md](v1/v1.md).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Startup
|
||||||
|
|
||||||
|
`main.py` is the entry point. At startup it:
|
||||||
|
|
||||||
|
1. Loads `.env` → creates the LLM client (DeepSeek) → stores on `app.state.llm_client`
|
||||||
|
2. Calls `load_all_agents()` → imports every agent and skill module (they self-register)
|
||||||
|
3. Imports `gateway.auth.jellyfin` → self-registers the Jellyfin auth service
|
||||||
|
4. Mounts routers: `/v1/*` (chat endpoints) and `/api/v1/auth/*` (auth endpoints)
|
||||||
|
5. Starts the Discord bot as a background asyncio task (lifespan)
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
"""
|
||||||
|
Auth Service registry — generic, pluggable authentication for any service.
|
||||||
|
|
||||||
|
Add a new service (Plex, Seerr, etc.) by:
|
||||||
|
1. Subclassing AuthService
|
||||||
|
2. Dropping the module in this package
|
||||||
|
3. Calling register_auth_service() at import time
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AuthResult — returned by AuthService.authenticate()
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AuthResult:
|
||||||
|
"""Outcome of a credential validation attempt."""
|
||||||
|
success: bool
|
||||||
|
external_user_id: Optional[str] = None
|
||||||
|
external_name: Optional[str] = None
|
||||||
|
credentials: Optional[dict] = None
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AuthService — abstract base class
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class AuthService(ABC):
|
||||||
|
"""A service that users can authenticate against (Jellyfin, Seerr, Plex, etc.)
|
||||||
|
|
||||||
|
Subclasses must implement:
|
||||||
|
- name : unique identifier used in URLs and DB keys
|
||||||
|
- display_name : human-readable label shown in Discord
|
||||||
|
- render_login_form(token, discord_id) → HTML string
|
||||||
|
- authenticate(form_data) → AuthResult
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Unique service name: "jellyfin", "seerr", etc."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def display_name(self) -> str:
|
||||||
|
"""Human-readable: "Jellyfin", "Seerr", "Plex" """
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def render_login_form(self, token: str, discord_id: int) -> str:
|
||||||
|
"""Return HTML string with a login form for this service.
|
||||||
|
|
||||||
|
The form MUST include these hidden fields:
|
||||||
|
<input type="hidden" name="token" value="{token}">
|
||||||
|
<input type="hidden" name="discord_id" value="{discord_id}">
|
||||||
|
<input type="hidden" name="service" value="{self.name}">
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def authenticate(self, form_data: dict) -> AuthResult:
|
||||||
|
"""Validate credentials against the service. Return AuthResult."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Global registry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_registry: dict[str, AuthService] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_auth_service(svc: AuthService) -> None:
|
||||||
|
"""Register an AuthService so it can be looked up by name."""
|
||||||
|
_registry[svc.name] = svc
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_service(name: str) -> AuthService | None:
|
||||||
|
"""Look up a registered AuthService by name."""
|
||||||
|
return _registry.get(name)
|
||||||
|
|
||||||
|
|
||||||
|
def list_auth_services() -> list[str]:
|
||||||
|
"""Return names of all registered auth services."""
|
||||||
|
return list(_registry.keys())
|
||||||
@@ -0,0 +1,152 @@
|
|||||||
|
# Auth — Service Registry & Persistence
|
||||||
|
|
||||||
|
The authentication system lets Discord users link their accounts to external
|
||||||
|
services (currently **Jellyfin**) so the agent can perform actions on their
|
||||||
|
behalf (e.g. checking watch history).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
gateway/auth/ gateway/v1/auth.py
|
||||||
|
┌──────────────────────┐ ┌──────────────────────────────┐
|
||||||
|
│ AuthService (ABC) │ │ GET /api/v1/auth/login │
|
||||||
|
│ ├─ JellyfinAuth │◀─────────│ POST /api/v1/auth/login │
|
||||||
|
│ └─ (Plex, Seerr…) │ │ GET /api/v1/auth/status │
|
||||||
|
│ │ │ GET /api/v1/auth/reset │
|
||||||
|
└─────────┬────────────┘ └──────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
src/auth_store.py
|
||||||
|
┌──────────────────────┐
|
||||||
|
│ SQLite │
|
||||||
|
│ ├─ link_tokens │ one-time tokens sent via Discord DM
|
||||||
|
│ └─ user_auth │ per-user, per-service credentials
|
||||||
|
└──────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Files
|
||||||
|
|
||||||
|
| File | Purpose |
|
||||||
|
|---|---|
|
||||||
|
| `gateway/auth/__init__.py` | Abstract `AuthService` base class + global registry |
|
||||||
|
| `gateway/auth/jellyfin.py` | Jellyfin implementation — Quick Connect + username/password |
|
||||||
|
| `gateway/v1/auth.py` | REST endpoints for the web-based login flow |
|
||||||
|
| `src/auth_store.py` | SQLite persistence for link tokens and stored credentials |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Flow: Discord User Links Jellyfin
|
||||||
|
|
||||||
|
```
|
||||||
|
Discord DM Web Browser Jellyfin Server
|
||||||
|
│ │ │
|
||||||
|
│ 1. /login jellyfin │ │
|
||||||
|
│ ──────────────────────────────▶│ │
|
||||||
|
│ Bot creates link token in │ │
|
||||||
|
│ SQLite, DMs the user a URL │ │
|
||||||
|
│ │ │
|
||||||
|
│ 2. User clicks link │ │
|
||||||
|
│ ◀─────────────────────────────▶│ │
|
||||||
|
│ │ GET /api/v1/auth/login │
|
||||||
|
│ │ ?service=jellyfin │
|
||||||
|
│ │ &token=xxx&discord_id=123 │
|
||||||
|
│ │ │
|
||||||
|
│ │ 3. Serve Quick Connect form │
|
||||||
|
│ │ ◀──────────────────────────── │
|
||||||
|
│ │ │
|
||||||
|
│ │ 4. Initiate Quick Connect │
|
||||||
|
│ │ ─────────────────────────────▶│
|
||||||
|
│ │ POST /QuickConnect/Initiate │
|
||||||
|
│ │ ◀── { Code: "ABC123" } │
|
||||||
|
│ │ │
|
||||||
|
│ 5. User enters code in │ │
|
||||||
|
│ Jellyfin app │ │
|
||||||
|
│ │ │
|
||||||
|
│ │ 6. Poll: is it authorized? │
|
||||||
|
│ │ ─────────────────────────────▶│
|
||||||
|
│ │ GET /QuickConnect/Connect │
|
||||||
|
│ │ ◀── Authenticated + Token │
|
||||||
|
│ │ │
|
||||||
|
│ 7. auth_store saves: │ │
|
||||||
|
│ (discord_id, jellyfin, │ │
|
||||||
|
│ AccessToken, username) │ │
|
||||||
|
│ │ │
|
||||||
|
│ 8. "✅ Linked to Jellyfin!" │ │
|
||||||
|
│ ◀───────────────────────────── │ │
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## AuthService Base Class
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AuthService(ABC):
|
||||||
|
name: str # "jellyfin"
|
||||||
|
display_name: str # "Jellyfin"
|
||||||
|
|
||||||
|
def render_login_form(token, discord_id) -> str: ...
|
||||||
|
async def authenticate(form_data) -> AuthResult: ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Add a new service (e.g. Plex, Seerr) by subclassing `AuthService`, dropping
|
||||||
|
the module in `gateway/auth/`, and calling `register_auth_service()` at import
|
||||||
|
time. The REST endpoints and auth store work generically — no changes needed.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Current Implementation: Jellyfin
|
||||||
|
|
||||||
|
`gateway/auth/jellyfin.py` supports two flows:
|
||||||
|
|
||||||
|
| Method | How it works |
|
||||||
|
|---|---|
|
||||||
|
| **Quick Connect** (primary) | Calls Jellyfin's `/QuickConnect/Initiate` → polls `/QuickConnect/Connect` → stores the `AccessToken` |
|
||||||
|
| **Username/Password** (fallback) | Renders an HTML form → user submits credentials → calls `/Users/AuthenticateByName` → stores the `AccessToken` |
|
||||||
|
|
||||||
|
The stored credentials include:
|
||||||
|
- `external_user_id` — Jellyfin user ID
|
||||||
|
- `external_name` — Jellyfin username
|
||||||
|
- `credentials` dict — `{"AccessToken": "...", "ServerURL": "..."}`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Auth Store (SQLite)
|
||||||
|
|
||||||
|
Two tables in `data/auth.db`:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- One-time tokens for the web login flow (expire after 10 min)
|
||||||
|
CREATE TABLE link_tokens (
|
||||||
|
token TEXT PRIMARY KEY,
|
||||||
|
discord_id INTEGER NOT NULL,
|
||||||
|
service TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
used INTEGER DEFAULT 0
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Per-user, per-service stored credentials
|
||||||
|
CREATE TABLE user_auth (
|
||||||
|
discord_id INTEGER NOT NULL,
|
||||||
|
service TEXT NOT NULL,
|
||||||
|
external_user_id TEXT,
|
||||||
|
external_name TEXT,
|
||||||
|
credentials TEXT, -- JSON
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (discord_id, service)
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Skill-Level Auth Gating
|
||||||
|
|
||||||
|
Skills can declare `requires_auth=["jellyfin"]`. When a tool is executed,
|
||||||
|
the skill system checks the auth store. If the user isn't linked:
|
||||||
|
|
||||||
|
1. The tool returns `ToolResult.fail("Please login first using /login jellyfin")`
|
||||||
|
2. The LLM relays this message to the user in Discord
|
||||||
|
3. The user types `/login jellyfin` → Quick Connect flow → re-linked → try again
|
||||||
@@ -0,0 +1,401 @@
|
|||||||
|
"""
|
||||||
|
Jellyfin AuthService — validates Jellyfin credentials and stores the session token.
|
||||||
|
|
||||||
|
Two authentication flows:
|
||||||
|
1. Quick Connect (primary): user enters a short code on their Jellyfin app.
|
||||||
|
- initiate_quick_connect() → {code, secret}
|
||||||
|
- poll_quick_connect(secret) → "Active" | "Authorized" | "Expired"
|
||||||
|
- authenticate_quick_connect(secret) → AuthResult with token
|
||||||
|
|
||||||
|
2. Username/password (legacy): renders an HTML form, called via the REST API.
|
||||||
|
- render_login_form(token, discord_id) → HTML string
|
||||||
|
- authenticate(form_data) → AuthResult
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from gateway.auth import AuthService, AuthResult, register_auth_service
|
||||||
|
from src.config import get_config
|
||||||
|
|
||||||
|
logger = logging.getLogger("auth.jellyfin")
|
||||||
|
|
||||||
|
# Emby-style authorization header required by Jellyfin's AuthenticateByName
|
||||||
|
_EMBY_HEADER = (
|
||||||
|
'MediaBrowser Client="AgentBot",'
|
||||||
|
'Device="DiscordBot",'
|
||||||
|
'DeviceId="agent-bot",'
|
||||||
|
'Version="1.0"'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QuickConnectResult:
|
||||||
|
"""Result of a Quick Connect initiation."""
|
||||||
|
secret: str
|
||||||
|
code: str
|
||||||
|
device_id: str
|
||||||
|
device_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class JellyfinAuth(AuthService):
|
||||||
|
name = "jellyfin"
|
||||||
|
display_name = "Jellyfin"
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Quick Connect helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _qc_headers(self) -> dict[str, str]:
|
||||||
|
"""Return headers used by all Quick Connect API calls."""
|
||||||
|
return {
|
||||||
|
"X-Emby-Authorization": (
|
||||||
|
'MediaBrowser Client="AgentBot",'
|
||||||
|
'Device="DiscordBot",'
|
||||||
|
'DeviceId="agent-bot-qc",'
|
||||||
|
'Version="1.0"'
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _resolve_url(self) -> str | None:
|
||||||
|
"""
|
||||||
|
Resolve the Jellyfin server URL.
|
||||||
|
1. Check JELLYFIN_URL env var (used in deployment).
|
||||||
|
2. Check if user already has a stored auth with a URL (from legacy login).
|
||||||
|
Returns None if no URL is configured.
|
||||||
|
"""
|
||||||
|
# First: explicit env var
|
||||||
|
env_url = get_config("JELLYFIN_URL")
|
||||||
|
if env_url:
|
||||||
|
return env_url.strip().rstrip("/")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Phase 1a: initiate Quick Connect
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def initiate_quick_connect(self, url: str | None = None) -> QuickConnectResult | None:
|
||||||
|
"""
|
||||||
|
Call Jellyfin's POST /QuickConnect/Initiate.
|
||||||
|
Returns a QuickConnectResult with {secret, code} or None on failure.
|
||||||
|
|
||||||
|
The *code* is what the user enters on their Jellyfin page.
|
||||||
|
The *secret* is used internally to poll/authenticate.
|
||||||
|
"""
|
||||||
|
base_url = url or await self._resolve_url()
|
||||||
|
if not base_url:
|
||||||
|
logger.error("QuickConnect failed — no JELLYFIN_URL configured.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.info("Initiating Quick Connect on %s", base_url)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
try:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{base_url}/QuickConnect/Initiate",
|
||||||
|
headers=self._qc_headers(),
|
||||||
|
json={},
|
||||||
|
)
|
||||||
|
if resp.status_code != 200:
|
||||||
|
logger.warning(
|
||||||
|
"QuickConnect init failed: HTTP %s — %s",
|
||||||
|
resp.status_code, resp.text[:200],
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
secret = data.get("Secret", "")
|
||||||
|
code = data.get("Code", "")
|
||||||
|
device_id = data.get("DeviceId", "")
|
||||||
|
device_name = data.get("DeviceName", "")
|
||||||
|
|
||||||
|
if not secret or not code:
|
||||||
|
logger.warning("QuickConnect init returned unexpected payload: %s", data)
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Quick Connect initiated: code=%s device=%s",
|
||||||
|
code, device_name,
|
||||||
|
)
|
||||||
|
return QuickConnectResult(
|
||||||
|
secret=secret,
|
||||||
|
code=code,
|
||||||
|
device_id=device_id,
|
||||||
|
device_name=device_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error("QuickConnect init timed out reaching %s", base_url)
|
||||||
|
return None
|
||||||
|
except httpx.ConnectError:
|
||||||
|
logger.error("QuickConnect init — cannot connect to %s", base_url)
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Unexpected error during QuickConnect init")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Phase 1b: poll Quick Connect status
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def poll_quick_connect(self, secret: str, url: str | None = None) -> str:
|
||||||
|
"""
|
||||||
|
Call Jellyfin's GET /QuickConnect/Connect?secret=<secret>.
|
||||||
|
Returns one of:
|
||||||
|
- "Active" → user hasn't entered the code yet
|
||||||
|
- "Authorized" → user entered code AND approved
|
||||||
|
- "Expired" → code expired / unknown secret
|
||||||
|
- "Error" → network or unexpected failure
|
||||||
|
"""
|
||||||
|
base_url = url or await self._resolve_url()
|
||||||
|
if not base_url:
|
||||||
|
logger.error("QuickConnect poll failed — no JELLYFIN_URL configured.")
|
||||||
|
return "Error"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
try:
|
||||||
|
resp = await client.get(
|
||||||
|
f"{base_url}/QuickConnect/Connect",
|
||||||
|
params={"secret": secret},
|
||||||
|
headers=self._qc_headers(),
|
||||||
|
)
|
||||||
|
if resp.status_code == 404:
|
||||||
|
return "Expired"
|
||||||
|
|
||||||
|
if resp.status_code == 200:
|
||||||
|
data = resp.json()
|
||||||
|
# Jellyfin returns "Authenticated" (not "Authorized")
|
||||||
|
if data.get("Authenticated") is True:
|
||||||
|
return "Authorized"
|
||||||
|
# "Authenticated" is false, missing, or null → still active
|
||||||
|
return "Active"
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"QuickConnect poll unexpected: HTTP %s — %s",
|
||||||
|
resp.status_code, resp.text[:200],
|
||||||
|
)
|
||||||
|
return "Error"
|
||||||
|
|
||||||
|
except (httpx.TimeoutException, httpx.ConnectError):
|
||||||
|
logger.warning("QuickConnect poll network error")
|
||||||
|
return "Error"
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Unexpected error during QuickConnect poll")
|
||||||
|
return "Error"
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Phase 1c: exchange secret for token
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def authenticate_quick_connect(
|
||||||
|
self, secret: str, url: str | None = None
|
||||||
|
) -> AuthResult:
|
||||||
|
"""
|
||||||
|
After poll_quick_connect returns "Authorized", call
|
||||||
|
POST /Users/AuthenticateWithQuickConnect to exchange the secret
|
||||||
|
for a real access token.
|
||||||
|
|
||||||
|
Returns AuthResult with token, user_id, username on success.
|
||||||
|
"""
|
||||||
|
base_url = url or await self._resolve_url()
|
||||||
|
if not base_url:
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message="No Jellyfin server URL configured.",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Exchanging QuickConnect secret for token on %s", base_url)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
try:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{base_url}/Users/AuthenticateWithQuickConnect",
|
||||||
|
json={"Secret": secret},
|
||||||
|
headers=self._qc_headers(),
|
||||||
|
)
|
||||||
|
if resp.status_code != 200:
|
||||||
|
logger.warning(
|
||||||
|
"QuickConnect auth exchange failed: HTTP %s",
|
||||||
|
resp.status_code,
|
||||||
|
)
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message="Quick Connect authentication failed. The code may have expired.",
|
||||||
|
)
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
user = data.get("User", {})
|
||||||
|
token = data.get("AccessToken", "")
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message="Jellyfin returned an unexpected response.",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"QuickConnect linked: user=%s (%s)",
|
||||||
|
user.get("Name", "?"),
|
||||||
|
user.get("Id", "?"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthResult(
|
||||||
|
success=True,
|
||||||
|
external_user_id=user.get("Id", ""),
|
||||||
|
external_name=user.get("Name", "?"),
|
||||||
|
credentials={
|
||||||
|
"token": token,
|
||||||
|
"url": base_url,
|
||||||
|
"user_id": user.get("Id", ""),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Could not reach {base_url} — connection timed out.",
|
||||||
|
)
|
||||||
|
except httpx.ConnectError:
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Could not connect to {base_url}. Is the server running?",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Unexpected error during QuickConnect auth exchange")
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message="An unexpected error occurred during authentication.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Login form (legacy — used by the REST API)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def render_login_form(self, token: str, discord_id: int) -> str:
|
||||||
|
return f"""<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
|
<title>Link Jellyfin</title>
|
||||||
|
<style>
|
||||||
|
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
|
||||||
|
h2 {{ margin-bottom: 4px; }}
|
||||||
|
.sub {{ color: #666; margin-bottom: 24px; }}
|
||||||
|
label {{ display: block; margin-top: 16px; font-weight: 600; }}
|
||||||
|
input {{ width: 100%; padding: 10px; margin-top: 4px; border: 1px solid #ccc; border-radius: 6px; box-sizing: border-box; }}
|
||||||
|
button {{ margin-top: 24px; width: 100%; padding: 12px; background: #aa5cc3; color: #fff; border: none; border-radius: 6px; font-size: 16px; cursor: pointer; }}
|
||||||
|
button:hover {{ background: #9448b0; }}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h2>🔗 Link Jellyfin to Discord</h2>
|
||||||
|
<p class="sub">Enter your Jellyfin server URL and credentials to link your account.</p>
|
||||||
|
|
||||||
|
<form method="POST" action="/api/v1/auth/login">
|
||||||
|
<input type="hidden" name="token" value="{token}">
|
||||||
|
<input type="hidden" name="discord_id" value="{discord_id}">
|
||||||
|
<input type="hidden" name="service" value="jellyfin">
|
||||||
|
|
||||||
|
<label for="jellyfin_url">Jellyfin Server URL</label>
|
||||||
|
<input id="jellyfin_url" name="jellyfin_url" type="url"
|
||||||
|
placeholder="https://jellyfin.example.com" required>
|
||||||
|
|
||||||
|
<label for="username">Username</label>
|
||||||
|
<input id="username" name="username" type="text"
|
||||||
|
placeholder="Your Jellyfin username" required autofocus>
|
||||||
|
|
||||||
|
<label for="password">Password</label>
|
||||||
|
<input id="password" name="password" type="password"
|
||||||
|
placeholder="Your Jellyfin password" required>
|
||||||
|
|
||||||
|
<button type="submit">Link Account</button>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>"""
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Authentication
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def authenticate(self, form_data: dict) -> AuthResult:
|
||||||
|
url = form_data.get("jellyfin_url", "").strip().rstrip("/")
|
||||||
|
username = form_data.get("username", "").strip()
|
||||||
|
password = form_data.get("password", "").strip()
|
||||||
|
|
||||||
|
if not url or not username or not password:
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message="All fields are required (URL, username, password).",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Attempting Jellyfin login for '%s' on %s", username, url)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
try:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{url}/Users/AuthenticateByName",
|
||||||
|
json={"Username": username, "Pw": password},
|
||||||
|
headers={"X-Emby-Authorization": _EMBY_HEADER},
|
||||||
|
)
|
||||||
|
if resp.status_code != 200:
|
||||||
|
logger.warning(
|
||||||
|
"Jellyfin login failed for '%s': HTTP %s", username, resp.status_code
|
||||||
|
)
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Login failed — check your server URL and credentials.",
|
||||||
|
)
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
user = data.get("User", {})
|
||||||
|
token = data.get("AccessToken", "")
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message="Jellyfin returned an unexpected response.",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Jellyfin login OK: user=%s (%s)",
|
||||||
|
user.get("Name", "?"),
|
||||||
|
user.get("Id", "?"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return AuthResult(
|
||||||
|
success=True,
|
||||||
|
external_user_id=user.get("Id", ""),
|
||||||
|
external_name=user.get("Name", username),
|
||||||
|
credentials={
|
||||||
|
"token": token,
|
||||||
|
"url": url,
|
||||||
|
"user_id": user.get("Id", ""),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Could not reach {url} — connection timed out. Check the URL.",
|
||||||
|
)
|
||||||
|
except httpx.ConnectError:
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Could not connect to {url}. Is the server running?",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Unexpected error during Jellyfin login")
|
||||||
|
return AuthResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"An unexpected error occurred. Please try again.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Self-register at import time
|
||||||
|
register_auth_service(JellyfinAuth())
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from core.graph import create_agent_graph
|
from src.graph import create_agent_graph
|
||||||
|
|
||||||
|
|
||||||
def get_llm_client(request: Request) -> OpenAI:
|
def get_llm_client(request: Request) -> OpenAI:
|
||||||
@@ -23,10 +23,12 @@ import os
|
|||||||
import discord
|
import discord
|
||||||
|
|
||||||
from agents import list_all as list_all_agents
|
from agents import list_all as list_all_agents
|
||||||
from bot.conversation import ConversationStore
|
from gateway.discord.conversation import ConversationStore
|
||||||
from core.config import DEEPSEEK_API_KEY, get_config
|
from src.config import DEEPSEEK_API_KEY, get_config
|
||||||
from core.graph import create_agent_graph
|
from src.graph import create_agent_graph
|
||||||
from core.llm import create_client
|
from src.llm import create_client
|
||||||
|
from src import auth_store
|
||||||
|
from gateway.auth import list_auth_services, get_auth_service
|
||||||
|
|
||||||
logger = logging.getLogger("bot.discord")
|
logger = logging.getLogger("bot.discord")
|
||||||
|
|
||||||
@@ -36,6 +38,7 @@ logger = logging.getLogger("bot.discord")
|
|||||||
DISCORD_BOT_TOKEN = get_config("DISCORD_BOT_TOKEN") or ""
|
DISCORD_BOT_TOKEN = get_config("DISCORD_BOT_TOKEN") or ""
|
||||||
DISCORD_MAX_HISTORY = int(get_config("DISCORD_MAX_HISTORY", "7"))
|
DISCORD_MAX_HISTORY = int(get_config("DISCORD_MAX_HISTORY", "7"))
|
||||||
DISCORD_DEFAULT_AGENT = get_config("DISCORD_DEFAULT_AGENT", "media-agent")
|
DISCORD_DEFAULT_AGENT = get_config("DISCORD_DEFAULT_AGENT", "media-agent")
|
||||||
|
BASE_URL = get_config("BASE_URL", "http://localhost:8000").rstrip("/")
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# LLM client shared by all agents (same as the REST API uses)
|
# LLM client shared by all agents (same as the REST API uses)
|
||||||
@@ -138,6 +141,12 @@ class AgentBot(discord.Client):
|
|||||||
# |--------------------------------------------------------------|
|
# |--------------------------------------------------------------|
|
||||||
|
|
||||||
user_id = message.author.id
|
user_id = message.author.id
|
||||||
|
content = message.content.strip()
|
||||||
|
|
||||||
|
# |-- Bot commands — handled directly, never sent to the LLM --|
|
||||||
|
if await self._handle_command(message, user_id, content):
|
||||||
|
return
|
||||||
|
# |--------------------------------------------------------------|
|
||||||
|
|
||||||
# Show typing indicator while the graph runs
|
# Show typing indicator while the graph runs
|
||||||
async with message.channel.typing():
|
async with message.channel.typing():
|
||||||
@@ -154,6 +163,140 @@ class AgentBot(discord.Client):
|
|||||||
"Please try again in a moment."
|
"Please try again in a moment."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Bot commands
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _handle_command(
|
||||||
|
self, message: discord.Message, user_id: int, content: str
|
||||||
|
) -> bool:
|
||||||
|
"""Handle bot commands (/login, /logout). Returns True if handled."""
|
||||||
|
lower = content.lower()
|
||||||
|
|
||||||
|
# --- /login [service] ---
|
||||||
|
if lower.startswith("/login"):
|
||||||
|
parts = content.split()
|
||||||
|
service = parts[1].lower() if len(parts) > 1 else None
|
||||||
|
|
||||||
|
available = list_auth_services()
|
||||||
|
if not available:
|
||||||
|
await message.channel.send("No auth services are configured yet.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if service is None:
|
||||||
|
svc_list = ", ".join(available)
|
||||||
|
await message.channel.send(
|
||||||
|
f"Available services to link: **{svc_list}**\n"
|
||||||
|
f"Use `/login <service>` — e.g. `/login jellyfin`"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if service not in available:
|
||||||
|
await message.channel.send(
|
||||||
|
f"Unknown service '{service}'. Available: {', '.join(available)}"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if auth_store.is_authenticated(user_id, service):
|
||||||
|
svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service
|
||||||
|
await message.channel.send(
|
||||||
|
f"You're already linked to **{svc_display}**! "
|
||||||
|
f"Use `/logout {service}` to unlink."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# --- Quick Connect flow ---
|
||||||
|
svc = get_auth_service(service)
|
||||||
|
if svc is None:
|
||||||
|
await message.channel.send(f"Unknown service: {service}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
await message.channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…")
|
||||||
|
|
||||||
|
qc_result = await svc.initiate_quick_connect()
|
||||||
|
if qc_result is None:
|
||||||
|
await message.channel.send(
|
||||||
|
f"❌ Could not start Quick Connect for **{svc.display_name}**.\n"
|
||||||
|
"Check that `JELLYFIN_URL` is configured and the server is reachable."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
await message.channel.send(
|
||||||
|
f"Open **{svc.display_name}** → **Quick Connect** and enter this code:\n\n"
|
||||||
|
f"**`{qc_result.code}`**\n\n"
|
||||||
|
f"⏳ Waiting for you to approve…"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Poll for authorization
|
||||||
|
async with message.channel.typing():
|
||||||
|
for attempt in range(24): # 24 × 5s = 2 minutes
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
status = await svc.poll_quick_connect(qc_result.secret)
|
||||||
|
|
||||||
|
if status == "Authorized":
|
||||||
|
auth_result = await svc.authenticate_quick_connect(qc_result.secret)
|
||||||
|
if auth_result.success:
|
||||||
|
auth_store.store_auth(
|
||||||
|
discord_user_id=user_id,
|
||||||
|
service=service,
|
||||||
|
external_user_id=auth_result.external_user_id or "",
|
||||||
|
external_name=auth_result.external_name or "",
|
||||||
|
credentials=auth_result.credentials,
|
||||||
|
)
|
||||||
|
await message.channel.send(
|
||||||
|
f"✅ Linked to **{svc.display_name}** as "
|
||||||
|
f"**{auth_result.external_name}**!"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await message.channel.send(
|
||||||
|
f"❌ Authentication failed: "
|
||||||
|
f"{auth_result.error_message or 'Unknown error'}"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
elif status == "Expired":
|
||||||
|
await message.channel.send(
|
||||||
|
"⌛ The Quick Connect code expired. "
|
||||||
|
f"Use `/login {service}` to try again."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# else: still "Active" — keep polling
|
||||||
|
|
||||||
|
await message.channel.send(
|
||||||
|
"⌛ Timed out waiting for Quick Connect approval. "
|
||||||
|
f"Use `/login {service}` to try again."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# --- /logout [service] ---
|
||||||
|
if lower.startswith("/logout"):
|
||||||
|
parts = content.split()
|
||||||
|
service = parts[1].lower() if len(parts) > 1 else None
|
||||||
|
|
||||||
|
if service is None:
|
||||||
|
linked = auth_store.list_services(user_id)
|
||||||
|
if not linked:
|
||||||
|
await message.channel.send("You don't have any linked services.")
|
||||||
|
else:
|
||||||
|
svc_list = ", ".join(linked)
|
||||||
|
await message.channel.send(
|
||||||
|
f"Linked services: **{svc_list}**\n"
|
||||||
|
f"Use `/logout <service>` to unlink."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not auth_store.is_authenticated(user_id, service):
|
||||||
|
await message.channel.send(f"You're not linked to **{service}**.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
auth_store.revoke(user_id, service)
|
||||||
|
svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service
|
||||||
|
await message.channel.send(f"Unlinked from **{svc_display}**. Use `/login {service}` to re-link.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Agent invocation
|
# Agent invocation
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -163,7 +306,6 @@ class AgentBot(discord.Client):
|
|||||||
reply, and return the assistant's final text."""
|
reply, and return the assistant's final text."""
|
||||||
|
|
||||||
# 1. Pick agent — defaults to DISCORD_DEFAULT_AGENT env var.
|
# 1. Pick agent — defaults to DISCORD_DEFAULT_AGENT env var.
|
||||||
# Change DISCORD_DEFAULT_AGENT in .env to switch agents.
|
|
||||||
agent_id = DISCORD_DEFAULT_AGENT
|
agent_id = DISCORD_DEFAULT_AGENT
|
||||||
|
|
||||||
# 2. Build message list from stored history + new user message
|
# 2. Build message list from stored history + new user message
|
||||||
@@ -172,7 +314,7 @@ class AgentBot(discord.Client):
|
|||||||
|
|
||||||
# 3. Run the LangGraph (tools execute inline if needed)
|
# 3. Run the LangGraph (tools execute inline if needed)
|
||||||
graph = _get_graph(agent_id)
|
graph = _get_graph(agent_id)
|
||||||
state = {"messages": messages}
|
state = {"messages": messages, "discord_user_id": user_id}
|
||||||
result = await graph.ainvoke(state)
|
result = await graph.ainvoke(state)
|
||||||
|
|
||||||
last_msg = result["messages"][-1]
|
last_msg = result["messages"][-1]
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
# Discord — Connector
|
||||||
|
|
||||||
|
The Discord module embeds a Discord bot **in-process** alongside FastAPI.
|
||||||
|
It uses the same LangGraph graphs and LLM client as the REST API — there is
|
||||||
|
no HTTP loopback, no separate process, and no code duplication.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Files
|
||||||
|
|
||||||
|
| File | Purpose |
|
||||||
|
|---|---|
|
||||||
|
| `bot.py` | Discord `Client` subclass (`AgentBot`) — DM handler, command parser, graph invoker, Quick Connect orchestrator |
|
||||||
|
| `conversation.py` | In-memory conversation history store, keyed by Discord user ID |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
Discord Gateway (websocket)
|
||||||
|
│ DM: "What's trending?"
|
||||||
|
▼
|
||||||
|
discord.Client.on_message()
|
||||||
|
│ 1. Check: is this a DM? shares a guild? not a command?
|
||||||
|
│ 2. Build message history from ConversationStore
|
||||||
|
│ 3. Append user message
|
||||||
|
▼
|
||||||
|
_create_agent_graph(agent_id="media-agent")
|
||||||
|
│ Uses the exact same create_agent_graph() from src/graph.py
|
||||||
|
│ as the REST API — same LLM client, same tools, same cache.
|
||||||
|
▼
|
||||||
|
graph.ainvoke({"messages": [...]})
|
||||||
|
│ LangGraph runs agent_node → tool_node → agent_node → END
|
||||||
|
▼
|
||||||
|
Response text → split into ≤2000-char Discord messages → sent to user
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
Commands are DMs that start with `/`. The bot parses them before hitting the
|
||||||
|
LLM:
|
||||||
|
|
||||||
|
| Command | Action |
|
||||||
|
|---|---|
|
||||||
|
| `/login <service>` | Generate a one-time auth link, DM it to the user |
|
||||||
|
| `/jellyfin login` | Alias for `/login jellyfin` |
|
||||||
|
| `/help` | Show available agents and commands |
|
||||||
|
| `/<agent_id>` | Switch to a different agent for future messages |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Auth Flow (Quick Connect)
|
||||||
|
|
||||||
|
When a user types `/login jellyfin`:
|
||||||
|
|
||||||
|
1. Bot generates a one-time token via `auth_store`
|
||||||
|
2. Bot calls `auth_store.create_link_token(discord_id, "jellyfin")`
|
||||||
|
3. Bot DMs the user: `https://<BASE_URL>/api/v1/auth/login?service=jellyfin&token=...&discord_id=...`
|
||||||
|
4. User clicks the link → FastAPI serves the Jellyfin login form (or Quick Connect prompt)
|
||||||
|
5. User authenticates → credentials stored in `auth_store`
|
||||||
|
6. Future tool calls (e.g. `watch_history`) automatically use the stored Jellyfin session
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Conversation Persistence
|
||||||
|
|
||||||
|
- Per-user history stored in `ConversationStore` (in-memory dict)
|
||||||
|
- Max history length configurable via `DISCORD_MAX_HISTORY` env var (default: 7)
|
||||||
|
- Oldest messages are silently dropped when the limit is exceeded
|
||||||
|
- History is NOT persisted across restarts (future: could use SQLite)
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
"""JellyStat REST API — watch history, genre summary, and user summary."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
from gateway.jellystat.db import get_pool
|
||||||
|
from gateway.jellystat.models import (
|
||||||
|
GenreSummaryResponse,
|
||||||
|
UserSummaryResponse,
|
||||||
|
WatchHistoryResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/jellystat", tags=["jellystat"])
|
||||||
|
|
||||||
|
DEFAULT_WINDOW_MINUTES = 10080 # 7 days
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /jellystat/history/{user_id}
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history/{user_id}", response_model=WatchHistoryResponse)
|
||||||
|
async def get_watch_history(
|
||||||
|
user_id: str,
|
||||||
|
minutes: int = Query(
|
||||||
|
default=DEFAULT_WINDOW_MINUTES, ge=1, description="Time window in minutes"
|
||||||
|
),
|
||||||
|
pool: asyncpg.Pool = Depends(get_pool),
|
||||||
|
):
|
||||||
|
"""Return watch history grouped by title, ordered by most-watched first."""
|
||||||
|
rows = await pool.fetch(
|
||||||
|
"SELECT * FROM fn_user_watch_history($1, $2)", user_id, minutes
|
||||||
|
)
|
||||||
|
return WatchHistoryResponse(
|
||||||
|
user_id=user_id,
|
||||||
|
window_minutes=minutes,
|
||||||
|
items=[
|
||||||
|
{
|
||||||
|
"title": r["title"],
|
||||||
|
"watch_time_sec": float(r["watch_time_sec"]),
|
||||||
|
"media_type": r["media_type"],
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /jellystat/genres/{user_id}
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/genres/{user_id}", response_model=GenreSummaryResponse)
|
||||||
|
async def get_genre_summary(
|
||||||
|
user_id: str,
|
||||||
|
minutes: int = Query(
|
||||||
|
default=DEFAULT_WINDOW_MINUTES, ge=1, description="Time window in minutes"
|
||||||
|
),
|
||||||
|
pool: asyncpg.Pool = Depends(get_pool),
|
||||||
|
):
|
||||||
|
"""Return total watch time per genre, ordered by most-watched first."""
|
||||||
|
rows = await pool.fetch(
|
||||||
|
"SELECT * FROM fn_user_genre_summary($1, $2)", user_id, minutes
|
||||||
|
)
|
||||||
|
return GenreSummaryResponse(
|
||||||
|
user_id=user_id,
|
||||||
|
window_minutes=minutes,
|
||||||
|
genres=[
|
||||||
|
{"genre": r["genre"], "watch_time_sec": float(r["watch_time_sec"])}
|
||||||
|
for r in rows
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /jellystat/summary/{user_id}
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/summary/{user_id}", response_model=UserSummaryResponse)
|
||||||
|
async def get_user_summary(
|
||||||
|
user_id: str,
|
||||||
|
pool: asyncpg.Pool = Depends(get_pool),
|
||||||
|
):
|
||||||
|
"""Return all-time summary: total watch time, most-watched titles, top genres."""
|
||||||
|
rows = await pool.fetch("SELECT * FROM fn_user_summary($1)", user_id)
|
||||||
|
|
||||||
|
# fn_user_summary returns key-value rows — build a dict
|
||||||
|
# asyncpg already deserialises JSONB → Python objects
|
||||||
|
metrics: dict[str, object] = {r["metric"]: r["value"] for r in rows}
|
||||||
|
|
||||||
|
top_genres_raw = metrics.get("top_genres", [])
|
||||||
|
top_genres: list[str] = top_genres_raw if isinstance(top_genres_raw, list) else []
|
||||||
|
|
||||||
|
return UserSummaryResponse(
|
||||||
|
user_id=user_id,
|
||||||
|
total_watch_time_sec=float(metrics.get("total_watch_time", 0)),
|
||||||
|
most_watched_series=metrics.get("most_watched_series"),
|
||||||
|
most_watched_movie=metrics.get("most_watched_movie"),
|
||||||
|
total_last_30d_sec=float(metrics.get("total_last_30d", 0)),
|
||||||
|
total_last_7d_sec=float(metrics.get("total_last_7d", 0)),
|
||||||
|
top_genres=top_genres,
|
||||||
|
)
|
||||||
@@ -0,0 +1,130 @@
|
|||||||
|
"""PostgreSQL connection pool for the JellyStat database."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
|
||||||
|
from src.config import get_config
|
||||||
|
|
||||||
|
logger = logging.getLogger("gateway.jellystat")
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DSN builder
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _build_dsn() -> str:
|
||||||
|
"""Build a PostgreSQL DSN from individual environment variables."""
|
||||||
|
host = get_config("JELLYSTAT_DB_HOST", "localhost")
|
||||||
|
port = get_config("JELLYSTAT_DB_PORT", "5432")
|
||||||
|
user = get_config("JELLYSTAT_DB_USER", "postgres")
|
||||||
|
password = get_config("JELLYSTAT_DB_PASSWORD", "")
|
||||||
|
dbname = get_config("JELLYSTAT_DB_NAME", "jfstat")
|
||||||
|
return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pool lifecycle (called from main.py lifespan)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def init_pool(app: FastAPI) -> None:
|
||||||
|
"""Create the connection pool and store it on app.state."""
|
||||||
|
dsn = _build_dsn()
|
||||||
|
safe = dsn.split("@")[1] if "@" in dsn else dsn
|
||||||
|
logger.info("Connecting to JellyStat database at %s", safe)
|
||||||
|
|
||||||
|
pool = await asyncpg.create_pool(dsn, min_size=1, max_size=5)
|
||||||
|
app.state.jellystat_pool = pool
|
||||||
|
|
||||||
|
# Deploy functions on every startup (CREATE OR REPLACE is idempotent)
|
||||||
|
await _ensure_functions(pool)
|
||||||
|
|
||||||
|
|
||||||
|
async def close_pool(app: FastAPI) -> None:
|
||||||
|
"""Close the pool on shutdown."""
|
||||||
|
pool: asyncpg.Pool | None = getattr(app.state, "jellystat_pool", None)
|
||||||
|
if pool:
|
||||||
|
await pool.close()
|
||||||
|
logger.info("JellyStat pool closed")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# FastAPI dependency
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def get_pool(request: Request) -> asyncpg.Pool:
|
||||||
|
"""Return the JellyStat connection pool from app state."""
|
||||||
|
return request.app.state.jellystat_pool
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Function deployment
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _ensure_functions(pool: asyncpg.Pool) -> None:
|
||||||
|
"""Run startup-functions.sql to create or replace all JellyStat functions."""
|
||||||
|
sql_path = Path(__file__).parent / "startup-functions.sql"
|
||||||
|
if not sql_path.exists():
|
||||||
|
logger.warning("startup-functions.sql not found — skipping function deployment")
|
||||||
|
return
|
||||||
|
|
||||||
|
sql = sql_path.read_text()
|
||||||
|
statements = _split_sql(sql)
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
for stmt in statements:
|
||||||
|
try:
|
||||||
|
await conn.execute(stmt)
|
||||||
|
except Exception:
|
||||||
|
# Log but don't crash — functions might already exist
|
||||||
|
logger.exception("Failed to deploy SQL statement — continuing")
|
||||||
|
|
||||||
|
logger.info("JellyStat functions deployed (%d statements)", len(statements))
|
||||||
|
|
||||||
|
|
||||||
|
def _split_sql(sql: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Split a multi-statement SQL string into individual statements.
|
||||||
|
|
||||||
|
Respects $$ dollar-quoting so that semicolons inside function bodies
|
||||||
|
don't cause premature splits. Pure comment lines (starting with ``--``)
|
||||||
|
outside dollar-quoted blocks are stripped.
|
||||||
|
"""
|
||||||
|
statements: list[str] = []
|
||||||
|
current: list[str] = []
|
||||||
|
in_dollar_quote = False
|
||||||
|
|
||||||
|
for line in sql.split("\n"):
|
||||||
|
stripped = line.strip()
|
||||||
|
|
||||||
|
# Skip pure comment lines outside of dollar-quoted blocks
|
||||||
|
if not in_dollar_quote and stripped.startswith("--"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Toggle dollar-quote state whenever we see $$
|
||||||
|
if "$$" in line:
|
||||||
|
in_dollar_quote = not in_dollar_quote
|
||||||
|
|
||||||
|
current.append(line)
|
||||||
|
|
||||||
|
# Statement terminator: semicolon at end of line, outside $$ block
|
||||||
|
if not in_dollar_quote and line.rstrip().endswith(";"):
|
||||||
|
stmt = "\n".join(current).strip()
|
||||||
|
if stmt:
|
||||||
|
statements.append(stmt)
|
||||||
|
current = []
|
||||||
|
|
||||||
|
# Catch any trailing statement that wasn't terminated by a semicolon
|
||||||
|
if current:
|
||||||
|
stmt = "\n".join(current).strip()
|
||||||
|
if stmt:
|
||||||
|
statements.append(stmt)
|
||||||
|
|
||||||
|
return statements
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
"""Pydantic response models for the JellyStat API."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class WatchHistoryItem(BaseModel):
|
||||||
|
title: str
|
||||||
|
watch_time_sec: float
|
||||||
|
media_type: str
|
||||||
|
|
||||||
|
|
||||||
|
class WatchHistoryResponse(BaseModel):
|
||||||
|
user_id: str
|
||||||
|
window_minutes: int
|
||||||
|
items: list[WatchHistoryItem]
|
||||||
|
|
||||||
|
|
||||||
|
class GenreSummaryItem(BaseModel):
|
||||||
|
genre: str
|
||||||
|
watch_time_sec: float
|
||||||
|
|
||||||
|
|
||||||
|
class GenreSummaryResponse(BaseModel):
|
||||||
|
user_id: str
|
||||||
|
window_minutes: int
|
||||||
|
genres: list[GenreSummaryItem]
|
||||||
|
|
||||||
|
|
||||||
|
class UserSummaryResponse(BaseModel):
|
||||||
|
user_id: str
|
||||||
|
total_watch_time_sec: float
|
||||||
|
most_watched_series: str | None
|
||||||
|
most_watched_movie: str | None
|
||||||
|
total_last_30d_sec: float
|
||||||
|
total_last_7d_sec: float
|
||||||
|
top_genres: list[str]
|
||||||
@@ -0,0 +1,224 @@
|
|||||||
|
-- ============================================================================
|
||||||
|
-- JellyStat API Functions
|
||||||
|
-- Parameterized database functions callable by the API layer as:
|
||||||
|
-- SELECT * FROM fn_user_watch_history('user_id_here', 10080);
|
||||||
|
-- SELECT * FROM fn_user_genre_summary('user_id_here', 10080);
|
||||||
|
-- SELECT * FROM fn_user_summary('user_id_here');
|
||||||
|
-- ============================================================================
|
||||||
|
|
||||||
|
-- ----------------------------------------------------------------------------
|
||||||
|
-- 1. User Watch History
|
||||||
|
-- Returns every distinct title watched in the last N minutes,
|
||||||
|
-- grouped and summed by title, ordered by most-watched first.
|
||||||
|
-- ----------------------------------------------------------------------------
|
||||||
|
CREATE OR REPLACE FUNCTION public.fn_user_watch_history(
|
||||||
|
p_user_id TEXT,
|
||||||
|
p_minutes INTEGER DEFAULT 10080 -- 7 days in minutes
|
||||||
|
)
|
||||||
|
RETURNS TABLE(
|
||||||
|
title TEXT,
|
||||||
|
watch_time_sec NUMERIC,
|
||||||
|
media_type TEXT
|
||||||
|
)
|
||||||
|
LANGUAGE sql
|
||||||
|
STABLE
|
||||||
|
AS $$
|
||||||
|
SELECT
|
||||||
|
COALESCE(a."SeriesName", a."NowPlayingItemName") AS title,
|
||||||
|
SUM(a."PlaybackDuration")::NUMERIC AS watch_time_sec,
|
||||||
|
CASE
|
||||||
|
WHEN a."SeriesName" IS NOT NULL THEN 'series'
|
||||||
|
ELSE 'movie'
|
||||||
|
END AS media_type
|
||||||
|
FROM jf_playback_activity a
|
||||||
|
WHERE a."UserId" = p_user_id
|
||||||
|
AND a."ActivityDateInserted"
|
||||||
|
>= NOW() - (p_minutes * INTERVAL '1 minute')
|
||||||
|
GROUP BY
|
||||||
|
COALESCE(a."SeriesName", a."NowPlayingItemName"),
|
||||||
|
CASE WHEN a."SeriesName" IS NOT NULL THEN 'series' ELSE 'movie' END
|
||||||
|
ORDER BY watch_time_sec DESC;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
-- ----------------------------------------------------------------------------
|
||||||
|
-- 2. Genre Summary
|
||||||
|
-- Returns total watch time per genre for a user over the last N minutes.
|
||||||
|
-- Resolves genres for both movies (directly on the item) and series
|
||||||
|
-- episodes (via jf_library_episodes → jf_library_items chain).
|
||||||
|
-- ----------------------------------------------------------------------------
|
||||||
|
CREATE OR REPLACE FUNCTION public.fn_user_genre_summary(
|
||||||
|
p_user_id TEXT,
|
||||||
|
p_minutes INTEGER DEFAULT 10080
|
||||||
|
)
|
||||||
|
RETURNS TABLE(
|
||||||
|
genre TEXT,
|
||||||
|
watch_time_sec NUMERIC
|
||||||
|
)
|
||||||
|
LANGUAGE sql
|
||||||
|
STABLE
|
||||||
|
AS $$
|
||||||
|
WITH movie_genres AS (
|
||||||
|
-- Movies: join playback directly to library_items on NowPlayingItemId
|
||||||
|
SELECT
|
||||||
|
genre_item.value AS genre,
|
||||||
|
SUM(a."PlaybackDuration") AS watch_time_sec
|
||||||
|
FROM jf_playback_activity a
|
||||||
|
JOIN jf_library_items i
|
||||||
|
ON i."Id" = a."NowPlayingItemId"
|
||||||
|
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
|
||||||
|
WHERE a."UserId" = p_user_id
|
||||||
|
AND a."SeriesName" IS NULL -- movies only
|
||||||
|
AND a."ActivityDateInserted"
|
||||||
|
>= NOW() - (p_minutes * INTERVAL '1 minute')
|
||||||
|
AND i."Genres" IS NOT NULL
|
||||||
|
AND jsonb_array_length(i."Genres") > 0
|
||||||
|
GROUP BY genre_item.value
|
||||||
|
),
|
||||||
|
series_genres AS (
|
||||||
|
-- Series: playback → episodes → series item → genres
|
||||||
|
SELECT
|
||||||
|
genre_item.value AS genre,
|
||||||
|
SUM(a."PlaybackDuration") AS watch_time_sec
|
||||||
|
FROM jf_playback_activity a
|
||||||
|
JOIN jf_library_episodes e
|
||||||
|
ON e."EpisodeId" = a."EpisodeId"
|
||||||
|
JOIN jf_library_items i
|
||||||
|
ON i."Id" = e."SeriesId"
|
||||||
|
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
|
||||||
|
WHERE a."UserId" = p_user_id
|
||||||
|
AND a."SeriesName" IS NOT NULL -- TV episodes only
|
||||||
|
AND a."ActivityDateInserted"
|
||||||
|
>= NOW() - (p_minutes * INTERVAL '1 minute')
|
||||||
|
AND i."Genres" IS NOT NULL
|
||||||
|
AND jsonb_array_length(i."Genres") > 0
|
||||||
|
GROUP BY genre_item.value
|
||||||
|
),
|
||||||
|
combined AS (
|
||||||
|
SELECT genre, watch_time_sec FROM movie_genres
|
||||||
|
UNION ALL
|
||||||
|
SELECT genre, watch_time_sec FROM series_genres
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
genre,
|
||||||
|
SUM(watch_time_sec)::NUMERIC AS watch_time_sec
|
||||||
|
FROM combined
|
||||||
|
GROUP BY genre
|
||||||
|
ORDER BY watch_time_sec DESC;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
-- ----------------------------------------------------------------------------
|
||||||
|
-- 3. User Summary
|
||||||
|
-- One-shot dashboard: all-time stats + recent windows + top genres.
|
||||||
|
-- Returns key-value rows that the API trivially converts to a JSON object
|
||||||
|
-- with Object.fromEntries() or similar.
|
||||||
|
-- ----------------------------------------------------------------------------
|
||||||
|
CREATE OR REPLACE FUNCTION public.fn_user_summary(
|
||||||
|
p_user_id TEXT
|
||||||
|
)
|
||||||
|
RETURNS TABLE(
|
||||||
|
metric TEXT,
|
||||||
|
value JSONB
|
||||||
|
)
|
||||||
|
LANGUAGE sql
|
||||||
|
STABLE
|
||||||
|
AS $$
|
||||||
|
-- total_watch_time (all time)
|
||||||
|
SELECT 'total_watch_time'::TEXT AS metric,
|
||||||
|
to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value
|
||||||
|
FROM jf_playback_activity
|
||||||
|
WHERE "UserId" = p_user_id
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
-- most_watched_series (by total watch time)
|
||||||
|
SELECT 'most_watched_series'::TEXT AS metric,
|
||||||
|
COALESCE(
|
||||||
|
(SELECT to_jsonb("SeriesName")
|
||||||
|
FROM jf_playback_activity
|
||||||
|
WHERE "UserId" = p_user_id
|
||||||
|
AND "SeriesName" IS NOT NULL
|
||||||
|
GROUP BY "SeriesName"
|
||||||
|
ORDER BY SUM("PlaybackDuration") DESC
|
||||||
|
LIMIT 1),
|
||||||
|
'null'::JSONB
|
||||||
|
) AS value
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
-- most_watched_movie (by total watch time)
|
||||||
|
SELECT 'most_watched_movie'::TEXT AS metric,
|
||||||
|
COALESCE(
|
||||||
|
(SELECT to_jsonb("NowPlayingItemName")
|
||||||
|
FROM jf_playback_activity
|
||||||
|
WHERE "UserId" = p_user_id
|
||||||
|
AND "SeriesName" IS NULL
|
||||||
|
GROUP BY "NowPlayingItemName"
|
||||||
|
ORDER BY SUM("PlaybackDuration") DESC
|
||||||
|
LIMIT 1),
|
||||||
|
'null'::JSONB
|
||||||
|
) AS value
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
-- total_watch_time_last_month (last 30 days)
|
||||||
|
SELECT 'total_last_30d'::TEXT AS metric,
|
||||||
|
to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value
|
||||||
|
FROM jf_playback_activity
|
||||||
|
WHERE "UserId" = p_user_id
|
||||||
|
AND "ActivityDateInserted" >= NOW() - INTERVAL '30 days'
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
-- total_watch_time_last_week (last 7 days)
|
||||||
|
SELECT 'total_last_7d'::TEXT AS metric,
|
||||||
|
to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value
|
||||||
|
FROM jf_playback_activity
|
||||||
|
WHERE "UserId" = p_user_id
|
||||||
|
AND "ActivityDateInserted" >= NOW() - INTERVAL '7 days'
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
-- top_genres (top 3 all-time, as a JSON array)
|
||||||
|
SELECT 'top_genres'::TEXT AS metric,
|
||||||
|
COALESCE(
|
||||||
|
(SELECT jsonb_agg(genre ORDER BY watch_time_sec DESC)
|
||||||
|
FROM (
|
||||||
|
SELECT genre, SUM(watch_time_sec) AS watch_time_sec
|
||||||
|
FROM (
|
||||||
|
-- movies
|
||||||
|
SELECT
|
||||||
|
genre_item.value AS genre,
|
||||||
|
SUM(a."PlaybackDuration") AS watch_time_sec
|
||||||
|
FROM jf_playback_activity a
|
||||||
|
JOIN jf_library_items i ON i."Id" = a."NowPlayingItemId"
|
||||||
|
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
|
||||||
|
WHERE a."UserId" = p_user_id
|
||||||
|
AND a."SeriesName" IS NULL
|
||||||
|
AND i."Genres" IS NOT NULL
|
||||||
|
AND jsonb_array_length(i."Genres") > 0
|
||||||
|
GROUP BY genre_item.value
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
-- series
|
||||||
|
SELECT
|
||||||
|
genre_item.value AS genre,
|
||||||
|
SUM(a."PlaybackDuration") AS watch_time_sec
|
||||||
|
FROM jf_playback_activity a
|
||||||
|
JOIN jf_library_episodes e ON e."EpisodeId" = a."EpisodeId"
|
||||||
|
JOIN jf_library_items i ON i."Id" = e."SeriesId"
|
||||||
|
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
|
||||||
|
WHERE a."UserId" = p_user_id
|
||||||
|
AND a."SeriesName" IS NOT NULL
|
||||||
|
AND i."Genres" IS NOT NULL
|
||||||
|
AND jsonb_array_length(i."Genres") > 0
|
||||||
|
GROUP BY genre_item.value
|
||||||
|
) combined
|
||||||
|
GROUP BY genre
|
||||||
|
ORDER BY SUM(watch_time_sec) DESC
|
||||||
|
LIMIT 3
|
||||||
|
) top3
|
||||||
|
),
|
||||||
|
'[]'::JSONB
|
||||||
|
) AS value;
|
||||||
|
$$;
|
||||||
@@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
Auth API — generic endpoints for linking Discord users to external services.
|
||||||
|
|
||||||
|
GET /api/v1/auth/login?service=X&token=Y&discord_id=Z
|
||||||
|
Validates the link token and serves a service-specific login form.
|
||||||
|
|
||||||
|
POST /api/v1/auth/login
|
||||||
|
Accepts the form submission, validates credentials against the service,
|
||||||
|
stores the session, and returns a result page.
|
||||||
|
|
||||||
|
GET /api/v1/auth/status?discord_id=Z
|
||||||
|
Returns which services are linked for this Discord user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Form, HTTPException, Request
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
|
from gateway.auth import get_auth_service, list_auth_services
|
||||||
|
from src import auth_store
|
||||||
|
|
||||||
|
logger = logging.getLogger("gateway.auth")
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /auth/login — serve the login form
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/login")
|
||||||
|
async def login_form(
|
||||||
|
service: str,
|
||||||
|
token: str,
|
||||||
|
discord_id: int,
|
||||||
|
):
|
||||||
|
"""Validate the one-time link token and return a service-specific login form."""
|
||||||
|
|
||||||
|
# Validate the token WITHOUT consuming it (the POST will consume it)
|
||||||
|
result = auth_store.validate_token(token)
|
||||||
|
if result is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
||||||
|
|
||||||
|
uid, svc = result
|
||||||
|
if uid != discord_id or svc != service:
|
||||||
|
raise HTTPException(status_code=400, detail="Token does not match the request.")
|
||||||
|
|
||||||
|
# Look up the AuthService
|
||||||
|
svc_obj = get_auth_service(service)
|
||||||
|
if svc_obj is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
||||||
|
|
||||||
|
logger.info("Serving login form: user=%s service=%s", discord_id, service)
|
||||||
|
return HTMLResponse(svc_obj.render_login_form(token, discord_id))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /auth/login — handle form submission
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/login")
|
||||||
|
async def login_submit(request: Request):
|
||||||
|
"""Handle the login form POST: validate credentials, store auth, show result."""
|
||||||
|
|
||||||
|
# Parse form data
|
||||||
|
form = await request.form()
|
||||||
|
token = form.get("token", "")
|
||||||
|
discord_id_str = form.get("discord_id", "")
|
||||||
|
service = form.get("service", "")
|
||||||
|
|
||||||
|
if not token or not discord_id_str or not service:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing required fields.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
discord_id = int(discord_id_str)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid discord_id.")
|
||||||
|
|
||||||
|
# Consume the token on POST (the GET only validated, didn't consume)
|
||||||
|
result = auth_store.consume_token(token)
|
||||||
|
if result is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
||||||
|
|
||||||
|
# Look up the AuthService
|
||||||
|
svc_obj = get_auth_service(service)
|
||||||
|
if svc_obj is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
||||||
|
|
||||||
|
# Collect service-specific form fields (everything except token, discord_id, service)
|
||||||
|
form_data: dict[str, str] = {}
|
||||||
|
for key, value in form.items():
|
||||||
|
if key not in ("token", "discord_id", "service"):
|
||||||
|
form_data[key] = str(value)
|
||||||
|
|
||||||
|
# Authenticate against the service
|
||||||
|
auth_result = await svc_obj.authenticate(form_data)
|
||||||
|
|
||||||
|
if not auth_result.success:
|
||||||
|
return HTMLResponse(
|
||||||
|
status_code=401,
|
||||||
|
content=f"""<!DOCTYPE html>
|
||||||
|
<html><head><meta charset="utf-8"><title>Login Failed</title>
|
||||||
|
<style>
|
||||||
|
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
|
||||||
|
h2 {{ color: #d32f2f; }}
|
||||||
|
a {{ color: #aa5cc3; }}
|
||||||
|
</style></head><body>
|
||||||
|
<h2>❌ Login Failed</h2>
|
||||||
|
<p>{auth_result.error_message or "Authentication failed. Please try again."}</p>
|
||||||
|
<p><a href="javascript:history.back()">← Go back and try again</a></p>
|
||||||
|
</body></html>""",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store the successful auth
|
||||||
|
auth_store.store_auth(
|
||||||
|
discord_user_id=discord_id,
|
||||||
|
service=service,
|
||||||
|
external_user_id=auth_result.external_user_id or "",
|
||||||
|
external_name=auth_result.external_name or "",
|
||||||
|
credentials=auth_result.credentials,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Auth linked: discord=%s → %s (%s)",
|
||||||
|
discord_id,
|
||||||
|
service,
|
||||||
|
auth_result.external_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
return HTMLResponse(f"""<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
|
<title>Account Linked</title>
|
||||||
|
<style>
|
||||||
|
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; text-align: center; }}
|
||||||
|
h1 {{ color: #388e3c; }}
|
||||||
|
.name {{ font-weight: bold; color: #aa5cc3; }}
|
||||||
|
p {{ color: #666; }}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>✅ Account Linked!</h1>
|
||||||
|
<p>Logged in as <span class="name">{auth_result.external_name}</span> on <strong>{svc_obj.display_name}</strong>.</p>
|
||||||
|
<p>You can close this page and return to Discord.</p>
|
||||||
|
</body>
|
||||||
|
</html>""")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /auth/status — get all linked services for a Discord user
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/Discord/status")
|
||||||
|
async def auth_status(discord_id: int):
|
||||||
|
"""
|
||||||
|
Return all services linked to this Discord user with full details.
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"discord_id": 123456789,
|
||||||
|
"linked_services": {
|
||||||
|
"jellyfin": {
|
||||||
|
"external_user_id": "abc123",
|
||||||
|
"external_name": "Tim",
|
||||||
|
"linked_at": "2026-05-25T10:00:00",
|
||||||
|
"credentials": {
|
||||||
|
"token": "jwt...",
|
||||||
|
"url": "http://jellyfin:8096",
|
||||||
|
"user_id": "abc123"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
auths = auth_store.get_all_auths(discord_id)
|
||||||
|
|
||||||
|
linked_services: dict[str, dict] = {}
|
||||||
|
for auth in auths:
|
||||||
|
svc_name = auth["service"]
|
||||||
|
linked_services[svc_name] = {
|
||||||
|
"external_user_id": auth["external_user_id"],
|
||||||
|
"external_name": auth["external_name"],
|
||||||
|
"linked_at": auth["linked_at"],
|
||||||
|
"credentials": auth["credentials"],
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"discord_id": discord_id,
|
||||||
|
"linked_services": linked_services,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /auth/reset — wipe auth store (DEV ONLY)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from src.config import get_config # noqa: E402
|
||||||
|
|
||||||
|
@router.post("/reset")
|
||||||
|
async def reset_auth():
|
||||||
|
"""
|
||||||
|
Reset the entire auth store — clears all link tokens and user auth records.
|
||||||
|
|
||||||
|
Only enabled when ALLOW_AUTH_RESET=true in the environment.
|
||||||
|
Returns 403 in production.
|
||||||
|
"""
|
||||||
|
if get_config("ALLOW_AUTH_RESET", "false").lower() != "true":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Auth reset is disabled. Set ALLOW_AUTH_RESET=true to enable (dev only).",
|
||||||
|
)
|
||||||
|
|
||||||
|
auth_store.reset_all()
|
||||||
|
logger.warning("Auth store reset via API endpoint.")
|
||||||
|
return {"status": "ok", "message": "Auth store cleared — all tokens and auth records removed."}
|
||||||
@@ -4,9 +4,9 @@ from openai import OpenAI
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from api.dependencies import get_llm_client, get_agent_graph
|
from gateway.dependencies import get_llm_client, get_agent_graph
|
||||||
from agents import get as get_agent, list_all as list_all_agents
|
from agents import get as get_agent, list_all as list_all_agents
|
||||||
from core.state import AgentState
|
from src.state import AgentState
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
# V1 — Chat & Agent API Endpoints
|
||||||
|
|
||||||
|
This is the primary HTTP API surface for the chatbot agent system. It exposes
|
||||||
|
both a custom streaming chat endpoint and an OpenAI-compatible
|
||||||
|
`/chat/completions` endpoint so it works as a drop-in backend for OpenWebUI,
|
||||||
|
LibreChat, or any OpenAI-compatible client.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
|
||||||
|
| Method | Path | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `GET ` | `/v1/` | Health check — returns `{"status": "ok"}` |
|
||||||
|
| `GET ` | `/v1/agents` | List all registered agents (id + description) |
|
||||||
|
| `GET ` | `/v1/models` | OpenAI-compatible model list (one entry per agent) |
|
||||||
|
| `POST` | `/v1/chat` | Chat with an agent — streaming (SSE) |
|
||||||
|
| `POST` | `/v1/chat/sync` | Chat with an agent — non-streaming |
|
||||||
|
| `POST` | `/v1/chat/completions` | OpenAI-compatible chat completions (supports `stream: true`) |
|
||||||
|
|
||||||
|
All `/v1/*` endpoints are mounted by `main.py` via:
|
||||||
|
|
||||||
|
```python
|
||||||
|
app.include_router(v1_router, prefix="/v1")
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Agent Resolution
|
||||||
|
|
||||||
|
Each request can target a specific agent. The resolution order is:
|
||||||
|
|
||||||
|
1. **Explicit `agent_id`** field in the request body
|
||||||
|
2. **OpenAI `model` field** (OpenWebUI sends this — mapped to `agent_id` if a matching agent is registered)
|
||||||
|
3. **Fallback** to the `"naked"` agent (a plain LLM with no tools)
|
||||||
|
|
||||||
|
This means an OpenWebUI client can simply set `model: "media-agent"` and get
|
||||||
|
the full Media Agent with Seerr tools.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Request Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
Client (OpenWebUI / HTTP)
|
||||||
|
│ POST /v1/chat/completions
|
||||||
|
│ { model: "media-agent", messages: [...], stream: true/false }
|
||||||
|
▼
|
||||||
|
chat_completions()
|
||||||
|
│ 1. _resolve_agent(req.model) → Agent(id="media-agent", skills=[...])
|
||||||
|
│ 2. get_agent_graph("media-agent", request)
|
||||||
|
│ → lazy-compiled LangGraph StateGraph, cached on app.state
|
||||||
|
│ 3. stream=True → _stream_graph(graph, messages) → SSE token stream
|
||||||
|
│ stream=False → _invoke_graph(graph, messages) → plain response
|
||||||
|
▼
|
||||||
|
LangGraph StateGraph (src/graph.py)
|
||||||
|
│
|
||||||
|
├── agent_node: calls LLM with system prompt + tool definitions
|
||||||
|
│ └── LLM returns text OR tool_calls
|
||||||
|
│
|
||||||
|
├── _should_continue: if tool_calls → tool_node, else → END
|
||||||
|
│
|
||||||
|
└── tool_node: executes tool via agents/skills system → ToolMessage
|
||||||
|
└── loops back to agent_node with the result
|
||||||
|
```
|
||||||
|
|
||||||
|
For a detailed walkthrough, see [api.md](../api.md).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Streaming
|
||||||
|
|
||||||
|
Two streaming modes exist:
|
||||||
|
|
||||||
|
### SSE (Server-Sent Events) — `/v1/chat`
|
||||||
|
```
|
||||||
|
data: {"token": "Here"}
|
||||||
|
data: {"token": " are"}
|
||||||
|
data: {"token": " the"}
|
||||||
|
...
|
||||||
|
data: [DONE]
|
||||||
|
```
|
||||||
|
|
||||||
|
The graph runs to completion (tools execute silently), then the final text is
|
||||||
|
yielded token-by-token as SSE events.
|
||||||
|
|
||||||
|
### OpenAI-compatible — `/v1/chat/completions` with `stream: true`
|
||||||
|
```
|
||||||
|
data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"}}]}
|
||||||
|
data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{"content":"!"}}]}
|
||||||
|
data: [DONE]
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Future improvement:** true token-level streaming (tokens appear as the LLM
|
||||||
|
> generates them) would require using `langchain-openai`'s `ChatOpenAI` in
|
||||||
|
> place of the raw `openai` client. The current approach avoids adding that
|
||||||
|
> dependency.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
Endpoints receive shared singletons via FastAPI `Depends`:
|
||||||
|
|
||||||
|
- **`get_llm_client(request)`** → returns `request.app.state.llm_client` (OpenAI client singleton, created once in `main.py`)
|
||||||
|
- **`get_agent_graph(agent_id, request)`** → returns a lazy-compiled LangGraph from `request.app.state.agent_graphs`
|
||||||
@@ -4,9 +4,11 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from api.v1.chat import router as v1_router
|
from gateway.v1.auth import router as auth_router
|
||||||
from core.config import DEEPSEEK_API_KEY
|
from gateway.v1.chat import router as v1_router
|
||||||
from core.llm import create_client
|
from gateway.jellystat.api import router as jellystat_router
|
||||||
|
from src.config import DEEPSEEK_API_KEY, get_config
|
||||||
|
from src.llm import create_client
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Logging — tool calls will appear in the uvicorn console
|
# Logging — tool calls will appear in the uvicorn console
|
||||||
@@ -18,23 +20,29 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Load all agents & skills so they self-register at startup
|
# Load all agents, skills, AND auth services so they self-register at startup
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
from agents import load_all_agents # noqa: E402
|
from agents import load_all_agents # noqa: E402
|
||||||
|
|
||||||
load_all_agents()
|
load_all_agents()
|
||||||
|
|
||||||
|
import gateway.auth.jellyfin # noqa: E402 — self-registers JellyfinAuth
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Lifespan
|
# Lifespan
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
from bot.discord_bot import start_in_background # noqa: E402
|
from gateway.discord.bot import start_in_background # noqa: E402
|
||||||
|
from gateway.jellystat.db import init_pool, close_pool # noqa: E402
|
||||||
|
|
||||||
|
await init_pool(app)
|
||||||
start_in_background()
|
start_in_background()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
await close_pool(app)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# App
|
# App
|
||||||
@@ -61,3 +69,5 @@ app.state.agent_graphs: dict = {}
|
|||||||
# Routers
|
# Routers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
app.include_router(v1_router, prefix="/v1")
|
app.include_router(v1_router, prefix="/v1")
|
||||||
|
app.include_router(auth_router)
|
||||||
|
app.include_router(jellystat_router)
|
||||||
@@ -6,3 +6,5 @@ httpx
|
|||||||
langgraph
|
langgraph
|
||||||
langgraph-checkpoint
|
langgraph-checkpoint
|
||||||
discord.py
|
discord.py
|
||||||
|
python-multipart
|
||||||
|
asyncpg
|
||||||
@@ -0,0 +1,358 @@
|
|||||||
|
"""
|
||||||
|
Auth Store — SQLite-backed persistence for Discord-to-service authentication.
|
||||||
|
|
||||||
|
Two tables:
|
||||||
|
- link_tokens : one-time tokens sent via Discord DM to initiate login
|
||||||
|
- user_auth : per-user, per-service credentials (Jellyfin token, etc.)
|
||||||
|
|
||||||
|
Thread-safe via WAL mode and a shared lock. No passwords are ever stored
|
||||||
|
— only opaque service tokens (e.g. Jellyfin AccessToken).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from src.config import get_config
|
||||||
|
|
||||||
|
logger = logging.getLogger("auth_store")
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db")
|
||||||
|
TOKEN_EXPIRY_MINUTES = int(get_config("AUTH_TOKEN_EXPIRY", "10"))
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Singleton handle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_db_path: Path | None = None
|
||||||
|
_db_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_path() -> Path:
|
||||||
|
"""Turn AUTH_DB_PATH into an absolute path, creating parent dirs."""
|
||||||
|
global _db_path
|
||||||
|
if _db_path is not None:
|
||||||
|
return _db_path
|
||||||
|
p = Path(AUTH_DB_PATH)
|
||||||
|
if not p.is_absolute():
|
||||||
|
# Relative to the project root (two levels above this file)
|
||||||
|
project_root = Path(__file__).resolve().parent.parent
|
||||||
|
p = project_root / p
|
||||||
|
p.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
_db_path = p
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def _get_conn() -> sqlite3.Connection:
|
||||||
|
"""Return a thread-local connection to the auth database."""
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False)
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.execute("PRAGMA foreign_keys=ON")
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Schema
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS link_tokens (
|
||||||
|
token TEXT PRIMARY KEY,
|
||||||
|
discord_user_id INTEGER NOT NULL,
|
||||||
|
service TEXT NOT NULL,
|
||||||
|
expires_at TEXT NOT NULL,
|
||||||
|
used INTEGER DEFAULT 0,
|
||||||
|
created_at TEXT DEFAULT (datetime('now'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS user_auth (
|
||||||
|
discord_user_id INTEGER NOT NULL,
|
||||||
|
service TEXT NOT NULL,
|
||||||
|
external_user_id TEXT,
|
||||||
|
external_name TEXT,
|
||||||
|
credentials TEXT,
|
||||||
|
linked_at TEXT DEFAULT (datetime('now')),
|
||||||
|
is_active INTEGER DEFAULT 1,
|
||||||
|
PRIMARY KEY (discord_user_id, service)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
_initialized = False
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_schema() -> None:
|
||||||
|
global _initialized
|
||||||
|
if _initialized:
|
||||||
|
return
|
||||||
|
with _db_lock:
|
||||||
|
if _initialized:
|
||||||
|
return
|
||||||
|
conn = _get_conn()
|
||||||
|
conn.executescript(_SCHEMA)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
_initialized = True
|
||||||
|
logger.info("Auth store initialized at %s", _resolve_path())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API — Link Tokens
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def create_token(discord_user_id: int, service: str) -> str:
|
||||||
|
"""Generate a one-time link token. Expires after TOKEN_EXPIRY_MINUTES."""
|
||||||
|
_ensure_schema()
|
||||||
|
token = secrets.token_urlsafe(32)
|
||||||
|
expires = (datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRY_MINUTES)).isoformat()
|
||||||
|
|
||||||
|
with _db_lock:
|
||||||
|
conn = _get_conn()
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO link_tokens (token, discord_user_id, service, expires_at) VALUES (?, ?, ?, ?)",
|
||||||
|
(token, discord_user_id, service, expires),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
logger.info("Created link token for user %s / service %s", discord_user_id, service)
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def validate_token(token: str) -> tuple[int, str] | None:
|
||||||
|
"""Read-only validation — does NOT consume the token.
|
||||||
|
|
||||||
|
Returns (discord_user_id, service) if the token exists, is unused,
|
||||||
|
and has not expired. Returns None otherwise.
|
||||||
|
"""
|
||||||
|
_ensure_schema()
|
||||||
|
|
||||||
|
with _db_lock:
|
||||||
|
conn = _get_conn()
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
|
||||||
|
(token,),
|
||||||
|
).fetchone()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
if row["used"]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
expires = datetime.fromisoformat(row["expires_at"])
|
||||||
|
if datetime.now(timezone.utc) > expires:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return (row["discord_user_id"], row["service"])
|
||||||
|
|
||||||
|
|
||||||
|
def consume_token(token: str) -> tuple[int, str] | None:
|
||||||
|
"""Validate and consume a link token. Returns (discord_user_id, service) or None.
|
||||||
|
|
||||||
|
A token is valid if:
|
||||||
|
- It exists
|
||||||
|
- It has not been used
|
||||||
|
- It has not expired
|
||||||
|
"""
|
||||||
|
_ensure_schema()
|
||||||
|
|
||||||
|
with _db_lock:
|
||||||
|
conn = _get_conn()
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
|
||||||
|
(token,),
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
conn.close()
|
||||||
|
return None
|
||||||
|
|
||||||
|
if row["used"]:
|
||||||
|
conn.close()
|
||||||
|
logger.warning("Token already used: %s…", token[:8])
|
||||||
|
return None
|
||||||
|
|
||||||
|
expires = datetime.fromisoformat(row["expires_at"])
|
||||||
|
if datetime.now(timezone.utc) > expires:
|
||||||
|
conn.close()
|
||||||
|
logger.warning("Token expired: %s…", token[:8])
|
||||||
|
return None
|
||||||
|
|
||||||
|
conn.execute("UPDATE link_tokens SET used = 1 WHERE token = ?", (token,))
|
||||||
|
conn.commit()
|
||||||
|
result = (row["discord_user_id"], row["service"])
|
||||||
|
conn.close()
|
||||||
|
logger.info("Token consumed: %s… → user=%s service=%s", token[:8], result[0], result[1])
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API — User Auth
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def store_auth(
|
||||||
|
discord_user_id: int,
|
||||||
|
service: str,
|
||||||
|
*,
|
||||||
|
external_user_id: str = "",
|
||||||
|
external_name: str = "",
|
||||||
|
credentials: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Store or update authentication for a user on a service."""
|
||||||
|
_ensure_schema()
|
||||||
|
import json
|
||||||
|
|
||||||
|
creds_json = json.dumps(credentials) if credentials else "{}"
|
||||||
|
|
||||||
|
with _db_lock:
|
||||||
|
conn = _get_conn()
|
||||||
|
conn.execute(
|
||||||
|
"""INSERT INTO user_auth (discord_user_id, service, external_user_id, external_name, credentials, linked_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, datetime('now'))
|
||||||
|
ON CONFLICT(discord_user_id, service) DO UPDATE SET
|
||||||
|
external_user_id = excluded.external_user_id,
|
||||||
|
external_name = excluded.external_name,
|
||||||
|
credentials = excluded.credentials,
|
||||||
|
linked_at = datetime('now'),
|
||||||
|
is_active = 1""",
|
||||||
|
(discord_user_id, service, external_user_id, external_name, creds_json),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
logger.info("Stored auth for user %s on %s as %s", discord_user_id, service, external_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth(discord_user_id: int, service: str) -> dict | None:
|
||||||
|
"""Retrieve stored auth for a user on a service. Returns None if not linked."""
|
||||||
|
_ensure_schema()
|
||||||
|
import json
|
||||||
|
|
||||||
|
with _db_lock:
|
||||||
|
conn = _get_conn()
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT * FROM user_auth WHERE discord_user_id = ? AND service = ? AND is_active = 1",
|
||||||
|
(discord_user_id, service),
|
||||||
|
).fetchone()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
credentials = json.loads(row["credentials"]) if row["credentials"] else {}
|
||||||
|
return {
|
||||||
|
"discord_user_id": row["discord_user_id"],
|
||||||
|
"service": row["service"],
|
||||||
|
"external_user_id": row["external_user_id"],
|
||||||
|
"external_name": row["external_name"],
|
||||||
|
"credentials": credentials,
|
||||||
|
"linked_at": row["linked_at"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_authenticated(discord_user_id: int, service: str) -> bool:
|
||||||
|
"""Quick check: is this user linked to this service?"""
|
||||||
|
return get_auth(discord_user_id, service) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def list_services(discord_user_id: int) -> list[str]:
|
||||||
|
"""Return list of service names this user has linked."""
|
||||||
|
_ensure_schema()
|
||||||
|
|
||||||
|
with _db_lock:
|
||||||
|
conn = _get_conn()
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT service FROM user_auth WHERE discord_user_id = ? AND is_active = 1",
|
||||||
|
(discord_user_id,),
|
||||||
|
).fetchall()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return [r["service"] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def revoke(discord_user_id: int, service: str) -> None:
|
||||||
|
"""Unlink a user from a service."""
|
||||||
|
_ensure_schema()
|
||||||
|
|
||||||
|
with _db_lock:
|
||||||
|
conn = _get_conn()
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE user_auth SET is_active = 0 WHERE discord_user_id = ? AND service = ?",
|
||||||
|
(discord_user_id, service),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
logger.info("Revoked auth for user %s on %s", discord_user_id, service)
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_auths(discord_user_id: int) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Return all active auth records for a Discord user.
|
||||||
|
Each record includes service name, external user id, external name,
|
||||||
|
linked_at timestamp, and the raw credentials (e.g. Jellyfin token + URL).
|
||||||
|
|
||||||
|
Used by the /api/v1/auth/status endpoint so other services can discover
|
||||||
|
linked accounts for a given Discord ID.
|
||||||
|
"""
|
||||||
|
_ensure_schema()
|
||||||
|
import json
|
||||||
|
|
||||||
|
with _db_lock:
|
||||||
|
conn = _get_conn()
|
||||||
|
rows = conn.execute(
|
||||||
|
"""SELECT service, external_user_id, external_name, credentials, linked_at
|
||||||
|
FROM user_auth
|
||||||
|
WHERE discord_user_id = ? AND is_active = 1
|
||||||
|
ORDER BY linked_at DESC""",
|
||||||
|
(discord_user_id,),
|
||||||
|
).fetchall()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
results: list[dict] = []
|
||||||
|
for row in rows:
|
||||||
|
creds = {}
|
||||||
|
if row["credentials"]:
|
||||||
|
try:
|
||||||
|
creds = json.loads(row["credentials"])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
creds = {}
|
||||||
|
results.append({
|
||||||
|
"service": row["service"],
|
||||||
|
"external_user_id": row["external_user_id"] or "",
|
||||||
|
"external_name": row["external_name"] or "",
|
||||||
|
"linked_at": row["linked_at"] or "",
|
||||||
|
"credentials": creds,
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dev / testing — reset the entire store
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def reset_all() -> None:
|
||||||
|
"""Truncate all auth tables — for development and testing only."""
|
||||||
|
_ensure_schema()
|
||||||
|
|
||||||
|
with _db_lock:
|
||||||
|
conn = _get_conn()
|
||||||
|
conn.execute("DELETE FROM link_tokens")
|
||||||
|
conn.execute("DELETE FROM user_auth")
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
logger.warning("Auth store RESET — all tokens and auth records cleared.")
|
||||||
@@ -1,12 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
LangGraph agent graph factory.
|
LangGraph agent graph factory.
|
||||||
|
|
||||||
Builds a StateGraph that replaces the manual tool-calling loop in api/v1/chat.py.
|
Builds a StateGraph with two nodes:
|
||||||
The graph has two nodes:
|
|
||||||
- agent_node : calls the LLM (with system prompt + tool definitions)
|
- agent_node : calls the LLM (with system prompt + tool definitions)
|
||||||
- tool_node : executes tool calls via the existing skill system
|
- tool_node : executes tool calls via the existing skill system
|
||||||
|
|
||||||
A conditional edge routes tool_calls back to the agent, or ends the run.
|
A conditional edge routes tool_calls back to the agent, or ends the run.
|
||||||
|
When a tool fails due to missing authentication, the failure message is
|
||||||
|
relayed to the LLM, which tells the user to use /login.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -19,8 +20,8 @@ from langchain_core.messages import AIMessage, ToolMessage
|
|||||||
from langgraph.graph import END, StateGraph
|
from langgraph.graph import END, StateGraph
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from core.state import AgentState
|
from src.state import AgentState
|
||||||
from skills import get_all_tools, execute_tool
|
from agents.skills import get_all_tools, execute_tool
|
||||||
|
|
||||||
logger = logging.getLogger("graph")
|
logger = logging.getLogger("graph")
|
||||||
|
|
||||||
@@ -97,18 +98,14 @@ def _make_agent_node(
|
|||||||
full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
|
full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
|
||||||
for m in messages:
|
for m in messages:
|
||||||
if isinstance(m, dict):
|
if isinstance(m, dict):
|
||||||
# Already a plain dict — pass through.
|
|
||||||
# But fix tool_calls if they're in LangChain format.
|
|
||||||
d = dict(m)
|
d = dict(m)
|
||||||
tc = d.get("tool_calls")
|
tc = d.get("tool_calls")
|
||||||
if tc and isinstance(tc, list) and tc and isinstance(tc[0], dict) and "function" not in tc[0]:
|
if tc and isinstance(tc, list) and tc and isinstance(tc[0], dict) and "function" not in tc[0]:
|
||||||
d["tool_calls"] = _langchain_tc_to_openai(tc)
|
d["tool_calls"] = _langchain_tc_to_openai(tc)
|
||||||
full.append(d)
|
full.append(d)
|
||||||
else:
|
else:
|
||||||
# LangChain message object → OpenAI-compatible dict
|
|
||||||
role = _lc_role_to_openai(getattr(m, "type", "user"))
|
role = _lc_role_to_openai(getattr(m, "type", "user"))
|
||||||
d: dict[str, Any] = {"role": role, "content": getattr(m, "content", "")}
|
d: dict[str, Any] = {"role": role, "content": getattr(m, "content", "")}
|
||||||
# Serialize tool_calls back to OpenAI format (if this is an AI msg)
|
|
||||||
tc = getattr(m, "tool_calls", None)
|
tc = getattr(m, "tool_calls", None)
|
||||||
if tc:
|
if tc:
|
||||||
d["tool_calls"] = _langchain_tc_to_openai(tc)
|
d["tool_calls"] = _langchain_tc_to_openai(tc)
|
||||||
@@ -125,7 +122,6 @@ def _make_agent_node(
|
|||||||
)
|
)
|
||||||
choice = resp.choices[0]
|
choice = resp.choices[0]
|
||||||
|
|
||||||
# Convert OpenAI tool_calls to the dict format LangChain expects.
|
|
||||||
raw_tool_calls = list(choice.message.tool_calls) if choice.message.tool_calls else []
|
raw_tool_calls = list(choice.message.tool_calls) if choice.message.tool_calls else []
|
||||||
tool_calls: list[dict[str, Any]] = []
|
tool_calls: list[dict[str, Any]] = []
|
||||||
for tc in raw_tool_calls:
|
for tc in raw_tool_calls:
|
||||||
@@ -153,9 +149,9 @@ def _make_tool_node(skill_names: list[str]):
|
|||||||
"""
|
"""
|
||||||
Return a callable that executes tool_calls from the last AI message.
|
Return a callable that executes tool_calls from the last AI message.
|
||||||
|
|
||||||
This replaces LangGraph's built-in ToolNode — we call our own
|
If a tool fails because the user isn't authenticated, the failure
|
||||||
`execute_tool()` pipeline so that skill-level auth, httpx sessions,
|
message (which tells the user to /login) is returned to the LLM.
|
||||||
and ToolResult handling are fully preserved.
|
The LLM naturally relays the instructions to the user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def tool_node(state: AgentState) -> dict[str, list]:
|
async def tool_node(state: AgentState) -> dict[str, list]:
|
||||||
@@ -164,18 +160,16 @@ def _make_tool_node(skill_names: list[str]):
|
|||||||
if not tool_calls:
|
if not tool_calls:
|
||||||
return {"messages": []}
|
return {"messages": []}
|
||||||
|
|
||||||
|
discord_user_id = state.get("discord_user_id")
|
||||||
|
|
||||||
results: list[ToolMessage] = []
|
results: list[ToolMessage] = []
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
# Handle both LangChain format (top-level name/args) and
|
|
||||||
# OpenAI format (nested "function" key).
|
|
||||||
if isinstance(tc, dict):
|
if isinstance(tc, dict):
|
||||||
if "function" in tc:
|
if "function" in tc:
|
||||||
# OpenAI format: {"id":..., "function": {"name":..., "arguments":"..."}}
|
|
||||||
fn = tc["function"]
|
fn = tc["function"]
|
||||||
fn_name = fn.get("name", "")
|
fn_name = fn.get("name", "")
|
||||||
fn_args_raw = fn.get("arguments", "{}")
|
fn_args_raw = fn.get("arguments", "{}")
|
||||||
else:
|
else:
|
||||||
# LangChain format: {"name":..., "args":{...}, "id":...}
|
|
||||||
fn_name = tc.get("name", "")
|
fn_name = tc.get("name", "")
|
||||||
fn_args_raw = tc.get("args", {})
|
fn_args_raw = tc.get("args", {})
|
||||||
tc_id = tc.get("id", "")
|
tc_id = tc.get("id", "")
|
||||||
@@ -184,13 +178,15 @@ def _make_tool_node(skill_names: list[str]):
|
|||||||
fn_args_raw = getattr(tc, "args", {})
|
fn_args_raw = getattr(tc, "args", {})
|
||||||
tc_id = getattr(tc, "id", "")
|
tc_id = getattr(tc, "id", "")
|
||||||
|
|
||||||
# Parse args if they arrive as a JSON string
|
|
||||||
if isinstance(fn_args_raw, str):
|
if isinstance(fn_args_raw, str):
|
||||||
fn_args = json.loads(fn_args_raw)
|
fn_args = json.loads(fn_args_raw)
|
||||||
else:
|
else:
|
||||||
fn_args = fn_args_raw
|
fn_args = fn_args_raw
|
||||||
|
|
||||||
tr = await execute_tool(skill_names, fn_name, fn_args)
|
tr = await execute_tool(
|
||||||
|
skill_names, fn_name, fn_args,
|
||||||
|
discord_user_id=discord_user_id,
|
||||||
|
)
|
||||||
content = tr.content if tr else f"Tool '{fn_name}' is not available."
|
content = tr.content if tr else f"Tool '{fn_name}' is not available."
|
||||||
results.append(ToolMessage(content=content, tool_call_id=tc_id))
|
results.append(ToolMessage(content=content, tool_call_id=tc_id))
|
||||||
|
|
||||||
@@ -224,27 +220,16 @@ def create_agent_graph(
|
|||||||
) -> StateGraph:
|
) -> StateGraph:
|
||||||
"""
|
"""
|
||||||
Build and compile a LangGraph StateGraph for a single agent.
|
Build and compile a LangGraph StateGraph for a single agent.
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
client : The OpenAI-compatible client (already authenticated).
|
|
||||||
agent_skills : Skill names assigned to the agent (e.g. ["seerr", "triage"]).
|
|
||||||
system_prompt : The fully-built system prompt (base + skill fragments).
|
|
||||||
model_name : Model identifier sent to the LLM provider.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
A compiled LangGraph graph ready for `.ainvoke()` or `.astream()`.
|
|
||||||
"""
|
"""
|
||||||
tool_defs = get_all_tools(agent_skills)
|
tool_defs = get_all_tools(agent_skills)
|
||||||
|
|
||||||
graph = StateGraph(AgentState)
|
graph = StateGraph(AgentState)
|
||||||
|
|
||||||
# Nodes
|
|
||||||
graph.add_node(
|
graph.add_node(
|
||||||
"agent_node",
|
"agent_node",
|
||||||
_make_agent_node(client, system_prompt, tool_defs, model_name),
|
_make_agent_node(client, system_prompt, tool_defs, model_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
if tool_defs:
|
if tool_defs:
|
||||||
graph.add_node("tool_node", _make_tool_node(agent_skills))
|
graph.add_node("tool_node", _make_tool_node(agent_skills))
|
||||||
graph.add_conditional_edges("agent_node", _should_continue, {
|
graph.add_conditional_edges("agent_node", _should_continue, {
|
||||||
@@ -253,7 +238,6 @@ def create_agent_graph(
|
|||||||
})
|
})
|
||||||
graph.add_edge("tool_node", "agent_node")
|
graph.add_edge("tool_node", "agent_node")
|
||||||
else:
|
else:
|
||||||
# No tools — agent responds once and finishes
|
|
||||||
graph.add_edge("agent_node", END)
|
graph.add_edge("agent_node", END)
|
||||||
|
|
||||||
graph.set_entry_point("agent_node")
|
graph.set_entry_point("agent_node")
|
||||||
@@ -18,3 +18,4 @@ class AgentState(TypedDict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
messages: Annotated[list, add_messages]
|
messages: Annotated[list, add_messages]
|
||||||
|
discord_user_id: int | None # set by the Discord bot, None for REST API calls
|
||||||
@@ -13,7 +13,7 @@ from typing import Any
|
|||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from skills import get_all_tools, execute_tool
|
from agents.skills import get_all_tools, execute_tool
|
||||||
|
|
||||||
|
|
||||||
def build_langgraph_tools(skill_names: list[str]) -> list:
|
def build_langgraph_tools(skill_names: list[str]) -> list:
|
||||||
Reference in New Issue
Block a user