"""PostgreSQL connection pool for the JellyStat database.""" from __future__ import annotations import logging from pathlib import Path import asyncpg from fastapi import FastAPI, Request from src.config import get_config logger = logging.getLogger("gateway.jellystat") # --------------------------------------------------------------------------- # DSN builder # --------------------------------------------------------------------------- def _build_dsn() -> str: """Build a PostgreSQL DSN from individual environment variables.""" host = get_config("JELLYSTAT_DB_HOST", "localhost") port = get_config("JELLYSTAT_DB_PORT", "5432") user = get_config("JELLYSTAT_DB_USER", "postgres") password = get_config("JELLYSTAT_DB_PASSWORD", "") dbname = get_config("JELLYSTAT_DB_NAME", "jfstat") return f"postgresql://{user}:{password}@{host}:{port}/{dbname}" # --------------------------------------------------------------------------- # Pool lifecycle (called from main.py lifespan) # --------------------------------------------------------------------------- async def init_pool(app: FastAPI) -> None: """Create the connection pool and store it on app.state.""" dsn = _build_dsn() safe = dsn.split("@")[1] if "@" in dsn else dsn logger.info("Connecting to JellyStat database at %s", safe) pool = await asyncpg.create_pool(dsn, min_size=1, max_size=5) app.state.jellystat_pool = pool # Deploy functions on every startup (CREATE OR REPLACE is idempotent) await _ensure_functions(pool) async def close_pool(app: FastAPI) -> None: """Close the pool on shutdown.""" pool: asyncpg.Pool | None = getattr(app.state, "jellystat_pool", None) if pool: await pool.close() logger.info("JellyStat pool closed") # --------------------------------------------------------------------------- # FastAPI dependency # --------------------------------------------------------------------------- async def get_pool(request: Request) -> asyncpg.Pool: """Return the JellyStat connection pool from app state.""" return request.app.state.jellystat_pool # --------------------------------------------------------------------------- # Function deployment # --------------------------------------------------------------------------- async def _ensure_functions(pool: asyncpg.Pool) -> None: """Run startup-functions.sql to create or replace all JellyStat functions.""" sql_path = Path(__file__).parent / "startup-functions.sql" if not sql_path.exists(): logger.warning("startup-functions.sql not found — skipping function deployment") return sql = sql_path.read_text() statements = _split_sql(sql) async with pool.acquire() as conn: for stmt in statements: try: await conn.execute(stmt) except Exception: # Log but don't crash — functions might already exist logger.exception("Failed to deploy SQL statement — continuing") logger.info("JellyStat functions deployed (%d statements)", len(statements)) def _split_sql(sql: str) -> list[str]: """ Split a multi-statement SQL string into individual statements. Respects $$ dollar-quoting so that semicolons inside function bodies don't cause premature splits. Pure comment lines (starting with ``--``) outside dollar-quoted blocks are stripped. """ statements: list[str] = [] current: list[str] = [] in_dollar_quote = False for line in sql.split("\n"): stripped = line.strip() # Skip pure comment lines outside of dollar-quoted blocks if not in_dollar_quote and stripped.startswith("--"): continue # Toggle dollar-quote state whenever we see $$ if "$$" in line: in_dollar_quote = not in_dollar_quote current.append(line) # Statement terminator: semicolon at end of line, outside $$ block if not in_dollar_quote and line.rstrip().endswith(";"): stmt = "\n".join(current).strip() if stmt: statements.append(stmt) current = [] # Catch any trailing statement that wasn't terminated by a semicolon if current: stmt = "\n".join(current).strip() if stmt: statements.append(stmt) return statements