diff --git a/.env.example b/.env.example index b072ac9..fc20b32 100644 --- a/.env.example +++ b/.env.example @@ -6,6 +6,13 @@ # LLM — DeepSeek (OpenAI-compatible) DEEPSEEK_API_KEY=sk-your-deepseek-api-key +# --------------------------------------------------------------------------- +# Discord Bot +# --------------------------------------------------------------------------- +DISCORD_BOT_TOKEN=your-discord-bot-token-here +# DISCORD_MAX_HISTORY=7 # optional, defaults to 7 (max past messages per user) +# DISCORD_DEFAULT_AGENT=media-agent # optional, which agent the DM bot uses + # --------------------------------------------------------------------------- # Seerr (Overseerr / Jellyseerr) # --------------------------------------------------------------------------- diff --git a/bot/__init__.py b/bot/__init__.py new file mode 100644 index 0000000..4872dd6 --- /dev/null +++ b/bot/__init__.py @@ -0,0 +1 @@ +# Discord bot package diff --git a/bot/conversation.py b/bot/conversation.py new file mode 100644 index 0000000..532dfb2 --- /dev/null +++ b/bot/conversation.py @@ -0,0 +1,61 @@ +""" +Per-user conversation history store. + +Each Discord user gets their own isolated message list. Only the last +`max_history` messages are kept — older ones are silently dropped so the +LLM context stays small. 
import logging
from typing import Dict, List

logger = logging.getLogger("bot.conversation")

# role we assign to user messages inside the OpenAI-style message list
_USER_ROLE = "user"
# role we assign to bot responses
_ASSISTANT_ROLE = "assistant"


class ConversationStore:
    """In-memory chat history, one message list per Discord user ID.

    Each entry is an OpenAI-style ``{"role": ..., "content": ...}`` dict.
    After every append the list is cut back to the newest ``max_history``
    entries.  Because entries are added in user/assistant pairs, an odd
    limit leaves an assistant message at the head of a full history.

    The store takes no lock of its own; callers that share one instance
    between threads must serialise access themselves.  Users are never
    evicted automatically — call :meth:`clear` to drop one.
    """

    def __init__(self, max_history: int = 7) -> None:
        """Create an empty store.

        Args:
            max_history: Maximum number of messages kept per user.  Values
                below zero are treated as zero (keep nothing) instead of
                breaking the trim step later on.
        """
        self._max = max(0, max_history)
        self._store: Dict[int, List[dict]] = {}

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------

    def get_history(self, user_id: int) -> List[dict]:
        """Return a detached copy of the stored messages for *user_id*.

        Both the list and the message dicts are copied, so the caller may
        extend or edit the result without touching the stored history.
        Unknown users yield an empty list.
        """
        return [dict(message) for message in self._store.get(user_id, ())]

    def append(self, user_id: int, user_msg: str, assistant_reply: str) -> None:
        """Record one user/assistant exchange, then trim to the limit."""
        history = self._store.setdefault(user_id, [])
        history.append({"role": _USER_ROLE, "content": user_msg})
        history.append({"role": _ASSISTANT_ROLE, "content": assistant_reply})

        # Drop the oldest entries in one slice delete.  Computing the
        # overflow explicitly also handles a limit of 0, where a negative
        # slice bound (``history[:-0]``) would delete nothing.
        overflow = len(history) - self._max
        if overflow > 0:
            del history[:overflow]

    def clear(self, user_id: int) -> None:
        """Forget everything stored for *user_id* (no-op if unknown)."""
        self._store.pop(user_id, None)
        logger.info("Cleared conversation for user %s", user_id)

    # ------------------------------------------------------------------
    # debug / introspection
    # ------------------------------------------------------------------

    @property
    def user_count(self) -> int:
        """Number of users that currently have an entry in the store."""
        return len(self._store)
