"""
LangGraph agent graph factory.

Builds a StateGraph with two nodes:

- ``agent_node``: calls the LLM (with system prompt + tool definitions).
- ``tool_node``:  executes tool calls via the existing skill system.

A conditional edge routes tool_calls to the tool node (whose results flow back
to the agent), or ends the run. When a tool fails due to missing
authentication, the failure message is relayed to the LLM, which tells the
user to use /login. Other tool failures (malformed arguments, exceptions raised
while executing a skill) are likewise reported back to the LLM as the tool's
result instead of aborting the run.
"""

from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any, Literal

from langchain_core.messages import AIMessage, ToolMessage
from langgraph.graph import END, StateGraph
from openai import OpenAI

from core.state import AgentState
from skills import get_all_tools, execute_tool

if TYPE_CHECKING:
    # Only needed for the return annotation of create_agent_graph; imported
    # lazily so a relocated symbol in a future langgraph cannot break runtime.
    from langgraph.graph.state import CompiledStateGraph

logger = logging.getLogger("graph")


# ---------------------------------------------------------------------------
# Helpers — message / tool-call format conversion
# ---------------------------------------------------------------------------

# LangChain message ``type`` -> OpenAI chat ``role``.
_LC_TO_OPENAI_ROLE: dict[str, str] = {
    "human": "user",
    "ai": "assistant",
    "tool": "tool",
    "system": "system",
}


def _lc_role_to_openai(msg_type: str) -> str:
    """Convert a LangChain message type string to an OpenAI role.

    Unknown message types fall back to ``"user"``.
    """
    return _LC_TO_OPENAI_ROLE.get(msg_type, "user")


def _openai_tool_call(tc_id: Any, name: Any, args: Any) -> dict[str, Any]:
    """Build one OpenAI-format tool call (nested ``function`` object)."""
    if args is None:
        args = {}
    # Arguments that are already a JSON string are passed through unchanged;
    # re-serializing them would double-encode the payload.
    arguments = args if isinstance(args, str) else json.dumps(args)
    return {
        "id": tc_id,
        "type": "function",
        "function": {"name": name, "arguments": arguments},
    }


def _langchain_tc_to_openai(tool_calls: list) -> list[dict[str, Any]]:
    """
    Convert LangChain-format tool_calls (with `name`/`args` at top level)
    back to OpenAI format (with a nested `function` sub-object).

    Entries may be plain dicts or objects. Objects exposing ``model_dump``
    (pydantic models) are dumped first; other objects are read through their
    ``id``/``name``/``args`` attributes. Entries that are already in OpenAI
    format (they carry a ``function`` key) are passed through unchanged.
    """
    result: list[dict[str, Any]] = []
    for tc in tool_calls:
        if isinstance(tc, dict):
            fields = tc
        elif hasattr(tc, "model_dump"):
            fields = tc.model_dump()
        else:
            fields = {
                key: getattr(tc, key)
                for key in ("id", "name", "args")
                if hasattr(tc, key)
            }

        if "function" in fields:
            result.append(fields)
        else:
            result.append(
                _openai_tool_call(
                    fields.get("id", ""),
                    fields.get("name", ""),
                    fields.get("args", {}),
                )
            )
    return result


