246 lines
8.4 KiB
Python
246 lines
8.4 KiB
Python
"""
LangGraph agent graph factory.

Builds a StateGraph with two nodes:
- agent_node : calls the LLM (with the system prompt and tool definitions)
- tool_node  : executes tool calls via the existing skill system

A conditional edge sends an agent reply that contains tool_calls to
tool_node, whose results flow back to agent_node; any other reply ends the
run. If the agent has no tools, tool_node is omitted and agent_node ends the
run directly.

When a tool fails due to missing authentication, the failure message is
relayed to the LLM, which tells the user to use /login.
"""
|
|
|
|
from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any, Literal

from langchain_core.messages import AIMessage, ToolMessage
from langgraph.graph import END, StateGraph
from openai import OpenAI

from core.state import AgentState
from skills import get_all_tools, execute_tool

if TYPE_CHECKING:
    # Only needed for annotations; never imported at runtime.
    from langgraph.graph.state import CompiledStateGraph
|
|
|
|
# Module-level logger. It is named "graph" explicitly (not __name__), so
# logging configuration must target the "graph" logger to affect it.
logger: logging.Logger = logging.getLogger("graph")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helper — map LangChain message type → OpenAI role
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _lc_role_to_openai(msg_type: str) -> str:
|
|
"""Convert a LangChain message type string to an OpenAI role."""
|
|
mapping = {"human": "user", "ai": "assistant", "tool": "tool", "system": "system"}
|
|
return mapping.get(msg_type, "user")
|
|
|
|
|
|
def _langchain_tc_to_openai(tool_calls: list) -> list[dict[str, Any]]:
|
|
"""
|
|
Convert LangChain-format tool_calls (with `name`/`args` at top level)
|
|
back to OpenAI format (with a nested `function` sub-object).
|
|
"""
|
|
result: list[dict[str, Any]] = []
|
|
for tc in tool_calls:
|
|
if isinstance(tc, dict):
|
|
if "function" in tc:
|
|
result.append(tc)
|
|
else:
|
|
# LangChain format: {"name": ..., "args": ..., "id": ...}
|
|
result.append({
|
|
"id": tc.get("id", ""),
|
|
"type": "function",
|
|
"function": {
|
|
"name": tc.get("name", ""),
|
|
"arguments": json.dumps(tc.get("args", {})),
|
|
},
|
|
})
|
|
else:
|
|
# Pydantic model — dump to dict
|
|
d = tc.model_dump() if hasattr(tc, "model_dump") else {}
|
|
if "function" in d:
|
|
result.append(d)
|
|
else:
|
|
result.append({
|
|
"id": d.get("id", ""),
|
|
"type": "function",
|
|
"function": {
|
|
"name": d.get("name", ""),
|
|
"arguments": json.dumps(d.get("args", {})),
|
|
},
|
|
})
|
|
return result
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Agent node — calls the LLM
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _message_to_openai(m: Any) -> dict[str, Any]:
    """Convert one state message (plain dict or LangChain message) into an OpenAI chat dict."""
    if isinstance(m, dict):
        # Shallow copy so the dict stored in graph state is never mutated.
        d = dict(m)
        if d.get("tool_calls"):
            # LangChain-shaped calls are converted; OpenAI-shaped ones pass through.
            d["tool_calls"] = _langchain_tc_to_openai(d["tool_calls"])
        return d

    d: dict[str, Any] = {
        "role": _lc_role_to_openai(getattr(m, "type", "user")),
        "content": getattr(m, "content", ""),
    }
    tool_calls = getattr(m, "tool_calls", None)
    if tool_calls:
        d["tool_calls"] = _langchain_tc_to_openai(tool_calls)
    tool_call_id = getattr(m, "tool_call_id", None)
    if tool_call_id:
        d["tool_call_id"] = tool_call_id
    return d


def _parse_tool_arguments(raw: str | None) -> dict[str, Any]:
    """Decode an OpenAI tool-call ``arguments`` JSON string.

    Empty, blank or ``None`` arguments (and a JSON ``null``) become ``{}``, i.e.
    a call to a parameterless tool, instead of crashing ``json.loads``.
    Malformed JSON still raises ``json.JSONDecodeError``.
    """
    if raw is None or not raw.strip():
        return {}
    parsed = json.loads(raw)
    return {} if parsed is None else parsed


def _make_agent_node(
    client: OpenAI,
    system_prompt: str,
    tool_defs: list[dict[str, Any]],
    model_name: str = "deepseek-chat",
):
    """
    Return a callable suitable as a LangGraph node that calls the LLM.

    The node converts the state's messages (LangChain message objects or plain
    dicts) to OpenAI chat format, prepends ``system_prompt`` as a system
    message, and calls ``client.chat.completions.create``. The reply becomes a
    single ``AIMessage``. Any tool calls in it are stored in LangChain form
    (``name``/``args``/``id``) for the tool node to execute.

    Args:
        client: OpenAI client used for chat completions.
        system_prompt: Sent as the first (system) message on every call.
        tool_defs: OpenAI tool definitions. When empty, no tool parameters are
            sent at all.
        model_name: Model identifier passed to the API.

    Returns:
        ``agent_node(state) -> {"messages": [AIMessage]}``.

    The returned node raises ``json.JSONDecodeError`` if the model returns
    non-empty but malformed tool-call arguments.
    """
    # Decided once, like the graph's own tool/no-tool wiring. Tool parameters
    # are omitted entirely rather than sent as explicit ``None`` values, which
    # the client may forward as JSON ``null`` and some backends reject.
    tool_kwargs: dict[str, Any] = (
        {"tools": tool_defs, "tool_choice": "auto"} if tool_defs else {}
    )

    def agent_node(state: AgentState) -> dict[str, list]:
        full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
        full.extend(_message_to_openai(m) for m in state["messages"])

        resp = client.chat.completions.create(
            model=model_name,
            messages=full,
            **tool_kwargs,
        )
        message = resp.choices[0].message

        tool_calls: list[dict[str, Any]] = [
            {
                "name": tc.function.name,
                "args": _parse_tool_arguments(tc.function.arguments),
                "id": tc.id,
            }
            for tc in (message.tool_calls or [])
        ]
        ai_msg = AIMessage(
            content=message.content or "",
            tool_calls=tool_calls,
            id=getattr(message, "id", None),
        )
        return {"messages": [ai_msg]}

    return agent_node
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tool node — executes tools via the existing skill system
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _split_tool_call(tc: Any) -> tuple[str, Any, str]:
    """Return ``(name, raw_args, call_id)`` for a tool call in any supported shape.

    Accepts OpenAI-style dicts (nested ``function`` with JSON ``arguments``),
    LangChain-style dicts (top-level ``name``/``args``), or objects exposing
    ``name``/``args``/``id`` attributes.
    """
    if isinstance(tc, dict):
        if "function" in tc:
            fn = tc["function"]
            return fn.get("name", ""), fn.get("arguments", "{}"), tc.get("id", "")
        return tc.get("name", ""), tc.get("args", {}), tc.get("id", "")
    return getattr(tc, "name", ""), getattr(tc, "args", {}), getattr(tc, "id", "")


def _coerce_tool_args(raw: Any) -> Any:
    """Decode JSON-string tool arguments. ``None`` and blank strings become ``{}``."""
    if raw is None:
        return {}
    if isinstance(raw, str):
        return json.loads(raw) if raw.strip() else {}
    return raw


def _make_tool_node(skill_names: list[str]):
    """
    Return an async callable that executes tool_calls from the last AI message.

    Each tool call is run through ``execute_tool`` with the state's
    ``discord_user_id``, and exactly one ``ToolMessage`` is produced per call
    (matched by ``tool_call_id``). The possible outcomes are:
      - the tool returns a result: its ``content`` is returned;
      - ``execute_tool`` returns nothing: a "not available" message is returned;
      - the tool (or its argument decoding) raises: the error is logged and
        reported to the LLM as the tool's output.

    If a tool fails because the user isn't authenticated, the failure
    message (which tells the user to /login) is returned to the LLM, which
    relays the instructions to the user.

    Args:
        skill_names: Skills whose tools may be executed.
    """

    async def tool_node(state: AgentState) -> dict[str, list]:
        last_msg = state["messages"][-1]
        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            return {"messages": []}

        discord_user_id = state.get("discord_user_id")

        results: list[ToolMessage] = []
        for tc in tool_calls:
            fn_name, fn_args_raw, tc_id = _split_tool_call(tc)
            try:
                fn_args = _coerce_tool_args(fn_args_raw)
                tr = await execute_tool(
                    skill_names, fn_name, fn_args,
                    discord_user_id=discord_user_id,
                )
            except Exception as exc:
                # Raising here would abort the run and leave this tool_call
                # without a matching ToolMessage. Report the error to the LLM
                # instead, so it can recover or explain the failure.
                logger.exception("Tool %r failed", fn_name)
                content = f"Tool '{fn_name}' failed: {exc}"
            else:
                content = tr.content if tr else f"Tool '{fn_name}' is not available."
            results.append(ToolMessage(content=content, tool_call_id=tc_id))

        return {"messages": results}

    return tool_node
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Router — decides whether to continue tool-calling or stop
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _should_continue(state: AgentState) -> Literal["tool_node", END]:
    """Pick the node that follows agent_node.

    Returns ``"tool_node"`` when the newest message carries a non-empty
    ``tool_calls`` attribute. Returns ``END`` otherwise, including when the
    message has no ``tool_calls`` attribute at all.
    """
    newest = state["messages"][-1]
    wants_tools = bool(getattr(newest, "tool_calls", None))
    return "tool_node" if wants_tools else END
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Graph factory — the public API
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def create_agent_graph(
    *,
    client: OpenAI,
    agent_skills: list[str],
    system_prompt: str,
    model_name: str = "deepseek-chat",
) -> CompiledStateGraph:
    """
    Build and compile a LangGraph StateGraph for a single agent.

    Args:
        client: OpenAI client used by the agent node for chat completions.
        agent_skills: Skill names. Their tool definitions are offered to the
            LLM, and the tool node executes calls against them.
        system_prompt: System prompt prepended to every LLM call.
        model_name: Model identifier passed to the API.

    Returns:
        The compiled graph (the result of ``graph.compile()``, not the
        builder). When the skills expose tools, agent_node and tool_node
        loop until the LLM stops requesting tools. Otherwise the graph is
        a single agent_node followed by END.

    NOTE(review): tool_node is a coroutine function, so presumably callers
    must use the async graph API (e.g. ``ainvoke``) when tools are present.
    Confirm against call sites.
    """
    tool_defs = get_all_tools(agent_skills)

    graph = StateGraph(AgentState)
    graph.add_node(
        "agent_node",
        _make_agent_node(client, system_prompt, tool_defs, model_name),
    )
    graph.set_entry_point("agent_node")

    if not tool_defs:
        # Without tools the LLM can never request a call, so skip the tool loop.
        graph.add_edge("agent_node", END)
        return graph.compile()

    graph.add_node("tool_node", _make_tool_node(agent_skills))
    graph.add_conditional_edges(
        "agent_node",
        _should_continue,
        {"tool_node": "tool_node", END: END},
    )
    # Tool results always go back to the LLM so it can use them.
    graph.add_edge("tool_node", "agent_node")
    return graph.compile()
|