"""
Discord bot that connects users to the LangGraph agent via private messages.

Architecture
------------
- The bot runs in-process alongside FastAPI, in a background daemon thread
  that drives its own asyncio event loop.
- Private messages (DMs) are routed through the same LangGraph graphs that
  power the REST API — no HTTP loopback needed.
- Only users who share at least one server (guild) with the bot may talk to
  it via DM.
- Per-user conversation history is maintained so the LLM has context.

Environment
-----------
DISCORD_BOT_TOKEN     – the bot token from the Discord Developer Portal
DISCORD_MAX_HISTORY   – how many past messages to keep per user (default 7)
DISCORD_DEFAULT_AGENT – the agent that answers DMs (default "media-agent")
"""

from __future__ import annotations

import asyncio
import logging
import os

import discord

from agents import list_all as list_all_agents
from gateway.discord.conversation import ConversationStore
from src.config import DEEPSEEK_API_KEY, get_config
from src.graph import create_agent_graph
from src.llm import create_client
from src import auth_store
from gateway.auth import list_auth_services, get_auth_service

logger = logging.getLogger("bot.discord")


# ---------------------------------------------------------------------------
# Config — resolved once, when this module is imported
# ---------------------------------------------------------------------------
DISCORD_BOT_TOKEN: str = get_config("DISCORD_BOT_TOKEN") or ""
DISCORD_MAX_HISTORY: int = int(get_config("DISCORD_MAX_HISTORY", "7"))
DISCORD_DEFAULT_AGENT = get_config("DISCORD_DEFAULT_AGENT", "media-agent")
BASE_URL: str = get_config("BASE_URL", "http://localhost:8000").rstrip("/")


# ---------------------------------------------------------------------------
# Process-wide singletons
# ---------------------------------------------------------------------------
# A single LLM client is shared by every agent graph, just as the REST API
# does.
_llm_client = create_client(DEEPSEEK_API_KEY)

# One conversation store for the whole process, sized by
# DISCORD_MAX_HISTORY (past messages kept per user).
_conversations = ConversationStore(max_history=DISCORD_MAX_HISTORY)

# ---------------------------------------------------------------------------
# Graph cache — lazy-compiled per agent, same pattern as api/dependencies.py
# ---------------------------------------------------------------------------
_agent_graphs: dict[str, object] = {}


def _get_graph(agent_id: str):
    """Return the compiled LangGraph for *agent_id*, building it on first use.

    If *agent_id* is not a known agent, the ``"naked"`` agent is used instead;
    if that is missing too, a graph with no skills and a generic system prompt
    is built. A warning is logged on fallback. The compiled graph is cached
    under the *requested* id, so the fallback is resolved (and logged) only
    once per id.
    """
    if agent_id in _agent_graphs:
        return _agent_graphs[agent_id]

    agents = list_all_agents()
    agent = agents.get(agent_id)
    if agent is None:
        # A typo in DISCORD_DEFAULT_AGENT would otherwise silently switch every
        # conversation to a different agent — make the fallback visible.
        agent = agents.get("naked")
        logger.warning(
            "Agent %r not found; falling back to %s.",
            agent_id,
            "the 'naked' agent" if agent is not None else "a bare default prompt",
        )

    if agent is not None:
        skills = agent.skills
        system_prompt = agent.build_system_prompt()
    else:
        skills = []
        system_prompt = "You are a helpful, general-purpose assistant."

    graph = create_agent_graph(
        client=_llm_client,
        agent_skills=skills,
        system_prompt=system_prompt,
    )
    _agent_graphs[agent_id] = graph
    return graph

# ---------------------------------------------------------------------------
# Discord client
# ---------------------------------------------------------------------------