def _parse_tool_args(raw: Any) -> dict[str, Any]:
    """Decode tool-call arguments into a dict.

    ``None`` and empty/whitespace-only strings mean "no arguments" and yield
    ``{}``; some OpenAI-compatible models send ``""`` for parameterless tools.
    Strings are decoded as JSON; anything else is used as-is.

    Raises:
        ValueError: if the arguments are not valid JSON or do not decode to
            an object (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    if raw is None:
        return {}
    if isinstance(raw, str):
        if not raw.strip():
            return {}
        parsed = json.loads(raw)
    else:
        parsed = raw
    if not isinstance(parsed, dict):
        raise ValueError(
            f"tool arguments must be a JSON object, got {type(parsed).__name__}"
        )
    return parsed


def _to_openai_message(m: Any) -> dict[str, Any]:
    """Convert one state message (plain dict or LangChain message) to an OpenAI chat dict."""
    if isinstance(m, dict):
        d = dict(m)
        tc = d.get("tool_calls")
        # Only LangChain-shaped calls need rewriting; OpenAI-shaped ones
        # already carry a nested "function" object.
        if isinstance(tc, list) and tc and isinstance(tc[0], dict) and "function" not in tc[0]:
            d["tool_calls"] = _langchain_tc_to_openai(tc)
        return d

    d: dict[str, Any] = {
        "role": _lc_role_to_openai(getattr(m, "type", "user")),
        "content": getattr(m, "content", ""),
    }
    tc = getattr(m, "tool_calls", None)
    if tc:
        d["tool_calls"] = _langchain_tc_to_openai(tc)
    tc_id = getattr(m, "tool_call_id", None)
    if tc_id:
        d["tool_call_id"] = tc_id
    return d


def _parse_llm_tool_calls(
    raw_tool_calls: list[Any] | None,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Split the OpenAI response's tool calls into LangChain valid/invalid lists.

    Returns ``(tool_calls, invalid_tool_calls)``. A call whose arguments cannot
    be decoded into a dict is logged and placed in ``invalid_tool_calls``
    (keeping the raw argument string and the error) instead of raising.
    """
    valid: list[dict[str, Any]] = []
    invalid: list[dict[str, Any]] = []
    for tc in raw_tool_calls or []:
        fn = tc.function
        try:
            args = _parse_tool_args(fn.arguments)
        except ValueError as exc:
            logger.warning(
                "LLM produced malformed arguments for tool %r: %s", fn.name, exc
            )
            invalid.append(
                {"name": fn.name, "args": fn.arguments, "id": tc.id, "error": str(exc)}
            )
        else:
            valid.append({"name": fn.name, "args": args, "id": tc.id})
    return valid, invalid


# ---------------------------------------------------------------------------
# Agent node — calls the LLM
# ---------------------------------------------------------------------------

def _make_agent_node(
    client: OpenAI,
    system_prompt: str,
    tool_defs: list[dict[str, Any]],
    model_name: str = "deepseek-chat",
):
    """
    Return a callable suitable as a LangGraph node.

    The node reads the current message list from state, prepends the system
    prompt, and calls the LLM. If ``tool_defs`` is non-empty the LLM may return
    tool_calls, which the tool node will handle.

    Args:
        client: OpenAI-compatible client used for ``chat.completions.create``.
        system_prompt: Prepended as a ``system`` message on every call.
        tool_defs: Tool definitions passed to the LLM; may be empty.
        model_name: Model identifier sent with each request.
    """
    # Send tool parameters only when there are tools: passing explicit None
    # serializes as null, which some OpenAI-compatible endpoints reject.
    tool_kwargs: dict[str, Any] = (
        {"tools": tool_defs, "tool_choice": "auto"} if tool_defs else {}
    )

    def agent_node(state: AgentState) -> dict[str, list]:
        # Convert LangChain message objects to plain dicts for the OpenAI client.
        full: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
        full.extend(_to_openai_message(m) for m in state["messages"])

        resp = client.chat.completions.create(
            model=model_name,
            messages=full,
            **tool_kwargs,
        )
        message = resp.choices[0].message
        tool_calls, invalid_tool_calls = _parse_llm_tool_calls(message.tool_calls)

        ai_msg = AIMessage(
            content=message.content or "",
            tool_calls=tool_calls,
            invalid_tool_calls=invalid_tool_calls,
            id=getattr(message, "id", None),
        )
        return {"messages": [ai_msg]}

    return agent_node


# ---------------------------------------------------------------------------
# Tool node — executes tools via the existing skill system
# ---------------------------------------------------------------------------

def _unpack_tool_call(tc: Any) -> tuple[Any, Any, Any]:
    """Return ``(name, raw_args, id)`` for a tool call in OpenAI, LangChain or object form."""
    if isinstance(tc, dict):
        if "function" in tc:
            fn = tc["function"]
            return fn.get("name", ""), fn.get("arguments", "{}"), tc.get("id", "")
        return tc.get("name", ""), tc.get("args", {}), tc.get("id", "")
    return getattr(tc, "name", ""), getattr(tc, "args", {}), getattr(tc, "id", "")


