small refactor of the structure
This commit is contained in:
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Auth API — generic endpoints for linking Discord users to external services.
|
||||
|
||||
GET /api/v1/auth/login?service=X&token=Y&discord_id=Z
|
||||
Validates the link token and serves a service-specific login form.
|
||||
|
||||
POST /api/v1/auth/login
|
||||
Accepts the form submission, validates credentials against the service,
|
||||
stores the session, and returns a result page.
|
||||
|
||||
GET /api/v1/auth/status?discord_id=Z
|
||||
Returns which services are linked for this Discord user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from gateway.auth import get_auth_service, list_auth_services
|
||||
from src import auth_store
|
||||
|
||||
logger = logging.getLogger("gateway.auth")
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /auth/login — serve the login form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/login")
|
||||
async def login_form(
|
||||
service: str,
|
||||
token: str,
|
||||
discord_id: int,
|
||||
):
|
||||
"""Validate the one-time link token and return a service-specific login form."""
|
||||
|
||||
# Validate the token WITHOUT consuming it (the POST will consume it)
|
||||
result = auth_store.validate_token(token)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
||||
|
||||
uid, svc = result
|
||||
if uid != discord_id or svc != service:
|
||||
raise HTTPException(status_code=400, detail="Token does not match the request.")
|
||||
|
||||
# Look up the AuthService
|
||||
svc_obj = get_auth_service(service)
|
||||
if svc_obj is None:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
||||
|
||||
logger.info("Serving login form: user=%s service=%s", discord_id, service)
|
||||
return HTMLResponse(svc_obj.render_login_form(token, discord_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/login — handle form submission
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/login")
|
||||
async def login_submit(request: Request):
|
||||
"""Handle the login form POST: validate credentials, store auth, show result."""
|
||||
|
||||
# Parse form data
|
||||
form = await request.form()
|
||||
token = form.get("token", "")
|
||||
discord_id_str = form.get("discord_id", "")
|
||||
service = form.get("service", "")
|
||||
|
||||
if not token or not discord_id_str or not service:
|
||||
raise HTTPException(status_code=400, detail="Missing required fields.")
|
||||
|
||||
try:
|
||||
discord_id = int(discord_id_str)
|
||||
except (ValueError, TypeError):
|
||||
raise HTTPException(status_code=400, detail="Invalid discord_id.")
|
||||
|
||||
# Consume the token on POST (the GET only validated, didn't consume)
|
||||
result = auth_store.consume_token(token)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired link token.")
|
||||
|
||||
# Look up the AuthService
|
||||
svc_obj = get_auth_service(service)
|
||||
if svc_obj is None:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown service: {service}")
|
||||
|
||||
# Collect service-specific form fields (everything except token, discord_id, service)
|
||||
form_data: dict[str, str] = {}
|
||||
for key, value in form.items():
|
||||
if key not in ("token", "discord_id", "service"):
|
||||
form_data[key] = str(value)
|
||||
|
||||
# Authenticate against the service
|
||||
auth_result = await svc_obj.authenticate(form_data)
|
||||
|
||||
if not auth_result.success:
|
||||
return HTMLResponse(
|
||||
status_code=401,
|
||||
content=f"""<!DOCTYPE html>
|
||||
<html><head><meta charset="utf-8"><title>Login Failed</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; }}
|
||||
h2 {{ color: #d32f2f; }}
|
||||
a {{ color: #aa5cc3; }}
|
||||
</style></head><body>
|
||||
<h2>❌ Login Failed</h2>
|
||||
<p>{auth_result.error_message or "Authentication failed. Please try again."}</p>
|
||||
<p><a href="javascript:history.back()">← Go back and try again</a></p>
|
||||
</body></html>""",
|
||||
)
|
||||
|
||||
# Store the successful auth
|
||||
auth_store.store_auth(
|
||||
discord_user_id=discord_id,
|
||||
service=service,
|
||||
external_user_id=auth_result.external_user_id or "",
|
||||
external_name=auth_result.external_name or "",
|
||||
credentials=auth_result.credentials,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Auth linked: discord=%s → %s (%s)",
|
||||
discord_id,
|
||||
service,
|
||||
auth_result.external_name,
|
||||
)
|
||||
|
||||
return HTMLResponse(f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Account Linked</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; max-width: 420px; margin: 60px auto; padding: 0 20px; text-align: center; }}
|
||||
h1 {{ color: #388e3c; }}
|
||||
.name {{ font-weight: bold; color: #aa5cc3; }}
|
||||
p {{ color: #666; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>✅ Account Linked!</h1>
|
||||
<p>Logged in as <span class="name">{auth_result.external_name}</span> on <strong>{svc_obj.display_name}</strong>.</p>
|
||||
<p>You can close this page and return to Discord.</p>
|
||||
</body>
|
||||
</html>""")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /auth/status — get all linked services for a Discord user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/Discord/status")
|
||||
async def auth_status(discord_id: int):
|
||||
"""
|
||||
Return all services linked to this Discord user with full details.
|
||||
|
||||
Response:
|
||||
{
|
||||
"discord_id": 123456789,
|
||||
"linked_services": {
|
||||
"jellyfin": {
|
||||
"external_user_id": "abc123",
|
||||
"external_name": "Tim",
|
||||
"linked_at": "2026-05-25T10:00:00",
|
||||
"credentials": {
|
||||
"token": "jwt...",
|
||||
"url": "http://jellyfin:8096",
|
||||
"user_id": "abc123"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
auths = auth_store.get_all_auths(discord_id)
|
||||
|
||||
linked_services: dict[str, dict] = {}
|
||||
for auth in auths:
|
||||
svc_name = auth["service"]
|
||||
linked_services[svc_name] = {
|
||||
"external_user_id": auth["external_user_id"],
|
||||
"external_name": auth["external_name"],
|
||||
"linked_at": auth["linked_at"],
|
||||
"credentials": auth["credentials"],
|
||||
}
|
||||
|
||||
return {
|
||||
"discord_id": discord_id,
|
||||
"linked_services": linked_services,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /auth/reset — wipe auth store (DEV ONLY)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from src.config import get_config # noqa: E402
|
||||
|
||||
@router.post("/reset")
|
||||
async def reset_auth():
|
||||
"""
|
||||
Reset the entire auth store — clears all link tokens and user auth records.
|
||||
|
||||
Only enabled when ALLOW_AUTH_RESET=true in the environment.
|
||||
Returns 403 in production.
|
||||
"""
|
||||
if get_config("ALLOW_AUTH_RESET", "false").lower() != "true":
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Auth reset is disabled. Set ALLOW_AUTH_RESET=true to enable (dev only).",
|
||||
)
|
||||
|
||||
auth_store.reset_all()
|
||||
logger.warning("Auth store reset via API endpoint.")
|
||||
return {"status": "ok", "message": "Auth store cleared — all tokens and auth records removed."}
|
||||
@@ -0,0 +1,241 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
from gateway.dependencies import get_llm_client, get_agent_graph
|
||||
from agents import get as get_agent, list_all as list_all_agents
|
||||
from src.state import AgentState
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
session_id: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
messages: list[dict]
|
||||
stream: bool = False
|
||||
model: str = "deepseek-chat"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _resolve_agent(agent_id: str | None = None, model: str | None = None):
|
||||
"""
|
||||
1. explicit agent_id
|
||||
2. model field (OpenWebUI sends this — maps to agent_id if registered)
|
||||
3. fallback to "naked"
|
||||
"""
|
||||
lookup = agent_id or model
|
||||
if lookup is None:
|
||||
return get_agent("naked")
|
||||
agent = get_agent(lookup)
|
||||
return agent if agent else get_agent("naked")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LangGraph helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _invoke_graph(graph, messages: list[dict]) -> str:
|
||||
"""Run the graph synchronously (non-streaming) and return the final text."""
|
||||
state: AgentState = {"messages": messages}
|
||||
result = await graph.ainvoke(state)
|
||||
last_msg = result["messages"][-1]
|
||||
return last_msg.content or ""
|
||||
|
||||
|
||||
async def _stream_graph(graph, messages: list[dict]):
|
||||
"""
|
||||
Run the graph and stream the final response token-by-token.
|
||||
|
||||
LangGraph's astream_events would require langchain-openai's ChatOpenAI
|
||||
to intercept LLM chunks. Instead we run the graph to completion (tools
|
||||
execute silently) and then stream the final text content character by
|
||||
character — this gives the client a real SSE stream without adding new
|
||||
dependencies.
|
||||
"""
|
||||
state: AgentState = {"messages": messages}
|
||||
result = await graph.ainvoke(state)
|
||||
content = result["messages"][-1].content or ""
|
||||
# Yield token-by-token so the SSE client sees incremental output
|
||||
for token in content:
|
||||
yield token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-streaming run (kept for /chat/sync and sync completions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_agent_with_tools(
|
||||
request: Request,
|
||||
messages: list[dict],
|
||||
agent_id: str | None = None,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""Send messages through the agent's LangGraph. Non-streaming."""
|
||||
agent = _resolve_agent(agent_id, model)
|
||||
graph = get_agent_graph(agent.agent_id, request)
|
||||
return await _invoke_graph(graph, messages)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming generator (kept for /chat and stream completions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_agent_stream(
|
||||
request: Request,
|
||||
messages: list[dict],
|
||||
agent_id: str | None = None,
|
||||
model: str | None = None,
|
||||
):
|
||||
"""Async generator — yields tokens via the agent's LangGraph."""
|
||||
agent = _resolve_agent(agent_id, model)
|
||||
graph = get_agent_graph(agent.agent_id, request)
|
||||
async for token in _stream_graph(graph, messages):
|
||||
yield token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/")
|
||||
def root():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def chat(
|
||||
req: ChatRequest,
|
||||
request: Request,
|
||||
client: OpenAI = Depends(get_llm_client),
|
||||
):
|
||||
"""Streaming chat — single message, no history."""
|
||||
messages = [{"role": "user", "content": req.message}]
|
||||
|
||||
async def event_stream():
|
||||
async for token in run_agent_stream(request, messages, req.agent_id):
|
||||
payload = json.dumps({"token": token, "session_id": req.session_id})
|
||||
yield f"data: {payload}\n\n"
|
||||
yield f"data: {json.dumps({'done': True, 'session_id': req.session_id})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/chat/sync")
|
||||
async def chat_sync(
|
||||
req: ChatRequest,
|
||||
request: Request,
|
||||
client: OpenAI = Depends(get_llm_client),
|
||||
):
|
||||
"""Non-streaming chat — single message."""
|
||||
messages = [{"role": "user", "content": req.message}]
|
||||
response = await run_agent_with_tools(request, messages, req.agent_id)
|
||||
return {"response": response, "session_id": req.session_id}
|
||||
|
||||
|
||||
@router.get("/agents")
|
||||
def list_agents():
|
||||
"""Return all registered agents."""
|
||||
return {
|
||||
"agents": [
|
||||
{
|
||||
"agent_id": a.agent_id,
|
||||
"description": a.description,
|
||||
"skills": a.skills,
|
||||
}
|
||||
for a in list_all_agents().values()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
def list_models():
|
||||
"""Return agents as selectable models for OpenWebUI."""
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": a.agent_id,
|
||||
"object": "model",
|
||||
"created": 0,
|
||||
"owned_by": "local-agent",
|
||||
}
|
||||
for a in list_all_agents().values()
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.post("/chat/completions")
|
||||
async def chat_completions(
|
||||
req: ChatCompletionRequest,
|
||||
request: Request,
|
||||
client: OpenAI = Depends(get_llm_client),
|
||||
):
|
||||
"""OpenAI-compatible /chat/completions — supports stream=True.
|
||||
Multi-turn: req.messages contains the FULL conversation history.
|
||||
Agent resolved from the model field (OpenWebUI sends this).
|
||||
"""
|
||||
agent = _resolve_agent(model=req.model)
|
||||
|
||||
if req.stream:
|
||||
async def sse_stream():
|
||||
async for token in run_agent_stream(
|
||||
request, req.messages, agent_id=agent.agent_id,
|
||||
):
|
||||
chunk = {
|
||||
"id": "chatcmpl-local",
|
||||
"object": "chat.completion.chunk",
|
||||
"choices": [
|
||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
final_chunk = {
|
||||
"id": "chatcmpl-local",
|
||||
"object": "chat.completion.chunk",
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
yield f"data: {json.dumps(final_chunk)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
sse_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
# Non-streaming — full history, LangGraph agent
|
||||
response = await run_agent_with_tools(
|
||||
request, req.messages, agent_id=agent.agent_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": "chatcmpl-local",
|
||||
"object": "chat.completion",
|
||||
"created": 0,
|
||||
"model": req.model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": response},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
Reference in New Issue
Block a user