class AgentBot(discord.Client):
    """A discord.py Client that connects users to the LangGraph agent.

    Only direct messages are handled, and only from users who share at least
    one guild with the bot. ``/login`` and ``/logout`` are handled by the bot
    itself; every other text message is forwarded to the agent graph.
    """

    # Discord rejects message content longer than 2000 characters, so agent
    # replies are split into pieces of at most this size before sending.
    _MAX_MESSAGE_LEN = 2000

    def __init__(self) -> None:
        # message_content lets us read DM text.
        # guilds keeps self.guilds populated; _shares_guild iterates it to
        # gate DMs, so without it every DM would be rejected.
        intents = discord.Intents.default()
        intents.message_content = True
        intents.guilds = True
        super().__init__(intents=intents)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def on_ready(self) -> None:
        """Log (and print) the bot's identity once the client is ready."""
        logger.info("Bot logged in as %s (ID %s)", self.user, self.user.id)
        # Print a ready banner so the dev knows it's alive
        print(f"\n🤖 Discord bot ready — logged in as {self.user}\n")

    # ------------------------------------------------------------------
    # Shared-guild helper — uses the REST API, no privileged intents
    # ------------------------------------------------------------------

    async def _shares_guild(self, user: discord.User) -> bool:
        """Return True if *user* and the bot share at least one guild.

        Issues one ``Guild.fetch_member`` REST call per guild, stopping at the
        first guild where the user is found.
        """
        for guild in self.guilds:
            try:
                member = await guild.fetch_member(user.id)
            except discord.HTTPException:
                # NotFound (not a member) and Forbidden are both subclasses of
                # HTTPException; any failure just means "try the next guild".
                continue
            if member is not None:
                return True
        return False

    # ------------------------------------------------------------------
    # Message handler — DMs only
    # ------------------------------------------------------------------

    async def on_message(self, message: discord.Message) -> None:
        """Route a DM either to a bot command or to the agent, and reply."""
        # Never reply to ourselves
        if message.author == self.user:
            return

        # DM channel only for now
        if not isinstance(message.channel, discord.DMChannel):
            logger.debug("Ignoring message from #%s (not a DM)", message.channel)
            return

        content = message.content.strip()
        # Attachment/sticker-only messages carry no text; forwarding them would
        # give the LLM an empty user turn. Checked before the guild gate so it
        # costs no REST calls.
        if not content:
            return

        # Shared-server gate — only users who share at least one guild with
        # the bot can interact via DM. fetch_member (REST API) is used instead
        # of User.mutual_guilds because the latter requires the privileged
        # "members" intent; this way no privileged intents are needed.
        if not await self._shares_guild(message.author):
            logger.warning(
                "Blocking DM from %s — no mutual guilds.",
                message.author.name,
            )
            return

        user_id = message.author.id

        # Bot commands — handled directly, never sent to the LLM. They make
        # network calls, so failures must still produce a reply to the user.
        try:
            handled = await self._handle_command(message, user_id, content)
        except Exception:
            logger.exception("Command failed for user %s", user_id)
            await message.channel.send(
                "Sorry, something went wrong processing your request. "
                "Please try again in a moment."
            )
            return
        if handled:
            return

        # Show typing indicator while the graph runs
        async with message.channel.typing():
            try:
                reply = await self._run_agent(user_id=user_id, user_msg=content)
                await self._send_reply(message.channel, reply)
            except Exception:
                logger.exception("Agent run failed for user %s", user_id)
                await message.channel.send(
                    "Sorry, something went wrong processing your request. "
                    "Please try again in a moment."
                )

    async def _send_reply(self, channel: discord.abc.Messageable, text: str) -> None:
        """Send *text* to *channel*, split to fit Discord's message size limit."""
        chunks = self._chunk_text(text, self._MAX_MESSAGE_LEN)
        if not chunks:
            # Discord refuses empty messages; tell the user instead of failing.
            chunks = ["(The agent returned an empty response.)"]
        for chunk in chunks:
            await channel.send(chunk)

    @staticmethod
    def _chunk_text(text: str, limit: int) -> list[str]:
        """Split *text* into non-blank pieces of at most *limit* characters.

        Breaks at the last newline inside each window when there is one, and
        falls back to a hard cut when a single line is longer than *limit*.
        """
        chunks: list[str] = []
        rest = text
        while len(rest) > limit:
            cut = rest.rfind("\n", 0, limit)
            if cut <= 0:
                cut = limit
            chunks.append(rest[:cut])
            rest = rest[cut:].lstrip("\n")
        chunks.append(rest)
        # Whitespace-only pieces would be rejected by Discord as empty.
        return [chunk for chunk in chunks if chunk.strip()]

    # ------------------------------------------------------------------
    # Bot commands
    # ------------------------------------------------------------------

    async def _handle_command(
        self, message: discord.Message, user_id: int, content: str
    ) -> bool:
        """Handle bot commands (/login, /logout). Returns True if handled.

        The command must be the whole first word (case-insensitive), so text
        such as "/logins" is not mistaken for "/login". The optional second
        word is the service name, lower-cased.
        """
        parts = content.split()
        if not parts:
            return False

        command = parts[0].lower()
        service = parts[1].lower() if len(parts) > 1 else None

        if command == "/login":
            await self._cmd_login(message, user_id, service)
            return True
        if command == "/logout":
            await self._cmd_logout(message, user_id, service)
            return True
        return False

    async def _cmd_login(
        self, message: discord.Message, user_id: int, service: str | None
    ) -> None:
        """Link the user to *service* via the service's Quick Connect flow.

        With no *service*, lists the services that can be linked. Otherwise
        starts Quick Connect, shows the code to the user and polls every 5 s,
        for up to 2 minutes, until the code is authorized or expires.
        """
        channel = message.channel

        available = list_auth_services()
        if not available:
            await channel.send("No auth services are configured yet.")
            return

        if service is None:
            svc_list = ", ".join(available)
            await channel.send(
                f"Available services to link: **{svc_list}**\n"
                f"Use `/login <service>` — e.g. `/login jellyfin`"
            )
            return

        if service not in available:
            await channel.send(
                f"Unknown service '{service}'. Available: {', '.join(available)}"
            )
            return

        if auth_store.is_authenticated(user_id, service):
            existing = get_auth_service(service)
            svc_display = (existing and existing.display_name) or service
            await channel.send(
                f"You're already linked to **{svc_display}**! "
                f"Use `/logout {service}` to unlink."
            )
            return

        # --- Quick Connect flow ---
        svc = get_auth_service(service)

        await channel.send(f"🔑 Starting **{svc.display_name}** Quick Connect…")

        qc_result = await svc.initiate_quick_connect()
        if qc_result is None:
            await channel.send(
                f"❌ Could not start Quick Connect for **{svc.display_name}**.\n"
                "Check that `JELLYFIN_URL` is configured and the server is reachable."
            )
            return

        await channel.send(
            f"Open **{svc.display_name}** → **Quick Connect** and enter this code:\n\n"
            f"**`{qc_result.code}`**\n\n"
            f"⏳ Waiting for you to approve…"
        )

        # Poll for authorization
        async with channel.typing():
            for _ in range(24):  # 24 × 5s = 2 minutes
                await asyncio.sleep(5)
                status = await svc.poll_quick_connect(qc_result.secret)

                if status == "Authorized":
                    auth_result = await svc.authenticate_quick_connect(qc_result.secret)
                    if auth_result.success:
                        auth_store.store_auth(
                            discord_user_id=user_id,
                            service=service,
                            external_user_id=auth_result.external_user_id or "",
                            external_name=auth_result.external_name or "",
                            credentials=auth_result.credentials,
                        )
                        await channel.send(
                            f"✅ Linked to **{svc.display_name}** as "
                            f"**{auth_result.external_name}**!"
                        )
                    else:
                        await channel.send(
                            f"❌ Authentication failed: "
                            f"{auth_result.error_message or 'Unknown error'}"
                        )
                    return

                if status == "Expired":
                    await channel.send(
                        "⌛ The Quick Connect code expired. "
                        f"Use `/login {service}` to try again."
                    )
                    return

                # else: still "Active" — keep polling

        await channel.send(
            "⌛ Timed out waiting for Quick Connect approval. "
            f"Use `/login {service}` to try again."
        )

    async def _cmd_logout(
        self, message: discord.Message, user_id: int, service: str | None
    ) -> None:
        """Unlink *service*; with no *service*, list the user's linked services."""
        channel = message.channel

        if service is None:
            linked = auth_store.list_services(user_id)
            if not linked:
                await channel.send("You don't have any linked services.")
            else:
                svc_list = ", ".join(linked)
                await channel.send(
                    f"Linked services: **{svc_list}**\n"
                    f"Use `/logout <service>` to unlink."
                )
            return

        if not auth_store.is_authenticated(user_id, service):
            await channel.send(f"You're not linked to **{service}**.")
            return

        auth_store.revoke(user_id, service)
        svc = get_auth_service(service)
        svc_display = (svc and svc.display_name) or service
        await channel.send(
            f"Unlinked from **{svc_display}**. Use `/login {service}` to re-link."
        )

    # ------------------------------------------------------------------
    # Agent invocation
    # ------------------------------------------------------------------

    async def _run_agent(self, *, user_id: int, user_msg: str) -> str:
        """Run the agent for one user turn and return the assistant's reply.

        Builds the message list from the user's stored history plus
        *user_msg*, invokes the graph for DISCORD_DEFAULT_AGENT, records the
        exchange in the conversation store and returns the final message's
        text ("" if it has none).
        """
        # 1. Pick agent — defaults to DISCORD_DEFAULT_AGENT env var.
        agent_id = DISCORD_DEFAULT_AGENT

        # 2. Build message list from stored history + new user message
        history = _conversations.get_history(user_id)
        messages = [*history, {"role": "user", "content": user_msg}]

        # 3. Run the LangGraph (tools execute inline if needed). The Discord
        #    user id travels in the state — presumably so tools can look up
        #    the user's linked services; verify against the graph's tools.
        graph = _get_graph(agent_id)
        state = {"messages": messages, "discord_user_id": user_id}
        result = await graph.ainvoke(state)

        last_msg = result["messages"][-1]
        reply = last_msg.content or ""

        # 4. Persist the conversation
        _conversations.append(user_id, user_msg, reply)

        return reply

