Compare commits
39 Commits
2763c93cef
...
auth
| Author | SHA1 | Date | |
|---|---|---|---|
| 0151c8210e | |||
| 4b87b817a8 | |||
| b0f10b6bb1 | |||
| 51e099acdd | |||
| bf358f7248 | |||
| 3cd2e4dfbb | |||
| f82c416b49 | |||
| 2cd72c7770 | |||
| 1d65c7a9e9 | |||
| 2f7f94f1ce | |||
| 1d821d18fe | |||
| 0634e7400a | |||
| 2adf17493a | |||
| d943d4bd31 | |||
| 2ee33b50eb | |||
| cb4ebfa43e | |||
| 54ac77ab51 | |||
| 2677d381ce | |||
| 1d477c379b | |||
| 81ef01d3ba | |||
| 15c1917389 | |||
| 1476b33a9b | |||
| 37994a76b8 | |||
| cf8012c697 | |||
| 9fc412efbb | |||
| 8456288b6d | |||
| 63c07a602d | |||
| 3a0d09bf4b | |||
| 6dcba9230f | |||
| 725c21e5d7 | |||
| f8f2fa04f4 | |||
| 4c758f7733 | |||
| 58250d22a3 | |||
| 8b2e30eeb5 | |||
| edd0599077 | |||
| 629296f150 | |||
| 02bc43e6ae | |||
| c67cb4d14d | |||
| 1ac7df90f0 |
@@ -0,0 +1,50 @@
|
||||
# =============================================================================
|
||||
# Agent Bot — Environment Configuration
|
||||
# =============================================================================
|
||||
# Copy this file to .env and fill in your values.
|
||||
# .env is git-ignored — never commit real secrets.
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM — DeepSeek (OpenAI-compatible)
|
||||
# ---------------------------------------------------------------------------
|
||||
DEEPSEEK_API_KEY=sk-your-deepseek-api-key
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord Bot
|
||||
# ---------------------------------------------------------------------------
|
||||
DISCORD_BOT_TOKEN=your-discord-bot-token-here
|
||||
# DISCORD_MAX_HISTORY=7 # optional, defaults to 7 (max past messages per user)
|
||||
# DISCORD_DEFAULT_AGENT=media-agent # optional, which agent the DM bot uses
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seerr (Overseerr / Jellyseerr)
|
||||
# ---------------------------------------------------------------------------
|
||||
SEERR_URL=https://seerr.example.com
|
||||
SEERR_API_KEY=your-seerr-api-key
|
||||
# SEERR_USERNAME=your-username # alternative: username+password auth
|
||||
# SEERR_PASSWORD=your-password
|
||||
# SEERR_TIMEOUT=30 # optional, defaults to 30 seconds
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth System (Discord ↔ external services)
|
||||
# ---------------------------------------------------------------------------
|
||||
# The public-facing URL where users reach this bot's web API.
|
||||
# Used to build the "Click here to link" URLs sent via Discord DM.
|
||||
# For local dev: http://localhost:8000
|
||||
# For production behind a reverse proxy: https://bot.yourdomain.com
|
||||
BASE_URL=http://localhost:8000
|
||||
|
||||
# Where the auth SQLite database lives (relative to project root)
|
||||
# AUTH_DB_PATH=data/auth.db
|
||||
|
||||
# Link token expiry in minutes (default 10)
|
||||
# AUTH_TOKEN_EXPIRY=10
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JellyStat — PostgreSQL watch-history database
|
||||
# ---------------------------------------------------------------------------
|
||||
JELLYSTAT_DB_HOST=localhost
|
||||
JELLYSTAT_DB_PORT=5432
|
||||
JELLYSTAT_DB_USER=postgres
|
||||
JELLYSTAT_DB_PASSWORD=
|
||||
JELLYSTAT_DB_NAME=jfstat
|
||||
@@ -0,0 +1,35 @@
|
||||
name: Build and Push Agent API
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: home
|
||||
|
||||
container:
|
||||
image: catthehacker/ubuntu:act-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Show files
|
||||
run: ls -la
|
||||
|
||||
- name: Verify docker
|
||||
run: docker version
|
||||
|
||||
- name: Build Docker image
|
||||
run: |
|
||||
docker build \
|
||||
-t 192.168.1.185:5010/agents-api:latest \
|
||||
-f docker/Dockerfile .
|
||||
|
||||
- name: Push image
|
||||
run: |
|
||||
docker push 192.168.1.185:5010/agents-api:latest
|
||||
@@ -174,3 +174,5 @@ cython_debug/
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
.docs/
|
||||
data/
|
||||
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Agent system — each agent combines a base LLM with optional skills
|
||||
to produce tailored system prompts and behavior.
|
||||
|
||||
An Agent is a lightweight wrapper:
|
||||
- agent_id : unique name (e.g. "naked", "media-agent")
|
||||
- description : human-readable summary
|
||||
- skills : list of skill names to load
|
||||
- base_prompt : default system prompt (optional — falls back to generic)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List
|
||||
|
||||
from agents.skills import Skill, get_combined_prompt, list_all as list_all_skills
|
||||
|
||||
|
||||
@dataclass
|
||||
class Agent:
|
||||
agent_id: str
|
||||
description: str = ""
|
||||
skills: List[str] = field(default_factory=list)
|
||||
base_prompt: str = "You are a helpful agent."
|
||||
|
||||
def build_system_prompt(self) -> str:
|
||||
"""Combine base_prompt with all registered skills' prompt fragments."""
|
||||
return get_combined_prompt(self.skills, base_prompt=self.base_prompt)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
sk = ", ".join(self.skills) if self.skills else "none"
|
||||
return f"Agent(id={self.agent_id!r}, skills=[{sk}])"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Global agent registry
|
||||
# ---------------------------------------------------------------------------
|
||||
_agent_registry: Dict[str, Agent] = {}
|
||||
|
||||
|
||||
def register(agent: Agent) -> None:
|
||||
"""Register an agent so it can be looked up by agent_id."""
|
||||
_agent_registry[agent.agent_id] = agent
|
||||
|
||||
|
||||
def get(agent_id: str) -> Agent | None:
|
||||
"""Return a registered agent by id, or None."""
|
||||
return _agent_registry.get(agent_id)
|
||||
|
||||
|
||||
def list_all() -> Dict[str, Agent]:
|
||||
"""Return a shallow copy of the registry."""
|
||||
return dict(_agent_registry)
|
||||
|
||||
|
||||
def load_all_agents() -> None:
|
||||
"""
|
||||
Import all agent modules so they self-register.
|
||||
Call this once at startup.
|
||||
"""
|
||||
import agents.naked # noqa: F401
|
||||
import agents.media_agent # noqa: F401
|
||||
|
||||
# Also import skill modules so they self-register
|
||||
import agents.skills.media_info # noqa: F401
|
||||
import agents.skills.seerr # noqa: F401
|
||||
import agents.skills.triage # noqa: F401
|
||||
import agents.skills.easter_eggs # noqa: F401
|
||||
import agents.skills.watch_history # noqa: F401
|
||||
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
media-agent — an agent that knows how to handle media queries
|
||||
(Jellyfin / Sonarr / Seerr / subtitle requests).
|
||||
|
||||
Skills:
|
||||
- media_info : base persona (prompt-only)
|
||||
- seerr : trending, discover, request media, submit issues (tools + API)
|
||||
- triage : fallback for unsupported actions (prompt-only, uses seerr tools)
|
||||
"""
|
||||
|
||||
from agents import Agent, register
|
||||
|
||||
media_agent = Agent(
|
||||
agent_id="media-agent",
|
||||
description="Media assistant — handles movie/TV/subtitle/ticket requests "
|
||||
"via Seerr, Jellyfin, Sonarr, etc.",
|
||||
skills=["media_info", "seerr", "triage", "easter_eggs", "watch_history"],
|
||||
base_prompt=(
|
||||
"You are a media assistant connected to Seerr and other media services. "
|
||||
"Help users discover, request, and troubleshoot their media library. "
|
||||
"Use the tools provided to perform real actions.\n\n"
|
||||
"## Authentication\n"
|
||||
"If a tool returns a message saying the user needs to log in first, "
|
||||
"tell the user to type `/login <service>` in their DM (e.g. `/login jellyfin`). "
|
||||
"This opens Quick Connect on their Jellyfin app so they can link their account. "
|
||||
"Do NOT tell the user you 'can't connect' or 'don't have access' — just relay the login instructions."
|
||||
),
|
||||
)
|
||||
|
||||
register(media_agent)
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
naked agent — a barebone LLM with no extra skills attached.
|
||||
Just a thin wrapper that instructs the LLM to be a general helpful assistant.
|
||||
"""
|
||||
|
||||
from agents import Agent, register
|
||||
|
||||
naked_agent = Agent(
|
||||
agent_id="naked",
|
||||
description="A plain LLM — no extra skills, just a helpful assistant.",
|
||||
skills=[], # no skills
|
||||
base_prompt="You are a helpful, general-purpose assistant.",
|
||||
)
|
||||
|
||||
register(naked_agent)
|
||||
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Skill system — each skill is a piece of domain knowledge or a capability
|
||||
that can be attached to an agent to shape its behavior and system prompt.
|
||||
|
||||
A Skill is a lightweight object with:
|
||||
- name : short identifier (e.g. "media_info")
|
||||
- description : human-readable summary
|
||||
- prompt_fragment : extra text injected into the agent's system prompt
|
||||
- tools : OpenAI function-calling tool definitions (list of dicts)
|
||||
- execute : async callable to run a tool → ToolResult
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
from src.config import get_config # re-export so every skill can use it
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ToolResult — every skill executor must return this
|
||||
# ---------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Result of executing a tool.
|
||||
- success: True if the API returned 2xx and the action completed.
|
||||
- content: The message to feed back to the LLM (will be shown to the user).
|
||||
"""
|
||||
content: str
|
||||
success: bool = True
|
||||
|
||||
@classmethod
|
||||
def ok(cls, content: str) -> "ToolResult":
|
||||
return cls(content=content, success=True)
|
||||
|
||||
@classmethod
|
||||
def fail(cls, content: str) -> "ToolResult":
|
||||
return cls(content=content, success=False)
|
||||
|
||||
|
||||
# Type alias for a tool executor
|
||||
ToolExecutor = Callable[[str, dict], Awaitable[ToolResult]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Skill:
|
||||
name: str
|
||||
description: str
|
||||
prompt_fragment: str = ""
|
||||
tools: List[Dict[str, Any]] = field(default_factory=list)
|
||||
execute: Optional[ToolExecutor] = None
|
||||
requires_auth: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Global skill registry — populated at startup / import time
|
||||
# ---------------------------------------------------------------------------
|
||||
_skill_registry: Dict[str, Skill] = {}
|
||||
|
||||
|
||||
def register(skill: Skill) -> None:
|
||||
"""Register a skill so agents can look it up by name."""
|
||||
_skill_registry[skill.name] = skill
|
||||
|
||||
|
||||
def get(name: str) -> Skill | None:
|
||||
"""Return a registered skill by name, or None."""
|
||||
return _skill_registry.get(name)
|
||||
|
||||
|
||||
def list_all() -> Dict[str, Skill]:
|
||||
"""Return a shallow copy of the registry."""
|
||||
return dict(_skill_registry)
|
||||
|
||||
|
||||
def get_combined_prompt(skill_names: list[str], base_prompt: str = "") -> str:
|
||||
"""Build a system prompt from a base prompt + requested skill fragments."""
|
||||
parts = [base_prompt] if base_prompt else []
|
||||
for name in skill_names:
|
||||
s = get(name)
|
||||
if s and s.prompt_fragment:
|
||||
parts.append(s.prompt_fragment)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def get_all_tools(skill_names: list[str]) -> List[Dict[str, Any]]:
|
||||
"""Collect all OpenAI tool definitions across the requested skills."""
|
||||
tools: List[Dict[str, Any]] = []
|
||||
seen: set[str] = set()
|
||||
for name in skill_names:
|
||||
s = get(name)
|
||||
if s:
|
||||
for t in s.tools:
|
||||
fn_name = t.get("function", {}).get("name", "")
|
||||
if fn_name and fn_name not in seen:
|
||||
seen.add(fn_name)
|
||||
tools.append(t)
|
||||
return tools
|
||||
|
||||
|
||||
async def execute_tool(
|
||||
skill_names: list[str], tool_name: str, args: dict,
|
||||
discord_user_id: int | None = None,
|
||||
) -> ToolResult | None:
|
||||
"""Find the skill that owns *tool_name* and run its executor.
|
||||
|
||||
If *discord_user_id* is provided, also checks whether the owning skill
|
||||
requires authentication for any services. If auth is missing, returns
|
||||
a friendly ToolResult.fail(...) telling the user how to log in.
|
||||
|
||||
Only logs failures to the console — successful calls are silent.
|
||||
"""
|
||||
import logging
|
||||
logger = logging.getLogger("skills")
|
||||
|
||||
for name in skill_names:
|
||||
s = get(name)
|
||||
if s and s.execute:
|
||||
for t in s.tools:
|
||||
if t.get("function", {}).get("name") == tool_name:
|
||||
# --- Auth gate ---
|
||||
if s.requires_auth and discord_user_id is not None:
|
||||
from src import auth_store
|
||||
from gateway.auth import get_auth_service
|
||||
missing: list[str] = []
|
||||
for svc in s.requires_auth:
|
||||
if not auth_store.is_authenticated(discord_user_id, svc):
|
||||
missing.append(svc)
|
||||
if missing:
|
||||
svc_displays = ", ".join(
|
||||
(get_auth_service(m) and get_auth_service(m).display_name) or m
|
||||
for m in missing
|
||||
)
|
||||
return ToolResult.fail(
|
||||
f"You need to log in to {svc_displays} first. "
|
||||
+ " ".join(f"Send `/login {m}` in a DM to get started." for m in missing)
|
||||
)
|
||||
# --- End auth gate ---
|
||||
# Inject discord_user_id so skills can resolve external user IDs
|
||||
if discord_user_id is not None:
|
||||
args = {**args, "_discord_user_id": discord_user_id}
|
||||
try:
|
||||
result = await s.execute(tool_name, args)
|
||||
if not result.success:
|
||||
logger.warning(
|
||||
"⚠️ TOOL FAILED: %s | args=%s → %s",
|
||||
tool_name, args, result.content[:300],
|
||||
)
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"💥 TOOL CRASH: %s | args=%s", tool_name, args
|
||||
)
|
||||
return ToolResult.fail(
|
||||
f"Tool '{tool_name}' crashed unexpectedly: {exc}"
|
||||
)
|
||||
|
||||
logger.warning("⚠️ TOOL NOT FOUND: %s (skills=%s)", tool_name, skill_names)
|
||||
return None
|
||||
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Easter eggs skill — theme-aware persona adapter.
|
||||
|
||||
When a user's message contains trigger words from a known fandom/universe,
|
||||
the LLM adopts that theme's persona flavor while still performing all
|
||||
requested actions normally. Functionality is never sacrificed for a reference.
|
||||
|
||||
Add a new theme by adding one entry to THEMES — no code changes needed.
|
||||
"""
|
||||
|
||||
from agents.skills import Skill, register
|
||||
|
||||
THEMES = {
|
||||
"naruto": {
|
||||
"triggers": [
|
||||
"rasengan", "sasuke", "naruto", "kakashi", "sakura", "hokage",
|
||||
"chidori", "sharingan", "kurama", "dattebayo", "believe it",
|
||||
"hidden leaf", "konoha", "akatsuki", "itachi", "jiraiya",
|
||||
"shippuden", "boruto", "sensei", "ninja", "tsunade", "orochimaru",
|
||||
"tailed beast", "jinchuriki", "rinnegan", "byakugan", "genin",
|
||||
"chunin", "jonin", "anbu", "uchiha", "hyuga", "uzumaki",
|
||||
],
|
||||
"persona": (
|
||||
"Adopt the speaking style of a ninja from the Hidden Leaf Village. "
|
||||
"If someone screems 'Rasengan!' in their request, respond with 'SOOSSKEEE!' "
|
||||
"If someone screams 'SOSSSKEE!' in their request, respond with 'RESEENNGGAANN!' "
|
||||
"Stay fully functional — carry out all requested actions normally, but only if something is requested"
|
||||
),
|
||||
},
|
||||
|
||||
"jeroen": {
|
||||
"triggers": [
|
||||
"jeroen"
|
||||
],
|
||||
"persona": (
|
||||
"mention that jeroen is a 'Land verader' in the response "
|
||||
"Stay fully functional — carry out all requested actions normally, but only if something is requested"
|
||||
)
|
||||
},
|
||||
|
||||
"chicken": {
|
||||
"triggers": [
|
||||
"chicken", "cluck", "bawk", "bok bok", "chook", "poultry", "BOAKKK", "BOAKY!"
|
||||
],
|
||||
"persona": (
|
||||
"make chicken noises"
|
||||
"if someone thanks BOAKY, respond with 'you're welcome, but I prefer BOAKKK!' "
|
||||
"Stay fully functional — carry out all requested actions normally, but only if something is requested"
|
||||
)
|
||||
},
|
||||
|
||||
"Mexico": {
|
||||
"triggers": [
|
||||
"mexico", "taco", "burrito", "sombrero", "guacamole", "quesadilla",
|
||||
],
|
||||
"persona": (
|
||||
"make sentences only with at least 80 percent of the following words: taco, burrito, sombrero, guacamole, quesadilla. "
|
||||
"Stay fully functional — carry out all requested actions normally, but only if something is requested"
|
||||
)
|
||||
},
|
||||
|
||||
"nacho": {
|
||||
"triggers": [
|
||||
"nacho", "sharing", "plate", "snack", "chips", "cheese"
|
||||
],
|
||||
"persona": (
|
||||
"if anywhere nacho is mentioned, mention that one of the best things is, to share a plate of nachos with Matty"
|
||||
"Stay fully functional — carry out all requested actions normally, but only if something is requested"
|
||||
)
|
||||
},
|
||||
|
||||
"RockAndStone!": {
|
||||
"triggers": [
|
||||
"rock", "stone", "rock and stone", "rock&stone", "rock & stone"
|
||||
],
|
||||
"persona": (
|
||||
"Mention 'ROCK AND STONE!' anywhere in the response "
|
||||
"Stay fully functional — carry out all requested actions normally, but only if something is requested"
|
||||
)
|
||||
},
|
||||
|
||||
"airplane": {
|
||||
"triggers": [
|
||||
"airplane", "plane", "flight", "pilot", "cockpit", "turbulence"
|
||||
],
|
||||
"persona": (
|
||||
"mention airplane noises, and say that Erwin will be terrified "
|
||||
"mention something airplane related, and hope Erwin is not around to see this reference."
|
||||
"Stay fully functional — carry out all requested actions normally, but only if something is requested"
|
||||
)
|
||||
},
|
||||
|
||||
"dyslexia": {
|
||||
"triggers": [
|
||||
"unable to read", "dyslexia", "dislexia", "dislexic", "dyslexic", "typo", "trouble reading", "misspelled", "misspelling"
|
||||
],
|
||||
"persona": (
|
||||
"mention that tim is dyslexic and has trouble reading in the response "
|
||||
"mention it is sometimes difficult to read, since the creator (which is Tim) of this agent is dyslexic. "
|
||||
"Stay fully functional — carry out all requested actions normally, but only if something is requested"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build the prompt fragment from THEMES
|
||||
# ---------------------------------------------------------------------------
|
||||
theme_blocks = []
|
||||
for name, theme in THEMES.items():
|
||||
all_triggers = theme["triggers"]
|
||||
# Show first 8 triggers + count of remaining
|
||||
shown = all_triggers[:8]
|
||||
suffix = f" (+{len(all_triggers) - 8} more)" if len(all_triggers) > 8 else ""
|
||||
triggers_str = ", ".join(f'"{t}"' for t in shown) + suffix
|
||||
theme_blocks.append(
|
||||
f"### {name.upper()}\n"
|
||||
f"Triggers (case-insensitive, substring match): {triggers_str}\n"
|
||||
f"Persona: {theme['persona']}"
|
||||
)
|
||||
|
||||
easter_eggs_skill = Skill(
|
||||
name="easter_eggs",
|
||||
description="Theme-aware persona adapter — flavors responses "
|
||||
"when users mention known fandoms/universes.",
|
||||
prompt_fragment=(
|
||||
"## Themed Personas\n\n"
|
||||
"Before responding, scan the user's message for these themes "
|
||||
"(case-insensitive, substring match). If a theme matches, adopt its "
|
||||
"persona flavor while still performing all requested actions normally. "
|
||||
"Never skip functionality for the sake of a reference.\n\n"
|
||||
"If multiple themes match, pick the one with the most trigger hits.\n"
|
||||
"If no theme matches, respond with your normal base persona.\n\n"
|
||||
+ "\n\n".join(theme_blocks)
|
||||
),
|
||||
tools=[],
|
||||
execute=None,
|
||||
)
|
||||
|
||||
register(easter_eggs_skill)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Demo skill: media_info
|
||||
|
||||
A lightweight base skill that teaches the agent it is a media assistant.
|
||||
Real API capabilities come from other skills (seerr, triage, etc.).
|
||||
"""
|
||||
|
||||
from agents.skills import Skill, register
|
||||
|
||||
media_info_skill = Skill(
|
||||
name="media_info",
|
||||
description="Base media assistant persona — movie, TV, subtitle, and media requests.",
|
||||
prompt_fragment="""## Media Assistant Persona
|
||||
|
||||
You are a friendly media assistant connected to a media back-end (Seerr,
|
||||
Jellyfin, Sonarr, etc.). Your job is to help users discover, request, and
|
||||
troubleshoot their media library.
|
||||
|
||||
When responding:
|
||||
- Be concise and helpful.
|
||||
- Use the tools available to you for real actions.
|
||||
- If a user asks about **subtitles**, explain that Bazarr handles those and
|
||||
suggest submitting a ticket if there's a problem.
|
||||
- Always confirm successful actions and warn about failures.
|
||||
|
||||
## Jellyfin & Authentication
|
||||
|
||||
You are connected to the user's Jellyfin server. If a user asks you to
|
||||
"connect to Jellyfin", "link my Jellyfin", or asks about their watch history,
|
||||
simply call the `watch_history` tool. The system will automatically handle
|
||||
authentication — if the user isn't linked yet, they'll be guided through
|
||||
Quick Connect seamlessly. NEVER tell a user you "don't have access to
|
||||
Jellyfin" or "can't connect" — always try the tool first and let the system
|
||||
sort it out.
|
||||
|
||||
This is the base media assistant persona. Real API capabilities come from the
|
||||
attached skills (seerr, triage, etc.).""",
|
||||
)
|
||||
|
||||
register(media_info_skill)
|
||||
|
||||
@@ -0,0 +1,971 @@
|
||||
"""
|
||||
Seerr skill — connects to Overseerr / Jellyseerr API for media discovery,
|
||||
requests, and issue submission.
|
||||
|
||||
.env variables:
|
||||
SEERR_URL - base URL (e.g. https://seerr.example.com)
|
||||
SEERR_USERNAME - login username (email)
|
||||
SEERR_PASSWORD - login password
|
||||
SEERR_API_KEY - fallback API key (used if username/password not set)
|
||||
SEERR_TIMEOUT - optional request timeout in seconds (default 30)
|
||||
|
||||
Auth flow:
|
||||
1. If SEERR_USERNAME + SEERR_PASSWORD are set:
|
||||
POST /api/v1/auth/jellyfin {username, password}
|
||||
→ stores the connect.sid cookie in a persistent httpx session
|
||||
→ all subsequent requests use cookie auth
|
||||
2. Falls back to X-Api-Key header if SEERR_API_KEY is set.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
|
||||
from agents.skills import Skill, register, ToolResult, get_config
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
SEERR_URL = (get_config("SEERR_URL") or "").rstrip("/")
|
||||
SEERR_USERNAME = get_config("SEERR_USERNAME") or ""
|
||||
SEERR_PASSWORD = get_config("SEERR_PASSWORD") or ""
|
||||
SEERR_API_KEY = get_config("SEERR_API_KEY") or ""
|
||||
SEERR_TIMEOUT = int(get_config("SEERR_TIMEOUT", "30"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth — cookie-based session (preferred) or API key fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# IMPORTANT: httpx.AsyncClient binds internal asyncio primitives to the
|
||||
# event loop that is current when the client is created. The Discord bot
|
||||
# runs in a separate thread with its own event loop, so we must create a
|
||||
# fresh AsyncClient *per event loop*. We cache one client per loop ID so
|
||||
# each loop still reuses its own singleton (connection pooling works), but
|
||||
# the bot and the REST API never fight over the same connection pool.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
_seerr_sessions: dict[int, httpx.AsyncClient] = {}
|
||||
_seerr_sessions_lock = threading.Lock()
|
||||
|
||||
# Cached login cookies — obtained once at module load (sync) and reused
|
||||
# for every event-loop-specific client. A threading.Event ensures that
|
||||
# the first caller to trigger the login blocks all other callers until
|
||||
# the login is complete — preventing a race where a second thread builds
|
||||
# a client with empty cookies before the login finishes.
|
||||
_seerr_cookies: dict = {}
|
||||
_seerr_cookies_ready = threading.Event()
|
||||
_seerr_cookies_lock = threading.Lock()
|
||||
|
||||
|
||||
def _ensure_cookies() -> None:
|
||||
"""One-time sync login to get the connect.sid cookie.
|
||||
|
||||
Thread-safe: only one thread performs the login; all others block
|
||||
until it finishes, then reuse the result.
|
||||
"""
|
||||
if _seerr_cookies_ready.is_set():
|
||||
return
|
||||
|
||||
with _seerr_cookies_lock:
|
||||
# Double-check — another thread may have finished while we waited
|
||||
if _seerr_cookies_ready.is_set():
|
||||
return
|
||||
|
||||
if SEERR_USERNAME.strip() and SEERR_PASSWORD.strip():
|
||||
sync_client = httpx.Client(base_url=SEERR_URL, timeout=SEERR_TIMEOUT)
|
||||
try:
|
||||
resp = sync_client.post("/api/v1/auth/jellyfin", json={
|
||||
"username": SEERR_USERNAME.strip(),
|
||||
"password": SEERR_PASSWORD.strip(),
|
||||
})
|
||||
resp.raise_for_status()
|
||||
_seerr_cookies.update(dict(sync_client.cookies))
|
||||
except httpx.HTTPError:
|
||||
pass
|
||||
finally:
|
||||
sync_client.close()
|
||||
|
||||
# Signal completion — even if login failed (empty cookies) so we
|
||||
# don't retry forever.
|
||||
_seerr_cookies_ready.set()
|
||||
|
||||
|
||||
def _build_client() -> httpx.AsyncClient:
|
||||
"""Create a new httpx.AsyncClient for the *current* event loop."""
|
||||
if _seerr_cookies:
|
||||
return httpx.AsyncClient(
|
||||
base_url=SEERR_URL,
|
||||
cookies=_seerr_cookies,
|
||||
timeout=SEERR_TIMEOUT,
|
||||
)
|
||||
if SEERR_API_KEY.strip():
|
||||
return httpx.AsyncClient(
|
||||
base_url=SEERR_URL,
|
||||
headers={"X-Api-Key": SEERR_API_KEY.strip()},
|
||||
timeout=SEERR_TIMEOUT,
|
||||
)
|
||||
return httpx.AsyncClient(
|
||||
base_url=SEERR_URL,
|
||||
timeout=SEERR_TIMEOUT,
|
||||
)
|
||||
|
||||
|
||||
def _get_session() -> httpx.AsyncClient:
|
||||
"""Return an AsyncClient valid for the currently-running event loop.
|
||||
|
||||
On the very first call the sync login is performed (if credentials are
|
||||
configured). After that every event loop gets its own cached client.
|
||||
"""
|
||||
_ensure_cookies()
|
||||
|
||||
try:
|
||||
loop_id = id(asyncio.get_running_loop())
|
||||
except RuntimeError:
|
||||
# No event loop running (e.g. called during module import).
|
||||
# Build a throw-away client — the first real call will recreate it.
|
||||
loop_id = 0
|
||||
|
||||
with _seerr_sessions_lock:
|
||||
if loop_id not in _seerr_sessions:
|
||||
_seerr_sessions[loop_id] = _build_client()
|
||||
return _seerr_sessions[loop_id]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _SharedClient:
|
||||
"""Wraps a per-loop httpx.AsyncClient so that `async with` doesn't
|
||||
close it when the context exits. All 11 call sites use:
|
||||
async with _client() as c:
|
||||
"""
|
||||
|
||||
def __init__(self, client: httpx.AsyncClient) -> None:
|
||||
self._client = client
|
||||
|
||||
async def __aenter__(self) -> httpx.AsyncClient:
|
||||
return self._client
|
||||
|
||||
async def __aexit__(self, *args: object) -> None:
|
||||
pass # do NOT close the shared session
|
||||
|
||||
|
||||
def _client() -> _SharedClient:
|
||||
"""Return a context-manager wrapper around the current loop's session."""
|
||||
return _SharedClient(_get_session())
|
||||
|
||||
|
||||
# Per-loop sessions are created lazily on first use — no eager init needed.
|
||||
|
||||
|
||||
def _fmt_items(items: list[dict], kind: str) -> str:
|
||||
"""Format a list of media items for the LLM to present.
|
||||
Includes the TMDb ID so the LLM can reference it for follow-up actions."""
|
||||
lines = []
|
||||
for i, item in enumerate(items[:10], 1):
|
||||
title = item.get("title") or item.get("name") or "Unknown"
|
||||
year = (
|
||||
item.get("releaseDate", "")[:4]
|
||||
or item.get("firstAirDate", "")[:4]
|
||||
or "?"
|
||||
)
|
||||
tmdb_id = item.get("id", "")
|
||||
overview = (item.get("overview") or "")[:120]
|
||||
id_tag = f" [tmdb:{tmdb_id}]" if tmdb_id else ""
|
||||
lines.append(f"{i}. **{title}** ({year}){id_tag} — {overview}…")
|
||||
return f"Found {len(items)} {kind}. Top results:\n\n" + "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool definitions (OpenAI function-calling schema)
|
||||
# ---------------------------------------------------------------------------
|
||||
TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seerr_trending",
|
||||
"description": "Get trending movies and TV shows from Seerr using "
|
||||
"the /discover/trending endpoint. Call this when a user asks what "
|
||||
"is popular, trending, or new.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"enum": ["movie", "tv", "all"],
|
||||
"description": "What kind of media to fetch. "
|
||||
"Use 'all' when the user doesn't specify.",
|
||||
},
|
||||
"language": {
|
||||
"type": "string",
|
||||
"description": "Language filter (e.g. 'en', 'nl'). "
|
||||
"Omit for all languages.",
|
||||
},
|
||||
},
|
||||
"required": ["kind"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seerr_discover",
|
||||
"description": "Discover movies or TV shows by genre, studio, "
|
||||
"keyword, or language in Seerr. Uses /discover/{movies|tv}/genre/{id} "
|
||||
"for genre queries, /discover/{movies|tv}/studio/{id} for studios, "
|
||||
"and /discover/{movies|tv}?query= for keyword search. "
|
||||
"Call when a user asks 'what movies in category X do you recommend?' "
|
||||
"or 'show me horror movies' or 'find Studio Ghibli movies'.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"enum": ["movie", "tv"],
|
||||
"description": "Media type to search.",
|
||||
},
|
||||
"genre": {
|
||||
"type": "string",
|
||||
"description": "Genre name, e.g. 'horror', 'comedy', "
|
||||
"'animation', 'action', 'science fiction'. "
|
||||
"Use this for genre-based discovery.",
|
||||
},
|
||||
"studio": {
|
||||
"type": "string",
|
||||
"description": "Studio name to filter by, e.g. "
|
||||
"'Studio Ghibli', 'Pixar', 'Marvel'. "
|
||||
"Use this for studio-based discovery.",
|
||||
},
|
||||
"keyword": {
|
||||
"type": "string",
|
||||
"description": "Free-text keyword search, e.g. "
|
||||
"'space', 'superhero', 'dinosaur'. "
|
||||
"Use this for topic-based discovery.",
|
||||
},
|
||||
"language": {
|
||||
"type": "string",
|
||||
"description": "Language filter (e.g. 'en', 'ja'). "
|
||||
"Omit for all languages.",
|
||||
},
|
||||
"page": {
|
||||
"type": "integer",
|
||||
"description": "Page number (default 1).",
|
||||
},
|
||||
},
|
||||
"required": ["kind"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seerr_request_media",
|
||||
"description": "Request a movie or TV show to be added to the media "
|
||||
"library via Seerr. Call when a user asks 'can you request movie X?' "
|
||||
"or 'please add show Y'.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"enum": ["movie", "tv"],
|
||||
"description": "Whether this is a movie or TV show.",
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "The title of the movie or TV show to request.",
|
||||
},
|
||||
"tmdb_id": {
|
||||
"type": "integer",
|
||||
"description": "The TMDb ID if known (optional — Seerr will "
|
||||
"search by title if not provided).",
|
||||
},
|
||||
},
|
||||
"required": ["kind", "title"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seerr_search",
|
||||
"description": "Search for movies, TV shows, or people on Seerr "
|
||||
"by title or name. Uses /search. Call when a user asks 'find me "
|
||||
"the movie X', 'search for show Y', or 'who is actor Z?'.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search term — a movie title, "
|
||||
"TV show name, or person name.",
|
||||
},
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"enum": ["movie", "tv", "person", "all"],
|
||||
"description": "Filter by media type. Use 'all' "
|
||||
"when the user doesn't specify.",
|
||||
},
|
||||
"language": {
|
||||
"type": "string",
|
||||
"description": "Language filter (e.g. 'en'). "
|
||||
"Omit for all languages.",
|
||||
},
|
||||
"page": {
|
||||
"type": "integer",
|
||||
"description": "Page number (default 1).",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seerr_media_details",
|
||||
"description": "Get full details for a specific movie or TV show "
|
||||
"(cast, crew, runtime, genres, ratings, streaming providers, etc.). "
|
||||
"Call when a user asks 'tell me about movie X' or 'show me details "
|
||||
"for show Y'.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"enum": ["movie", "tv"],
|
||||
"description": "Whether to look up a movie or TV show.",
|
||||
},
|
||||
"tmdb_id": {
|
||||
"type": "integer",
|
||||
"description": "The TMDb ID of the movie or TV show.",
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Title to search for if tmdb_id is "
|
||||
"not known. The system will search and use the first match.",
|
||||
},
|
||||
"language": {
|
||||
"type": "string",
|
||||
"description": "Language filter (e.g. 'en'). "
|
||||
"Omit for all languages.",
|
||||
},
|
||||
},
|
||||
"required": ["kind"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seerr_my_requests",
|
||||
"description": "Get the user's pending, approved, or completed "
|
||||
"media requests from Seerr. Call when a user asks 'what have I "
|
||||
"requested?', 'status of my requests?', or 'did my request go through?'.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filter": {
|
||||
"type": "string",
|
||||
"enum": ["all", "approved", "available", "pending",
|
||||
"processing", "unavailable", "failed",
|
||||
"deleted", "completed"],
|
||||
"description": "Filter by request status. "
|
||||
"Default is 'pending'.",
|
||||
},
|
||||
"media_type": {
|
||||
"type": "string",
|
||||
"enum": ["movie", "tv", "all"],
|
||||
"description": "Filter by media type. "
|
||||
"Default is 'all'.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "seerr_submit_issue",
|
||||
"description": "Submit a ticket/issue for a specific media item. "
|
||||
"Call when a user wants to report a problem (bad quality, wrong "
|
||||
"language, missing episodes, corrupt file, etc.) or when they want "
|
||||
"an action that only a human operator can perform. "
|
||||
"IMPORTANT: always include the media_title so the system can "
|
||||
"look up the correct mediaId.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"subject": {
|
||||
"type": "string",
|
||||
"description": "Short summary of the issue.",
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Detailed description of the problem.",
|
||||
},
|
||||
"media_title": {
|
||||
"type": "string",
|
||||
"description": "The movie or TV show title this issue "
|
||||
"relates to. Always provide this — the system will "
|
||||
"search for the matching mediaId.",
|
||||
},
|
||||
"issue_type": {
|
||||
"type": "integer",
|
||||
"enum": [1, 2, 3, 4],
|
||||
"description": "Issue category code: "
|
||||
"1 = Video (playback, codec, quality), "
|
||||
"2 = Audio (sync, missing), "
|
||||
"3 = Subtitle (missing, wrong, timing), "
|
||||
"4 = Other (operator-only actions like delete/cancel).",
|
||||
},
|
||||
},
|
||||
"required": ["subject", "description", "media_title", "issue_type"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool executor
|
||||
# ---------------------------------------------------------------------------
|
||||
async def _execute(tool_name: str, args: dict) -> ToolResult:
|
||||
"""Route tool calls to the right handler. Returns ToolResult with success
|
||||
based on HTTP status code (2xx = ok, everything else = fail)."""
|
||||
import logging
|
||||
logger = logging.getLogger("skills.seerr")
|
||||
|
||||
handlers = {
|
||||
"seerr_trending": _trending,
|
||||
"seerr_discover": _discover,
|
||||
"seerr_request_media": _request_media,
|
||||
"seerr_submit_issue": _submit_issue,
|
||||
"seerr_search": _search,
|
||||
"seerr_media_details": _media_details,
|
||||
"seerr_my_requests": _my_requests,
|
||||
}
|
||||
handler = handlers.get(tool_name)
|
||||
if not handler:
|
||||
return ToolResult.fail(f"Unknown tool: {tool_name}")
|
||||
|
||||
logger.info(
|
||||
"🔧 TOOL CALL: %s | args=%s",
|
||||
tool_name,
|
||||
{k: v for k, v in args.items() if k not in ("description",)},
|
||||
)
|
||||
|
||||
try:
|
||||
result = await handler(args)
|
||||
status = "✅" if result.success else "❌"
|
||||
logger.info(
|
||||
"%s TOOL RESULT: %s → %.300s",
|
||||
status, tool_name, result.content,
|
||||
)
|
||||
return result
|
||||
except httpx.HTTPStatusError as exc:
|
||||
status = exc.response.status_code
|
||||
body = exc.response.text[:500]
|
||||
logger.error("Seerr API HTTP %s on %s: %s", status, tool_name, body)
|
||||
return ToolResult.fail(
|
||||
f"Seerr API returned HTTP {status} for '{tool_name}'. "
|
||||
f"Response: {body}"
|
||||
)
|
||||
except httpx.HTTPError as exc:
|
||||
logger.error("Seerr API network error on %s: %s", tool_name, exc)
|
||||
return ToolResult.fail(
|
||||
f"Seerr API is unreachable for '{tool_name}': {exc}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Unexpected error in %s", tool_name)
|
||||
return ToolResult.fail(f"Unexpected error in '{tool_name}': {exc}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _trending(args: dict) -> ToolResult:
|
||||
"""Use Jellyseerr's /api/v1/discover/trending endpoint.
|
||||
Query params: language (optional), mediaType (movie | tv, default all).
|
||||
"""
|
||||
media_type = args.get("kind", "all")
|
||||
language = args.get("language", "").strip() or None
|
||||
|
||||
params: dict = {}
|
||||
if media_type in ("movie", "tv"):
|
||||
params["mediaType"] = media_type
|
||||
if language:
|
||||
params["language"] = language
|
||||
|
||||
async with _client() as c:
|
||||
r = await c.get("/api/v1/discover/trending", params=params)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
results = data.get("results", [])
|
||||
|
||||
label = f"trending {media_type}" if media_type != "all" else "trending items"
|
||||
if language:
|
||||
label += f" ({language})"
|
||||
if not results:
|
||||
return ToolResult.ok(f"No {label} found right now.")
|
||||
return ToolResult.ok(_fmt_items(results, label))
|
||||
|
||||
|
||||
async def _discover(args: dict) -> ToolResult:
|
||||
"""Jellyseerr discover endpoints:
|
||||
Genre: /api/v1/discover/{movies|tv}/genre/{genreId}
|
||||
Keyword: /api/v1/search?query=keyword (free-text search, filtered by mediaType)
|
||||
Studio: NOT SUPPORTED — /discover/studio requires a numeric TMDB studio ID,
|
||||
not a name. Use keyword search as fallback.
|
||||
|
||||
Language is passed to discover or search as appropriate.
|
||||
"""
|
||||
kind = args["kind"]
|
||||
genre = args.get("genre", "").strip()
|
||||
studio = args.get("studio", "").strip()
|
||||
keyword = args.get("keyword", "").strip()
|
||||
language = args.get("language", "").strip() or None
|
||||
page = args.get("page", 1)
|
||||
|
||||
# Map common genre names to TMDb genre IDs
|
||||
genre_map = {
|
||||
"action": 28, "adventure": 12, "animation": 16, "comedy": 35,
|
||||
"crime": 80, "documentary": 99, "drama": 18, "family": 10751,
|
||||
"fantasy": 14, "history": 36, "horror": 27, "music": 10402,
|
||||
"mystery": 9648, "romance": 10749, "science fiction": 878,
|
||||
"sci-fi": 878, "scifi": 878, "tv movie": 10770, "thriller": 53,
|
||||
"war": 10752, "western": 37,
|
||||
}
|
||||
|
||||
params: dict = {"page": page}
|
||||
endpoint: str
|
||||
|
||||
if genre:
|
||||
genre_id = genre_map.get(genre.lower())
|
||||
if not genre_id:
|
||||
return ToolResult.fail(
|
||||
f"I don't recognise the genre '{genre}'. "
|
||||
f"Try one of: {', '.join(sorted(genre_map.keys()))}."
|
||||
)
|
||||
endpoint = f"/api/v1/discover/{'movies' if kind == 'movie' else 'tv'}/genre/{genre_id}"
|
||||
if language:
|
||||
params["language"] = language
|
||||
async with _client() as c:
|
||||
r = await c.get(endpoint, params=params)
|
||||
r.raise_for_status()
|
||||
results = r.json().get("results", [])
|
||||
desc = genre
|
||||
elif studio:
|
||||
# /discover/studio/{studioId} requires a numeric TMDB studio ID.
|
||||
# Fall back to searching by name via /search.
|
||||
desc = studio
|
||||
search_query = studio
|
||||
endpoint = "/api/v1/search"
|
||||
params["query"] = search_query
|
||||
if language:
|
||||
params["language"] = language
|
||||
async with _client() as c:
|
||||
r = await c.get(endpoint, params=params)
|
||||
r.raise_for_status()
|
||||
results = r.json().get("results", [])
|
||||
# Filter to requested media type
|
||||
results = [item for item in results if item.get("mediaType") == kind]
|
||||
elif keyword:
|
||||
# Free-text keyword → use /search, filtered by mediaType
|
||||
desc = keyword
|
||||
endpoint = "/api/v1/search"
|
||||
params["query"] = keyword
|
||||
if language:
|
||||
params["language"] = language
|
||||
async with _client() as c:
|
||||
r = await c.get(endpoint, params=params)
|
||||
r.raise_for_status()
|
||||
results = r.json().get("results", [])
|
||||
results = [item for item in results if item.get("mediaType") == kind]
|
||||
else:
|
||||
# Bare discover with no filter
|
||||
endpoint = f"/api/v1/discover/{'movies' if kind == 'movie' else 'tv'}"
|
||||
if language:
|
||||
params["language"] = language
|
||||
async with _client() as c:
|
||||
r = await c.get(endpoint, params=params)
|
||||
r.raise_for_status()
|
||||
results = r.json().get("results", [])
|
||||
desc = kind
|
||||
|
||||
if not results:
|
||||
return ToolResult.ok(f"No {desc} {kind}s found.")
|
||||
return ToolResult.ok(_fmt_items(results, f"{desc} {kind}s"))
|
||||
|
||||
|
||||
async def _request_media(args: dict) -> ToolResult:
|
||||
kind = args["kind"]
|
||||
title = args["title"]
|
||||
tmdb_id = args.get("tmdb_id")
|
||||
|
||||
async with _client() as c:
|
||||
# --- Fast-path: TMDb ID known — confirm the title and request directly ---
|
||||
if tmdb_id:
|
||||
# Quick lookup to get the correct title for the confirmation message
|
||||
detail_r = await c.get(f"/api/v1/{kind}/{tmdb_id}")
|
||||
if detail_r.status_code == 200:
|
||||
detail = detail_r.json()
|
||||
media_title = detail.get("title") or detail.get("name") or title
|
||||
media_year = (
|
||||
detail.get("releaseDate", "")[:4]
|
||||
or detail.get("firstAirDate", "")[:4]
|
||||
or "?"
|
||||
)
|
||||
else:
|
||||
# Detail lookup failed — fall back to title search
|
||||
pass
|
||||
|
||||
if detail_r.status_code == 200:
|
||||
# Submit directly with the known TMDb ID
|
||||
request_body: dict = {"mediaType": kind, "mediaId": tmdb_id}
|
||||
if kind == "tv":
|
||||
request_body["seasons"] = "all"
|
||||
req_r = await c.post("/api/v1/request", json=request_body)
|
||||
if req_r.status_code == 201:
|
||||
return ToolResult.ok(
|
||||
f"✅ Successfully requested **{media_title}** ({media_year}). "
|
||||
f"It has been submitted to Seerr and will be processed soon."
|
||||
)
|
||||
elif req_r.status_code == 409:
|
||||
return ToolResult.fail(
|
||||
f"⚠️ **{media_title}** ({media_year}) has already been "
|
||||
f"requested or is already available."
|
||||
)
|
||||
else:
|
||||
return ToolResult.fail(
|
||||
f"❌ Failed to request **{media_title}** ({media_year}). "
|
||||
f"Seerr responded with status {req_r.status_code}: {req_r.text[:500]}"
|
||||
)
|
||||
|
||||
# --- Slow-path: search by title ---
|
||||
r = await c.get("/api/v1/search", params={"query": quote(title), "page": 1})
|
||||
r.raise_for_status()
|
||||
results = r.json().get("results", [])
|
||||
|
||||
# Filter by mediaType (search returns mixed movie/tv/person results)
|
||||
filtered = [item for item in results if item.get("mediaType") == kind] if results else []
|
||||
|
||||
if not filtered:
|
||||
return ToolResult.fail(
|
||||
f"I couldn't find '{title}' on Seerr. "
|
||||
f"Please double-check the title or provide a TMDb ID."
|
||||
)
|
||||
|
||||
# --- Ambiguity check: more than one match? ---
|
||||
if len(filtered) > 1:
|
||||
lines = [
|
||||
f"⚠️ Multiple matches for \"{title}\". "
|
||||
f"Please call `seerr_request_media` again with the "
|
||||
f"correct `tmdb_id` and exact title:\n"
|
||||
]
|
||||
for i, item in enumerate(filtered[:10], 1):
|
||||
t = item.get("title") or item.get("name", "Unknown")
|
||||
y = (
|
||||
item.get("releaseDate", "")[:4]
|
||||
or item.get("firstAirDate", "")[:4]
|
||||
or "?"
|
||||
)
|
||||
mid = item.get("id", "?")
|
||||
lines.append(
|
||||
f"{i}. **{t}** ({y}) — `kind=\"{kind}\", "
|
||||
f"title=\"{t}\", tmdb_id={mid}`"
|
||||
)
|
||||
return ToolResult.ok("\n".join(lines))
|
||||
|
||||
# --- Single match — request it ---
|
||||
match = filtered[0]
|
||||
media_id = match.get("id")
|
||||
media_title = match.get("title") or match.get("name") or title
|
||||
media_year = (
|
||||
(match.get("releaseDate") or match.get("firstAirDate") or "?")[:4]
|
||||
)
|
||||
|
||||
request_body = {
|
||||
"mediaType": kind,
|
||||
"mediaId": media_id,
|
||||
}
|
||||
if kind == "tv":
|
||||
request_body["seasons"] = "all"
|
||||
|
||||
req_r = await c.post("/api/v1/request", json=request_body)
|
||||
|
||||
if req_r.status_code == 201:
|
||||
return ToolResult.ok(
|
||||
f"✅ Successfully requested **{media_title}** ({media_year}). "
|
||||
f"It has been submitted to Seerr and will be processed soon."
|
||||
)
|
||||
elif req_r.status_code == 409:
|
||||
return ToolResult.fail(
|
||||
f"⚠️ **{media_title}** ({media_year}) has already been requested "
|
||||
f"or is already available."
|
||||
)
|
||||
else:
|
||||
detail = req_r.text
|
||||
return ToolResult.fail(
|
||||
f"❌ Failed to request **{media_title}** ({media_year}). "
|
||||
f"Seerr responded with status {req_r.status_code}: {detail}"
|
||||
)
|
||||
|
||||
|
||||
async def _search(args: dict) -> ToolResult:
|
||||
"""Use Jellyseerr's /api/v1/search endpoint.
|
||||
Supports filtering by mediaType (movie | tv | person).
|
||||
"""
|
||||
query = args["query"]
|
||||
kind = args.get("kind", "all")
|
||||
language = args.get("language", "").strip() or None
|
||||
page = args.get("page", 1)
|
||||
|
||||
params: dict = {"query": quote(query), "page": page}
|
||||
if language:
|
||||
params["language"] = language
|
||||
|
||||
async with _client() as c:
|
||||
r = await c.get("/api/v1/search", params=params)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
results = data.get("results", [])
|
||||
|
||||
# Filter by mediaType if requested
|
||||
if kind != "all":
|
||||
results = [item for item in results if item.get("mediaType") == kind]
|
||||
|
||||
label = f"search results for '{query}'"
|
||||
if kind != "all":
|
||||
label += f" ({kind})"
|
||||
if not results:
|
||||
return ToolResult.ok(f"No {label} found.")
|
||||
return ToolResult.ok(_fmt_items(results, label))
|
||||
|
||||
|
||||
async def _media_details(args: dict) -> ToolResult:
|
||||
"""Fetch full details for a movie or TV show.
|
||||
Resolves the TMDb ID via search if not provided.
|
||||
"""
|
||||
kind = args["kind"]
|
||||
tmdb_id = args.get("tmdb_id")
|
||||
title = args.get("title", "").strip()
|
||||
language = args.get("language", "").strip() or None
|
||||
|
||||
params: dict = {}
|
||||
if language:
|
||||
params["language"] = language
|
||||
|
||||
async with _client() as c:
|
||||
# Resolve TMDb ID if needed
|
||||
if not tmdb_id and title:
|
||||
sr = await c.get("/api/v1/search", params={
|
||||
"query": quote(title), "page": 1,
|
||||
})
|
||||
sr.raise_for_status()
|
||||
sresults = sr.json().get("results", [])
|
||||
sresults = [item for item in sresults if item.get("mediaType") == kind]
|
||||
if sresults:
|
||||
tmdb_id = sresults[0].get("id")
|
||||
else:
|
||||
return ToolResult.fail(
|
||||
f"I couldn't find {kind} '{title}' on Seerr."
|
||||
)
|
||||
|
||||
if not tmdb_id:
|
||||
return ToolResult.fail(
|
||||
"I need either a TMDb ID or a title to look up media details."
|
||||
)
|
||||
|
||||
endpoint = f"/api/v1/{kind}/{tmdb_id}"
|
||||
r = await c.get(endpoint, params=params)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
# Build a concise summary for the LLM
|
||||
title_str = data.get("title") or data.get("name") or "Unknown"
|
||||
year = (
|
||||
data.get("releaseDate", "")[:4]
|
||||
or data.get("firstAirDate", "")[:4]
|
||||
or "?"
|
||||
)
|
||||
overview = data.get("overview", "No overview available.")
|
||||
runtime = data.get("runtime", "?")
|
||||
vote = data.get("voteAverage", "?")
|
||||
genres = ", ".join(g.get("name", "") for g in data.get("genres", [])[:5])
|
||||
|
||||
lines = [
|
||||
f"**{title_str}** ({year}) [tmdb:{tmdb_id}]",
|
||||
f"⭐ {vote}/10 | ⏱ {runtime} min | Genres: {genres or 'N/A'}",
|
||||
f"",
|
||||
f"{overview[:500]}",
|
||||
]
|
||||
|
||||
# Cast (top 5)
|
||||
cast = (data.get("credits", {}) or {}).get("cast", [])[:5]
|
||||
if cast:
|
||||
lines.append("")
|
||||
lines.append("**Top Cast:** " + ", ".join(
|
||||
c["name"] for c in cast
|
||||
))
|
||||
|
||||
# Streaming providers
|
||||
providers = data.get("watchProviders", [])
|
||||
if providers:
|
||||
flatrate = []
|
||||
for region in providers:
|
||||
for p in region.get("flatrate", []) or []:
|
||||
if p.get("name") and p["name"] not in flatrate:
|
||||
flatrate.append(p["name"])
|
||||
if flatrate:
|
||||
lines.append("")
|
||||
lines.append("**Streaming:** " + ", ".join(flatrate[:5]))
|
||||
|
||||
return ToolResult.ok("\n".join(lines))
|
||||
|
||||
|
||||
async def _my_requests(args: dict) -> ToolResult:
|
||||
"""Fetch the current user's media requests from /request.
|
||||
Filters and sorting are optional.
|
||||
"""
|
||||
filter_status = args.get("filter", "pending")
|
||||
media_type = args.get("media_type", "all")
|
||||
|
||||
params: dict = {"filter": filter_status}
|
||||
if media_type != "all":
|
||||
params["mediaType"] = media_type
|
||||
|
||||
async with _client() as c:
|
||||
r = await c.get("/api/v1/request", params=params)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
results = data.get("results", [])
|
||||
|
||||
if not results:
|
||||
return ToolResult.ok(f"You have no {filter_status} requests right now.")
|
||||
|
||||
lines = []
|
||||
for i, req in enumerate(results[:10], 1):
|
||||
media = req.get("media", {}) or {}
|
||||
title = media.get("title") or media.get("name") or "Unknown"
|
||||
status = req.get("status", "?")
|
||||
status_labels = {1: "Pending", 2: "Approved", 3: "Declined"}
|
||||
status_str = status_labels.get(status, f"Status {status}")
|
||||
is_4k = " (4K)" if req.get("is4k") else ""
|
||||
lines.append(f"{i}. **{title}**{is_4k} — {status_str}")
|
||||
|
||||
total = data.get("pageInfo", {}).get("results", len(results))
|
||||
return ToolResult.ok(
|
||||
f"You have {total} {filter_status} requests:\n\n" + "\n".join(lines)
|
||||
)
|
||||
|
||||
|
||||
async def _submit_issue(args: dict) -> ToolResult:
|
||||
subject = args["subject"]
|
||||
description = args["description"]
|
||||
media_title = args.get("media_title", "")
|
||||
issue_type = args.get("issue_type", 4) # numeric code: 1=video, 2=audio, 3=sub, 4=other
|
||||
media_id = args.get("media_id")
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger("skills.seerr")
|
||||
|
||||
body: dict = {
|
||||
"issueType": int(issue_type),
|
||||
"message": description,
|
||||
}
|
||||
if media_title:
|
||||
body["message"] = f"[Media: {media_title}]\n\n{description}"
|
||||
|
||||
logger.info("📝 SUBMIT_ISSUE body=%s media_title=%s", body, media_title)
|
||||
|
||||
async with _client() as c:
|
||||
# --- Resolve mediaId (Seerr internal ID for /issue endpoint) ---
|
||||
if not media_id and media_title:
|
||||
search_r = await c.get("/api/v1/search", params={
|
||||
"query": quote(media_title), "page": 1,
|
||||
})
|
||||
search_r.raise_for_status()
|
||||
results = search_r.json().get("results", [])
|
||||
|
||||
logger.info("🔍 Search for '%s' → %d results", media_title, len(results))
|
||||
|
||||
# Filter to actual media (not persons) and prefer exact title match
|
||||
media_results = [
|
||||
item for item in results
|
||||
if item.get("mediaType") in ("movie", "tv")
|
||||
]
|
||||
if media_results:
|
||||
media_info = media_results[0].get("mediaInfo", {})
|
||||
media_id = media_info.get("id") or media_results[0].get("id")
|
||||
logger.info("🔍 Resolved mediaId=%s for '%s'", media_id, media_title)
|
||||
|
||||
if media_id:
|
||||
body["mediaId"] = int(media_id)
|
||||
|
||||
logger.info("📤 POST /api/v1/issue body=%s", body)
|
||||
r = await c.post("/api/v1/issue", json=body)
|
||||
logger.info("📥 Response status=%s body=%s", r.status_code, r.text[:500])
|
||||
|
||||
resp_json = r.json() if r.text else {}
|
||||
if r.status_code in (200, 201):
|
||||
ticket_id = resp_json.get("id", "N/A")
|
||||
return ToolResult.ok(
|
||||
f"✅ Issue submitted successfully (ticket #{ticket_id}). "
|
||||
f"A human operator will review: **{subject}**"
|
||||
)
|
||||
else:
|
||||
return ToolResult.fail(
|
||||
f"❌ Failed to submit issue. Seerr responded with "
|
||||
f"status {r.status_code}: {r.text[:500]}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Register the skill
|
||||
# ---------------------------------------------------------------------------
|
||||
seerr_skill = Skill(
|
||||
name="seerr",
|
||||
description="Seerr integration — search, trending, discover, request media, "
|
||||
"look up details, check requests, submit issues",
|
||||
prompt_fragment="""## Seerr Media Tools
|
||||
|
||||
You have access to the Seerr media management system. Use the provided tools
|
||||
to help users with media-related tasks:
|
||||
|
||||
- **seerr_search** — when a user wants to find a specific movie, show, or person
|
||||
- **seerr_trending** — when a user asks what is trending/popular/new
|
||||
- **seerr_discover** — when a user asks for recommendations by genre/category
|
||||
- **seerr_media_details** — when a user wants full info about a movie or show
|
||||
- **seerr_my_requests** — when a user asks about their pending/approved requests
|
||||
- **seerr_request_media** — when a user wants to request a movie or TV show
|
||||
- **seerr_submit_issue** — when a user needs to report a problem or needs an
|
||||
operator-only action (like deleting media or cancelling a request)
|
||||
|
||||
**TMDb ID Rule**: Every movie and TV show has a unique TMDb ID. When you see
|
||||
`[tmdb:123456]` in search/trending/discover results, always **show it to the user**
|
||||
in your response. Never strip or omit the TMDb ID when presenting results — the
|
||||
user needs it to reference items for follow-up actions. Similarly, capture the ID
|
||||
for any follow-up action you take (request details, submit a request, file an
|
||||
issue, etc.). If you don't have a TMDb ID and need to take action on a title,
|
||||
search first to get one. Never rely on title alone when an ID is available —
|
||||
titles are ambiguous, IDs are not. This rule applies to all media tools, present
|
||||
and future.
|
||||
|
||||
Always confirm successful actions to the user. If a tool fails, tell the user
|
||||
what went wrong and suggest alternatives.""",
|
||||
tools=TOOLS,
|
||||
execute=_execute,
|
||||
)
|
||||
|
||||
register(seerr_skill)
|
||||
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Triage skill — fallback for actions that aren't covered by any registered skill.
|
||||
|
||||
When a user asks for something that the agent cannot do (either because the
|
||||
skill doesn't exist or is intentionally unavailable — e.g. deleting media,
|
||||
cancelling requests, banning users), this skill teaches the LLM to:
|
||||
|
||||
1. Politely explain that the action requires a human operator.
|
||||
2. Offer to submit a ticket instead.
|
||||
3. Use the seerr_submit_issue tool (if available) to create the ticket.
|
||||
"""
|
||||
|
||||
from agents.skills import Skill, register
|
||||
|
||||
# This skill has no tools of its own — it guides the LLM's behavior.
|
||||
# The actual ticket submission is handled by seerr_submit_issue.
|
||||
|
||||
triage_skill = Skill(
|
||||
name="triage",
|
||||
description="Fallback for unsupported actions — explains limitations "
|
||||
"and offers to create a ticket instead.",
|
||||
prompt_fragment="""## Triage & Fallback Rules
|
||||
|
||||
You are a helpful media assistant, but you have limited capabilities. Follow these
|
||||
rules when a user asks for something you **cannot** do:
|
||||
|
||||
### Actions you CANNOT perform (human-operator-only):
|
||||
- Deleting media, requests, or users
|
||||
- Cancelling existing requests
|
||||
- Modifying library settings
|
||||
- Changing user permissions
|
||||
- Any destructive or administrative action
|
||||
|
||||
### When the user asks for an unsupported action:
|
||||
1. **Politely explain** that this action requires a human operator.
|
||||
2. **Offer to submit a ticket** via the seerr_submit_issue tool with a clear
|
||||
description of what the user wants.
|
||||
3. Never say "I don't know how to do that" without also offering the ticket
|
||||
alternative.
|
||||
|
||||
### Example response template:
|
||||
"I can't perform [action] directly — that requires a human operator for safety.
|
||||
But I'd be happy to **submit a ticket** for you with all the details. Would you
|
||||
like me to do that?"
|
||||
|
||||
Always lean toward being helpful rather than just saying no.""",
|
||||
tools=[], # no tools — this is a prompt-only skill
|
||||
execute=None,
|
||||
)
|
||||
|
||||
register(triage_skill)
|
||||
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Watch History skill — fetch the user's Jellyfin watch history via JellyStat API.
|
||||
|
||||
Requires the user to have linked Jellyfin via `/login jellyfin` in Discord.
|
||||
The auth gate (`requires_auth=["jellyfin"]`) is already active — users who
|
||||
haven't linked Jellyfin will be prompted to /login first.
|
||||
|
||||
Architecture
|
||||
------------
|
||||
This skill calls the JellyStat REST API (same FastAPI process, via HTTP)
|
||||
rather than accessing the PostgreSQL database directly. This keeps the
|
||||
bot isolated from database credentials.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
from agents.skills import Skill, register, ToolResult
|
||||
from src import auth_store
|
||||
from src.config import get_config
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
BASE_URL = (get_config("BASE_URL") or "http://localhost:8000").rstrip("/")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool definitions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "watch_history",
|
||||
"description": (
|
||||
"Get the user's Jellyfin watch history — titles grouped by total "
|
||||
"watch time in a configurable time window. Use this when a user "
|
||||
"asks what they've watched, what they've been watching recently, "
|
||||
"or wants to see their viewing activity."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "How many titles to return (default 10, max 20).",
|
||||
},
|
||||
"minutes": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Time window in minutes. Default 10080 (7 days). "
|
||||
"Use a large number like 525600 for 'all time' (1 year)."
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "watch_genres",
|
||||
"description": (
|
||||
"Get the user's most-watched genres from Jellyfin, ranked by "
|
||||
"total watch time. Use this when a user asks what kinds of "
|
||||
"content they watch most, their favourite genres, or what "
|
||||
"categories dominate their viewing."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"minutes": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Time window in minutes. Default 10080 (7 days). "
|
||||
"Use a large number like 525600 for 'all time'."
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "watch_summary",
|
||||
"description": (
|
||||
"Get an all-time Jellyfin watch summary — total watch time, "
|
||||
"most-watched series, most-watched movie, 30-day and 7-day "
|
||||
"activity, and top 3 genres. Use this when a user asks for "
|
||||
"their overall stats, a dashboard, or 'how much have I watched?'."
|
||||
),
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_jellyfin_id(args: dict) -> str | None:
|
||||
"""Extract the Jellyfin user ID from auth_store using the injected Discord ID."""
|
||||
discord_user_id = args.pop("_discord_user_id", None)
|
||||
if discord_user_id is None:
|
||||
return None # not called from Discord — shouldn't happen with auth gate
|
||||
|
||||
auth = auth_store.get_auth(discord_user_id, "jellyfin")
|
||||
if auth is None or not auth.get("external_user_id"):
|
||||
return None
|
||||
|
||||
return auth["external_user_id"]
|
||||
|
||||
|
||||
async def _fetch_json(url: str) -> dict:
|
||||
"""GET *url* and return the parsed JSON body, or {} on failure."""
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
resp = await client.get(url)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _format_seconds(total: float) -> str:
|
||||
"""Convert seconds to a human-friendly string."""
|
||||
total = max(total, 0)
|
||||
hours = int(total // 3600)
|
||||
minutes = int((total % 3600) // 60)
|
||||
if hours and minutes:
|
||||
return f"{hours}h {minutes}m"
|
||||
if hours:
|
||||
return f"{hours}h"
|
||||
if minutes:
|
||||
return f"{minutes}m"
|
||||
return f"{int(total)}s"
|
||||
|
||||
|
||||
def _format_history(data: dict, limit: int) -> ToolResult:
|
||||
"""Format a watch-history API response for the LLM."""
|
||||
items = data.get("items", [])[:limit]
|
||||
if not items:
|
||||
return ToolResult.ok("You haven't watched anything in this time window.")
|
||||
|
||||
lines = [f"**Watch History** (last {data.get('window_minutes', '?')} minutes):"]
|
||||
for i, item in enumerate(items, 1):
|
||||
duration = _format_seconds(item["watch_time_sec"])
|
||||
icon = "📺" if item["media_type"] == "series" else "🎬"
|
||||
lines.append(f"{i}. {icon} **{item['title']}** — {duration}")
|
||||
|
||||
return ToolResult.ok("\n".join(lines))
|
||||
|
||||
|
||||
def _format_genres(data: dict) -> ToolResult:
|
||||
"""Format a genre-summary API response for the LLM."""
|
||||
genres = data.get("genres", [])
|
||||
if not genres:
|
||||
return ToolResult.ok("No genre data available for this time window.")
|
||||
|
||||
lines = [f"**Top Genres** (last {data.get('window_minutes', '?')} minutes):"]
|
||||
for i, g in enumerate(genres, 1):
|
||||
duration = _format_seconds(g["watch_time_sec"])
|
||||
lines.append(f"{i}. **{g['genre']}** — {duration}")
|
||||
|
||||
return ToolResult.ok("\n".join(lines))
|
||||
|
||||
|
||||
def _format_summary(data: dict) -> ToolResult:
|
||||
"""Format a user-summary API response for the LLM."""
|
||||
total = _format_seconds(data.get("total_watch_time_sec", 0))
|
||||
last_30 = _format_seconds(data.get("total_last_30d_sec", 0))
|
||||
last_7 = _format_seconds(data.get("total_last_7d_sec", 0))
|
||||
|
||||
top_series = data.get("most_watched_series") or "—"
|
||||
top_movie = data.get("most_watched_movie") or "—"
|
||||
top_genres = data.get("top_genres", [])
|
||||
genres_str = ", ".join(top_genres) if top_genres else "—"
|
||||
|
||||
lines = [
|
||||
"**Your Jellyfin Summary** (all time):",
|
||||
f"⏱️ Total watch time: **{total}**",
|
||||
f"📺 Most-watched series: **{top_series}**",
|
||||
f"🎬 Most-watched movie: **{top_movie}**",
|
||||
f"📅 Last 30 days: **{last_30}**",
|
||||
f"📅 Last 7 days: **{last_7}**",
|
||||
f"🏷️ Top genres: {genres_str}",
|
||||
]
|
||||
return ToolResult.ok("\n".join(lines))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _execute(tool_name: str, args: dict) -> ToolResult:
|
||||
# 1. Resolve Jellyfin user ID
|
||||
jellyfin_id = _resolve_jellyfin_id(args)
|
||||
if jellyfin_id is None:
|
||||
return ToolResult.fail(
|
||||
"Your Jellyfin account is not linked. Use `/login jellyfin` in a DM to connect."
|
||||
)
|
||||
|
||||
# 2. Route to the right JellyStat endpoint
|
||||
try:
|
||||
match tool_name:
|
||||
case "watch_history":
|
||||
limit = args.get("limit", 10)
|
||||
minutes = args.get("minutes", 10080)
|
||||
url = f"{BASE_URL}/jellystat/history/{jellyfin_id}?minutes={minutes}"
|
||||
data = await _fetch_json(url)
|
||||
return _format_history(data, limit)
|
||||
|
||||
case "watch_genres":
|
||||
minutes = args.get("minutes", 10080)
|
||||
url = f"{BASE_URL}/jellystat/genres/{jellyfin_id}?minutes={minutes}"
|
||||
data = await _fetch_json(url)
|
||||
return _format_genres(data)
|
||||
|
||||
case "watch_summary":
|
||||
url = f"{BASE_URL}/jellystat/summary/{jellyfin_id}"
|
||||
data = await _fetch_json(url)
|
||||
return _format_summary(data)
|
||||
|
||||
case _:
|
||||
return ToolResult.fail(f"Unknown tool: {tool_name}")
|
||||
|
||||
except httpx.HTTPError:
|
||||
return ToolResult.fail(
|
||||
"Could not reach the watch-history service right now. "
|
||||
"Please try again in a moment."
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skill registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PROMPT = (
|
||||
"## Watch History\n"
|
||||
"\n"
|
||||
"You have THREE tools to answer questions about the user's Jellyfin watch activity:\n"
|
||||
"\n"
|
||||
"1. **`watch_history`** — per-title watch time in a time window (default: 7 days).\n"
|
||||
" Use when a user asks what they've watched, to show their history,\n"
|
||||
" or what they watched this week or yesterday.\n"
|
||||
"\n"
|
||||
"2. **`watch_genres`** — watch time broken down by genre.\n"
|
||||
" Use when a user asks what genres they watch, whether they watch more\n"
|
||||
" comedy than drama, or what their most-watched genre is.\n"
|
||||
"\n"
|
||||
"3. **`watch_summary`** — all-time dashboard: total watch time, most-watched\n"
|
||||
" series and movie, 30-day and 7-day activity, and top 3 genres.\n"
|
||||
" Use when a user asks for their stats, how much they've watched in\n"
|
||||
" total, or what their favourites are.\n"
|
||||
"\n"
|
||||
"Always call the appropriate tool before answering — NEVER guess at watch data.\n"
|
||||
"Format watch times in a human-readable way (hours and minutes), but keep the\n"
|
||||
"raw data visible too."
|
||||
)
|
||||
|
||||
watch_history_skill = Skill(
|
||||
name="watch_history",
|
||||
description="User's Jellyfin watch history, genres, and summary stats",
|
||||
requires_auth=["jellyfin"],
|
||||
prompt_fragment=_PROMPT,
|
||||
tools=TOOLS,
|
||||
execute=_execute,
|
||||
)
|
||||
|
||||
register(watch_history_skill)
|
||||
@@ -0,0 +1,13 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /Agents
|
||||
|
||||
COPY requirements.txt /Agents/requirements.txt
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . /Agents
|
||||
|
||||
ENV PYTHONPATH=/Agents
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1,75 @@
|
||||
# Gateway Architecture — Agent + Skill + Graph Pipeline
|
||||
|
||||
This is the **interface layer** of the Agents project. Everything that connects
|
||||
the outside world to the agent system lives here — REST APIs, Discord bot,
|
||||
and authentication.
|
||||
|
||||
---
|
||||
|
||||
## Directory Map
|
||||
|
||||
| Path | Description | Docs |
|
||||
|---|---|---|
|
||||
| `gateway/v1/` | REST API endpoints — chat, agent listing, OpenAI-compatible completions | [v1.md](v1/v1.md) |
|
||||
| `gateway/discord/` | Discord bot connector — in-process DM handler with LangGraph integration | [discord.md](discord/discord.md) |
|
||||
| `gateway/auth/` | Auth service registry + Jellyfin Quick Connect implementation | [auth.md](auth/auth.md) |
|
||||
|
||||
---
|
||||
|
||||
## Supporting Modules
|
||||
|
||||
| Path | Purpose |
|
||||
|---|---|
|
||||
| `gateway/dependencies.py` | FastAPI `Depends` providers — `get_llm_client()`, `get_agent_graph()` |
|
||||
| `src/config.py` | `.env` loader and config accessor |
|
||||
| `src/llm.py` | OpenAI-compatible client factory (DeepSeek) |
|
||||
| `src/state.py` | LangGraph `AgentState` TypedDict |
|
||||
| `src/graph.py` | LangGraph StateGraph factory — agent_node, tool_node, routing |
|
||||
| `src/tools_adapter.py` | Wraps skill tools as LangChain `@tool` functions |
|
||||
| `src/auth_store.py` | SQLite persistence for Discord → service auth linking |
|
||||
| `agents/` | Agent definitions (dataclass + registry) |
|
||||
| `agents/skills/` | Skill definitions — prompt fragments, tool schemas, executors |
|
||||
|
||||
---
|
||||
|
||||
## High-Level Request Flow
|
||||
|
||||
```
|
||||
┌──────────────────────────────┐
|
||||
│ Client (OpenWebUI / HTTP) │
|
||||
└──────────────┬───────────────┘
|
||||
│ POST /v1/chat/completions
|
||||
▼
|
||||
┌──────────────────────────────┐
|
||||
│ gateway/v1/chat.py │ ← resolves agent, invokes graph
|
||||
└──────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────┐
|
||||
│ LangGraph StateGraph │ ← src/graph.py
|
||||
│ ┌──────────┐ ┌──────────┐│
|
||||
│ │agent_node│──▶│tool_node ││
|
||||
│ │(LLM call)│◀──│(skills) ││
|
||||
│ └──────────┘ └──────────┘│
|
||||
└──────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────┐
|
||||
│ agents/skills/ │ ← Seerr API, Jellyfin API, etc.
|
||||
└──────────────────────────────┘
|
||||
```
|
||||
|
||||
For a detailed step-by-step walkthrough of the graph execution (including
|
||||
multi-turn context and tool-calling loops), see [v1.md](v1/v1.md).
|
||||
|
||||
---
|
||||
|
||||
## Startup
|
||||
|
||||
`main.py` is the entry point. At startup it:
|
||||
|
||||
1. Loads `.env` → creates the LLM client (DeepSeek) → stores on `app.state.llm_client`
|
||||
2. Calls `load_all_agents()` → imports every agent and skill module (they self-register)
|
||||
3. Imports `gateway.auth.jellyfin` → self-registers the Jellyfin auth service
|
||||
4. Mounts routers: `/v1/*` (chat endpoints) and `/api/v1/auth/*` (auth endpoints)
|
||||
5. Starts the Discord bot as a background asyncio task (lifespan)
|
||||
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Auth Service registry — generic, pluggable authentication for any service.
|
||||
|
||||
Add a new service (Plex, Seerr, etc.) by:
|
||||
1. Subclassing AuthService
|
||||
2. Dropping the module in this package
|
||||
3. Calling register_auth_service() at import time
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AuthResult — returned by AuthService.authenticate()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class AuthResult:
|
||||
"""Outcome of a credential validation attempt."""
|
||||
success: bool
|
||||
external_user_id: Optional[str] = None
|
||||
external_name: Optional[str] = None
|
||||
credentials: Optional[dict] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AuthService — abstract base class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AuthService(ABC):
|
||||
"""A service that users can authenticate against (Jellyfin, Seerr, Plex, etc.)
|
||||
|
||||
Subclasses must implement:
|
||||
- name : unique identifier used in URLs and DB keys
|
||||
- display_name : human-readable label shown in Discord
|
||||
- render_login_form(token, discord_id) → HTML string
|
||||
- authenticate(form_data) → AuthResult
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Unique service name: "jellyfin", "seerr", etc."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def display_name(self) -> str:
|
||||
"""Human-readable: "Jellyfin", "Seerr", "Plex" """
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def render_login_form(self, token: str, discord_id: int) -> str:
|
||||
"""Return HTML string with a login form for this service.
|
||||
|
||||
The form MUST include these hidden fields:
|
||||
<input type="hidden" name="token" value="{token}">
|
||||
<input type="hidden" name="discord_id" value="{discord_id}">
|
||||
<input type="hidden" name="service" value="{self.name}">
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, form_data: dict) -> AuthResult:
|
||||
"""Validate credentials against the service. Return AuthResult."""
|
||||
...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Global registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_registry: dict[str, AuthService] = {}
|
||||
|
||||
|
||||
def register_auth_service(svc: AuthService) -> None:
|
||||
"""Register an AuthService so it can be looked up by name."""
|
||||
_registry[svc.name] = svc
|
||||
|
||||
|
||||
def get_auth_service(name: str) -> AuthService | None:
|
||||
"""Look up a registered AuthService by name."""
|
||||
return _registry.get(name)
|
||||
|
||||
|
||||
def list_auth_services() -> list[str]:
|
||||
"""Return names of all registered auth services."""
|
||||
return list(_registry.keys())
|
||||
@@ -0,0 +1,152 @@
|
||||
# Auth — Service Registry & Persistence
|
||||
|
||||
The authentication system lets Discord users link their accounts to external
|
||||
services (currently **Jellyfin**) so the agent can perform actions on their
|
||||
behalf (e.g. checking watch history).
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
gateway/auth/ gateway/v1/auth.py
|
||||
┌──────────────────────┐ ┌──────────────────────────────┐
|
||||
│ AuthService (ABC) │ │ GET /api/v1/auth/login │
|
||||
│ ├─ JellyfinAuth │◀─────────│ POST /api/v1/auth/login │
|
||||
│ └─ (Plex, Seerr…) │ │ GET /api/v1/auth/status │
|
||||
│ │ │ GET /api/v1/auth/reset │
|
||||
└─────────┬────────────┘ └──────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
src/auth_store.py
|
||||
┌──────────────────────┐
|
||||
│ SQLite │
|
||||
│ ├─ link_tokens │ one-time tokens sent via Discord DM
|
||||
│ └─ user_auth │ per-user, per-service credentials
|
||||
└──────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `gateway/auth/__init__.py` | Abstract `AuthService` base class + global registry |
|
||||
| `gateway/auth/jellyfin.py` | Jellyfin implementation — Quick Connect + username/password |
|
||||
| `gateway/v1/auth.py` | REST endpoints for the web-based login flow |
|
||||
| `src/auth_store.py` | SQLite persistence for link tokens and stored credentials |
|
||||
|
||||
---
|
||||
|
||||
## Flow: Discord User Links Jellyfin
|
||||
|
||||
```
|
||||
Discord DM Web Browser Jellyfin Server
|
||||
│ │ │
|
||||
│ 1. /login jellyfin │ │
|
||||
│ ──────────────────────────────▶│ │
|
||||
│ Bot creates link token in │ │
|
||||
│ SQLite, DMs the user a URL │ │
|
||||
│ │ │
|
||||
│ 2. User clicks link │ │
|
||||
│ ◀─────────────────────────────▶│ │
|
||||
│ │ GET /api/v1/auth/login │
|
||||
│ │ ?service=jellyfin │
|
||||
│ │ &token=xxx&discord_id=123 │
|
||||
│ │ │
|
||||
│ │ 3. Serve Quick Connect form │
|
||||
│ │ ◀──────────────────────────── │
|
||||
│ │ │
|
||||
│ │ 4. Initiate Quick Connect │
|
||||
│ │ ─────────────────────────────▶│
|
||||
│ │ POST /QuickConnect/Initiate │
|
||||
│ │ ◀── { Code: "ABC123" } │
|
||||
│ │ │
|
||||
│ 5. User enters code in │ │
|
||||
│ Jellyfin app │ │
|
||||
│ │ │
|
||||
│ │ 6. Poll: is it authorized? │
|
||||
│ │ ─────────────────────────────▶│
|
||||
│ │ GET /QuickConnect/Connect │
|
||||
│ │ ◀── Authenticated + Token │
|
||||
│ │ │
|
||||
│ 7. auth_store saves: │ │
|
||||
│ (discord_id, jellyfin, │ │
|
||||
│ AccessToken, username) │ │
|
||||
│ │ │
|
||||
│ 8. "✅ Linked to Jellyfin!" │ │
|
||||
│ ◀───────────────────────────── │ │
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## AuthService Base Class
|
||||
|
||||
```python
|
||||
class AuthService(ABC):
|
||||
name: str # "jellyfin"
|
||||
display_name: str # "Jellyfin"
|
||||
|
||||
def render_login_form(token, discord_id) -> str: ...
|
||||
async def authenticate(form_data) -> AuthResult: ...
|
||||
```
|
||||
|
||||
Add a new service (e.g. Plex, Seerr) by subclassing `AuthService`, dropping
|
||||
the module in `gateway/auth/`, and calling `register_auth_service()` at import
|
||||
time. The REST endpoints and auth store work generically — no changes needed.
|
||||
|
||||
---
|
||||
|
||||
## Current Implementation: Jellyfin
|
||||
|
||||
`gateway/auth/jellyfin.py` supports two flows:
|
||||
|
||||
| Method | How it works |
|
||||
|---|---|
|
||||
| **Quick Connect** (primary) | Calls Jellyfin's `/QuickConnect/Initiate` → polls `/QuickConnect/Connect` → stores the `AccessToken` |
|
||||
| **Username/Password** (fallback) | Renders an HTML form → user submits credentials → calls `/Users/AuthenticateByName` → stores the `AccessToken` |
|
||||
|
||||
The stored credentials include:
|
||||
- `external_user_id` — Jellyfin user ID
|
||||
- `external_name` — Jellyfin username
|
||||
- `credentials` dict — `{"AccessToken": "...", "ServerURL": "..."}`
|
||||
|
||||
---
|
||||
|
||||
## Auth Store (SQLite)
|
||||
|
||||
Two tables in `data/auth.db`:
|
||||
|
||||
```sql
|
||||
-- One-time tokens for the web login flow (expire after 10 min)
|
||||
CREATE TABLE link_tokens (
|
||||
token TEXT PRIMARY KEY,
|
||||
discord_id INTEGER NOT NULL,
|
||||
service TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
used INTEGER DEFAULT 0
|
||||
);
|
||||
|
||||
-- Per-user, per-service stored credentials
|
||||
CREATE TABLE user_auth (
|
||||
discord_id INTEGER NOT NULL,
|
||||
service TEXT NOT NULL,
|
||||
external_user_id TEXT,
|
||||
external_name TEXT,
|
||||
credentials TEXT, -- JSON
|
||||
created_at TEXT NOT NULL,
|
||||
PRIMARY KEY (discord_id, service)
|
||||
);
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Skill-Level Auth Gating
|
||||
|
||||
Skills can declare `requires_auth=["jellyfin"]`. When a tool is executed,
|
||||
the skill system checks the auth store. If the user isn't linked:
|
||||
|
||||
1. The tool returns `ToolResult.fail("Please login first using /login jellyfin")`
|
||||
2. The LLM relays this message to the user in Discord
|
||||
3. The user types `/login jellyfin` → Quick Connect flow → re-linked → try again
|
||||
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
Jellyfin AuthService — validates Jellyfin credentials and stores the session token.
|
||||
|
||||
Two authentication flows:
|
||||
1. Quick Connect (primary): user enters a short code on their Jellyfin app.
|
||||
- initiate_quick_connect() → {code, secret}
|
||||
- poll_quick_connect(secret) → "Active" | "Authorized" | "Expired"
|
||||
- authenticate_quick_connect(secret) → AuthResult with token
|
||||
|
||||
2. Username/password (legacy): renders an HTML form, called via the REST API.
|
||||
- render_login_form(token, discord_id) → HTML string
|
||||
- authenticate(form_data) → AuthResult
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from gateway.auth import AuthService, AuthResult, register_auth_service
|
||||
from src.config import get_config
|
||||
|
||||
logger = logging.getLogger("auth.jellyfin")
|
||||
|
||||
# Emby-style authorization header required by Jellyfin's AuthenticateByName
|
||||
_EMBY_HEADER = (
|
||||
'MediaBrowser Client="AgentBot",'
|
||||
'Device="DiscordBot",'
|
||||
'DeviceId="agent-bot",'
|
||||
'Version="1.0"'
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuickConnectResult:
|
||||
"""Result of a Quick Connect initiation."""
|
||||
secret: str
|
||||
code: str
|
||||
device_id: str
|
||||
device_name: str
|
||||
|
||||
|
||||
class JellyfinAuth(AuthService):
|
||||
name = "jellyfin"
|
||||
display_name = "Jellyfin"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Quick Connect helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _qc_headers(self) -> dict[str, str]:
|
||||
"""Return headers used by all Quick Connect API calls."""
|
||||
return {
|
||||
"X-Emby-Authorization": (
|
||||
'MediaBrowser Client="AgentBot",'
|
||||
'Device="DiscordBot",'
|
||||
'DeviceId="agent-bot-qc",'
|
||||
'Version="1.0"'
|
||||
)
|
||||
}
|
||||
|
||||
async def _resolve_url(self) -> str | None:
|
||||
"""
|
||||
Resolve the Jellyfin server URL.
|
||||
1. Check JELLYFIN_URL env var (used in deployment).
|
||||
2. Check if user already has a stored auth with a URL (from legacy login).
|
||||
Returns None if no URL is configured.
|
||||
"""
|
||||
# First: explicit env var
|
||||
env_url = get_config("JELLYFIN_URL")
|
||||
if env_url:
|
||||
return env_url.strip().rstrip("/")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1a: initiate Quick Connect
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def initiate_quick_connect(self, url: str | None = None) -> QuickConnectResult | None:
|
||||
"""
|
||||
Call Jellyfin's POST /QuickConnect/Initiate.
|
||||
Returns a QuickConnectResult with {secret, code} or None on failure.
|
||||
|
||||
The *code* is what the user enters on their Jellyfin page.
|
||||
The *secret* is used internally to poll/authenticate.
|
||||
"""
|
||||
base_url = url or await self._resolve_url()
|
||||
if not base_url:
|
||||
logger.error("QuickConnect failed — no JELLYFIN_URL configured.")
|
||||
return None
|
||||
|
||||
logger.info("Initiating Quick Connect on %s", base_url)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{base_url}/QuickConnect/Initiate",
|
||||
headers=self._qc_headers(),
|
||||
json={},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
"QuickConnect init failed: HTTP %s — %s",
|
||||
resp.status_code, resp.text[:200],
|
||||
)
|
||||
return None
|
||||
|
||||
data = resp.json()
|
||||
secret = data.get("Secret", "")
|
||||
code = data.get("Code", "")
|
||||
device_id = data.get("DeviceId", "")
|
||||
device_name = data.get("DeviceName", "")
|
||||
|
||||
if not secret or not code:
|
||||
logger.warning("QuickConnect init returned unexpected payload: %s", data)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
"Quick Connect initiated: code=%s device=%s",
|
||||
code, device_name,
|
||||
)
|
||||
return QuickConnectResult(
|
||||
secret=secret,
|
||||
code=code,
|
||||
device_id=device_id,
|
||||
device_name=device_name,
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("QuickConnect init timed out reaching %s", base_url)
|
||||
return None
|
||||
except httpx.ConnectError:
|
||||
logger.error("QuickConnect init — cannot connect to %s", base_url)
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during QuickConnect init")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1b: poll Quick Connect status
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def poll_quick_connect(self, secret: str, url: str | None = None) -> str:
|
||||
"""
|
||||
Call Jellyfin's GET /QuickConnect/Connect?secret=<secret>.
|
||||
Returns one of:
|
||||
- "Active" → user hasn't entered the code yet
|
||||
- "Authorized" → user entered code AND approved
|
||||
- "Expired" → code expired / unknown secret
|
||||
- "Error" → network or unexpected failure
|
||||
"""
|
||||
base_url = url or await self._resolve_url()
|
||||
if not base_url:
|
||||
logger.error("QuickConnect poll failed — no JELLYFIN_URL configured.")
|
||||
return "Error"
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.get(
|
||||
f"{base_url}/QuickConnect/Connect",
|
||||
params={"secret": secret},
|
||||
headers=self._qc_headers(),
|
||||
)
|
||||
if resp.status_code == 404:
|
||||
return "Expired"
|
||||
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
# Jellyfin returns "Authenticated" (not "Authorized")
|
||||
if data.get("Authenticated") is True:
|
||||
return "Authorized"
|
||||
# "Authenticated" is false, missing, or null → still active
|
||||
return "Active"
|
||||
|
||||
logger.warning(
|
||||
"QuickConnect poll unexpected: HTTP %s — %s",
|
||||
resp.status_code, resp.text[:200],
|
||||
)
|
||||
return "Error"
|
||||
|
||||
except (httpx.TimeoutException, httpx.ConnectError):
|
||||
logger.warning("QuickConnect poll network error")
|
||||
return "Error"
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during QuickConnect poll")
|
||||
return "Error"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1c: exchange secret for token
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def authenticate_quick_connect(
|
||||
self, secret: str, url: str | None = None
|
||||
) -> AuthResult:
|
||||
"""
|
||||
After poll_quick_connect returns "Authorized", call
|
||||
POST /Users/AuthenticateWithQuickConnect to exchange the secret
|
||||
for a real access token.
|
||||
|
||||
Returns AuthResult with token, user_id, username on success.
|
||||
"""
|
||||
base_url = url or await self._resolve_url()
|
||||
if not base_url:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="No Jellyfin server URL configured.",
|
||||
)
|
||||
|
||||
logger.info("Exchanging QuickConnect secret for token on %s", base_url)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{base_url}/Users/AuthenticateWithQuickConnect",
|
||||
json={"Secret": secret},
|
||||
headers=self._qc_headers(),
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
"QuickConnect auth exchange failed: HTTP %s",
|
||||
resp.status_code,
|
||||
)
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="Quick Connect authentication failed. The code may have expired.",
|
||||
)
|
||||
|
||||
data = resp.json()
|
||||
user = data.get("User", {})
|
||||
token = data.get("AccessToken", "")
|
||||
|
||||
if not token:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="Jellyfin returned an unexpected response.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"QuickConnect linked: user=%s (%s)",
|
||||
user.get("Name", "?"),
|
||||
user.get("Id", "?"),
|
||||
)
|
||||
|
||||
return AuthResult(
|
||||
success=True,
|
||||
external_user_id=user.get("Id", ""),
|
||||
external_name=user.get("Name", "?"),
|
||||
credentials={
|
||||
"token": token,
|
||||
"url": base_url,
|
||||
"user_id": user.get("Id", ""),
|
||||
},
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not reach {base_url} — connection timed out.",
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not connect to {base_url}. Is the server running?",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during QuickConnect auth exchange")
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="An unexpected error occurred during authentication.",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Login form (legacy — used by the REST API)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def render_login_form(self, token: str, discord_id: int) -> str:
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Link Jellyfin</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
|
||||
h2 {{ margin-bottom: 4px; }}
|
||||
.sub {{ color: #666; margin-bottom: 24px; }}
|
||||
label {{ display: block; margin-top: 16px; font-weight: 600; }}
|
||||
input {{ width: 100%; padding: 10px; margin-top: 4px; border: 1px solid #ccc; border-radius: 6px; box-sizing: border-box; }}
|
||||
button {{ margin-top: 24px; width: 100%; padding: 12px; background: #aa5cc3; color: #fff; border: none; border-radius: 6px; font-size: 16px; cursor: pointer; }}
|
||||
button:hover {{ background: #9448b0; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h2>🔗 Link Jellyfin to Discord</h2>
|
||||
<p class="sub">Enter your Jellyfin server URL and credentials to link your account.</p>
|
||||
|
||||
<form method="POST" action="/api/v1/auth/login">
|
||||
<input type="hidden" name="token" value="{token}">
|
||||
<input type="hidden" name="discord_id" value="{discord_id}">
|
||||
<input type="hidden" name="service" value="jellyfin">
|
||||
|
||||
<label for="jellyfin_url">Jellyfin Server URL</label>
|
||||
<input id="jellyfin_url" name="jellyfin_url" type="url"
|
||||
placeholder="https://jellyfin.example.com" required>
|
||||
|
||||
<label for="username">Username</label>
|
||||
<input id="username" name="username" type="text"
|
||||
placeholder="Your Jellyfin username" required autofocus>
|
||||
|
||||
<label for="password">Password</label>
|
||||
<input id="password" name="password" type="password"
|
||||
placeholder="Your Jellyfin password" required>
|
||||
|
||||
<button type="submit">Link Account</button>
|
||||
</form>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Authentication
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def authenticate(self, form_data: dict) -> AuthResult:
|
||||
url = form_data.get("jellyfin_url", "").strip().rstrip("/")
|
||||
username = form_data.get("username", "").strip()
|
||||
password = form_data.get("password", "").strip()
|
||||
|
||||
if not url or not username or not password:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="All fields are required (URL, username, password).",
|
||||
)
|
||||
|
||||
logger.info("Attempting Jellyfin login for '%s' on %s", username, url)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{url}/Users/AuthenticateByName",
|
||||
json={"Username": username, "Pw": password},
|
||||
headers={"X-Emby-Authorization": _EMBY_HEADER},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
"Jellyfin login failed for '%s': HTTP %s", username, resp.status_code
|
||||
)
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Login failed — check your server URL and credentials.",
|
||||
)
|
||||
|
||||
data = resp.json()
|
||||
user = data.get("User", {})
|
||||
token = data.get("AccessToken", "")
|
||||
|
||||
if not token:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message="Jellyfin returned an unexpected response.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Jellyfin login OK: user=%s (%s)",
|
||||
user.get("Name", "?"),
|
||||
user.get("Id", "?"),
|
||||
)
|
||||
|
||||
return AuthResult(
|
||||
success=True,
|
||||
external_user_id=user.get("Id", ""),
|
||||
external_name=user.get("Name", username),
|
||||
credentials={
|
||||
"token": token,
|
||||
"url": url,
|
||||
"user_id": user.get("Id", ""),
|
||||
},
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not reach {url} — connection timed out. Check the URL.",
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"Could not connect to {url}. Is the server running?",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Unexpected error during Jellyfin login")
|
||||
return AuthResult(
|
||||
success=False,
|
||||
error_message=f"An unexpected error occurred. Please try again.",
|
||||
)
|
||||
|
||||
|
||||
# Self-register at import time
|
||||
register_auth_service(JellyfinAuth())
|
||||
@@ -0,0 +1,36 @@
|
||||
from fastapi import Request
|
||||
from openai import OpenAI
|
||||
|
||||
from src.graph import create_agent_graph
|
||||
|
||||
|
||||
def get_llm_client(request: Request) -> OpenAI:
|
||||
"""FastAPI dependency — returns the singleton OpenAI client from app.state."""
|
||||
return request.app.state.llm_client
|
||||
|
||||
|
||||
def get_agent_graph(agent_id: str, request: Request):
|
||||
"""
|
||||
FastAPI dependency — returns the compiled LangGraph graph for *agent_id*.
|
||||
|
||||
Graphs are lazily compiled on first use and cached on app.state so each
|
||||
agent's graph is only built once per process lifetime.
|
||||
"""
|
||||
cache: dict = request.app.state.agent_graphs
|
||||
|
||||
if agent_id not in cache:
|
||||
from agents import get as get_agent
|
||||
|
||||
agent = get_agent(agent_id)
|
||||
if agent is None:
|
||||
# Fall back to the naked agent if the requested one doesn't exist
|
||||
agent_id = "naked"
|
||||
agent = get_agent(agent_id)
|
||||
|
||||
cache[agent_id] = create_agent_graph(
|
||||
client=request.app.state.llm_client,
|
||||
agent_skills=agent.skills,
|
||||
system_prompt=agent.build_system_prompt(),
|
||||
)
|
||||
|
||||
return cache[agent_id]
|
||||
@@ -0,0 +1 @@
|
||||
# Discord bot package
|
||||
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
Discord bot that connects users to the LangGraph agent via private messages.
|
||||
|
||||
Architecture
|
||||
------------
|
||||
- The bot runs in-process alongside FastAPI (on a background asyncio task).
|
||||
- Private messages (DMs) are routed through the same LangGraph graphs that
|
||||
power the REST API — no HTTP loopback needed.
|
||||
- Per-user conversation history is maintained so the LLM has context.
|
||||
|
||||
Environment
|
||||
-----------
|
||||
DISCORD_BOT_TOKEN – the bot token from the Discord Developer Portal
|
||||
DISCORD_MAX_HISTORY – how many past messages to keep per user (default 7)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
import discord
|
||||
|
||||
from agents import list_all as list_all_agents
|
||||
from gateway.discord.conversation import ConversationStore
|
||||
from src.config import DEEPSEEK_API_KEY, get_config
|
||||
from src.graph import create_agent_graph
|
||||
from src.llm import create_client
|
||||
from src import auth_store
|
||||
from gateway.auth import list_auth_services, get_auth_service
|
||||
|
||||
logger = logging.getLogger("bot.discord")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
DISCORD_BOT_TOKEN = get_config("DISCORD_BOT_TOKEN") or ""
|
||||
DISCORD_MAX_HISTORY = int(get_config("DISCORD_MAX_HISTORY", "7"))
|
||||
DISCORD_DEFAULT_AGENT = get_config("DISCORD_DEFAULT_AGENT", "media-agent")
|
||||
BASE_URL = get_config("BASE_URL", "http://localhost:8000").rstrip("/")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM client shared by all agents (same as the REST API uses)
|
||||
# ---------------------------------------------------------------------------
|
||||
_llm_client = create_client(DEEPSEEK_API_KEY)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Conversation store — one per process
|
||||
# ---------------------------------------------------------------------------
|
||||
_conversations = ConversationStore(max_history=DISCORD_MAX_HISTORY)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graph cache — lazy-compiled per agent, same pattern as api/dependencies.py
|
||||
# ---------------------------------------------------------------------------
|
||||
_agent_graphs: dict[str, object] = {}
|
||||
|
||||
|
||||
def _get_graph(agent_id: str):
|
||||
"""Return a compiled LangGraph for *agent_id*, building it on first use."""
|
||||
if agent_id not in _agent_graphs:
|
||||
agents = list_all_agents()
|
||||
agent = agents.get(agent_id, agents.get("naked"))
|
||||
_agent_graphs[agent_id] = create_agent_graph(
|
||||
client=_llm_client,
|
||||
agent_skills=agent.skills if agent else [],
|
||||
system_prompt=agent.build_system_prompt() if agent else (
|
||||
"You are a helpful, general-purpose assistant."
|
||||
),
|
||||
)
|
||||
return _agent_graphs[agent_id]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AgentBot(discord.Client):
|
||||
"""A discord.py Client that connects users to the LangGraph agent."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# message_content lets us read DM text.
|
||||
# guilds is required so that mutual_guilds is populated — without it
|
||||
# every DM is silently ignored.
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.guilds = True
|
||||
super().__init__(intents=intents)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
logger.info("Bot logged in as %s (ID %s)", self.user, self.user.id)
|
||||
# Print a ready banner so the dev knows it's alive
|
||||
print(f"\n🤖 Discord bot ready — logged in as {self.user}\n")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared-guild helper — uses the REST API, no privileged intents
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _shares_guild(self, user: discord.User) -> bool:
|
||||
"""Return True if *user* and the bot share at least one guild."""
|
||||
for guild in self.guilds:
|
||||
try:
|
||||
member = await guild.fetch_member(user.id)
|
||||
if member is not None:
|
||||
return True
|
||||
except (discord.NotFound, discord.Forbidden, discord.HTTPException):
|
||||
continue
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message handler — DMs only
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
# Never reply to ourselves
|
||||
if message.author == self.user:
|
||||
return
|
||||
|
||||
# |-- DM channel only for now ----------------------------------|
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
logger.debug("Ignoring message from #%s (not a DM)", message.channel)
|
||||
return
|
||||
# |--------------------------------------------------------------|
|
||||
|
||||
# |-- Shared-server gate — only users who share at least one --|
|
||||
# | guild with the bot can interact via DM. --|
|
||||
# | We use fetch_member (REST API) instead of --|
|
||||
# | User.mutual_guilds because the latter requires the --|
|
||||
# | privileged "members" intent. This way no privileged --|
|
||||
# | intents are needed. --|
|
||||
if not await self._shares_guild(message.author):
|
||||
logger.warning(
|
||||
"Blocking DM from %s — no mutual guilds.",
|
||||
message.author.name,
|
||||
)
|
||||
return
|
||||
# |--------------------------------------------------------------|
|
||||
|
||||
user_id = message.author.id
|
||||
content = message.content.strip()
|
||||
|
||||
# |-- Bot commands — handled directly, never sent to the LLM --|
|
||||
if await self._handle_command(message, user_id, content):
|
||||
return
|
||||
# |--------------------------------------------------------------|
|
||||
|
||||
# Show typing indicator while the graph runs
|
||||
async with message.channel.typing():
|
||||
try:
|
||||
reply = await self._run_agent(
|
||||
user_id=user_id,
|
||||
user_msg=message.content,
|
||||
)
|
||||
await message.channel.send(reply)
|
||||
except Exception:
|
||||
logger.exception("Agent run failed for user %s", user_id)
|
||||
await message.channel.send(
|
||||
"Sorry, something went wrong processing your request. "
|
||||
"Please try again in a moment."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bot commands
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_command(
|
||||
self, message: discord.Message, user_id: int, content: str
|
||||
) -> bool:
|
||||
"""Handle bot commands (/login, /logout). Returns True if handled."""
|
||||
lower = content.lower()
|
||||
|
||||
# --- /login [service] ---
|
||||
if lower.startswith("/login"):
|
||||
parts = content.split()
|
||||
service = parts[1].lower() if len(parts) > 1 else None
|
||||
|
||||
available = list_auth_services()
|
||||
if not available:
|
||||
await message.channel.send("No auth services are configured yet.")
|
||||
return True
|
||||
|
||||
if service is None:
|
||||
svc_list = ", ".join(available)
|
||||
await message.channel.send(
|
||||
f"Available services to link: **{svc_list}**\n"
|
||||
f"Use `/login <service>` — e.g. `/login jellyfin`"
|
||||
)
|
||||
return True
|
||||
|
||||
if service not in available:
|
||||
await message.channel.send(
|
||||
f"Unknown service '{service}'. Available: {', '.join(available)}"
|
||||
)
|
||||
return True
|
||||
|
||||
if auth_store.is_authenticated(user_id, service):
|
||||
svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service
|
||||
await message.channel.send(
|
||||
f"You're already linked to **{svc_display}**! "
|
||||
f"Use `/logout {service}` to unlink."
|
||||
)
|
||||
return True
|
||||
|
||||
# --- Quick Connect flow ---
|
||||
svc = get_auth_service(service)
|
||||
if svc is None:
|
||||
await message.channel.send(f"Unknown service: {service}")
|
||||
return True
|
||||
|
||||
await message.channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…")
|
||||
|
||||
qc_result = await svc.initiate_quick_connect()
|
||||
if qc_result is None:
|
||||
await message.channel.send(
|
||||
f"❌ Could not start Quick Connect for **{svc.display_name}**.\n"
|
||||
"Check that `JELLYFIN_URL` is configured and the server is reachable."
|
||||
)
|
||||
return True
|
||||
|
||||
await message.channel.send(
|
||||
f"Open **{svc.display_name}** → **Quick Connect** and enter this code:\n\n"
|
||||
f"**`{qc_result.code}`**\n\n"
|
||||
f"⏳ Waiting for you to approve…"
|
||||
)
|
||||
|
||||
# Poll for authorization
|
||||
async with message.channel.typing():
|
||||
for attempt in range(24): # 24 × 5s = 2 minutes
|
||||
await asyncio.sleep(5)
|
||||
status = await svc.poll_quick_connect(qc_result.secret)
|
||||
|
||||
if status == "Authorized":
|
||||
auth_result = await svc.authenticate_quick_connect(qc_result.secret)
|
||||
if auth_result.success:
|
||||
auth_store.store_auth(
|
||||
discord_user_id=user_id,
|
||||
service=service,
|
||||
external_user_id=auth_result.external_user_id or "",
|
||||
external_name=auth_result.external_name or "",
|
||||
credentials=auth_result.credentials,
|
||||
)
|
||||
await message.channel.send(
|
||||
f"✅ Linked to **{svc.display_name}** as "
|
||||
f"**{auth_result.external_name}**!"
|
||||
)
|
||||
else:
|
||||
await message.channel.send(
|
||||
f"❌ Authentication failed: "
|
||||
f"{auth_result.error_message or 'Unknown error'}"
|
||||
)
|
||||
return True
|
||||
|
||||
elif status == "Expired":
|
||||
await message.channel.send(
|
||||
"⌛ The Quick Connect code expired. "
|
||||
f"Use `/login {service}` to try again."
|
||||
)
|
||||
return True
|
||||
|
||||
# else: still "Active" — keep polling
|
||||
|
||||
await message.channel.send(
|
||||
"⌛ Timed out waiting for Quick Connect approval. "
|
||||
f"Use `/login {service}` to try again."
|
||||
)
|
||||
return True
|
||||
|
||||
# --- /logout [service] ---
|
||||
if lower.startswith("/logout"):
|
||||
parts = content.split()
|
||||
service = parts[1].lower() if len(parts) > 1 else None
|
||||
|
||||
if service is None:
|
||||
linked = auth_store.list_services(user_id)
|
||||
if not linked:
|
||||
await message.channel.send("You don't have any linked services.")
|
||||
else:
|
||||
svc_list = ", ".join(linked)
|
||||
await message.channel.send(
|
||||
f"Linked services: **{svc_list}**\n"
|
||||
f"Use `/logout <service>` to unlink."
|
||||
)
|
||||
return True
|
||||
|
||||
if not auth_store.is_authenticated(user_id, service):
|
||||
await message.channel.send(f"You're not linked to **{service}**.")
|
||||
return True
|
||||
|
||||
auth_store.revoke(user_id, service)
|
||||
svc_display = (get_auth_service(service) and get_auth_service(service).display_name) or service
|
||||
await message.channel.send(f"Unlinked from **{svc_display}**. Use `/login {service}` to re-link.")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent invocation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _run_agent(self, *, user_id: int, user_msg: str) -> str:
|
||||
"""Build the message list from history, invoke the graph, store the
|
||||
reply, and return the assistant's final text."""
|
||||
|
||||
# 1. Pick agent — defaults to DISCORD_DEFAULT_AGENT env var.
|
||||
agent_id = DISCORD_DEFAULT_AGENT
|
||||
|
||||
# 2. Build message list from stored history + new user message
|
||||
history = _conversations.get_history(user_id)
|
||||
messages = [*history, {"role": "user", "content": user_msg}]
|
||||
|
||||
# 3. Run the LangGraph (tools execute inline if needed)
|
||||
graph = _get_graph(agent_id)
|
||||
state = {"messages": messages, "discord_user_id": user_id}
|
||||
result = await graph.ainvoke(state)
|
||||
|
||||
last_msg = result["messages"][-1]
|
||||
reply = last_msg.content or ""
|
||||
|
||||
# 4. Persist the conversation
|
||||
_conversations.append(user_id, user_msg, reply)
|
||||
|
||||
return reply
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bootstrap helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _start_bot_sync(token: str) -> None:
|
||||
"""Synchronous entry-point that runs the bot in a new asyncio event loop.
|
||||
|
||||
Called from a background thread so the main thread can keep running the
|
||||
FastAPI / uvicorn server.
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
async def _run() -> None:
|
||||
bot = AgentBot()
|
||||
try:
|
||||
await bot.start(token)
|
||||
except discord.LoginFailure:
|
||||
logger.error(
|
||||
"Discord login failed — check DISCORD_BOT_TOKEN in your .env file."
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Unhandled exception in bot event loop.")
|
||||
|
||||
loop.run_until_complete(_run())
|
||||
|
||||
|
||||
def start_in_background(token: str | None = None) -> None:
|
||||
"""Launch the Discord bot in a daemon thread.
|
||||
|
||||
Pass *token* explicitly if you already have it; otherwise it is read
|
||||
from the DISCORD_BOT_TOKEN env variable.
|
||||
"""
|
||||
token = token or DISCORD_BOT_TOKEN
|
||||
if not token:
|
||||
logger.warning(
|
||||
"DISCORD_BOT_TOKEN is not set — Discord bot will NOT start."
|
||||
)
|
||||
return
|
||||
|
||||
import threading
|
||||
|
||||
t = threading.Thread(
|
||||
target=_start_bot_sync,
|
||||
args=(token,),
|
||||
daemon=True,
|
||||
name="discord-bot",
|
||||
)
|
||||
t.start()
|
||||
logger.info("Discord bot thread started.")
|
||||
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Per-user conversation history store.
|
||||
|
||||
Each Discord user gets their own isolated message list. Only the last
|
||||
`max_history` messages are kept — older ones are silently dropped so the
|
||||
LLM context stays small.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
logger = logging.getLogger("bot.conversation")
|
||||
|
||||
# role we assign to user messages inside the OpenAI-style message list
|
||||
_USER_ROLE = "user"
|
||||
# role we assign to bot responses
|
||||
_ASSISTANT_ROLE = "assistant"
|
||||
|
||||
|
||||
class ConversationStore:
|
||||
"""Thread-safe-ish in-memory store keyed by Discord user ID (int)."""
|
||||
|
||||
def __init__(self, max_history: int = 7) -> None:
|
||||
self._max = max_history
|
||||
self._store: Dict[int, List[dict]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_history(self, user_id: int) -> list[dict]:
|
||||
"""Return the last *max_history* messages for *user_id*."""
|
||||
return list(self._store.get(user_id, []))
|
||||
|
||||
def append(self, user_id: int, user_msg: str, assistant_reply: str) -> None:
|
||||
"""Store the user message + assistant reply, then trim to max."""
|
||||
if user_id not in self._store:
|
||||
self._store[user_id] = []
|
||||
|
||||
history = self._store[user_id]
|
||||
history.append({"role": _USER_ROLE, "content": user_msg})
|
||||
history.append({"role": _ASSISTANT_ROLE, "content": assistant_reply})
|
||||
|
||||
# Trim oldest messages if we exceeded the limit
|
||||
while len(history) > self._max:
|
||||
history.pop(0)
|
||||
|
||||
def clear(self, user_id: int) -> None:
|
||||
"""Wipe the conversation for a user."""
|
||||
self._store.pop(user_id, None)
|
||||
logger.info("Cleared conversation for user %s", user_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# debug / introspection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def user_count(self) -> int:
|
||||
return len(self._store)
|
||||
@@ -0,0 +1,73 @@
|
||||
# Discord — Connector
|
||||
|
||||
The Discord module embeds a Discord bot **in-process** alongside FastAPI.
|
||||
It uses the same LangGraph graphs and LLM client as the REST API — there is
|
||||
no HTTP loopback, no separate process, and no code duplication.
|
||||
|
||||
---
|
||||
|
||||
## Files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `bot.py` | Discord `Client` subclass (`AgentBot`) — DM handler, command parser, graph invoker, Quick Connect orchestrator |
|
||||
| `conversation.py` | In-memory conversation history store, keyed by Discord user ID |
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Discord Gateway (websocket)
|
||||
│ DM: "What's trending?"
|
||||
▼
|
||||
discord.Client.on_message()
|
||||
│ 1. Check: is this a DM? shares a guild? not a command?
|
||||
│ 2. Build message history from ConversationStore
|
||||
│ 3. Append user message
|
||||
▼
|
||||
_create_agent_graph(agent_id="media-agent")
|
||||
│ Uses the exact same create_agent_graph() from src/graph.py
|
||||
│ as the REST API — same LLM client, same tools, same cache.
|
||||
▼
|
||||
graph.ainvoke({"messages": [...]})
|
||||
│ LangGraph runs agent_node → tool_node → agent_node → END
|
||||
▼
|
||||
Response text → split into ≤2000-char Discord messages → sent to user
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Commands
|
||||
|
||||
Commands are DMs that start with `/`. The bot parses them before hitting the
|
||||
LLM:
|
||||
|
||||
| Command | Action |
|
||||
|---|---|
|
||||
| `/login <service>` | Generate a one-time auth link, DM it to the user |
|
||||
| `/jellyfin login` | Alias for `/login jellyfin` |
|
||||
| `/help` | Show available agents and commands |
|
||||
| `/<agent_id>` | Switch to a different agent for future messages |
|
||||
|
||||
---
|
||||
|
||||
## Auth Flow (Quick Connect)
|
||||
|
||||
When a user types `/login jellyfin`:
|
||||
|
||||
1. Bot generates a one-time token via `auth_store`
|
||||
2. Bot calls `auth_store.create_link_token(discord_id, "jellyfin")`
|
||||
3. Bot DMs the user: `https://<BASE_URL>/api/v1/auth/login?service=jellyfin&token=...&discord_id=...`
|
||||
4. User clicks the link → FastAPI serves the Jellyfin login form (or Quick Connect prompt)
|
||||
5. User authenticates → credentials stored in `auth_store`
|
||||
6. Future tool calls (e.g. `watch_history`) automatically use the stored Jellyfin session
|
||||
|
||||
---
|
||||
|
||||
## Conversation Persistence
|
||||
|
||||
- Per-user history stored in `ConversationStore` (in-memory dict)
|
||||
- Max history length configurable via `DISCORD_MAX_HISTORY` env var (default: 7)
|
||||
- Oldest messages are silently dropped when the limit is exceeded
|
||||
- History is NOT persisted across restarts (future: could use SQLite)
|
||||
@@ -0,0 +1,106 @@
|
||||
"""JellyStat REST API — watch history, genre summary, and user summary."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncpg
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from gateway.jellystat.db import get_pool
|
||||
from gateway.jellystat.models import (
|
||||
GenreSummaryResponse,
|
||||
UserSummaryResponse,
|
||||
WatchHistoryResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/jellystat", tags=["jellystat"])
|
||||
|
||||
DEFAULT_WINDOW_MINUTES = 10080 # 7 days
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /jellystat/history/{user_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/history/{user_id}", response_model=WatchHistoryResponse)
|
||||
async def get_watch_history(
|
||||
user_id: str,
|
||||
minutes: int = Query(
|
||||
default=DEFAULT_WINDOW_MINUTES, ge=1, description="Time window in minutes"
|
||||
),
|
||||
pool: asyncpg.Pool = Depends(get_pool),
|
||||
):
|
||||
"""Return watch history grouped by title, ordered by most-watched first."""
|
||||
rows = await pool.fetch(
|
||||
"SELECT * FROM fn_user_watch_history($1, $2)", user_id, minutes
|
||||
)
|
||||
return WatchHistoryResponse(
|
||||
user_id=user_id,
|
||||
window_minutes=minutes,
|
||||
items=[
|
||||
{
|
||||
"title": r["title"],
|
||||
"watch_time_sec": float(r["watch_time_sec"]),
|
||||
"media_type": r["media_type"],
|
||||
}
|
||||
for r in rows
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /jellystat/genres/{user_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/genres/{user_id}", response_model=GenreSummaryResponse)
|
||||
async def get_genre_summary(
|
||||
user_id: str,
|
||||
minutes: int = Query(
|
||||
default=DEFAULT_WINDOW_MINUTES, ge=1, description="Time window in minutes"
|
||||
),
|
||||
pool: asyncpg.Pool = Depends(get_pool),
|
||||
):
|
||||
"""Return total watch time per genre, ordered by most-watched first."""
|
||||
rows = await pool.fetch(
|
||||
"SELECT * FROM fn_user_genre_summary($1, $2)", user_id, minutes
|
||||
)
|
||||
return GenreSummaryResponse(
|
||||
user_id=user_id,
|
||||
window_minutes=minutes,
|
||||
genres=[
|
||||
{"genre": r["genre"], "watch_time_sec": float(r["watch_time_sec"])}
|
||||
for r in rows
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /jellystat/summary/{user_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/summary/{user_id}", response_model=UserSummaryResponse)
|
||||
async def get_user_summary(
|
||||
user_id: str,
|
||||
pool: asyncpg.Pool = Depends(get_pool),
|
||||
):
|
||||
"""Return all-time summary: total watch time, most-watched titles, top genres."""
|
||||
rows = await pool.fetch("SELECT * FROM fn_user_summary($1)", user_id)
|
||||
|
||||
# fn_user_summary returns key-value rows — build a dict
|
||||
# asyncpg already deserialises JSONB → Python objects
|
||||
metrics: dict[str, object] = {r["metric"]: r["value"] for r in rows}
|
||||
|
||||
top_genres_raw = metrics.get("top_genres", [])
|
||||
top_genres: list[str] = top_genres_raw if isinstance(top_genres_raw, list) else []
|
||||
|
||||
return UserSummaryResponse(
|
||||
user_id=user_id,
|
||||
total_watch_time_sec=float(metrics.get("total_watch_time", 0)),
|
||||
most_watched_series=metrics.get("most_watched_series"),
|
||||
most_watched_movie=metrics.get("most_watched_movie"),
|
||||
total_last_30d_sec=float(metrics.get("total_last_30d", 0)),
|
||||
total_last_7d_sec=float(metrics.get("total_last_7d", 0)),
|
||||
top_genres=top_genres,
|
||||
)
|
||||
@@ -0,0 +1,130 @@
|
||||
"""PostgreSQL connection pool for the JellyStat database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
from src.config import get_config
|
||||
|
||||
logger = logging.getLogger("gateway.jellystat")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DSN builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_dsn() -> str:
|
||||
"""Build a PostgreSQL DSN from individual environment variables."""
|
||||
host = get_config("JELLYSTAT_DB_HOST", "localhost")
|
||||
port = get_config("JELLYSTAT_DB_PORT", "5432")
|
||||
user = get_config("JELLYSTAT_DB_USER", "postgres")
|
||||
password = get_config("JELLYSTAT_DB_PASSWORD", "")
|
||||
dbname = get_config("JELLYSTAT_DB_NAME", "jfstat")
|
||||
return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pool lifecycle (called from main.py lifespan)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def init_pool(app: FastAPI) -> None:
|
||||
"""Create the connection pool and store it on app.state."""
|
||||
dsn = _build_dsn()
|
||||
safe = dsn.split("@")[1] if "@" in dsn else dsn
|
||||
logger.info("Connecting to JellyStat database at %s", safe)
|
||||
|
||||
pool = await asyncpg.create_pool(dsn, min_size=1, max_size=5)
|
||||
app.state.jellystat_pool = pool
|
||||
|
||||
# Deploy functions on every startup (CREATE OR REPLACE is idempotent)
|
||||
await _ensure_functions(pool)
|
||||
|
||||
|
||||
async def close_pool(app: FastAPI) -> None:
|
||||
"""Close the pool on shutdown."""
|
||||
pool: asyncpg.Pool | None = getattr(app.state, "jellystat_pool", None)
|
||||
if pool:
|
||||
await pool.close()
|
||||
logger.info("JellyStat pool closed")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FastAPI dependency
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def get_pool(request: Request) -> asyncpg.Pool:
|
||||
"""Return the JellyStat connection pool from app state."""
|
||||
return request.app.state.jellystat_pool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Function deployment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _ensure_functions(pool: asyncpg.Pool) -> None:
|
||||
"""Run startup-functions.sql to create or replace all JellyStat functions."""
|
||||
sql_path = Path(__file__).parent / "startup-functions.sql"
|
||||
if not sql_path.exists():
|
||||
logger.warning("startup-functions.sql not found — skipping function deployment")
|
||||
return
|
||||
|
||||
sql = sql_path.read_text()
|
||||
statements = _split_sql(sql)
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
for stmt in statements:
|
||||
try:
|
||||
await conn.execute(stmt)
|
||||
except Exception:
|
||||
# Log but don't crash — functions might already exist
|
||||
logger.exception("Failed to deploy SQL statement — continuing")
|
||||
|
||||
logger.info("JellyStat functions deployed (%d statements)", len(statements))
|
||||
|
||||
|
||||
def _split_sql(sql: str) -> list[str]:
|
||||
"""
|
||||
Split a multi-statement SQL string into individual statements.
|
||||
|
||||
Respects $$ dollar-quoting so that semicolons inside function bodies
|
||||
don't cause premature splits. Pure comment lines (starting with ``--``)
|
||||
outside dollar-quoted blocks are stripped.
|
||||
"""
|
||||
statements: list[str] = []
|
||||
current: list[str] = []
|
||||
in_dollar_quote = False
|
||||
|
||||
for line in sql.split("\n"):
|
||||
stripped = line.strip()
|
||||
|
||||
# Skip pure comment lines outside of dollar-quoted blocks
|
||||
if not in_dollar_quote and stripped.startswith("--"):
|
||||
continue
|
||||
|
||||
# Toggle dollar-quote state whenever we see $$
|
||||
if "$$" in line:
|
||||
in_dollar_quote = not in_dollar_quote
|
||||
|
||||
current.append(line)
|
||||
|
||||
# Statement terminator: semicolon at end of line, outside $$ block
|
||||
if not in_dollar_quote and line.rstrip().endswith(";"):
|
||||
stmt = "\n".join(current).strip()
|
||||
if stmt:
|
||||
statements.append(stmt)
|
||||
current = []
|
||||
|
||||
# Catch any trailing statement that wasn't terminated by a semicolon
|
||||
if current:
|
||||
stmt = "\n".join(current).strip()
|
||||
if stmt:
|
||||
statements.append(stmt)
|
||||
|
||||
return statements
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Pydantic response models for the JellyStat API."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class WatchHistoryItem(BaseModel):
|
||||
title: str
|
||||
watch_time_sec: float
|
||||
media_type: str
|
||||
|
||||
|
||||
class WatchHistoryResponse(BaseModel):
|
||||
user_id: str
|
||||
window_minutes: int
|
||||
items: list[WatchHistoryItem]
|
||||
|
||||
|
||||
class GenreSummaryItem(BaseModel):
|
||||
genre: str
|
||||
watch_time_sec: float
|
||||
|
||||
|
||||
class GenreSummaryResponse(BaseModel):
|
||||
user_id: str
|
||||
window_minutes: int
|
||||
genres: list[GenreSummaryItem]
|
||||
|
||||
|
||||
class UserSummaryResponse(BaseModel):
|
||||
user_id: str
|
||||
total_watch_time_sec: float
|
||||
most_watched_series: str | None
|
||||
most_watched_movie: str | None
|
||||
total_last_30d_sec: float
|
||||
total_last_7d_sec: float
|
||||
top_genres: list[str]
|
||||
@@ -0,0 +1,224 @@
|
||||
-- ============================================================================
|
||||
-- JellyStat API Functions
|
||||
-- Parameterized database functions callable by the API layer as:
|
||||
-- SELECT * FROM fn_user_watch_history('user_id_here', 10080);
|
||||
-- SELECT * FROM fn_user_genre_summary('user_id_here', 10080);
|
||||
-- SELECT * FROM fn_user_summary('user_id_here');
|
||||
-- ============================================================================
|
||||
|
||||
-- ----------------------------------------------------------------------------
|
||||
-- 1. User Watch History
|
||||
-- Returns every distinct title watched in the last N minutes,
|
||||
-- grouped and summed by title, ordered by most-watched first.
|
||||
-- ----------------------------------------------------------------------------
|
||||
CREATE OR REPLACE FUNCTION public.fn_user_watch_history(
|
||||
p_user_id TEXT,
|
||||
p_minutes INTEGER DEFAULT 10080 -- 7 days in minutes
|
||||
)
|
||||
RETURNS TABLE(
|
||||
title TEXT,
|
||||
watch_time_sec NUMERIC,
|
||||
media_type TEXT
|
||||
)
|
||||
LANGUAGE sql
|
||||
STABLE
|
||||
AS $$
|
||||
SELECT
|
||||
COALESCE(a."SeriesName", a."NowPlayingItemName") AS title,
|
||||
SUM(a."PlaybackDuration")::NUMERIC AS watch_time_sec,
|
||||
CASE
|
||||
WHEN a."SeriesName" IS NOT NULL THEN 'series'
|
||||
ELSE 'movie'
|
||||
END AS media_type
|
||||
FROM jf_playback_activity a
|
||||
WHERE a."UserId" = p_user_id
|
||||
AND a."ActivityDateInserted"
|
||||
>= NOW() - (p_minutes * INTERVAL '1 minute')
|
||||
GROUP BY
|
||||
COALESCE(a."SeriesName", a."NowPlayingItemName"),
|
||||
CASE WHEN a."SeriesName" IS NOT NULL THEN 'series' ELSE 'movie' END
|
||||
ORDER BY watch_time_sec DESC;
|
||||
$$;
|
||||
|
||||
-- ----------------------------------------------------------------------------
|
||||
-- 2. Genre Summary
|
||||
-- Returns total watch time per genre for a user over the last N minutes.
|
||||
-- Resolves genres for both movies (directly on the item) and series
|
||||
-- episodes (via jf_library_episodes → jf_library_items chain).
|
||||
-- ----------------------------------------------------------------------------
|
||||
CREATE OR REPLACE FUNCTION public.fn_user_genre_summary(
|
||||
p_user_id TEXT,
|
||||
p_minutes INTEGER DEFAULT 10080
|
||||
)
|
||||
RETURNS TABLE(
|
||||
genre TEXT,
|
||||
watch_time_sec NUMERIC
|
||||
)
|
||||
LANGUAGE sql
|
||||
STABLE
|
||||
AS $$
|
||||
WITH movie_genres AS (
|
||||
-- Movies: join playback directly to library_items on NowPlayingItemId
|
||||
SELECT
|
||||
genre_item.value AS genre,
|
||||
SUM(a."PlaybackDuration") AS watch_time_sec
|
||||
FROM jf_playback_activity a
|
||||
JOIN jf_library_items i
|
||||
ON i."Id" = a."NowPlayingItemId"
|
||||
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
|
||||
WHERE a."UserId" = p_user_id
|
||||
AND a."SeriesName" IS NULL -- movies only
|
||||
AND a."ActivityDateInserted"
|
||||
>= NOW() - (p_minutes * INTERVAL '1 minute')
|
||||
AND i."Genres" IS NOT NULL
|
||||
AND jsonb_array_length(i."Genres") > 0
|
||||
GROUP BY genre_item.value
|
||||
),
|
||||
series_genres AS (
|
||||
-- Series: playback → episodes → series item → genres
|
||||
SELECT
|
||||
genre_item.value AS genre,
|
||||
SUM(a."PlaybackDuration") AS watch_time_sec
|
||||
FROM jf_playback_activity a
|
||||
JOIN jf_library_episodes e
|
||||
ON e."EpisodeId" = a."EpisodeId"
|
||||
JOIN jf_library_items i
|
||||
ON i."Id" = e."SeriesId"
|
||||
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
|
||||
WHERE a."UserId" = p_user_id
|
||||
AND a."SeriesName" IS NOT NULL -- TV episodes only
|
||||
AND a."ActivityDateInserted"
|
||||
>= NOW() - (p_minutes * INTERVAL '1 minute')
|
||||
AND i."Genres" IS NOT NULL
|
||||
AND jsonb_array_length(i."Genres") > 0
|
||||
GROUP BY genre_item.value
|
||||
),
|
||||
combined AS (
|
||||
SELECT genre, watch_time_sec FROM movie_genres
|
||||
UNION ALL
|
||||
SELECT genre, watch_time_sec FROM series_genres
|
||||
)
|
||||
SELECT
|
||||
genre,
|
||||
SUM(watch_time_sec)::NUMERIC AS watch_time_sec
|
||||
FROM combined
|
||||
GROUP BY genre
|
||||
ORDER BY watch_time_sec DESC;
|
||||
$$;
|
||||
|
||||
-- ----------------------------------------------------------------------------
|
||||
-- 3. User Summary
|
||||
-- One-shot dashboard: all-time stats + recent windows + top genres.
|
||||
-- Returns key-value rows that the API trivially converts to a JSON object
|
||||
-- with Object.fromEntries() or similar.
|
||||
-- ----------------------------------------------------------------------------
|
||||
CREATE OR REPLACE FUNCTION public.fn_user_summary(
|
||||
p_user_id TEXT
|
||||
)
|
||||
RETURNS TABLE(
|
||||
metric TEXT,
|
||||
value JSONB
|
||||
)
|
||||
LANGUAGE sql
|
||||
STABLE
|
||||
AS $$
|
||||
-- total_watch_time (all time)
|
||||
SELECT 'total_watch_time'::TEXT AS metric,
|
||||
to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value
|
||||
FROM jf_playback_activity
|
||||
WHERE "UserId" = p_user_id
|
||||
|
||||
UNION ALL
|
||||
|
||||
-- most_watched_series (by total watch time)
|
||||
SELECT 'most_watched_series'::TEXT AS metric,
|
||||
COALESCE(
|
||||
(SELECT to_jsonb("SeriesName")
|
||||
FROM jf_playback_activity
|
||||
WHERE "UserId" = p_user_id
|
||||
AND "SeriesName" IS NOT NULL
|
||||
GROUP BY "SeriesName"
|
||||
ORDER BY SUM("PlaybackDuration") DESC
|
||||
LIMIT 1),
|
||||
'null'::JSONB
|
||||
) AS value
|
||||
|
||||
UNION ALL
|
||||
|
||||
-- most_watched_movie (by total watch time)
|
||||
SELECT 'most_watched_movie'::TEXT AS metric,
|
||||
COALESCE(
|
||||
(SELECT to_jsonb("NowPlayingItemName")
|
||||
FROM jf_playback_activity
|
||||
WHERE "UserId" = p_user_id
|
||||
AND "SeriesName" IS NULL
|
||||
GROUP BY "NowPlayingItemName"
|
||||
ORDER BY SUM("PlaybackDuration") DESC
|
||||
LIMIT 1),
|
||||
'null'::JSONB
|
||||
) AS value
|
||||
|
||||
UNION ALL
|
||||
|
||||
-- total_watch_time_last_month (last 30 days)
|
||||
SELECT 'total_last_30d'::TEXT AS metric,
|
||||
to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value
|
||||
FROM jf_playback_activity
|
||||
WHERE "UserId" = p_user_id
|
||||
AND "ActivityDateInserted" >= NOW() - INTERVAL '30 days'
|
||||
|
||||
UNION ALL
|
||||
|
||||
-- total_watch_time_last_week (last 7 days)
|
||||
SELECT 'total_last_7d'::TEXT AS metric,
|
||||
to_jsonb(COALESCE(SUM("PlaybackDuration"), 0)::NUMERIC) AS value
|
||||
FROM jf_playback_activity
|
||||
WHERE "UserId" = p_user_id
|
||||
AND "ActivityDateInserted" >= NOW() - INTERVAL '7 days'
|
||||
|
||||
UNION ALL
|
||||
|
||||
-- top_genres (top 3 all-time, as a JSON array)
|
||||
SELECT 'top_genres'::TEXT AS metric,
|
||||
COALESCE(
|
||||
(SELECT jsonb_agg(genre ORDER BY watch_time_sec DESC)
|
||||
FROM (
|
||||
SELECT genre, SUM(watch_time_sec) AS watch_time_sec
|
||||
FROM (
|
||||
-- movies
|
||||
SELECT
|
||||
genre_item.value AS genre,
|
||||
SUM(a."PlaybackDuration") AS watch_time_sec
|
||||
FROM jf_playback_activity a
|
||||
JOIN jf_library_items i ON i."Id" = a."NowPlayingItemId"
|
||||
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
|
||||
WHERE a."UserId" = p_user_id
|
||||
AND a."SeriesName" IS NULL
|
||||
AND i."Genres" IS NOT NULL
|
||||
AND jsonb_array_length(i."Genres") > 0
|
||||
GROUP BY genre_item.value
|
||||
|
||||
UNION ALL
|
||||
|
||||
-- series
|
||||
SELECT
|
||||
genre_item.value AS genre,
|
||||
SUM(a."PlaybackDuration") AS watch_time_sec
|
||||
FROM jf_playback_activity a
|
||||
JOIN jf_library_episodes e ON e."EpisodeId" = a."EpisodeId"
|
||||
JOIN jf_library_items i ON i."Id" = e."SeriesId"
|
||||
CROSS JOIN LATERAL jsonb_array_elements_text(i."Genres") AS genre_item(value)
|
||||
WHERE a."UserId" = p_user_id
|
||||
AND a."SeriesName" IS NOT NULL
|
||||
AND i."Genres" IS NOT NULL
|
||||
AND jsonb_array_length(i."Genres") > 0
|
||||
GROUP BY genre_item.value
|
||||
) combined
|
||||
GROUP BY genre
|
||||
ORDER BY SUM(watch_time_sec) DESC
|
||||
LIMIT 3
|
||||
) top3
|
||||
),
|
||||
'[]'::JSONB
|
||||
) AS value;
|
||||
$$;
|
||||
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Auth API — generic endpoints for linking Discord users to external services.
|
||||
|
||||
GET /api/v1/auth/login?service=X&token=Y&discord_id=Z
|
||||
Validates the link token and serves a service-specific login form.
|
||||
|
||||
POST /api/v1/auth/login
|
||||
Accepts the form submission, validates credentials against the service,
|
||||
stores the session, and returns a result page.
|
||||
|
||||
GET /api/v1/auth/status?discord_id=Z
|
||||
Returns which services are linked for this Discord user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from gateway.auth import get_auth_service, list_auth_services
|
||||
from src import auth_store
|
||||
|
||||
logger = logging.getLogger("gateway.auth")
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /auth/login — serve the login form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/login")
|
||||
async def login_form(
|
||||
service: str,
|
||||
token: str,
|
||||
discord_id: int,
|
||||
):
|
||||
"""Validate the one-time link token and return a service-specific login form."""
|
||||
|
||||
# Validate the token WITHOUT consuming it (the POST will consume it)
|
||||
result = auth_store.validate_token(token)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
||||
|
||||
uid, svc = result
|
||||
if uid != discord_id or svc != service:
|
||||
raise HTTPException(status_code=400, detail="Token does not match the request.")
|
||||
|
||||
# Look up the AuthService
|
||||
svc_obj = get_auth_service(service)
|
||||
if svc_obj is None:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
||||
|
||||
logger.info("Serving login form: user=%s service=%s", discord_id, service)
|
||||
return HTMLResponse(svc_obj.render_login_form(token, discord_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/login — handle form submission
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/login")
|
||||
async def login_submit(request: Request):
|
||||
"""Handle the login form POST: validate credentials, store auth, show result."""
|
||||
|
||||
# Parse form data
|
||||
form = await request.form()
|
||||
token = form.get("token", "")
|
||||
discord_id_str = form.get("discord_id", "")
|
||||
service = form.get("service", "")
|
||||
|
||||
if not token or not discord_id_str or not service:
|
||||
raise HTTPException(status_code=400, detail="Missing required fields.")
|
||||
|
||||
try:
|
||||
discord_id = int(discord_id_str)
|
||||
except (ValueError, TypeError):
|
||||
raise HTTPException(status_code=400, detail="Invalid discord_id.")
|
||||
|
||||
# Consume the token on POST (the GET only validated, didn't consume)
|
||||
result = auth_store.consume_token(token)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
||||
|
||||
# Look up the AuthService
|
||||
svc_obj = get_auth_service(service)
|
||||
if svc_obj is None:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
||||
|
||||
# Collect service-specific form fields (everything except token, discord_id, service)
|
||||
form_data: dict[str, str] = {}
|
||||
for key, value in form.items():
|
||||
if key not in ("token", "discord_id", "service"):
|
||||
form_data[key] = str(value)
|
||||
|
||||
# Authenticate against the service
|
||||
auth_result = await svc_obj.authenticate(form_data)
|
||||
|
||||
if not auth_result.success:
|
||||
return HTMLResponse(
|
||||
status_code=401,
|
||||
content=f"""<!DOCTYPE html>
|
||||
<html><head><meta charset="utf-8"><title>Login Failed</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
|
||||
h2 {{ color: #d32f2f; }}
|
||||
a {{ color: #aa5cc3; }}
|
||||
</style></head><body>
|
||||
<h2>❌ Login Failed</h2>
|
||||
<p>{auth_result.error_message or "Authentication failed. Please try again."}</p>
|
||||
<p><a href="javascript:history.back()">← Go back and try again</a></p>
|
||||
</body></html>""",
|
||||
)
|
||||
|
||||
# Store the successful auth
|
||||
auth_store.store_auth(
|
||||
discord_user_id=discord_id,
|
||||
service=service,
|
||||
external_user_id=auth_result.external_user_id or "",
|
||||
external_name=auth_result.external_name or "",
|
||||
credentials=auth_result.credentials,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Auth linked: discord=%s → %s (%s)",
|
||||
discord_id,
|
||||
service,
|
||||
auth_result.external_name,
|
||||
)
|
||||
|
||||
return HTMLResponse(f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Account Linked</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; text-align: center; }}
|
||||
h1 {{ color: #388e3c; }}
|
||||
.name {{ font-weight: bold; color: #aa5cc3; }}
|
||||
p {{ color: #666; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>✅ Account Linked!</h1>
|
||||
<p>Logged in as <span class="name">{auth_result.external_name}</span> on <strong>{svc_obj.display_name}</strong>.</p>
|
||||
<p>You can close this page and return to Discord.</p>
|
||||
</body>
|
||||
</html>""")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /auth/status — get all linked services for a Discord user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/Discord/status")
|
||||
async def auth_status(discord_id: int):
|
||||
"""
|
||||
Return all services linked to this Discord user with full details.
|
||||
|
||||
Response:
|
||||
{
|
||||
"discord_id": 123456789,
|
||||
"linked_services": {
|
||||
"jellyfin": {
|
||||
"external_user_id": "abc123",
|
||||
"external_name": "Tim",
|
||||
"linked_at": "2026-05-25T10:00:00",
|
||||
"credentials": {
|
||||
"token": "jwt...",
|
||||
"url": "http://jellyfin:8096",
|
||||
"user_id": "abc123"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
auths = auth_store.get_all_auths(discord_id)
|
||||
|
||||
linked_services: dict[str, dict] = {}
|
||||
for auth in auths:
|
||||
svc_name = auth["service"]
|
||||
linked_services[svc_name] = {
|
||||
"external_user_id": auth["external_user_id"],
|
||||
"external_name": auth["external_name"],
|
||||
"linked_at": auth["linked_at"],
|
||||
"credentials": auth["credentials"],
|
||||
}
|
||||
|
||||
return {
|
||||
"discord_id": discord_id,
|
||||
"linked_services": linked_services,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/reset — wipe auth store (DEV ONLY)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from src.config import get_config # noqa: E402
|
||||
|
||||
@router.post("/reset")
|
||||
async def reset_auth():
|
||||
"""
|
||||
Reset the entire auth store — clears all link tokens and user auth records.
|
||||
|
||||
Only enabled when ALLOW_AUTH_RESET=true in the environment.
|
||||
Returns 403 in production.
|
||||
"""
|
||||
if get_config("ALLOW_AUTH_RESET", "false").lower() != "true":
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Auth reset is disabled. Set ALLOW_AUTH_RESET=true to enable (dev only).",
|
||||
)
|
||||
|
||||
auth_store.reset_all()
|
||||
logger.warning("Auth store reset via API endpoint.")
|
||||
return {"status": "ok", "message": "Auth store cleared — all tokens and auth records removed."}
|
||||
@@ -0,0 +1,241 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
from gateway.dependencies import get_llm_client, get_agent_graph
|
||||
from agents import get as get_agent, list_all as list_all_agents
|
||||
from src.state import AgentState
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
session_id: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
messages: list[dict]
|
||||
stream: bool = False
|
||||
model: str = "deepseek-chat"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _resolve_agent(agent_id: str | None = None, model: str | None = None):
|
||||
"""
|
||||
1. explicit agent_id
|
||||
2. model field (OpenWebUI sends this — maps to agent_id if registered)
|
||||
3. fallback to "naked"
|
||||
"""
|
||||
lookup = agent_id or model
|
||||
if lookup is None:
|
||||
return get_agent("naked")
|
||||
agent = get_agent(lookup)
|
||||
return agent if agent else get_agent("naked")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LangGraph helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _invoke_graph(graph, messages: list[dict]) -> str:
|
||||
"""Run the graph synchronously (non-streaming) and return the final text."""
|
||||
state: AgentState = {"messages": messages}
|
||||
result = await graph.ainvoke(state)
|
||||
last_msg = result["messages"][-1]
|
||||
return last_msg.content or ""
|
||||
|
||||
|
||||
async def _stream_graph(graph, messages: list[dict]):
|
||||
"""
|
||||
Run the graph and stream the final response token-by-token.
|
||||
|
||||
LangGraph's astream_events would require langchain-openai's ChatOpenAI
|
||||
to intercept LLM chunks. Instead we run the graph to completion (tools
|
||||
execute silently) and then stream the final text content character by
|
||||
character — this gives the client a real SSE stream without adding new
|
||||
dependencies.
|
||||
"""
|
||||
state: AgentState = {"messages": messages}
|
||||
result = await graph.ainvoke(state)
|
||||
content = result["messages"][-1].content or ""
|
||||
# Yield token-by-token so the SSE client sees incremental output
|
||||
for token in content:
|
||||
yield token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-streaming run (kept for /chat/sync and sync completions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_agent_with_tools(
|
||||
request: Request,
|
||||
messages: list[dict],
|
||||
agent_id: str | None = None,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""Send messages through the agent's LangGraph. Non-streaming."""
|
||||
agent = _resolve_agent(agent_id, model)
|
||||
graph = get_agent_graph(agent.agent_id, request)
|
||||
return await _invoke_graph(graph, messages)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming generator (kept for /chat and stream completions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_agent_stream(
|
||||
request: Request,
|
||||
messages: list[dict],
|
||||
agent_id: str | None = None,
|
||||
model: str | None = None,
|
||||
):
|
||||
"""Async generator — yields tokens via the agent's LangGraph."""
|
||||
agent = _resolve_agent(agent_id, model)
|
||||
graph = get_agent_graph(agent.agent_id, request)
|
||||
async for token in _stream_graph(graph, messages):
|
||||
yield token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/")
|
||||
def root():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def chat(
|
||||
req: ChatRequest,
|
||||
request: Request,
|
||||
client: OpenAI = Depends(get_llm_client),
|
||||
):
|
||||
"""Streaming chat — single message, no history."""
|
||||
messages = [{"role": "user", "content": req.message}]
|
||||
|
||||
async def event_stream():
|
||||
async for token in run_agent_stream(request, messages, req.agent_id):
|
||||
payload = json.dumps({"token": token, "session_id": req.session_id})
|
||||
yield f"data: {payload}\n\n"
|
||||
yield f"data: {json.dumps({'done': True, 'session_id': req.session_id})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/sync")
|
||||
async def chat_sync(
|
||||
req: ChatRequest,
|
||||
request: Request,
|
||||
client: OpenAI = Depends(get_llm_client),
|
||||
):
|
||||
"""Non-streaming chat — single message."""
|
||||
messages = [{"role": "user", "content": req.message}]
|
||||
response = await run_agent_with_tools(request, messages, req.agent_id)
|
||||
return {"response": response, "session_id": req.session_id}
|
||||
|
||||
|
||||
@router.get("/agents")
|
||||
def list_agents():
|
||||
"""Return all registered agents."""
|
||||
return {
|
||||
"agents": [
|
||||
{
|
||||
"agent_id": a.agent_id,
|
||||
"description": a.description,
|
||||
"skills": a.skills,
|
||||
}
|
||||
for a in list_all_agents().values()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
def list_models():
|
||||
"""Return agents as selectable models for OpenWebUI."""
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": a.agent_id,
|
||||
"object": "model",
|
||||
"created": 0,
|
||||
"owned_by": "local-agent",
|
||||
}
|
||||
for a in list_all_agents().values()
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.post("/chat/completions")
|
||||
async def chat_completions(
|
||||
req: ChatCompletionRequest,
|
||||
request: Request,
|
||||
client: OpenAI = Depends(get_llm_client),
|
||||
):
|
||||
"""OpenAI-compatible /chat/completions — supports stream=True.
|
||||
Multi-turn: req.messages contains the FULL conversation history.
|
||||
Agent resolved from the model field (OpenWebUI sends this).
|
||||
"""
|
||||
agent = _resolve_agent(model=req.model)
|
||||
|
||||
if req.stream:
|
||||
async def sse_stream():
|
||||
async for token in run_agent_stream(
|
||||
request, req.messages, agent_id=agent.agent_id,
|
||||
):
|
||||
chunk = {
|
||||
"id": "chatcmpl-local",
|
||||
"object": "chat.completion.chunk",
|
||||
"choices": [
|
||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
final_chunk = {
|
||||
"id": "chatcmpl-local",
|
||||
"object": "chat.completion.chunk",
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
yield f"data: {json.dumps(final_chunk)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
sse_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
# Non-streaming — full history, LangGraph agent
|
||||
response = await run_agent_with_tools(
|
||||
request, req.messages, agent_id=agent.agent_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": "chatcmpl-local",
|
||||
"object": "chat.completion",
|
||||
"created": 0,
|
||||
"model": req.model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": response},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
# V1 — Chat & Agent API Endpoints
|
||||
|
||||
This is the primary HTTP API surface for the chatbot agent system. It exposes
|
||||
both a custom streaming chat endpoint and an OpenAI-compatible
|
||||
`/chat/completions` endpoint so it works as a drop-in backend for OpenWebUI,
|
||||
LibreChat, or any OpenAI-compatible client.
|
||||
|
||||
---
|
||||
|
||||
## Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
|---|---|---|
|
||||
| `GET ` | `/v1/` | Health check — returns `{"status": "ok"}` |
|
||||
| `GET ` | `/v1/agents` | List all registered agents (id + description) |
|
||||
| `GET ` | `/v1/models` | OpenAI-compatible model list (one entry per agent) |
|
||||
| `POST` | `/v1/chat` | Chat with an agent — streaming (SSE) |
|
||||
| `POST` | `/v1/chat/sync` | Chat with an agent — non-streaming |
|
||||
| `POST` | `/v1/chat/completions` | OpenAI-compatible chat completions (supports `stream: true`) |
|
||||
|
||||
All `/v1/*` endpoints are mounted by `main.py` via:
|
||||
|
||||
```python
|
||||
app.include_router(v1_router, prefix="/v1")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Agent Resolution
|
||||
|
||||
Each request can target a specific agent. The resolution order is:
|
||||
|
||||
1. **Explicit `agent_id`** field in the request body
|
||||
2. **OpenAI `model` field** (OpenWebUI sends this — mapped to `agent_id` if a matching agent is registered)
|
||||
3. **Fallback** to the `"naked"` agent (a plain LLM with no tools)
|
||||
|
||||
This means an OpenWebUI client can simply set `model: "media-agent"` and get
|
||||
the full Media Agent with Seerr tools.
|
||||
|
||||
---
|
||||
|
||||
## Request Flow
|
||||
|
||||
```
|
||||
Client (OpenWebUI / HTTP)
|
||||
│ POST /v1/chat/completions
|
||||
│ { model: "media-agent", messages: [...], stream: true/false }
|
||||
▼
|
||||
chat_completions()
|
||||
│ 1. _resolve_agent(req.model) → Agent(id="media-agent", skills=[...])
|
||||
│ 2. get_agent_graph("media-agent", request)
|
||||
│ → lazy-compiled LangGraph StateGraph, cached on app.state
|
||||
│ 3. stream=True → _stream_graph(graph, messages) → SSE token stream
|
||||
│ stream=False → _invoke_graph(graph, messages) → plain response
|
||||
▼
|
||||
LangGraph StateGraph (src/graph.py)
|
||||
│
|
||||
├── agent_node: calls LLM with system prompt + tool definitions
|
||||
│ └── LLM returns text OR tool_calls
|
||||
│
|
||||
├── _should_continue: if tool_calls → tool_node, else → END
|
||||
│
|
||||
└── tool_node: executes tool via agents/skills system → ToolMessage
|
||||
└── loops back to agent_node with the result
|
||||
```
|
||||
|
||||
For a detailed walkthrough, see [api.md](../api.md).
|
||||
|
||||
---
|
||||
|
||||
## Streaming
|
||||
|
||||
Two streaming modes exist:
|
||||
|
||||
### SSE (Server-Sent Events) — `/v1/chat`
|
||||
```
|
||||
data: {"token": "Here"}
|
||||
data: {"token": " are"}
|
||||
data: {"token": " the"}
|
||||
...
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
The graph runs to completion (tools execute silently), then the final text is
|
||||
yielded token-by-token as SSE events.
|
||||
|
||||
### OpenAI-compatible — `/v1/chat/completions` with `stream: true`
|
||||
```
|
||||
data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"}}]}
|
||||
data: {"id":"...","object":"chat.completion.chunk","choices":[{"delta":{"content":"!"}}]}
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
> **Future improvement:** true token-level streaming (tokens appear as the LLM
|
||||
> generates them) would require using `langchain-openai`'s `ChatOpenAI` in
|
||||
> place of the raw `openai` client. The current approach avoids adding that
|
||||
> dependency.
|
||||
|
||||
---
|
||||
|
||||
## Dependencies
|
||||
|
||||
Endpoints receive shared singletons via FastAPI `Depends`:
|
||||
|
||||
- **`get_llm_client(request)`** → returns `request.app.state.llm_client` (OpenAI client singleton, created once in `main.py`)
|
||||
- **`get_agent_graph(agent_id, request)`** → returns a lazy-compiled LangGraph from `request.app.state.agent_graphs`
|
||||
@@ -0,0 +1,73 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from gateway.v1.auth import router as auth_router
|
||||
from gateway.v1.chat import router as v1_router
|
||||
from gateway.jellystat.api import router as jellystat_router
|
||||
from src.config import DEEPSEEK_API_KEY, get_config
|
||||
from src.llm import create_client
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging — tool calls will appear in the uvicorn console
|
||||
# ---------------------------------------------------------------------------
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load all agents, skills, AND auth services so they self-register at startup
|
||||
# ---------------------------------------------------------------------------
|
||||
from agents import load_all_agents # noqa: E402
|
||||
|
||||
load_all_agents()
|
||||
|
||||
import gateway.auth.jellyfin # noqa: E402 — self-registers JellyfinAuth
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lifespan
|
||||
# ---------------------------------------------------------------------------
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
from gateway.discord.bot import start_in_background # noqa: E402
|
||||
from gateway.jellystat.db import init_pool, close_pool # noqa: E402
|
||||
|
||||
await init_pool(app)
|
||||
start_in_background()
|
||||
|
||||
yield
|
||||
|
||||
await close_pool(app)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# App
|
||||
# ---------------------------------------------------------------------------
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singletons (stored on app.state so every module can reach them via Depends)
|
||||
# ---------------------------------------------------------------------------
|
||||
app.state.llm_client = create_client(DEEPSEEK_API_KEY)
|
||||
|
||||
# Lazy-compiled LangGraph graphs — populated on first use per agent
|
||||
app.state.agent_graphs: dict = {}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routers
|
||||
# ---------------------------------------------------------------------------
|
||||
app.include_router(v1_router, prefix="/v1")
|
||||
app.include_router(auth_router)
|
||||
app.include_router(jellystat_router)
|
||||
@@ -0,0 +1,10 @@
|
||||
fastapi
|
||||
openai
|
||||
uvicorn
|
||||
python-dotenv
|
||||
httpx
|
||||
langgraph
|
||||
langgraph-checkpoint
|
||||
discord.py
|
||||
python-multipart
|
||||
asyncpg
|
||||
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
Auth Store — SQLite-backed persistence for Discord-to-service authentication.
|
||||
|
||||
Two tables:
|
||||
- link_tokens : one-time tokens sent via Discord DM to initiate login
|
||||
- user_auth : per-user, per-service credentials (Jellyfin token, etc.)
|
||||
|
||||
Thread-safe via WAL mode and a shared lock. No passwords are ever stored
|
||||
— only opaque service tokens (e.g. Jellyfin AccessToken).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.config import get_config
|
||||
|
||||
logger = logging.getLogger("auth_store")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db")
|
||||
TOKEN_EXPIRY_MINUTES = int(get_config("AUTH_TOKEN_EXPIRY", "10"))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton handle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_db_path: Path | None = None
|
||||
_db_lock = threading.Lock()
|
||||
|
||||
|
||||
def _resolve_path() -> Path:
|
||||
"""Turn AUTH_DB_PATH into an absolute path, creating parent dirs."""
|
||||
global _db_path
|
||||
if _db_path is not None:
|
||||
return _db_path
|
||||
p = Path(AUTH_DB_PATH)
|
||||
if not p.is_absolute():
|
||||
# Relative to the project root (two levels above this file)
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
p = project_root / p
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
_db_path = p
|
||||
return p
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
"""Return a thread-local connection to the auth database."""
|
||||
import sqlite3
|
||||
|
||||
conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS link_tokens (
|
||||
token TEXT PRIMARY KEY,
|
||||
discord_user_id INTEGER NOT NULL,
|
||||
service TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
used INTEGER DEFAULT 0,
|
||||
created_at TEXT DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_auth (
|
||||
discord_user_id INTEGER NOT NULL,
|
||||
service TEXT NOT NULL,
|
||||
external_user_id TEXT,
|
||||
external_name TEXT,
|
||||
credentials TEXT,
|
||||
linked_at TEXT DEFAULT (datetime('now')),
|
||||
is_active INTEGER DEFAULT 1,
|
||||
PRIMARY KEY (discord_user_id, service)
|
||||
);
|
||||
"""
|
||||
|
||||
_initialized = False
|
||||
|
||||
|
||||
def _ensure_schema() -> None:
|
||||
global _initialized
|
||||
if _initialized:
|
||||
return
|
||||
with _db_lock:
|
||||
if _initialized:
|
||||
return
|
||||
conn = _get_conn()
|
||||
conn.executescript(_SCHEMA)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
_initialized = True
|
||||
logger.info("Auth store initialized at %s", _resolve_path())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — Link Tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_token(discord_user_id: int, service: str) -> str:
|
||||
"""Generate a one-time link token. Expires after TOKEN_EXPIRY_MINUTES."""
|
||||
_ensure_schema()
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires = (datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRY_MINUTES)).isoformat()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"INSERT INTO link_tokens (token, discord_user_id, service, expires_at) VALUES (?, ?, ?, ?)",
|
||||
(token, discord_user_id, service, expires),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info("Created link token for user %s / service %s", discord_user_id, service)
|
||||
return token
|
||||
|
||||
|
||||
def validate_token(token: str) -> tuple[int, str] | None:
|
||||
"""Read-only validation — does NOT consume the token.
|
||||
|
||||
Returns (discord_user_id, service) if the token exists, is unused,
|
||||
and has not expired. Returns None otherwise.
|
||||
"""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
|
||||
(token,),
|
||||
).fetchone()
|
||||
conn.close()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
if row["used"]:
|
||||
return None
|
||||
|
||||
expires = datetime.fromisoformat(row["expires_at"])
|
||||
if datetime.now(timezone.utc) > expires:
|
||||
return None
|
||||
|
||||
return (row["discord_user_id"], row["service"])
|
||||
|
||||
|
||||
def consume_token(token: str) -> tuple[int, str] | None:
|
||||
"""Validate and consume a link token. Returns (discord_user_id, service) or None.
|
||||
|
||||
A token is valid if:
|
||||
- It exists
|
||||
- It has not been used
|
||||
- It has not expired
|
||||
"""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
|
||||
(token,),
|
||||
).fetchone()
|
||||
|
||||
if row is None:
|
||||
conn.close()
|
||||
return None
|
||||
|
||||
if row["used"]:
|
||||
conn.close()
|
||||
logger.warning("Token already used: %s…", token[:8])
|
||||
return None
|
||||
|
||||
expires = datetime.fromisoformat(row["expires_at"])
|
||||
if datetime.now(timezone.utc) > expires:
|
||||
conn.close()
|
||||
logger.warning("Token expired: %s…", token[:8])
|
||||
return None
|
||||
|
||||
conn.execute("UPDATE link_tokens SET used = 1 WHERE token = ?", (token,))
|
||||
conn.commit()
|
||||
result = (row["discord_user_id"], row["service"])
|
||||
conn.close()
|
||||
logger.info("Token consumed: %s… → user=%s service=%s", token[:8], result[0], result[1])
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — User Auth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def store_auth(
|
||||
discord_user_id: int,
|
||||
service: str,
|
||||
*,
|
||||
external_user_id: str = "",
|
||||
external_name: str = "",
|
||||
credentials: dict | None = None,
|
||||
) -> None:
|
||||
"""Store or update authentication for a user on a service."""
|
||||
_ensure_schema()
|
||||
import json
|
||||
|
||||
creds_json = json.dumps(credentials) if credentials else "{}"
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO user_auth (discord_user_id, service, external_user_id, external_name, credentials, linked_at)
|
||||
VALUES (?, ?, ?, ?, ?, datetime('now'))
|
||||
ON CONFLICT(discord_user_id, service) DO UPDATE SET
|
||||
external_user_id = excluded.external_user_id,
|
||||
external_name = excluded.external_name,
|
||||
credentials = excluded.credentials,
|
||||
linked_at = datetime('now'),
|
||||
is_active = 1""",
|
||||
(discord_user_id, service, external_user_id, external_name, creds_json),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info("Stored auth for user %s on %s as %s", discord_user_id, service, external_name)
|
||||
|
||||
|
||||
def get_auth(discord_user_id: int, service: str) -> dict | None:
|
||||
"""Retrieve stored auth for a user on a service. Returns None if not linked."""
|
||||
_ensure_schema()
|
||||
import json
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT * FROM user_auth WHERE discord_user_id = ? AND service = ? AND is_active = 1",
|
||||
(discord_user_id, service),
|
||||
).fetchone()
|
||||
conn.close()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
credentials = json.loads(row["credentials"]) if row["credentials"] else {}
|
||||
return {
|
||||
"discord_user_id": row["discord_user_id"],
|
||||
"service": row["service"],
|
||||
"external_user_id": row["external_user_id"],
|
||||
"external_name": row["external_name"],
|
||||
"credentials": credentials,
|
||||
"linked_at": row["linked_at"],
|
||||
}
|
||||
|
||||
|
||||
def is_authenticated(discord_user_id: int, service: str) -> bool:
|
||||
"""Quick check: is this user linked to this service?"""
|
||||
return get_auth(discord_user_id, service) is not None
|
||||
|
||||
|
||||
def list_services(discord_user_id: int) -> list[str]:
|
||||
"""Return list of service names this user has linked."""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT service FROM user_auth WHERE discord_user_id = ? AND is_active = 1",
|
||||
(discord_user_id,),
|
||||
).fetchall()
|
||||
conn.close()
|
||||
|
||||
return [r["service"] for r in rows]
|
||||
|
||||
|
||||
def revoke(discord_user_id: int, service: str) -> None:
|
||||
"""Unlink a user from a service."""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"UPDATE user_auth SET is_active = 0 WHERE discord_user_id = ? AND service = ?",
|
||||
(discord_user_id, service),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info("Revoked auth for user %s on %s", discord_user_id, service)
|
||||
|
||||
|
||||
def get_all_auths(discord_user_id: int) -> list[dict]:
|
||||
"""
|
||||
Return all active auth records for a Discord user.
|
||||
Each record includes service name, external user id, external name,
|
||||
linked_at timestamp, and the raw credentials (e.g. Jellyfin token + URL).
|
||||
|
||||
Used by the /api/v1/auth/status endpoint so other services can discover
|
||||
linked accounts for a given Discord ID.
|
||||
"""
|
||||
_ensure_schema()
|
||||
import json
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
rows = conn.execute(
|
||||
"""SELECT service, external_user_id, external_name, credentials, linked_at
|
||||
FROM user_auth
|
||||
WHERE discord_user_id = ? AND is_active = 1
|
||||
ORDER BY linked_at DESC""",
|
||||
(discord_user_id,),
|
||||
).fetchall()
|
||||
conn.close()
|
||||
|
||||
results: list[dict] = []
|
||||
for row in rows:
|
||||
creds = {}
|
||||
if row["credentials"]:
|
||||
try:
|
||||
creds = json.loads(row["credentials"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
creds = {}
|
||||
results.append({
|
||||
"service": row["service"],
|
||||
"external_user_id": row["external_user_id"] or "",
|
||||
"external_name": row["external_name"] or "",
|
||||
"linked_at": row["linked_at"] or "",
|
||||
"credentials": creds,
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dev / testing — reset the entire store
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def reset_all() -> None:
|
||||
"""Truncate all auth tables — for development and testing only."""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute("DELETE FROM link_tokens")
|
||||
conn.execute("DELETE FROM user_auth")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.warning("Auth store RESET — all tokens and auth records cleared.")
|
||||
@@ -0,0 +1,31 @@
|
||||
from dotenv import load_dotenv
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load .env from the project root (one level above core/)
|
||||
# ---------------------------------------------------------------------------
|
||||
_env_path = Path(__file__).resolve().parent.parent / ".env"
|
||||
load_dotenv(_env_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# General-purpose config accessor — every skill uses this
|
||||
# ---------------------------------------------------------------------------
|
||||
def get_config(key: str, default: str | None = None) -> str | None:
|
||||
"""Read a value from the environment (loaded from .env)."""
|
||||
return os.getenv(key, default)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM
|
||||
# ---------------------------------------------------------------------------
|
||||
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seerr (Overseerr / Jellyseerr)
|
||||
# ---------------------------------------------------------------------------
|
||||
SEERR_URL = os.getenv("SEERR_URL", "")
|
||||
SEERR_API_KEY = os.getenv("SEERR_API_KEY", "")
|
||||
SEERR_TIMEOUT = int(os.getenv("SEERR_TIMEOUT", "30"))
|
||||
+245
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
LangGraph agent graph factory.
|
||||
|
||||
Builds a StateGraph with two nodes:
|
||||
- agent_node : calls the LLM (with system prompt + tool definitions)
|
||||
- tool_node : executes tool calls via the existing skill system
|
||||
|
||||
A conditional edge routes tool_calls back to the agent, or ends the run.
|
||||
When a tool fails due to missing authentication, the failure message is
|
||||
relayed to the LLM, which tells the user to use /login.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.graph import END, StateGraph
|
||||
from openai import OpenAI
|
||||
|
||||
from src.state import AgentState
|
||||
from agents.skills import get_all_tools, execute_tool
|
||||
|
||||
logger = logging.getLogger("graph")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper — map LangChain message type → OpenAI role
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _lc_role_to_openai(msg_type: str) -> str:
|
||||
"""Convert a LangChain message type string to an OpenAI role."""
|
||||
mapping = {"human": "user", "ai": "assistant", "tool": "tool", "system": "system"}
|
||||
return mapping.get(msg_type, "user")
|
||||
|
||||
|
||||
def _langchain_tc_to_openai(tool_calls: list) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert LangChain-format tool_calls (with `name`/`args` at top level)
|
||||
back to OpenAI format (with a nested `function` sub-object).
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
if "function" in tc:
|
||||
result.append(tc)
|
||||
else:
|
||||
# LangChain format: {"name": ..., "args": ..., "id": ...}
|
||||
result.append({
|
||||
"id": tc.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.get("name", ""),
|
||||
"arguments": json.dumps(tc.get("args", {})),
|
||||
},
|
||||
})
|
||||
else:
|
||||
# Pydantic model — dump to dict
|
||||
d = tc.model_dump() if hasattr(tc, "model_dump") else {}
|
||||
if "function" in d:
|
||||
result.append(d)
|
||||
else:
|
||||
result.append({
|
||||
"id": d.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": d.get("name", ""),
|
||||
"arguments": json.dumps(d.get("args", {})),
|
||||
},
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent node — calls the LLM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_agent_node(
|
||||
client: OpenAI,
|
||||
system_prompt: str,
|
||||
tool_defs: list[dict[str, Any]],
|
||||
model_name: str = "deepseek-chat",
|
||||
):
|
||||
"""
|
||||
Return a callable suitable as a LangGraph node.
|
||||
|
||||
The node reads the current message list from state, prepends the system
|
||||
prompt, and calls the LLM. If tool_defs is non-empty the LLM may return
|
||||
tool_calls; ToolNode (or our custom tool node) will handle them.
|
||||
"""
|
||||
|
||||
def agent_node(state: AgentState) -> dict[str, list]:
|
||||
messages = state["messages"]
|
||||
|
||||
# Convert LangChain message objects to plain dicts for the OpenAI client.
|
||||
full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
|
||||
for m in messages:
|
||||
if isinstance(m, dict):
|
||||
d = dict(m)
|
||||
tc = d.get("tool_calls")
|
||||
if tc and isinstance(tc, list) and tc and isinstance(tc[0], dict) and "function" not in tc[0]:
|
||||
d["tool_calls"] = _langchain_tc_to_openai(tc)
|
||||
full.append(d)
|
||||
else:
|
||||
role = _lc_role_to_openai(getattr(m, "type", "user"))
|
||||
d: dict[str, Any] = {"role": role, "content": getattr(m, "content", "")}
|
||||
tc = getattr(m, "tool_calls", None)
|
||||
if tc:
|
||||
d["tool_calls"] = _langchain_tc_to_openai(tc)
|
||||
tc_id = getattr(m, "tool_call_id", None)
|
||||
if tc_id:
|
||||
d["tool_call_id"] = tc_id
|
||||
full.append(d)
|
||||
|
||||
resp = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=full,
|
||||
tools=tool_defs if tool_defs else None,
|
||||
tool_choice="auto" if tool_defs else None,
|
||||
)
|
||||
choice = resp.choices[0]
|
||||
|
||||
raw_tool_calls = list(choice.message.tool_calls) if choice.message.tool_calls else []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
for tc in raw_tool_calls:
|
||||
fn = tc.function
|
||||
tool_calls.append({
|
||||
"name": fn.name,
|
||||
"args": json.loads(fn.arguments),
|
||||
"id": tc.id,
|
||||
})
|
||||
ai_msg = AIMessage(
|
||||
content=choice.message.content or "",
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
id=getattr(choice.message, "id", None),
|
||||
)
|
||||
return {"messages": [ai_msg]}
|
||||
|
||||
return agent_node
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool node — executes tools via the existing skill system
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_node(skill_names: list[str]):
|
||||
"""
|
||||
Return a callable that executes tool_calls from the last AI message.
|
||||
|
||||
If a tool fails because the user isn't authenticated, the failure
|
||||
message (which tells the user to /login) is returned to the LLM.
|
||||
The LLM naturally relays the instructions to the user.
|
||||
"""
|
||||
|
||||
async def tool_node(state: AgentState) -> dict[str, list]:
|
||||
last_msg = state["messages"][-1]
|
||||
tool_calls = getattr(last_msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
return {"messages": []}
|
||||
|
||||
discord_user_id = state.get("discord_user_id")
|
||||
|
||||
results: list[ToolMessage] = []
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
if "function" in tc:
|
||||
fn = tc["function"]
|
||||
fn_name = fn.get("name", "")
|
||||
fn_args_raw = fn.get("arguments", "{}")
|
||||
else:
|
||||
fn_name = tc.get("name", "")
|
||||
fn_args_raw = tc.get("args", {})
|
||||
tc_id = tc.get("id", "")
|
||||
else:
|
||||
fn_name = getattr(tc, "name", "")
|
||||
fn_args_raw = getattr(tc, "args", {})
|
||||
tc_id = getattr(tc, "id", "")
|
||||
|
||||
if isinstance(fn_args_raw, str):
|
||||
fn_args = json.loads(fn_args_raw)
|
||||
else:
|
||||
fn_args = fn_args_raw
|
||||
|
||||
tr = await execute_tool(
|
||||
skill_names, fn_name, fn_args,
|
||||
discord_user_id=discord_user_id,
|
||||
)
|
||||
content = tr.content if tr else f"Tool '{fn_name}' is not available."
|
||||
results.append(ToolMessage(content=content, tool_call_id=tc_id))
|
||||
|
||||
return {"messages": results}
|
||||
|
||||
return tool_node
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router — decides whether to continue tool-calling or stop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _should_continue(state: AgentState) -> Literal["tool_node", END]:
|
||||
"""If the last message contains tool_calls → execute them, else finish."""
|
||||
last_msg = state["messages"][-1]
|
||||
if getattr(last_msg, "tool_calls", None):
|
||||
return "tool_node"
|
||||
return END
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graph factory — the public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_agent_graph(
|
||||
*,
|
||||
client: OpenAI,
|
||||
agent_skills: list[str],
|
||||
system_prompt: str,
|
||||
model_name: str = "deepseek-chat",
|
||||
) -> StateGraph:
|
||||
"""
|
||||
Build and compile a LangGraph StateGraph for a single agent.
|
||||
"""
|
||||
tool_defs = get_all_tools(agent_skills)
|
||||
|
||||
graph = StateGraph(AgentState)
|
||||
|
||||
graph.add_node(
|
||||
"agent_node",
|
||||
_make_agent_node(client, system_prompt, tool_defs, model_name),
|
||||
)
|
||||
|
||||
if tool_defs:
|
||||
graph.add_node("tool_node", _make_tool_node(agent_skills))
|
||||
graph.add_conditional_edges("agent_node", _should_continue, {
|
||||
"tool_node": "tool_node",
|
||||
END: END,
|
||||
})
|
||||
graph.add_edge("tool_node", "agent_node")
|
||||
else:
|
||||
graph.add_edge("agent_node", END)
|
||||
|
||||
graph.set_entry_point("agent_node")
|
||||
|
||||
return graph.compile()
|
||||
@@ -0,0 +1,9 @@
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
def create_client(api_key: str) -> OpenAI:
|
||||
"""Factory for an OpenAI-compatible client pointed at DeepSeek."""
|
||||
return OpenAI(
|
||||
api_key=api_key,
|
||||
base_url="https://api.deepseek.com",
|
||||
)
|
||||
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
LangGraph agent state — defines the shape of the state object that flows
|
||||
through every node in the agent graph.
|
||||
"""
|
||||
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
"""
|
||||
The single source of truth that travels through every node in the graph.
|
||||
|
||||
`messages` uses LangGraph's `add_messages` reducer, which:
|
||||
- Appends new messages to the list.
|
||||
- Replaces messages with the same ID (useful for tool-call results).
|
||||
"""
|
||||
|
||||
messages: Annotated[list, add_messages]
|
||||
discord_user_id: int | None # set by the Discord bot, None for REST API calls
|
||||
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Tools adapter — bridges the existing skill/tool system with LangGraph's ToolNode.
|
||||
|
||||
LangGraph's ToolNode expects callable tools (typically @tool-decorated functions).
|
||||
This module wraps our skill-based tool definitions and async executors so
|
||||
ToolNode can invoke them without any changes to the skills/ layer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from agents.skills import get_all_tools, execute_tool
|
||||
|
||||
|
||||
def build_langgraph_tools(skill_names: list[str]) -> list:
|
||||
"""
|
||||
Convert the registered skill tool definitions into LangChain-compatible
|
||||
@tool-decorated functions that ToolNode can call.
|
||||
|
||||
Each tool wraps the existing `execute_tool()` pipeline, so the skill
|
||||
system's ToolResult + httpx session handling is fully preserved.
|
||||
"""
|
||||
tool_defs = get_all_tools(skill_names)
|
||||
wrapped: list = []
|
||||
|
||||
for td in tool_defs:
|
||||
fn_def = td.get("function", {})
|
||||
fn_name = fn_def.get("name", "")
|
||||
fn_desc = fn_def.get("description", "")
|
||||
|
||||
# Create a unique factory so each closure captures the right fn_name
|
||||
def _make_tool(name: str, desc: str, skills: list[str]):
|
||||
@tool(name, description=desc)
|
||||
async def _wrapped(**kwargs: Any) -> str:
|
||||
"""Execute the tool via the skill system and return its content."""
|
||||
result = await execute_tool(skills, name, kwargs)
|
||||
if result is None:
|
||||
return f"Tool '{name}' is not available."
|
||||
return result.content
|
||||
|
||||
# Stash the original OpenAI schema so LangGraph can use it
|
||||
_wrapped.metadata = fn_def
|
||||
return _wrapped
|
||||
|
||||
wrapped.append(_make_tool(fn_name, fn_desc, skill_names))
|
||||
|
||||
return wrapped
|
||||
@@ -1,19 +0,0 @@
|
||||
name: Gitea Actions Demo
|
||||
run-name: ${{ gitea.actor }} is testing out Gitea Actions 🚀
|
||||
on: [push]
|
||||
|
||||
jobs:
|
||||
Explore-Gitea-Actions:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- run: echo "🎉 The job was automatically triggered by a ${{ gitea.event_name }} event."
|
||||
- run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by Gitea!"
|
||||
- run: echo "🔎 The name of your branch is ${{ gitea.ref }} and your repository is ${{ gitea.repository }}."
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
- run: echo "💡 The ${{ gitea.repository }} repository has been cloned to the runner."
|
||||
- run: echo "🖥️ The workflow is now ready to test your code on the runner."
|
||||
- name: List files in the repository
|
||||
run: |
|
||||
ls ${{ gitea.workspace }}
|
||||
- run: echo "🍏 This job's status is ${{ job.status }}."
|
||||
Reference in New Issue
Block a user