317 lines
9.6 KiB
Python
317 lines
9.6 KiB
Python
"""
|
|
Auth Store — SQLite-backed persistence for Discord-to-service authentication.
|
|
|
|
Two tables:
|
|
- link_tokens : one-time tokens sent via Discord DM to initiate login
|
|
- user_auth : per-user, per-service credentials (Jellyfin token, etc.)
|
|
|
|
Access is serialised within a process by a shared lock, and WAL mode lets
readers proceed alongside a writer. No passwords are ever stored — only
opaque service tokens (e.g. Jellyfin AccessToken).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import secrets
|
|
import sqlite3
|
|
import threading
|
|
from datetime import datetime, timedelta, timezone
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from core.config import get_config
|
|
|
|
# Module logger, registered under the fixed name "auth_store".
logger = logging.getLogger("auth_store")

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Location of the SQLite file. A relative value is anchored by _resolve_path()
# at the parent of this module's directory.
AUTH_DB_PATH = get_config("AUTH_DB_PATH", "data/auth.db")
# Lifetime, in minutes, of link tokens issued by create_token(). Parsed with
# int() at import time, so a non-numeric setting raises ValueError on import.
TOKEN_EXPIRY_MINUTES = int(get_config("AUTH_TOKEN_EXPIRY", "10"))


# ---------------------------------------------------------------------------
# Shared module state (no connection is kept open between calls)
# ---------------------------------------------------------------------------

# Absolute database path, filled in lazily and cached by _resolve_path().
_db_path: Path | None = None
# Serialises database access between threads of this process; every public
# function takes it around its connection. It does not coordinate with other
# processes that open the same file.
_db_lock = threading.Lock()
|
|
|
|
|
|
def _resolve_path() -> Path:
    """Return the absolute path of the auth database, computing it only once.

    A relative ``AUTH_DB_PATH`` is joined onto the parent of this module's
    directory (``Path(__file__).resolve().parent.parent``). The database's
    parent directory is created if it does not exist yet, so sqlite3 can
    create the file. The result is cached in the module-level ``_db_path``.
    """
    global _db_path
    if _db_path is None:
        candidate = Path(AUTH_DB_PATH)
        if not candidate.is_absolute():
            # Anchor relative paths at the project root rather than the CWD.
            candidate = Path(__file__).resolve().parent.parent / candidate
        candidate.parent.mkdir(parents=True, exist_ok=True)
        _db_path = candidate
    return _db_path
|
|
|
|
|
|
def _get_conn() -> sqlite3.Connection:
    """Open and return a new connection to the auth database.

    Every call creates a *fresh* connection; nothing is pooled or cached per
    thread, and the caller is responsible for closing it.
    ``check_same_thread=False`` allows the connection to be used from a thread
    other than the one that created it. Serialising access is left to the
    callers, which hold ``_db_lock`` while using the connection.

    Rows are returned as :class:`sqlite3.Row` so columns can be read by name.
    If configuring the connection fails, it is closed before the error
    propagates.
    """
    conn = sqlite3.connect(str(_resolve_path()), check_same_thread=False)
    try:
        # WAL journaling lets readers proceed while a writer is active.
        conn.execute("PRAGMA journal_mode=WAL")
        # SQLite enforces foreign keys per connection, so enable it on each one.
        conn.execute("PRAGMA foreign_keys=ON")
    except Exception:
        # Don't leak the handle if a PRAGMA fails (e.g. "database is locked").
        conn.close()
        raise
    conn.row_factory = sqlite3.Row
    return conn
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Schema
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# DDL executed once per process by _ensure_schema(). Both statements use
# IF NOT EXISTS, so re-running the script on an existing database is harmless.
#   link_tokens: one row per one-time login token. `used` is set to 1 when the
#                token is consumed; `expires_at` holds the ISO-8601 UTC
#                timestamp written by create_token().
#   user_auth:   at most one row per (discord_user_id, service). revoke() sets
#                is_active = 0 instead of deleting the row; `credentials` is a
#                JSON object serialised by store_auth().
_SCHEMA = """
CREATE TABLE IF NOT EXISTS link_tokens (
token TEXT PRIMARY KEY,
discord_user_id INTEGER NOT NULL,
service TEXT NOT NULL,
expires_at TEXT NOT NULL,
used INTEGER DEFAULT 0,
created_at TEXT DEFAULT (datetime('now'))
);

CREATE TABLE IF NOT EXISTS user_auth (
discord_user_id INTEGER NOT NULL,
service TEXT NOT NULL,
external_user_id TEXT,
external_name TEXT,
credentials TEXT,
linked_at TEXT DEFAULT (datetime('now')),
is_active INTEGER DEFAULT 1,
PRIMARY KEY (discord_user_id, service)
);
"""

# Set to True by _ensure_schema() once the schema script has succeeded in
# this process.
_initialized = False
|
|
|
|
|
|
def _ensure_schema() -> None:
    """Create the auth tables on first use, at most once per process.

    The unlocked ``_initialized`` check keeps the common path cheap. The
    re-check under ``_db_lock`` stops two threads from both running the
    schema script. The flag is set only after the script succeeds, so a
    failed attempt is retried on the next call. The connection is always
    closed, even if the script fails.
    """
    global _initialized
    if _initialized:
        return

    with _db_lock:
        if _initialized:
            return
        conn = _get_conn()
        try:
            conn.executescript(_SCHEMA)
            conn.commit()
        finally:
            conn.close()
        _initialized = True
        logger.info("Auth store initialized at %s", _resolve_path())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Public API — Link Tokens
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def create_token(discord_user_id: int, service: str) -> str:
    """Generate and store a one-time link token for a user/service pair.

    The token comes from ``secrets.token_urlsafe(32)`` (32 random bytes,
    URL-safe base64). It expires ``TOKEN_EXPIRY_MINUTES`` after creation and
    is stored as a timezone-aware UTC ISO-8601 string.

    Args:
        discord_user_id: Discord user the token is issued to.
        service: Name of the service the token will link.

    Returns:
        The token string to deliver to the user.
    """
    _ensure_schema()
    token = secrets.token_urlsafe(32)
    expires = (datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRY_MINUTES)).isoformat()

    with _db_lock:
        conn = _get_conn()
        try:
            conn.execute(
                "INSERT INTO link_tokens (token, discord_user_id, service, expires_at) VALUES (?, ?, ?, ?)",
                (token, discord_user_id, service, expires),
            )
            conn.commit()
        finally:
            # Release the connection, and with it any uncommitted write
            # transaction, even if the INSERT or COMMIT raises.
            conn.close()

    logger.info("Created link token for user %s / service %s", discord_user_id, service)
    return token
|
|
|
|
|
|
def validate_token(token: str) -> tuple[int, str] | None:
    """Check a link token read-only; it is NOT consumed.

    Returns:
        ``(discord_user_id, service)`` if the token exists, is unused and has
        not expired; otherwise ``None``. Use :func:`consume_token` to redeem it.
    """
    _ensure_schema()

    with _db_lock:
        conn = _get_conn()
        try:
            row = conn.execute(
                "SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
                (token,),
            ).fetchone()
        finally:
            conn.close()

    if row is None or row["used"]:
        return None

    # create_token() stores a timezone-aware UTC ISO string, so the parsed
    # value compares directly against an aware "now".
    expires = datetime.fromisoformat(row["expires_at"])
    if datetime.now(timezone.utc) > expires:
        return None

    return (row["discord_user_id"], row["service"])
|
|
|
|
|
|
def consume_token(token: str) -> tuple[int, str] | None:
    """Validate and consume a one-time link token.

    A token is accepted only if it exists, has not been used, and has not
    expired. On success it is marked used and ``(discord_user_id, service)``
    is returned; in every other case ``None`` is returned.

    ``_db_lock`` only serialises threads in this process. To stay correct when
    another process shares the database file, the ``used`` flag is flipped
    with a conditional ``UPDATE ... AND used = 0`` and the affected row count
    is checked. At most one caller can therefore redeem a given token.
    """
    _ensure_schema()

    with _db_lock:
        conn = _get_conn()
        try:
            row = conn.execute(
                "SELECT discord_user_id, service, used, expires_at FROM link_tokens WHERE token = ?",
                (token,),
            ).fetchone()

            if row is None:
                return None

            if row["used"]:
                logger.warning("Token already used: %s…", token[:8])
                return None

            expires = datetime.fromisoformat(row["expires_at"])
            if datetime.now(timezone.utc) > expires:
                logger.warning("Token expired: %s…", token[:8])
                return None

            cur = conn.execute(
                "UPDATE link_tokens SET used = 1 WHERE token = ? AND used = 0",
                (token,),
            )
            if cur.rowcount != 1:
                # Another consumer redeemed it between our SELECT and UPDATE.
                logger.warning("Token already used: %s…", token[:8])
                return None
            conn.commit()
            result = (row["discord_user_id"], row["service"])
        finally:
            # Closing without a commit rolls back any pending change.
            conn.close()

    logger.info("Token consumed: %s… → user=%s service=%s", token[:8], result[0], result[1])
    return result
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Public API — User Auth
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def store_auth(
    discord_user_id: int,
    service: str,
    *,
    external_user_id: str = "",
    external_name: str = "",
    credentials: dict | None = None,
) -> None:
    """Create or overwrite a user's link to a service.

    Upserts on ``(discord_user_id, service)``. An existing row, including a
    revoked one, is overwritten, set back to ``is_active = 1`` and has its
    ``linked_at`` refreshed.

    Args:
        discord_user_id: Discord user being linked.
        service: Service name.
        external_user_id: Stored in ``user_auth.external_user_id``.
        external_name: Stored in ``user_auth.external_name`` and echoed in the log.
        credentials: JSON-serialisable dict. ``None`` or an empty dict is
            stored as ``"{}"``.

    Raises:
        TypeError: if ``credentials`` is not JSON-serialisable (raised before
            the database is touched).
    """
    _ensure_schema()
    import json

    creds_json = json.dumps(credentials) if credentials else "{}"

    with _db_lock:
        conn = _get_conn()
        try:
            conn.execute(
                """INSERT INTO user_auth (discord_user_id, service, external_user_id, external_name, credentials, linked_at)
                VALUES (?, ?, ?, ?, ?, datetime('now'))
                ON CONFLICT(discord_user_id, service) DO UPDATE SET
                    external_user_id = excluded.external_user_id,
                    external_name = excluded.external_name,
                    credentials = excluded.credentials,
                    linked_at = datetime('now'),
                    is_active = 1""",
                (discord_user_id, service, external_user_id, external_name, creds_json),
            )
            conn.commit()
        finally:
            # Release the connection and any open write transaction on error.
            conn.close()

    logger.info("Stored auth for user %s on %s as %s", discord_user_id, service, external_name)
|
|
|
|
|
|
def get_auth(discord_user_id: int, service: str) -> dict | None:
    """Return a user's active link to a service, or ``None`` if not linked.

    Revoked rows (``is_active = 0``) count as not linked.

    Returns:
        A dict with keys ``discord_user_id``, ``service``, ``external_user_id``,
        ``external_name``, ``credentials`` (decoded from JSON; ``{}`` when the
        column is empty) and ``linked_at``.

    Raises:
        json.JSONDecodeError: if the stored ``credentials`` value is not valid JSON.
    """
    _ensure_schema()
    import json

    with _db_lock:
        conn = _get_conn()
        try:
            row = conn.execute(
                "SELECT * FROM user_auth WHERE discord_user_id = ? AND service = ? AND is_active = 1",
                (discord_user_id, service),
            ).fetchone()
        finally:
            conn.close()

    if row is None:
        return None

    credentials = json.loads(row["credentials"]) if row["credentials"] else {}
    return {
        "discord_user_id": row["discord_user_id"],
        "service": row["service"],
        "external_user_id": row["external_user_id"],
        "external_name": row["external_name"],
        "credentials": credentials,
        "linked_at": row["linked_at"],
    }
|
|
|
|
|
|
def is_authenticated(discord_user_id: int, service: str) -> bool:
    """Return True when the user has an active (non-revoked) link to *service*."""
    record = get_auth(discord_user_id, service)
    return record is not None
|
|
|
|
|
|
def list_services(discord_user_id: int) -> list[str]:
    """Return the names of all services the user is actively linked to.

    Revoked links (``is_active = 0``) are excluded. The list is empty if the
    user has no active links.
    """
    _ensure_schema()

    with _db_lock:
        conn = _get_conn()
        try:
            rows = conn.execute(
                "SELECT service FROM user_auth WHERE discord_user_id = ? AND is_active = 1",
                (discord_user_id,),
            ).fetchall()
        finally:
            conn.close()

    return [r["service"] for r in rows]
|
|
|
|
|
|
def revoke(discord_user_id: int, service: str) -> None:
    """Unlink a user from a service.

    This is a soft delete: the row's ``is_active`` is set to 0, and the stored
    data is kept until a later :func:`store_auth` overwrites it. Calling it for
    a pair that was never linked is a no-op (it still logs).
    """
    _ensure_schema()

    with _db_lock:
        conn = _get_conn()
        try:
            conn.execute(
                "UPDATE user_auth SET is_active = 0 WHERE discord_user_id = ? AND service = ?",
                (discord_user_id, service),
            )
            conn.commit()
        finally:
            # Release the connection and any open write transaction on error.
            conn.close()

    logger.info("Revoked auth for user %s on %s", discord_user_id, service)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dev / testing — reset the entire store
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def reset_all() -> None:
    """Delete every row from both auth tables — for development and testing only.

    Both DELETEs are committed together. If either fails, closing the
    connection without a commit rolls the partial change back.
    """
    _ensure_schema()

    with _db_lock:
        conn = _get_conn()
        try:
            conn.execute("DELETE FROM link_tokens")
            conn.execute("DELETE FROM user_auth")
            conn.commit()
        finally:
            conn.close()

    logger.warning("Auth store RESET — all tokens and auth records cleared.")
|