158 lines
6.0 KiB
Python
158 lines
6.0 KiB
Python
"""
|
|
Skill system — each skill is a piece of domain knowledge or a capability
|
|
that can be attached to an agent to shape its behavior and system prompt.
|
|
|
|
A Skill is a lightweight object with:
|
|
- name : short identifier (e.g. "media_info")
|
|
- description : human-readable summary
|
|
- prompt_fragment : extra text injected into the agent's system prompt
|
|
- tools : OpenAI function-calling tool definitions (list of dicts)
|
|
- execute : async callable to run a tool → ToolResult
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
|
from src.config import get_config # re-export so every skill can use it
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ToolResult — every skill executor must return this
|
|
# ---------------------------------------------------------------------------
|
|
@dataclass
class ToolResult:
    """Outcome of running a single tool call.

    Attributes:
        content: Message fed back to the LLM (and shown to the user).
        success: Whether the action completed (e.g. the backing API returned
            a 2xx response). Defaults to True.
    """

    content: str
    success: bool = True

    @classmethod
    def ok(cls, content: str) -> "ToolResult":
        """Build a successful result carrying *content*."""
        return cls(content, True)

    @classmethod
    def fail(cls, content: str) -> "ToolResult":
        """Build a failed result carrying *content*."""
        return cls(content, success=False)
|
|
|
|
|
|
# Type alias for a tool executor: skills are invoked as
# ``await execute(tool_name, args)`` and must resolve to a ToolResult.
ToolExecutor = Callable[
    [str, dict],
    Awaitable[ToolResult],
]
|
|
|
|
|
|
@dataclass
class Skill:
    """A capability that can be attached to an agent.

    Fields, in constructor order:
        name: Short identifier; ``register`` stores the skill under it.
        description: Human-readable summary.
        prompt_fragment: Text appended to the agent's system prompt.
        tools: OpenAI function-calling tool definitions owned by this skill.
        execute: Async executor called as ``execute(tool_name, args)``.
            If None, ``execute_tool`` never dispatches this skill's tools.
        requires_auth: Service names the user must be authenticated with.
            ``execute_tool`` checks this when a discord user id is supplied.
    """

    name: str
    description: str
    prompt_fragment: str = field(default="")
    # Mutable containers use default_factory so instances never share a list.
    tools: List[Dict[str, Any]] = field(default_factory=list)
    execute: Optional[ToolExecutor] = field(default=None)
    requires_auth: List[str] = field(default_factory=list)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Global skill registry — populated at startup / import time
|
|
# ---------------------------------------------------------------------------
|
|
# Maps skill name -> Skill. Written by register(); read by get() and list_all().
_skill_registry: Dict[str, Skill] = dict()
|
|
|
|
|
|
def register(skill: Skill) -> None:
    """Add *skill* to the global registry, keyed by ``skill.name``.

    Registering a second skill with the same name replaces the first.
    """
    key = skill.name
    _skill_registry[key] = skill
|
|
|
|
|
|
def get(name: str) -> Skill | None:
    """Look up a registered skill by *name*; ``None`` if it was never registered."""
    try:
        return _skill_registry[name]
    except KeyError:
        return None
|
|
|
|
|
|
def list_all() -> Dict[str, Skill]:
    """Return a shallow copy of the registry.

    Callers may mutate the returned dict without affecting the registry;
    the Skill objects themselves are shared.
    """
    return _skill_registry.copy()
|
|
|
|
|
|
def get_combined_prompt(skill_names: list[str], base_prompt: str = "") -> str:
    """Assemble a system prompt from *base_prompt* plus each skill's fragment.

    Sections are joined with a blank line. Empty values are skipped:
    an empty *base_prompt*, unknown skill names, and skills with no
    ``prompt_fragment`` contribute nothing.
    """
    sections = [base_prompt] if base_prompt else []
    sections.extend(
        skill.prompt_fragment
        for skill in map(get, skill_names)
        if skill and skill.prompt_fragment
    )
    return "\n\n".join(sections)
|
|
|
|
|
|
def get_all_tools(skill_names: list[str]) -> List[Dict[str, Any]]:
    """Gather the OpenAI tool definitions of every requested skill.

    Skills are visited in *skill_names* order and unknown names are ignored.
    A tool is kept only if it has a non-empty ``function.name`` that has not
    already been collected, so the first skill to declare a name wins.
    """
    collected: List[Dict[str, Any]] = []
    known: set[str] = set()
    for skill in filter(None, map(get, skill_names)):
        for tool in skill.tools:
            tool_name = tool.get("function", {}).get("name", "")
            if not tool_name or tool_name in known:
                continue
            known.add(tool_name)
            collected.append(tool)
    return collected
|
|
|
|
|
|
def _find_owning_skill(skill_names: list[str], tool_name: str) -> Skill | None:
    """Return the first requested skill that has an executor and declares *tool_name*."""
    for name in skill_names:
        skill = get(name)
        if not (skill and skill.execute):
            continue
        if any(t.get("function", {}).get("name") == tool_name for t in skill.tools):
            return skill
    return None


def _auth_failure(skill: Skill, discord_user_id: int) -> ToolResult | None:
    """Return a login-prompt failure if the user is missing any auth *skill* needs.

    Returns None when the user is authenticated for every service in
    ``skill.requires_auth``.
    """
    # Imported lazily, only on the auth path. Presumably this avoids import
    # cycles at module load — NOTE(review): confirm.
    from src import auth_store
    from gateway.auth import get_auth_service

    missing = [
        svc
        for svc in skill.requires_auth
        if not auth_store.is_authenticated(discord_user_id, svc)
    ]
    if not missing:
        return None

    display_names = []
    for svc in missing:
        # Look the service up once (the old code called get_auth_service twice).
        service = get_auth_service(svc)
        display_names.append((service and service.display_name) or svc)
    svc_displays = ", ".join(display_names)

    return ToolResult.fail(
        f"You need to log in to {svc_displays} first. "
        + " ".join(f"Send `/login {m}` in a DM to get started." for m in missing)
    )


async def execute_tool(
    skill_names: list[str], tool_name: str, args: dict,
    discord_user_id: int | None = None,
) -> ToolResult | None:
    """Find the skill that owns *tool_name* and run its executor.

    The owner is the first skill in *skill_names* that has an executor and
    declares a tool whose ``function.name`` equals *tool_name*.

    If *discord_user_id* is provided and the owning skill lists services in
    ``requires_auth``, the user must be authenticated for all of them.
    Otherwise a ToolResult.fail(...) explaining how to log in is returned.
    When *discord_user_id* is None, the auth check is skipped.

    When *discord_user_id* is provided, it is passed to the executor as
    ``args["_discord_user_id"]``. A new dict is built, so the caller's *args*
    is not mutated.

    Returns:
        The executor's result. If the executor raises (or returns something
        without ``.success``), a ToolResult.fail describing the crash is
        returned. Returns None if no requested skill owns *tool_name*.

    Only failures are logged (via the "skills" logger); successful calls
    are silent.
    """
    import logging

    logger = logging.getLogger("skills")

    skill = _find_owning_skill(skill_names, tool_name)
    if skill is None:
        logger.warning("⚠️ TOOL NOT FOUND: %s (skills=%s)", tool_name, skill_names)
        return None

    # --- Auth gate ---
    if skill.requires_auth and discord_user_id is not None:
        denied = _auth_failure(skill, discord_user_id)
        if denied is not None:
            return denied
    # --- End auth gate ---

    # Inject discord_user_id so skills can resolve external user IDs
    if discord_user_id is not None:
        args = {**args, "_discord_user_id": discord_user_id}

    try:
        result = await skill.execute(tool_name, args)
        # Kept inside the try: a malformed return value (e.g. None) is
        # reported as a crash rather than propagating.
        if not result.success:
            logger.warning(
                "⚠️ TOOL FAILED: %s | args=%s → %s",
                tool_name, args, result.content[:300],
            )
        return result
    except Exception as exc:
        logger.exception(
            "💥 TOOL CRASH: %s | args=%s", tool_name, args
        )
        return ToolResult.fail(
            f"Tool '{tool_name}' crashed unexpectedly: {exc}"
        )
|