# ---------------------------------------------------------------------------
# Bootstrap helpers
# ---------------------------------------------------------------------------


def _start_bot_sync(token: str) -> None:
    """Synchronous entry-point that runs the bot on a fresh asyncio event loop.

    Called from a background thread so the main thread can keep running the
    FastAPI / uvicorn server. Blocks until the client stops. Login failures
    and unexpected errors are logged, not raised. The client is always closed
    (releasing its HTTP session), and so is the event loop.
    """

    async def _run() -> None:
        bot = AgentBot()
        try:
            await bot.start(token)
        except discord.LoginFailure:
            logger.error(
                "Discord login failed — check DISCORD_BOT_TOKEN in your .env file."
            )
        except Exception:
            logger.exception("Unhandled exception in bot event loop.")
        finally:
            # start() does not clean up after a failure; without this the
            # client's HTTP session would be left open.
            if not bot.is_closed():
                await bot.close()

    # asyncio.run creates the loop, installs it as this thread's event loop,
    # cancels leftover tasks and closes the loop when _run() finishes.
    asyncio.run(_run())

def start_in_background(token: str | None = None) -> None:
    """Start the Discord bot on a daemon thread and return immediately.

    A non-empty *token* argument wins; otherwise the module-level
    DISCORD_BOT_TOKEN setting is used. If neither yields a token, a warning
    is logged and no thread is started.
    """
    import threading

    effective_token = token or DISCORD_BOT_TOKEN
    if effective_token:
        worker = threading.Thread(
            target=_start_bot_sync,
            args=(effective_token,),
            daemon=True,
            name="discord-bot",
        )
        worker.start()
        logger.info("Discord bot thread started.")
        return

    logger.warning(
        "DISCORD_BOT_TOKEN is not set — Discord bot will NOT start."
    )