async def _run_tool(
    skill_names: list[str],
    fn_name: Any,
    fn_args_raw: Any,
    discord_user_id: Any,
) -> Any:
    """Execute one tool and return the content to send back to the LLM.

    Never raises for ordinary errors: invalid arguments and exceptions from
    ``execute_tool`` are logged and turned into a short error string, so every
    tool call still receives a response message.
    """
    try:
        fn_args = _parse_tool_args(fn_args_raw)
    except ValueError as exc:
        logger.warning("Invalid arguments for tool %r: %s", fn_name, exc)
        return f"Tool '{fn_name}' was called with invalid arguments: {exc}"

    try:
        tr = await execute_tool(
            skill_names,
            fn_name,
            fn_args,
            discord_user_id=discord_user_id,
        )
    except Exception as exc:  # boundary: relay the failure to the LLM
        logger.exception("Tool %r raised an exception", fn_name)
        # Only the exception type is exposed to the LLM; details are in the log.
        return f"Tool '{fn_name}' failed with an unexpected error ({type(exc).__name__})."

    return tr.content if tr else f"Tool '{fn_name}' is not available."


def _make_tool_node(skill_names: list[str]):
    """
    Return a callable that executes tool_calls from the last AI message.

    If a tool fails because the user isn't authenticated, the failure message
    (which tells the user to /login) is returned to the LLM. The LLM
    naturally relays the instructions to the user. Every tool call gets
    exactly one ``ToolMessage`` (even on error), because the follow-up LLM
    request needs a response for each tool call in the preceding assistant
    message.
    """

    async def tool_node(state: AgentState) -> dict[str, list]:
        last_msg = state["messages"][-1]
        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            return {"messages": []}

        discord_user_id = state.get("discord_user_id")
        results: list[ToolMessage] = []
        for tc in tool_calls:
            fn_name, fn_args_raw, tc_id = _unpack_tool_call(tc)
            content = await _run_tool(skill_names, fn_name, fn_args_raw, discord_user_id)
            results.append(ToolMessage(content=content, tool_call_id=tc_id))
        return {"messages": results}

    return tool_node


# ---------------------------------------------------------------------------
# Router — decides whether to continue tool-calling or stop
# ---------------------------------------------------------------------------

def _should_continue(state: AgentState) -> Literal["tool_node", "__end__"]:
    """If the last message contains tool_calls → execute them, else finish (END)."""
    last_msg = state["messages"][-1]
    return "tool_node" if getattr(last_msg, "tool_calls", None) else END


# ---------------------------------------------------------------------------
# Graph factory — the public API
# ---------------------------------------------------------------------------

def create_agent_graph(
    *,
    client: OpenAI,
    agent_skills: list[str],
    system_prompt: str,
    model_name: str = "deepseek-chat",
) -> CompiledStateGraph:
    """
    Build and compile a LangGraph StateGraph for a single agent.

    Args:
        client: OpenAI-compatible client used by the agent node.
        agent_skills: Skill names whose tools are exposed to the LLM and used
            by the tool node.
        system_prompt: System prompt prepended to every LLM call.
        model_name: Model identifier sent with each request.

    Returns:
        The compiled graph (the result of ``StateGraph.compile()``). When the
        skills expose no tools, the graph is a single LLM call with no tool loop.
    """
    tool_defs = get_all_tools(agent_skills)

    graph = StateGraph(AgentState)
    graph.add_node(
        "agent_node",
        _make_agent_node(client, system_prompt, tool_defs, model_name),
    )

    if tool_defs:
        graph.add_node("tool_node", _make_tool_node(agent_skills))
        graph.add_conditional_edges(
            "agent_node",
            _should_continue,
            {"tool_node": "tool_node", END: END},
        )
        graph.add_edge("tool_node", "agent_node")
    else:
        graph.add_edge("agent_node", END)

    graph.set_entry_point("agent_node")
    return graph.compile()