+ +Architecture +------------ +- The bot runs in-process alongside FastAPI (on a background asyncio task). +- Private messages (DMs) are routed through the same LangGraph graphs that + power the REST API — no HTTP loopback needed. +- Per-user conversation history is maintained so the LLM has context. + +Environment +----------- +DISCORD_BOT_TOKEN – the bot token from the Discord Developer Portal +DISCORD_MAX_HISTORY – how many past messages to keep per user (default 7) +""" + +from __future__ import annotations + +import asyncio +import logging +import os + +import discord + +from agents import list_all as list_all_agents +from bot.conversation import ConversationStore +from core.config import DEEPSEEK_API_KEY, get_config +from core.graph import create_agent_graph +from core.llm import create_client + +logger = logging.getLogger("bot.discord") + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- +DISCORD_BOT_TOKEN = get_config("DISCORD_BOT_TOKEN") or "" +DISCORD_MAX_HISTORY = int(get_config("DISCORD_MAX_HISTORY", "7")) +DISCORD_DEFAULT_AGENT = get_config("DISCORD_DEFAULT_AGENT", "media-agent") + +# --------------------------------------------------------------------------- +# LLM client shared by all agents (same as the REST API uses) +# --------------------------------------------------------------------------- +_llm_client = create_client(DEEPSEEK_API_KEY) + +# --------------------------------------------------------------------------- +# Conversation store — one per process +# --------------------------------------------------------------------------- +_conversations = ConversationStore(max_history=DISCORD_MAX_HISTORY) + +# --------------------------------------------------------------------------- +# Graph cache — lazy-compiled per agent, same pattern as api/dependencies.py +# 
# ---------------------------------------------------------------------------
_agent_graphs: dict[str, object] = {}

# Discord rejects a message whose content is longer than this many characters.
_DISCORD_MESSAGE_LIMIT = 2000
# Sent in place of an empty agent reply (Discord also rejects empty messages).
_EMPTY_REPLY_FALLBACK = "(The agent returned an empty response.)"
# Sent when the agent run itself raised.
_AGENT_ERROR_REPLY = (
    "Sorry, something went wrong processing your request. "
    "Please try again in a moment."
)


def _get_graph(agent_id: str):
    """Return the compiled LangGraph for *agent_id*, building it on first use.

    An unknown *agent_id* falls back to the ``"naked"`` agent; if that is
    missing too, a graph with no skills and a generic system prompt is
    built.  The result is cached in ``_agent_graphs`` under the requested
    *agent_id*, so the fallback is resolved only once per ID.
    """
    if agent_id not in _agent_graphs:
        agents = list_all_agents()
        if agent_id not in agents:
            # Most likely a typo in DISCORD_DEFAULT_AGENT — make it visible.
            logger.warning(
                "Unknown agent %r — falling back to the default agent.", agent_id
            )
        agent = agents.get(agent_id, agents.get("naked"))
        _agent_graphs[agent_id] = create_agent_graph(
            client=_llm_client,
            agent_skills=agent.skills if agent else [],
            system_prompt=agent.build_system_prompt() if agent else (
                "You are a helpful, general-purpose assistant."
            ),
        )
    return _agent_graphs[agent_id]


def _split_reply(text: str, limit: int = _DISCORD_MESSAGE_LIMIT) -> list[str]:
    """Split *text* into pieces of at most *limit* characters each.

    Cuts are made at the last newline (or, failing that, the last space)
    that fits, and only fall mid-word when a single run of text is longer
    than *limit*.  Blank input yields a one-element list holding a
    placeholder, so the caller always has something sendable.
    """
    remaining = text.strip()
    if not remaining:
        return [_EMPTY_REPLY_FALLBACK]

    pieces: list[str] = []
    while len(remaining) > limit:
        cut = remaining.rfind("\n", 0, limit + 1)
        if cut <= 0:
            cut = remaining.rfind(" ", 0, limit + 1)
        if cut <= 0:
            cut = limit  # no whitespace to break on — hard cut
        pieces.append(remaining[:cut].rstrip())
        remaining = remaining[cut:].lstrip()
    if remaining:
        pieces.append(remaining)
    return pieces


# ---------------------------------------------------------------------------
# Discord client
# ---------------------------------------------------------------------------

class AgentBot(discord.Client):
    """discord.py client that answers direct messages through a LangGraph agent.

    Only DMs are handled, and only from users who are a member of at least
    one guild the bot is in.  Conversation context comes from the
    module-level ``_conversations`` store.
    """

    def __init__(self) -> None:
        intents = discord.Intents.default()
        # message_content lets us read the text of incoming DMs.
        intents.message_content = True
        # ``_shares_guild`` iterates ``self.guilds``, which is only populated
        # when the guilds intent is on.  ``Intents.default()`` already
        # enables it; it is set explicitly to make the dependency visible.
        intents.guilds = True
        super().__init__(intents=intents)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def on_ready(self) -> None:
        """Log the bot identity and print a banner once the bot is ready."""
        logger.info("Bot logged in as %s (ID %s)", self.user, self.user.id)
        # Print a ready banner so the dev knows it's alive
        print(f"\n🤖 Discord bot ready — logged in as {self.user}\n")

    # ------------------------------------------------------------------
    # Shared-guild helper — uses the REST API, no privileged intents
    # ------------------------------------------------------------------

    async def _shares_guild(self, user: discord.User) -> bool:
        """Return True if *user* is a member of at least one of the bot's guilds.

        Membership is checked with one ``fetch_member`` REST call per guild
        until a match is found.  A guild where the lookup fails for any
        HTTP reason (not found, forbidden, other API error) counts as "not
        a member".
        """
        for guild in self.guilds:
            try:
                await guild.fetch_member(user.id)
            except discord.HTTPException:
                # NotFound and Forbidden are HTTPException subclasses.
                continue
            return True
        return False

    # ------------------------------------------------------------------
    # Message handler — DMs only
    # ------------------------------------------------------------------

    async def on_message(self, message: discord.Message) -> None:
        """Answer a direct message with the agent's reply.

        Messages from the bot itself, messages outside DM channels, and
        DMs from users who share no guild with the bot are ignored.  A
        failed agent run produces an apology message; a reply that cannot
        be delivered is logged and dropped.
        """
        # Never reply to ourselves
        if message.author == self.user:
            return

        # DM channel only for now
        if not isinstance(message.channel, discord.DMChannel):
            logger.debug("Ignoring message from #%s (not a DM)", message.channel)
            return

        # Shared-server gate.  fetch_member (REST) is used instead of
        # User.mutual_guilds because the latter needs the privileged
        # "members" intent.
        if not await self._shares_guild(message.author):
            logger.warning(
                "Blocking DM from %s — no mutual guilds.",
                message.author.name,
            )
            return

        user_id = message.author.id

        # Show typing indicator while the graph runs
        async with message.channel.typing():
            try:
                reply = await self._run_agent(
                    user_id=user_id,
                    user_msg=message.content,
                )
            except Exception:
                logger.exception("Agent run failed for user %s", user_id)
                outgoing = [_AGENT_ERROR_REPLY]
            else:
                outgoing = _split_reply(reply)

            # Delivery is handled separately from the agent run so that a
            # send failure is not reported as an agent failure and cannot
            # escape the event handler.
            try:
                for piece in outgoing:
                    await message.channel.send(piece)
            except discord.HTTPException:
                logger.exception("Could not deliver reply to user %s", user_id)

    # ------------------------------------------------------------------
    # Agent invocation
    # ------------------------------------------------------------------

    async def _run_agent(self, *, user_id: int, user_msg: str) -> str:
        """Run the agent on *user_msg* and return its final text.

        The stored history for *user_id* plus the new message is passed to
        the graph; the exchange is written back to the conversation store
        before returning.  An empty or missing final content is returned
        as ``""``.
        """
        # 1. Pick agent — defaults to DISCORD_DEFAULT_AGENT env var.
        #    Change DISCORD_DEFAULT_AGENT in .env to switch agents.
        agent_id = DISCORD_DEFAULT_AGENT

        # 2. Build message list from stored history + new user message
        history = _conversations.get_history(user_id)
        messages = [*history, {"role": "user", "content": user_msg}]

        # 3. Run the LangGraph (tools execute inline if needed)
        graph = _get_graph(agent_id)
        result = await graph.ainvoke({"messages": messages})

        reply = result["messages"][-1].content or ""

        # 4. Persist the conversation
        _conversations.append(user_id, user_msg, reply)

        return reply


# ---------------------------------------------------------------------------
# Bootstrap helpers
# ---------------------------------------------------------------------------

def _start_bot_sync(token: str) -> None:
    """Run the bot to completion on a fresh event loop (blocking).

    Meant to be the target of a background thread so the main thread can
    keep running the FastAPI / uvicorn server.  ``asyncio.run`` creates
    the loop for this thread and closes it again when the bot stops.
    """

    async def _run() -> None:
        bot = AgentBot()
        try:
            await bot.start(token)
        except discord.LoginFailure:
            logger.error(
                "Discord login failed — check DISCORD_BOT_TOKEN in your .env file."
            )
        except Exception:
            logger.exception("Unhandled exception in bot event loop.")
        finally:
            # Release the HTTP session / gateway connection even when
            # start() failed part-way through.
            if not bot.is_closed():
                await bot.close()

    asyncio.run(_run())


def start_in_background(token: str | None = None) -> None:
    """Launch the Discord bot in a daemon thread.

    Pass *token* explicitly if you already have it; otherwise the
    module-level ``DISCORD_BOT_TOKEN`` is used.  With no token available a
    warning is logged and nothing is started.
    """
    token = token or DISCORD_BOT_TOKEN
    if not token:
        logger.warning(
            "DISCORD_BOT_TOKEN is not set — Discord bot will NOT start."
        )
        return

    import threading

    t = threading.Thread(
        target=_start_bot_sync,
        args=(token,),
        daemon=True,
        name="discord-bot",
    )
    t.start()
    logger.info("Discord bot thread started.")
# ---------------------------------------------------------------------------
@app.on_event("startup")
async def _start_discord_bot() -> None:
    """Start the Discord bot when the application starts up.

    Delegates to ``bot.discord_bot.start_in_background()`` with no
    arguments, so the bot token is taken from that module's configuration.
    """
    # Imported here rather than at module top so that bot.discord_bot is
    # only loaded when the startup hook actually runs.
    # NOTE(review): recent FastAPI releases document ``on_event`` as
    # deprecated in favour of lifespan handlers — confirm against the
    # pinned FastAPI version before migrating.
    from bot.discord_bot import start_in_background  # noqa: E402

    start_in_background()
# The Discord bot runs in a separate thread with its own event loop, so a
# fresh httpx.AsyncClient is created *per event loop*.  One client is cached
# per loop so each loop still gets connection pooling, but the bot and the
# REST API never share a connection pool.
# ---------------------------------------------------------------------------

import asyncio
import logging
import threading

_seerr_log = logging.getLogger(__name__)

# id(loop) -> client created for that loop.
_seerr_sessions: dict[int, httpx.AsyncClient] = {}
# id(loop) -> the loop itself.  Holding the loop object keeps its id from
# being reused by a different loop while the entry exists, and lets closed
# loops be detected and evicted.
_seerr_session_loops: dict = {}
_seerr_sessions_lock = threading.Lock()

# Login cookies, fetched by the first call to _ensure_cookies() and reused
# for every per-loop client.  The Event records that the one-time login
# attempt has finished (successfully or not); the lock makes concurrent
# first callers wait for it.
_seerr_cookies: dict = {}
_seerr_cookies_ready = threading.Event()
_seerr_cookies_lock = threading.Lock()


def _ensure_cookies() -> None:
    """Perform the one-time synchronous login that fills ``_seerr_cookies``.

    Only the first caller logs in; concurrent callers block on the lock
    and then return.  If no username/password is configured, or the login
    fails with an ``httpx.HTTPError``, the cookies stay empty and the
    attempt is still marked complete so it is not repeated.

    NOTE(review): the login uses a blocking ``httpx.Client``; when the
    first call happens inside a running event loop it stalls that loop for
    up to ``SEERR_TIMEOUT`` seconds.
    """
    if _seerr_cookies_ready.is_set():
        return

    with _seerr_cookies_lock:
        # Double-check — another thread may have finished while we waited
        if _seerr_cookies_ready.is_set():
            return

        username = SEERR_USERNAME.strip()
        password = SEERR_PASSWORD.strip()
        if username and password:
            try:
                with httpx.Client(
                    base_url=SEERR_URL, timeout=SEERR_TIMEOUT
                ) as sync_client:
                    resp = sync_client.post("/api/v1/auth/jellyfin", json={
                        "username": username,
                        "password": password,
                    })
                    resp.raise_for_status()
                    _seerr_cookies.update(dict(sync_client.cookies))
            except httpx.HTTPError as exc:
                # Best effort: _build_client() falls back to the API key
                # (or no auth), but leave a trace of why.
                _seerr_log.warning(
                    "Seerr cookie login failed (%s); continuing without "
                    "session cookies.",
                    exc,
                )

        # Signal completion — even if login failed (empty cookies) so we
        # don't retry forever.
        _seerr_cookies_ready.set()


def _build_client() -> httpx.AsyncClient:
    """Create a new ``httpx.AsyncClient`` for the Seerr API.

    Auth preference: session cookies if the login produced any, otherwise
    the ``X-Api-Key`` header if ``SEERR_API_KEY`` is non-blank, otherwise
    no credentials.
    """
    options: dict = {"base_url": SEERR_URL, "timeout": SEERR_TIMEOUT}
    if _seerr_cookies:
        options["cookies"] = dict(_seerr_cookies)
    elif SEERR_API_KEY.strip():
        options["headers"] = {"X-Api-Key": SEERR_API_KEY.strip()}
    return httpx.AsyncClient(**options)


def _get_session() -> httpx.AsyncClient:
    """Return an AsyncClient bound to the currently running event loop.

    The first call triggers the one-time cookie login.  Each running loop
    gets one cached client; entries belonging to loops that have since
    been closed are dropped.  When no loop is running, a new uncached
    client is returned so that it is never shared with a real loop.
    """
    _ensure_cookies()

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop running (e.g. called during module import).
        return _build_client()

    key = id(loop)
    with _seerr_sessions_lock:
        # Evict clients whose loop is gone; they can no longer be used.
        closed = [k for k, lp in _seerr_session_loops.items() if lp.is_closed()]
        for stale_key in closed:
            _seerr_session_loops.pop(stale_key, None)
            _seerr_sessions.pop(stale_key, None)

        client = _seerr_sessions.get(key)
        if client is None or _seerr_session_loops.get(key) is not loop:
            client = _build_client()
            _seerr_sessions[key] = client
            _seerr_session_loops[key] = loop
        return client
""" def __init__(self, client: httpx.AsyncClient) -> None: @@ -113,13 +160,11 @@ class _SharedClient: def _client() -> _SharedClient: - """Return a context-manager wrapper around the shared httpx session.""" - assert _seerr_session is not None, "Seerr session not initialised" - return _SharedClient(_seerr_session) + """Return a context-manager wrapper around the current loop's session.""" + return _SharedClient(_get_session()) -# Initialise at import time -_init_session() +# Per-loop sessions are created lazily on first use — no eager init needed. def _fmt_items(items: list[dict], kind: str) -> str: