small refactor of the structure
This commit is contained in:
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
Auth Store — SQLite-backed persistence for Discord-to-service authentication.
|
||||
|
||||
Two tables:
|
||||
- link_tokens : one-time tokens sent via Discord DM to initiate login
|
||||
- user_auth : per-user, per-service credentials (Jellyfin token, etc.)
|
||||
|
||||
Thread-safe via WAL mode and a shared lock. No passwords are ever stored
|
||||
— only opaque service tokens (e.g. Jellyfin AccessToken).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.config import get_config
|
||||
|
||||
logger = logging.getLogger("auth_store")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db")
|
||||
TOKEN_EXPIRY_MINUTES = int(get_config("AUTH_TOKEN_EXPIRY", "10"))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton handle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_db_path: Path | None = None
|
||||
_db_lock = threading.Lock()
|
||||
|
||||
|
||||
def _resolve_path() -> Path:
|
||||
"""Turn AUTH_DB_PATH into an absolute path, creating parent dirs."""
|
||||
global _db_path
|
||||
if _db_path is not None:
|
||||
return _db_path
|
||||
p = Path(AUTH_DB_PATH)
|
||||
if not p.is_absolute():
|
||||
# Relative to the project root (two levels above this file)
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
p = project_root / p
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
_db_path = p
|
||||
return p
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
"""Return a thread-local connection to the auth database."""
|
||||
import sqlite3
|
||||
|
||||
conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS link_tokens (
|
||||
token TEXT PRIMARY KEY,
|
||||
discord_user_id INTEGER NOT NULL,
|
||||
service TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
used INTEGER DEFAULT 0,
|
||||
created_at TEXT DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_auth (
|
||||
discord_user_id INTEGER NOT NULL,
|
||||
service TEXT NOT NULL,
|
||||
external_user_id TEXT,
|
||||
external_name TEXT,
|
||||
credentials TEXT,
|
||||
linked_at TEXT DEFAULT (datetime('now')),
|
||||
is_active INTEGER DEFAULT 1,
|
||||
PRIMARY KEY (discord_user_id, service)
|
||||
);
|
||||
"""
|
||||
|
||||
_initialized = False
|
||||
|
||||
|
||||
def _ensure_schema() -> None:
|
||||
global _initialized
|
||||
if _initialized:
|
||||
return
|
||||
with _db_lock:
|
||||
if _initialized:
|
||||
return
|
||||
conn = _get_conn()
|
||||
conn.executescript(_SCHEMA)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
_initialized = True
|
||||
logger.info("Auth store initialized at %s", _resolve_path())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — Link Tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_token(discord_user_id: int, service: str) -> str:
|
||||
"""Generate a one-time link token. Expires after TOKEN_EXPIRY_MINUTES."""
|
||||
_ensure_schema()
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires = (datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRY_MINUTES)).isoformat()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"INSERT INTO link_tokens (token, discord_user_id, service, expires_at) VALUES (?, ?, ?, ?)",
|
||||
(token, discord_user_id, service, expires),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info("Created link token for user %s / service %s", discord_user_id, service)
|
||||
return token
|
||||
|
||||
|
||||
def validate_token(token: str) -> tuple[int, str] | None:
|
||||
"""Read-only validation — does NOT consume the token.
|
||||
|
||||
Returns (discord_user_id, service) if the token exists, is unused,
|
||||
and has not expired. Returns None otherwise.
|
||||
"""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
|
||||
(token,),
|
||||
).fetchone()
|
||||
conn.close()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
if row["used"]:
|
||||
return None
|
||||
|
||||
expires = datetime.fromisoformat(row["expires_at"])
|
||||
if datetime.now(timezone.utc) > expires:
|
||||
return None
|
||||
|
||||
return (row["discord_user_id"], row["service"])
|
||||
|
||||
|
||||
def consume_token(token: str) -> tuple[int, str] | None:
|
||||
"""Validate and consume a link token. Returns (discord_user_id, service) or None.
|
||||
|
||||
A token is valid if:
|
||||
- It exists
|
||||
- It has not been used
|
||||
- It has not expired
|
||||
"""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
|
||||
(token,),
|
||||
).fetchone()
|
||||
|
||||
if row is None:
|
||||
conn.close()
|
||||
return None
|
||||
|
||||
if row["used"]:
|
||||
conn.close()
|
||||
logger.warning("Token already used: %s…", token[:8])
|
||||
return None
|
||||
|
||||
expires = datetime.fromisoformat(row["expires_at"])
|
||||
if datetime.now(timezone.utc) > expires:
|
||||
conn.close()
|
||||
logger.warning("Token expired: %s…", token[:8])
|
||||
return None
|
||||
|
||||
conn.execute("UPDATE link_tokens SET used = 1 WHERE token = ?", (token,))
|
||||
conn.commit()
|
||||
result = (row["discord_user_id"], row["service"])
|
||||
conn.close()
|
||||
logger.info("Token consumed: %s… → user=%s service=%s", token[:8], result[0], result[1])
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — User Auth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def store_auth(
|
||||
discord_user_id: int,
|
||||
service: str,
|
||||
*,
|
||||
external_user_id: str = "",
|
||||
external_name: str = "",
|
||||
credentials: dict | None = None,
|
||||
) -> None:
|
||||
"""Store or update authentication for a user on a service."""
|
||||
_ensure_schema()
|
||||
import json
|
||||
|
||||
creds_json = json.dumps(credentials) if credentials else "{}"
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"""INSERT INTO user_auth (discord_user_id, service, external_user_id, external_name, credentials, linked_at)
|
||||
VALUES (?, ?, ?, ?, ?, datetime('now'))
|
||||
ON CONFLICT(discord_user_id, service) DO UPDATE SET
|
||||
external_user_id = excluded.external_user_id,
|
||||
external_name = excluded.external_name,
|
||||
credentials = excluded.credentials,
|
||||
linked_at = datetime('now'),
|
||||
is_active = 1""",
|
||||
(discord_user_id, service, external_user_id, external_name, creds_json),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info("Stored auth for user %s on %s as %s", discord_user_id, service, external_name)
|
||||
|
||||
|
||||
def get_auth(discord_user_id: int, service: str) -> dict | None:
|
||||
"""Retrieve stored auth for a user on a service. Returns None if not linked."""
|
||||
_ensure_schema()
|
||||
import json
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
row = conn.execute(
|
||||
"SELECT * FROM user_auth WHERE discord_user_id = ? AND service = ? AND is_active = 1",
|
||||
(discord_user_id, service),
|
||||
).fetchone()
|
||||
conn.close()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
credentials = json.loads(row["credentials"]) if row["credentials"] else {}
|
||||
return {
|
||||
"discord_user_id": row["discord_user_id"],
|
||||
"service": row["service"],
|
||||
"external_user_id": row["external_user_id"],
|
||||
"external_name": row["external_name"],
|
||||
"credentials": credentials,
|
||||
"linked_at": row["linked_at"],
|
||||
}
|
||||
|
||||
|
||||
def is_authenticated(discord_user_id: int, service: str) -> bool:
|
||||
"""Quick check: is this user linked to this service?"""
|
||||
return get_auth(discord_user_id, service) is not None
|
||||
|
||||
|
||||
def list_services(discord_user_id: int) -> list[str]:
|
||||
"""Return list of service names this user has linked."""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT service FROM user_auth WHERE discord_user_id = ? AND is_active = 1",
|
||||
(discord_user_id,),
|
||||
).fetchall()
|
||||
conn.close()
|
||||
|
||||
return [r["service"] for r in rows]
|
||||
|
||||
|
||||
def revoke(discord_user_id: int, service: str) -> None:
|
||||
"""Unlink a user from a service."""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"UPDATE user_auth SET is_active = 0 WHERE discord_user_id = ? AND service = ?",
|
||||
(discord_user_id, service),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.info("Revoked auth for user %s on %s", discord_user_id, service)
|
||||
|
||||
|
||||
def get_all_auths(discord_user_id: int) -> list[dict]:
|
||||
"""
|
||||
Return all active auth records for a Discord user.
|
||||
Each record includes service name, external user id, external name,
|
||||
linked_at timestamp, and the raw credentials (e.g. Jellyfin token + URL).
|
||||
|
||||
Used by the /api/v1/auth/status endpoint so other services can discover
|
||||
linked accounts for a given Discord ID.
|
||||
"""
|
||||
_ensure_schema()
|
||||
import json
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
rows = conn.execute(
|
||||
"""SELECT service, external_user_id, external_name, credentials, linked_at
|
||||
FROM user_auth
|
||||
WHERE discord_user_id = ? AND is_active = 1
|
||||
ORDER BY linked_at DESC""",
|
||||
(discord_user_id,),
|
||||
).fetchall()
|
||||
conn.close()
|
||||
|
||||
results: list[dict] = []
|
||||
for row in rows:
|
||||
creds = {}
|
||||
if row["credentials"]:
|
||||
try:
|
||||
creds = json.loads(row["credentials"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
creds = {}
|
||||
results.append({
|
||||
"service": row["service"],
|
||||
"external_user_id": row["external_user_id"] or "",
|
||||
"external_name": row["external_name"] or "",
|
||||
"linked_at": row["linked_at"] or "",
|
||||
"credentials": creds,
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dev / testing — reset the entire store
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def reset_all() -> None:
|
||||
"""Truncate all auth tables — for development and testing only."""
|
||||
_ensure_schema()
|
||||
|
||||
with _db_lock:
|
||||
conn = _get_conn()
|
||||
conn.execute("DELETE FROM link_tokens")
|
||||
conn.execute("DELETE FROM user_auth")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
logger.warning("Auth store RESET — all tokens and auth records cleared.")
|
||||
@@ -0,0 +1,31 @@
|
||||
from dotenv import load_dotenv
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load .env from the project root (one level above core/)
|
||||
# ---------------------------------------------------------------------------
|
||||
_env_path = Path(__file__).resolve().parent.parent / ".env"
|
||||
load_dotenv(_env_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# General-purpose config accessor — every skill uses this
|
||||
# ---------------------------------------------------------------------------
|
||||
def get_config(key: str, default: str | None = None) -> str | None:
|
||||
"""Read a value from the environment (loaded from .env)."""
|
||||
return os.getenv(key, default)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM
|
||||
# ---------------------------------------------------------------------------
|
||||
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seerr (Overseerr / Jellyseerr)
|
||||
# ---------------------------------------------------------------------------
|
||||
SEERR_URL = os.getenv("SEERR_URL", "")
|
||||
SEERR_API_KEY = os.getenv("SEERR_API_KEY", "")
|
||||
SEERR_TIMEOUT = int(os.getenv("SEERR_TIMEOUT", "30"))
|
||||
+245
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
LangGraph agent graph factory.
|
||||
|
||||
Builds a StateGraph with two nodes:
|
||||
- agent_node : calls the LLM (with system prompt + tool definitions)
|
||||
- tool_node : executes tool calls via the existing skill system
|
||||
|
||||
A conditional edge routes tool_calls back to the agent, or ends the run.
|
||||
When a tool fails due to missing authentication, the failure message is
|
||||
relayed to the LLM, which tells the user to use /login.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.graph import END, StateGraph
|
||||
from openai import OpenAI
|
||||
|
||||
from src.state import AgentState
|
||||
from agents.skills import get_all_tools, execute_tool
|
||||
|
||||
logger = logging.getLogger("graph")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper — map LangChain message type → OpenAI role
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _lc_role_to_openai(msg_type: str) -> str:
|
||||
"""Convert a LangChain message type string to an OpenAI role."""
|
||||
mapping = {"human": "user", "ai": "assistant", "tool": "tool", "system": "system"}
|
||||
return mapping.get(msg_type, "user")
|
||||
|
||||
|
||||
def _langchain_tc_to_openai(tool_calls: list) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert LangChain-format tool_calls (with `name`/`args` at top level)
|
||||
back to OpenAI format (with a nested `function` sub-object).
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
if "function" in tc:
|
||||
result.append(tc)
|
||||
else:
|
||||
# LangChain format: {"name": ..., "args": ..., "id": ...}
|
||||
result.append({
|
||||
"id": tc.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.get("name", ""),
|
||||
"arguments": json.dumps(tc.get("args", {})),
|
||||
},
|
||||
})
|
||||
else:
|
||||
# Pydantic model — dump to dict
|
||||
d = tc.model_dump() if hasattr(tc, "model_dump") else {}
|
||||
if "function" in d:
|
||||
result.append(d)
|
||||
else:
|
||||
result.append({
|
||||
"id": d.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": d.get("name", ""),
|
||||
"arguments": json.dumps(d.get("args", {})),
|
||||
},
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent node — calls the LLM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_agent_node(
|
||||
client: OpenAI,
|
||||
system_prompt: str,
|
||||
tool_defs: list[dict[str, Any]],
|
||||
model_name: str = "deepseek-chat",
|
||||
):
|
||||
"""
|
||||
Return a callable suitable as a LangGraph node.
|
||||
|
||||
The node reads the current message list from state, prepends the system
|
||||
prompt, and calls the LLM. If tool_defs is non-empty the LLM may return
|
||||
tool_calls; ToolNode (or our custom tool node) will handle them.
|
||||
"""
|
||||
|
||||
def agent_node(state: AgentState) -> dict[str, list]:
|
||||
messages = state["messages"]
|
||||
|
||||
# Convert LangChain message objects to plain dicts for the OpenAI client.
|
||||
full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
|
||||
for m in messages:
|
||||
if isinstance(m, dict):
|
||||
d = dict(m)
|
||||
tc = d.get("tool_calls")
|
||||
if tc and isinstance(tc, list) and tc and isinstance(tc[0], dict) and "function" not in tc[0]:
|
||||
d["tool_calls"] = _langchain_tc_to_openai(tc)
|
||||
full.append(d)
|
||||
else:
|
||||
role = _lc_role_to_openai(getattr(m, "type", "user"))
|
||||
d: dict[str, Any] = {"role": role, "content": getattr(m, "content", "")}
|
||||
tc = getattr(m, "tool_calls", None)
|
||||
if tc:
|
||||
d["tool_calls"] = _langchain_tc_to_openai(tc)
|
||||
tc_id = getattr(m, "tool_call_id", None)
|
||||
if tc_id:
|
||||
d["tool_call_id"] = tc_id
|
||||
full.append(d)
|
||||
|
||||
resp = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=full,
|
||||
tools=tool_defs if tool_defs else None,
|
||||
tool_choice="auto" if tool_defs else None,
|
||||
)
|
||||
choice = resp.choices[0]
|
||||
|
||||
raw_tool_calls = list(choice.message.tool_calls) if choice.message.tool_calls else []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
for tc in raw_tool_calls:
|
||||
fn = tc.function
|
||||
tool_calls.append({
|
||||
"name": fn.name,
|
||||
"args": json.loads(fn.arguments),
|
||||
"id": tc.id,
|
||||
})
|
||||
ai_msg = AIMessage(
|
||||
content=choice.message.content or "",
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
id=getattr(choice.message, "id", None),
|
||||
)
|
||||
return {"messages": [ai_msg]}
|
||||
|
||||
return agent_node
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool node — executes tools via the existing skill system
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_node(skill_names: list[str]):
|
||||
"""
|
||||
Return a callable that executes tool_calls from the last AI message.
|
||||
|
||||
If a tool fails because the user isn't authenticated, the failure
|
||||
message (which tells the user to /login) is returned to the LLM.
|
||||
The LLM naturally relays the instructions to the user.
|
||||
"""
|
||||
|
||||
async def tool_node(state: AgentState) -> dict[str, list]:
|
||||
last_msg = state["messages"][-1]
|
||||
tool_calls = getattr(last_msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
return {"messages": []}
|
||||
|
||||
discord_user_id = state.get("discord_user_id")
|
||||
|
||||
results: list[ToolMessage] = []
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
if "function" in tc:
|
||||
fn = tc["function"]
|
||||
fn_name = fn.get("name", "")
|
||||
fn_args_raw = fn.get("arguments", "{}")
|
||||
else:
|
||||
fn_name = tc.get("name", "")
|
||||
fn_args_raw = tc.get("args", {})
|
||||
tc_id = tc.get("id", "")
|
||||
else:
|
||||
fn_name = getattr(tc, "name", "")
|
||||
fn_args_raw = getattr(tc, "args", {})
|
||||
tc_id = getattr(tc, "id", "")
|
||||
|
||||
if isinstance(fn_args_raw, str):
|
||||
fn_args = json.loads(fn_args_raw)
|
||||
else:
|
||||
fn_args = fn_args_raw
|
||||
|
||||
tr = await execute_tool(
|
||||
skill_names, fn_name, fn_args,
|
||||
discord_user_id=discord_user_id,
|
||||
)
|
||||
content = tr.content if tr else f"Tool '{fn_name}' is not available."
|
||||
results.append(ToolMessage(content=content, tool_call_id=tc_id))
|
||||
|
||||
return {"messages": results}
|
||||
|
||||
return tool_node
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router — decides whether to continue tool-calling or stop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _should_continue(state: AgentState) -> Literal["tool_node", END]:
|
||||
"""If the last message contains tool_calls → execute them, else finish."""
|
||||
last_msg = state["messages"][-1]
|
||||
if getattr(last_msg, "tool_calls", None):
|
||||
return "tool_node"
|
||||
return END
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graph factory — the public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_agent_graph(
|
||||
*,
|
||||
client: OpenAI,
|
||||
agent_skills: list[str],
|
||||
system_prompt: str,
|
||||
model_name: str = "deepseek-chat",
|
||||
) -> StateGraph:
|
||||
"""
|
||||
Build and compile a LangGraph StateGraph for a single agent.
|
||||
"""
|
||||
tool_defs = get_all_tools(agent_skills)
|
||||
|
||||
graph = StateGraph(AgentState)
|
||||
|
||||
graph.add_node(
|
||||
"agent_node",
|
||||
_make_agent_node(client, system_prompt, tool_defs, model_name),
|
||||
)
|
||||
|
||||
if tool_defs:
|
||||
graph.add_node("tool_node", _make_tool_node(agent_skills))
|
||||
graph.add_conditional_edges("agent_node", _should_continue, {
|
||||
"tool_node": "tool_node",
|
||||
END: END,
|
||||
})
|
||||
graph.add_edge("tool_node", "agent_node")
|
||||
else:
|
||||
graph.add_edge("agent_node", END)
|
||||
|
||||
graph.set_entry_point("agent_node")
|
||||
|
||||
return graph.compile()
|
||||
@@ -0,0 +1,9 @@
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
def create_client(api_key: str) -> OpenAI:
|
||||
"""Factory for an OpenAI-compatible client pointed at DeepSeek."""
|
||||
return OpenAI(
|
||||
api_key=api_key,
|
||||
base_url="https://api.deepseek.com",
|
||||
)
|
||||
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
LangGraph agent state — defines the shape of the state object that flows
|
||||
through every node in the agent graph.
|
||||
"""
|
||||
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
|
||||
class AgentState(TypedDict):
|
||||
"""
|
||||
The single source of truth that travels through every node in the graph.
|
||||
|
||||
`messages` uses LangGraph's `add_messages` reducer, which:
|
||||
- Appends new messages to the list.
|
||||
- Replaces messages with the same ID (useful for tool-call results).
|
||||
"""
|
||||
|
||||
messages: Annotated[list, add_messages]
|
||||
discord_user_id: int | None # set by the Discord bot, None for REST API calls
|
||||
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Tools adapter — bridges the existing skill/tool system with LangGraph's ToolNode.
|
||||
|
||||
LangGraph's ToolNode expects callable tools (typically @tool-decorated functions).
|
||||
This module wraps our skill-based tool definitions and async executors so
|
||||
ToolNode can invoke them without any changes to the skills/ layer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from agents.skills import get_all_tools, execute_tool
|
||||
|
||||
|
||||
def build_langgraph_tools(skill_names: list[str]) -> list:
|
||||
"""
|
||||
Convert the registered skill tool definitions into LangChain-compatible
|
||||
@tool-decorated functions that ToolNode can call.
|
||||
|
||||
Each tool wraps the existing `execute_tool()` pipeline, so the skill
|
||||
system's ToolResult + httpx session handling is fully preserved.
|
||||
"""
|
||||
tool_defs = get_all_tools(skill_names)
|
||||
wrapped: list = []
|
||||
|
||||
for td in tool_defs:
|
||||
fn_def = td.get("function", {})
|
||||
fn_name = fn_def.get("name", "")
|
||||
fn_desc = fn_def.get("description", "")
|
||||
|
||||
# Create a unique factory so each closure captures the right fn_name
|
||||
def _make_tool(name: str, desc: str, skills: list[str]):
|
||||
@tool(name, description=desc)
|
||||
async def _wrapped(**kwargs: Any) -> str:
|
||||
"""Execute the tool via the skill system and return its content."""
|
||||
result = await execute_tool(skills, name, kwargs)
|
||||
if result is None:
|
||||
return f"Tool '{name}' is not available."
|
||||
return result.content
|
||||
|
||||
# Stash the original OpenAI schema so LangGraph can use it
|
||||
_wrapped.metadata = fn_def
|
||||
return _wrapped
|
||||
|
||||
wrapped.append(_make_tool(fn_name, fn_desc, skill_names))
|
||||
|
||||
return wrapped
|
||||
Reference in New Issue
Block a user