small refactor of the structure

This commit is contained in:
2026-05-25 12:16:24 +02:00
parent 51e099acdd
commit b0f10b6bb1
26 changed files with 37 additions and 37 deletions
View File
+358
View File
@@ -0,0 +1,358 @@
"""
Auth Store — SQLite-backed persistence for Discord-to-service authentication.
Two tables:
- link_tokens : one-time tokens sent via Discord DM to initiate login
- user_auth : per-user, per-service credentials (Jellyfin token, etc.)
Thread-safe via WAL mode and a shared lock. No passwords are ever stored
— only opaque service tokens (e.g. Jellyfin AccessToken).
"""
from __future__ import annotations
import logging
import secrets
import sqlite3
import threading
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
from src.config import get_config
logger = logging.getLogger("auth_store")
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db")
TOKEN_EXPIRY_MINUTES = int(get_config("AUTH_TOKEN_EXPIRY", "10"))
# ---------------------------------------------------------------------------
# Singleton handle
# ---------------------------------------------------------------------------
_db_path: Path | None = None
_db_lock = threading.Lock()
def _resolve_path() -> Path:
"""Turn AUTH_DB_PATH into an absolute path, creating parent dirs."""
global _db_path
if _db_path is not None:
return _db_path
p = Path(AUTH_DB_PATH)
if not p.is_absolute():
# Relative to the project root (two levels above this file)
project_root = Path(__file__).resolve().parent.parent
p = project_root / p
p.parent.mkdir(parents=True, exist_ok=True)
_db_path = p
return p
def _get_conn() -> sqlite3.Connection:
"""Return a thread-local connection to the auth database."""
import sqlite3
conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False)
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
conn.row_factory = sqlite3.Row
return conn
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
_SCHEMA = """
CREATE TABLE IF NOT EXISTS link_tokens (
token TEXT PRIMARY KEY,
discord_user_id INTEGER NOT NULL,
service TEXT NOT NULL,
expires_at TEXT NOT NULL,
used INTEGER DEFAULT 0,
created_at TEXT DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS user_auth (
discord_user_id INTEGER NOT NULL,
service TEXT NOT NULL,
external_user_id TEXT,
external_name TEXT,
credentials TEXT,
linked_at TEXT DEFAULT (datetime('now')),
is_active INTEGER DEFAULT 1,
PRIMARY KEY (discord_user_id, service)
);
"""
_initialized = False
def _ensure_schema() -> None:
global _initialized
if _initialized:
return
with _db_lock:
if _initialized:
return
conn = _get_conn()
conn.executescript(_SCHEMA)
conn.commit()
conn.close()
_initialized = True
logger.info("Auth store initialized at %s", _resolve_path())
# ---------------------------------------------------------------------------
# Public API — Link Tokens
# ---------------------------------------------------------------------------
def create_token(discord_user_id: int, service: str) -> str:
"""Generate a one-time link token. Expires after TOKEN_EXPIRY_MINUTES."""
_ensure_schema()
token = secrets.token_urlsafe(32)
expires = (datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRY_MINUTES)).isoformat()
with _db_lock:
conn = _get_conn()
conn.execute(
"INSERT INTO link_tokens (token, discord_user_id, service, expires_at) VALUES (?, ?, ?, ?)",
(token, discord_user_id, service, expires),
)
conn.commit()
conn.close()
logger.info("Created link token for user %s / service %s", discord_user_id, service)
return token
def validate_token(token: str) -> tuple[int, str] | None:
"""Read-only validation — does NOT consume the token.
Returns (discord_user_id, service) if the token exists, is unused,
and has not expired. Returns None otherwise.
"""
_ensure_schema()
with _db_lock:
conn = _get_conn()
row = conn.execute(
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
(token,),
).fetchone()
conn.close()
if row is None:
return None
if row["used"]:
return None
expires = datetime.fromisoformat(row["expires_at"])
if datetime.now(timezone.utc) > expires:
return None
return (row["discord_user_id"], row["service"])
def consume_token(token: str) -> tuple[int, str] | None:
"""Validate and consume a link token. Returns (discord_user_id, service) or None.
A token is valid if:
- It exists
- It has not been used
- It has not expired
"""
_ensure_schema()
with _db_lock:
conn = _get_conn()
row = conn.execute(
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
(token,),
).fetchone()
if row is None:
conn.close()
return None
if row["used"]:
conn.close()
logger.warning("Token already used: %s", token[:8])
return None
expires = datetime.fromisoformat(row["expires_at"])
if datetime.now(timezone.utc) > expires:
conn.close()
logger.warning("Token expired: %s", token[:8])
return None
conn.execute("UPDATE link_tokens SET used = 1 WHERE token = ?", (token,))
conn.commit()
result = (row["discord_user_id"], row["service"])
conn.close()
logger.info("Token consumed: %s… → user=%s service=%s", token[:8], result[0], result[1])
return result
# ---------------------------------------------------------------------------
# Public API — User Auth
# ---------------------------------------------------------------------------
def store_auth(
discord_user_id: int,
service: str,
*,
external_user_id: str = "",
external_name: str = "",
credentials: dict | None = None,
) -> None:
"""Store or update authentication for a user on a service."""
_ensure_schema()
import json
creds_json = json.dumps(credentials) if credentials else "{}"
with _db_lock:
conn = _get_conn()
conn.execute(
"""INSERT INTO user_auth (discord_user_id, service, external_user_id, external_name, credentials, linked_at)
VALUES (?, ?, ?, ?, ?, datetime('now'))
ON CONFLICT(discord_user_id, service) DO UPDATE SET
external_user_id = excluded.external_user_id,
external_name = excluded.external_name,
credentials = excluded.credentials,
linked_at = datetime('now'),
is_active = 1""",
(discord_user_id, service, external_user_id, external_name, creds_json),
)
conn.commit()
conn.close()
logger.info("Stored auth for user %s on %s as %s", discord_user_id, service, external_name)
def get_auth(discord_user_id: int, service: str) -> dict | None:
"""Retrieve stored auth for a user on a service. Returns None if not linked."""
_ensure_schema()
import json
with _db_lock:
conn = _get_conn()
row = conn.execute(
"SELECT * FROM user_auth WHERE discord_user_id = ? AND service = ? AND is_active = 1",
(discord_user_id, service),
).fetchone()
conn.close()
if row is None:
return None
credentials = json.loads(row["credentials"]) if row["credentials"] else {}
return {
"discord_user_id": row["discord_user_id"],
"service": row["service"],
"external_user_id": row["external_user_id"],
"external_name": row["external_name"],
"credentials": credentials,
"linked_at": row["linked_at"],
}
def is_authenticated(discord_user_id: int, service: str) -> bool:
"""Quick check: is this user linked to this service?"""
return get_auth(discord_user_id, service) is not None
def list_services(discord_user_id: int) -> list[str]:
"""Return list of service names this user has linked."""
_ensure_schema()
with _db_lock:
conn = _get_conn()
rows = conn.execute(
"SELECT service FROM user_auth WHERE discord_user_id = ? AND is_active = 1",
(discord_user_id,),
).fetchall()
conn.close()
return [r["service"] for r in rows]
def revoke(discord_user_id: int, service: str) -> None:
"""Unlink a user from a service."""
_ensure_schema()
with _db_lock:
conn = _get_conn()
conn.execute(
"UPDATE user_auth SET is_active = 0 WHERE discord_user_id = ? AND service = ?",
(discord_user_id, service),
)
conn.commit()
conn.close()
logger.info("Revoked auth for user %s on %s", discord_user_id, service)
def get_all_auths(discord_user_id: int) -> list[dict]:
"""
Return all active auth records for a Discord user.
Each record includes service name, external user id, external name,
linked_at timestamp, and the raw credentials (e.g. Jellyfin token + URL).
Used by the /api/v1/auth/status endpoint so other services can discover
linked accounts for a given Discord ID.
"""
_ensure_schema()
import json
with _db_lock:
conn = _get_conn()
rows = conn.execute(
"""SELECT service, external_user_id, external_name, credentials, linked_at
FROM user_auth
WHERE discord_user_id = ? AND is_active = 1
ORDER BY linked_at DESC""",
(discord_user_id,),
).fetchall()
conn.close()
results: list[dict] = []
for row in rows:
creds = {}
if row["credentials"]:
try:
creds = json.loads(row["credentials"])
except (json.JSONDecodeError, TypeError):
creds = {}
results.append({
"service": row["service"],
"external_user_id": row["external_user_id"] or "",
"external_name": row["external_name"] or "",
"linked_at": row["linked_at"] or "",
"credentials": creds,
})
return results
# ---------------------------------------------------------------------------
# Dev / testing — reset the entire store
# ---------------------------------------------------------------------------
def reset_all() -> None:
"""Truncate all auth tables — for development and testing only."""
_ensure_schema()
with _db_lock:
conn = _get_conn()
conn.execute("DELETE FROM link_tokens")
conn.execute("DELETE FROM user_auth")
conn.commit()
conn.close()
logger.warning("Auth store RESET — all tokens and auth records cleared.")
+31
View File
@@ -0,0 +1,31 @@
from dotenv import load_dotenv
from pathlib import Path
import os
# ---------------------------------------------------------------------------
# Load .env from the project root (one level above core/)
# ---------------------------------------------------------------------------
_env_path = Path(__file__).resolve().parent.parent / ".env"
load_dotenv(_env_path)
# ---------------------------------------------------------------------------
# General-purpose config accessor — every skill uses this
# ---------------------------------------------------------------------------
def get_config(key: str, default: str | None = None) -> str | None:
"""Read a value from the environment (loaded from .env)."""
return os.getenv(key, default)
# ---------------------------------------------------------------------------
# LLM
# ---------------------------------------------------------------------------
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
# ---------------------------------------------------------------------------
# Seerr (Overseerr / Jellyseerr)
# ---------------------------------------------------------------------------
SEERR_URL = os.getenv("SEERR_URL", "")
SEERR_API_KEY = os.getenv("SEERR_API_KEY", "")
SEERR_TIMEOUT = int(os.getenv("SEERR_TIMEOUT", "30"))
+245
View File
@@ -0,0 +1,245 @@
"""
LangGraph agent graph factory.
Builds a StateGraph with two nodes:
- agent_node : calls the LLM (with system prompt + tool definitions)
- tool_node : executes tool calls via the existing skill system
A conditional edge routes tool_calls back to the agent, or ends the run.
When a tool fails due to missing authentication, the failure message is
relayed to the LLM, which tells the user to use /login.
"""
from __future__ import annotations
import json
import logging
from typing import Any, Literal
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.graph import END, StateGraph
from openai import OpenAI
from src.state import AgentState
from agents.skills import get_all_tools, execute_tool
logger = logging.getLogger("graph")
# ---------------------------------------------------------------------------
# Helper — map LangChain message type → OpenAI role
# ---------------------------------------------------------------------------
def _lc_role_to_openai(msg_type: str) -> str:
"""Convert a LangChain message type string to an OpenAI role."""
mapping = {"human": "user", "ai": "assistant", "tool": "tool", "system": "system"}
return mapping.get(msg_type, "user")
def _langchain_tc_to_openai(tool_calls: list) -> list[dict[str, Any]]:
"""
Convert LangChain-format tool_calls (with `name`/`args` at top level)
back to OpenAI format (with a nested `function` sub-object).
"""
result: list[dict[str, Any]] = []
for tc in tool_calls:
if isinstance(tc, dict):
if "function" in tc:
result.append(tc)
else:
# LangChain format: {"name": ..., "args": ..., "id": ...}
result.append({
"id": tc.get("id", ""),
"type": "function",
"function": {
"name": tc.get("name", ""),
"arguments": json.dumps(tc.get("args", {})),
},
})
else:
# Pydantic model — dump to dict
d = tc.model_dump() if hasattr(tc, "model_dump") else {}
if "function" in d:
result.append(d)
else:
result.append({
"id": d.get("id", ""),
"type": "function",
"function": {
"name": d.get("name", ""),
"arguments": json.dumps(d.get("args", {})),
},
})
return result
# ---------------------------------------------------------------------------
# Agent node — calls the LLM
# ---------------------------------------------------------------------------
def _make_agent_node(
client: OpenAI,
system_prompt: str,
tool_defs: list[dict[str, Any]],
model_name: str = "deepseek-chat",
):
"""
Return a callable suitable as a LangGraph node.
The node reads the current message list from state, prepends the system
prompt, and calls the LLM. If tool_defs is non-empty the LLM may return
tool_calls; ToolNode (or our custom tool node) will handle them.
"""
def agent_node(state: AgentState) -> dict[str, list]:
messages = state["messages"]
# Convert LangChain message objects to plain dicts for the OpenAI client.
full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
for m in messages:
if isinstance(m, dict):
d = dict(m)
tc = d.get("tool_calls")
if tc and isinstance(tc, list) and tc and isinstance(tc[0], dict) and "function" not in tc[0]:
d["tool_calls"] = _langchain_tc_to_openai(tc)
full.append(d)
else:
role = _lc_role_to_openai(getattr(m, "type", "user"))
d: dict[str, Any] = {"role": role, "content": getattr(m, "content", "")}
tc = getattr(m, "tool_calls", None)
if tc:
d["tool_calls"] = _langchain_tc_to_openai(tc)
tc_id = getattr(m, "tool_call_id", None)
if tc_id:
d["tool_call_id"] = tc_id
full.append(d)
resp = client.chat.completions.create(
model=model_name,
messages=full,
tools=tool_defs if tool_defs else None,
tool_choice="auto" if tool_defs else None,
)
choice = resp.choices[0]
raw_tool_calls = list(choice.message.tool_calls) if choice.message.tool_calls else []
tool_calls: list[dict[str, Any]] = []
for tc in raw_tool_calls:
fn = tc.function
tool_calls.append({
"name": fn.name,
"args": json.loads(fn.arguments),
"id": tc.id,
})
ai_msg = AIMessage(
content=choice.message.content or "",
tool_calls=tool_calls if tool_calls else [],
id=getattr(choice.message, "id", None),
)
return {"messages": [ai_msg]}
return agent_node
# ---------------------------------------------------------------------------
# Tool node — executes tools via the existing skill system
# ---------------------------------------------------------------------------
def _make_tool_node(skill_names: list[str]):
"""
Return a callable that executes tool_calls from the last AI message.
If a tool fails because the user isn't authenticated, the failure
message (which tells the user to /login) is returned to the LLM.
The LLM naturally relays the instructions to the user.
"""
async def tool_node(state: AgentState) -> dict[str, list]:
last_msg = state["messages"][-1]
tool_calls = getattr(last_msg, "tool_calls", None)
if not tool_calls:
return {"messages": []}
discord_user_id = state.get("discord_user_id")
results: list[ToolMessage] = []
for tc in tool_calls:
if isinstance(tc, dict):
if "function" in tc:
fn = tc["function"]
fn_name = fn.get("name", "")
fn_args_raw = fn.get("arguments", "{}")
else:
fn_name = tc.get("name", "")
fn_args_raw = tc.get("args", {})
tc_id = tc.get("id", "")
else:
fn_name = getattr(tc, "name", "")
fn_args_raw = getattr(tc, "args", {})
tc_id = getattr(tc, "id", "")
if isinstance(fn_args_raw, str):
fn_args = json.loads(fn_args_raw)
else:
fn_args = fn_args_raw
tr = await execute_tool(
skill_names, fn_name, fn_args,
discord_user_id=discord_user_id,
)
content = tr.content if tr else f"Tool '{fn_name}' is not available."
results.append(ToolMessage(content=content, tool_call_id=tc_id))
return {"messages": results}
return tool_node
# ---------------------------------------------------------------------------
# Router — decides whether to continue tool-calling or stop
# ---------------------------------------------------------------------------
def _should_continue(state: AgentState) -> Literal["tool_node", END]:
"""If the last message contains tool_calls → execute them, else finish."""
last_msg = state["messages"][-1]
if getattr(last_msg, "tool_calls", None):
return "tool_node"
return END
# ---------------------------------------------------------------------------
# Graph factory — the public API
# ---------------------------------------------------------------------------
def create_agent_graph(
*,
client: OpenAI,
agent_skills: list[str],
system_prompt: str,
model_name: str = "deepseek-chat",
) -> StateGraph:
"""
Build and compile a LangGraph StateGraph for a single agent.
"""
tool_defs = get_all_tools(agent_skills)
graph = StateGraph(AgentState)
graph.add_node(
"agent_node",
_make_agent_node(client, system_prompt, tool_defs, model_name),
)
if tool_defs:
graph.add_node("tool_node", _make_tool_node(agent_skills))
graph.add_conditional_edges("agent_node", _should_continue, {
"tool_node": "tool_node",
END: END,
})
graph.add_edge("tool_node", "agent_node")
else:
graph.add_edge("agent_node", END)
graph.set_entry_point("agent_node")
return graph.compile()
+9
View File
@@ -0,0 +1,9 @@
from openai import OpenAI
def create_client(api_key: str) -> OpenAI:
"""Factory for an OpenAI-compatible client pointed at DeepSeek."""
return OpenAI(
api_key=api_key,
base_url="https://api.deepseek.com",
)
+21
View File
@@ -0,0 +1,21 @@
"""
LangGraph agent state — defines the shape of the state object that flows
through every node in the agent graph.
"""
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages
class AgentState(TypedDict):
"""
The single source of truth that travels through every node in the graph.
`messages` uses LangGraph's `add_messages` reducer, which:
- Appends new messages to the list.
- Replaces messages with the same ID (useful for tool-call results).
"""
messages: Annotated[list, add_messages]
discord_user_id: int | None # set by the Discord bot, None for REST API calls
+51
View File
@@ -0,0 +1,51 @@
"""
Tools adapter — bridges the existing skill/tool system with LangGraph's ToolNode.
LangGraph's ToolNode expects callable tools (typically @tool-decorated functions).
This module wraps our skill-based tool definitions and async executors so
ToolNode can invoke them without any changes to the skills/ layer.
"""
from __future__ import annotations
import json
from typing import Any
from langchain_core.tools import tool
from agents.skills import get_all_tools, execute_tool
def build_langgraph_tools(skill_names: list[str]) -> list:
"""
Convert the registered skill tool definitions into LangChain-compatible
@tool-decorated functions that ToolNode can call.
Each tool wraps the existing `execute_tool()` pipeline, so the skill
system's ToolResult + httpx session handling is fully preserved.
"""
tool_defs = get_all_tools(skill_names)
wrapped: list = []
for td in tool_defs:
fn_def = td.get("function", {})
fn_name = fn_def.get("name", "")
fn_desc = fn_def.get("description", "")
# Create a unique factory so each closure captures the right fn_name
def _make_tool(name: str, desc: str, skills: list[str]):
@tool(name, description=desc)
async def _wrapped(**kwargs: Any) -> str:
"""Execute the tool via the skill system and return its content."""
result = await execute_tool(skills, name, kwargs)
if result is None:
return f"Tool '{name}' is not available."
return result.content
# Stash the original OpenAI schema so LangGraph can use it
_wrapped.metadata = fn_def
return _wrapped
wrapped.append(_make_tool(fn_name, fn_desc, skill_names))
return wrapped