backend: tidy modules, consolidate schema migrations, expand API tests

This commit is contained in:
Andras Schmelczer 2026-05-31 10:49:26 +01:00
parent 4156d1d469
commit d9724a462d
9 changed files with 254 additions and 74 deletions

View file

@ -1,8 +1,6 @@
"""APIRouter with all Life Towers endpoints."""
from __future__ import annotations
import json
import sqlite3
import time
from typing import Annotated
@ -12,6 +10,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from .auth import get_current_user
from .db import db_connection
from .limits import limiter
from .logging import token_log_id
from .models import (
BlockOut,
DataIn,
@ -53,7 +52,7 @@ async def register(request: Request, body: RegisterRequest) -> RegisterResponse:
(now, token),
)
conn.commit()
logger.info("user_registered", user_id=token, new=existing is None)
logger.info("user_registered", user_id=token_log_id(token), new=existing is None)
return RegisterResponse(user_id=token)
@ -94,7 +93,7 @@ async def get_data(
block_rows = conn.execute(
"""
SELECT id, tag, description, is_done, created_at
SELECT id, tag, description, is_done, difficulty, created_at
FROM blocks
WHERE tower_id = ?
ORDER BY position
@ -108,6 +107,7 @@ async def get_data(
tag=b["tag"],
description=b["description"],
is_done=bool(b["is_done"]),
difficulty=b["difficulty"],
created_at=b["created_at"],
)
for b in block_rows
@ -213,8 +213,9 @@ async def put_data(
"""
INSERT INTO blocks
(id, tower_id, user_id, position, tag,
description, is_done, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
description, is_done, difficulty,
created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
block.id,
@ -224,6 +225,7 @@ async def put_data(
block.tag,
block.description,
1 if block.is_done else 0,
block.difficulty,
created_at,
now,
),
@ -236,8 +238,14 @@ async def put_data(
)
conn.commit()
except sqlite3.IntegrityError as exc:
conn.rollback()
raise HTTPException(
status_code=409,
detail="Submitted IDs conflict with existing data",
) from exc
except Exception:
conn.rollback()
raise
logger.info("data_replaced", user_id=user_id, pages=len(body.pages))
logger.info("data_replaced", user_id=token_log_id(user_id), pages=len(body.pages))

View file

@ -2,11 +2,10 @@
from __future__ import annotations
import uuid
from fastapi import HTTPException, Request
from .db import db_connection
from .models import _canonical_uuidv4
# Single generic detail used for ALL 401 responses. Per spec, the response
# must not distinguish between missing / malformed / unknown tokens — that
@ -21,25 +20,28 @@ def _unauthorized() -> HTTPException:
)
def get_current_user(request: Request) -> str:
"""Dependency that extracts and validates a Bearer token, returns user_id."""
def extract_bearer_token(request: Request) -> str | None:
"""Return the raw Bearer token from the Authorization header, or None."""
auth_header = request.headers.get("Authorization") or request.headers.get(
"authorization"
)
if not auth_header:
raise _unauthorized()
return None
parts = auth_header.split()
if len(parts) != 2 or parts[0].lower() != "bearer":
raise _unauthorized()
if len(parts) == 2 and parts[0].lower() == "bearer":
return parts[1]
return None
token = parts[1]
def get_current_user(request: Request) -> str:
"""Dependency that extracts and validates a Bearer token, returns user_id."""
token = extract_bearer_token(request)
if token is None:
raise _unauthorized()
try:
u = uuid.UUID(token)
if u.version != 4:
raise ValueError("Not v4")
except (ValueError, AttributeError):
token = _canonical_uuidv4(token)
except ValueError:
raise _unauthorized()
with db_connection() as conn:

View file

@ -8,6 +8,8 @@ from fastapi import Request, Response
from slowapi import Limiter
from slowapi.util import get_remote_address
from .auth import extract_bearer_token
PAYLOAD_LIMIT_BYTES = 2 * 1024 * 1024 # 2 MiB
_TOO_LARGE_BODY = json.dumps(
@ -20,12 +22,7 @@ _TOO_LARGE_BODY = json.dumps(
def _get_token_or_ip(request: Request) -> str:
"""Key function for rate limiting: use Bearer token if present, else IP."""
auth = request.headers.get("Authorization") or request.headers.get("authorization")
if auth:
parts = auth.split()
if len(parts) == 2 and parts[0].lower() == "bearer":
return parts[1]
return get_remote_address(request)
return extract_bearer_token(request) or get_remote_address(request)
limiter = Limiter(key_func=_get_token_or_ip, default_limits=[])

View file

@ -4,10 +4,13 @@ from __future__ import annotations
import time
import uuid as _uuid_mod
from hashlib import sha256
import structlog
from fastapi import Request, Response
from .auth import extract_bearer_token
def configure_logging() -> None:
"""Configure structlog for JSON output."""
@ -26,19 +29,19 @@ def configure_logging() -> None:
)
def token_log_id(token: str) -> str:
return sha256(token.encode("utf-8")).hexdigest()[:12]
async def request_logging_middleware(request: Request, call_next) -> Response:
"""Log each request with method, path, status, duration_ms, user_id, request_id."""
"""Log each request without writing bearer credentials to the log stream."""
request_id = str(_uuid_mod.uuid4())
structlog.contextvars.clear_contextvars()
structlog.contextvars.bind_contextvars(request_id=request_id)
# Extract user_id from Authorization header for logging (no DB call here)
auth = request.headers.get("Authorization") or request.headers.get("authorization")
user_id: str | None = None
if auth:
parts = auth.split()
if len(parts) == 2 and parts[0].lower() == "bearer":
user_id = parts[1]
token = extract_bearer_token(request)
user_id = token_log_id(token) if token else None
start = time.monotonic()
response = await call_next(request)

View file

@ -1,18 +1,17 @@
"""ASGI app, lifespan, static files mount, route registration."""
from __future__ import annotations
import json
import os
from contextlib import asynccontextmanager
from html import escape
from pathlib import Path
from typing import AsyncGenerator
from urllib.parse import urlsplit, urlunsplit
import structlog
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from starlette.middleware.base import BaseHTTPMiddleware
@ -35,7 +34,6 @@ STATUS_CODE_MAP: dict[int, str] = {
413: "payload_too_large",
422: "bad_request",
429: "rate_limited",
507: "quota_exceeded",
500: "server_error",
}
@ -92,10 +90,15 @@ def create_app() -> FastAPI:
request: Request, exc: RequestValidationError
) -> JSONResponse:
fields = sorted(
{".".join(str(loc) for loc in e.get("loc", ()) if loc != "body") for e in exc.errors()}
field
for field in {
".".join(str(loc) for loc in e.get("loc", ()) if loc != "body")
for e in exc.errors()
}
if field
)
if fields:
detail_str = "Validation failed for: " + ", ".join(f for f in fields if f)
detail_str = "Validation failed for: " + ", ".join(fields)
else:
detail_str = "Validation failed"
return JSONResponse(
@ -113,7 +116,7 @@ def create_app() -> FastAPI:
code = STATUS_CODE_MAP.get(exc.status_code, "server_error")
detail = {"error": code, "detail": str(exc.detail)}
headers = getattr(exc, "headers", None) or {}
headers = exc.headers or {}
return JSONResponse(status_code=exc.status_code, content=detail, headers=headers)
# Generic 500 handler
@ -151,7 +154,44 @@ def _mount_static(app: FastAPI, static_dir: Path) -> None:
r"[-.][A-Za-z0-9]{8,}\.(?:js|css|woff2?|png|jpe?g|svg|ico|map)$"
)
def _serve_file(file_path: Path) -> FileResponse:
def _absolute_meta_urls(request: Request) -> tuple[str, str]:
configured_public_url = os.environ.get("LIFE_TOWERS_PUBLIC_URL", "").strip()
if configured_public_url:
public_root = configured_public_url.rstrip("/") + "/"
return public_root, f"{public_root}og-image.png"
parts = urlsplit(str(request.url))
canonical_url = urlunsplit((parts.scheme, parts.netloc, parts.path or "/", "", ""))
root_path = str(request.scope.get("root_path") or "").strip("/")
og_image_path = f"/{root_path}/og-image.png" if root_path else "/og-image.png"
og_image_url = urlunsplit((parts.scheme, parts.netloc, og_image_path, "", ""))
return canonical_url, og_image_url
def _serve_index(file_path: Path, request: Request) -> HTMLResponse:
canonical_url, og_image_url = _absolute_meta_urls(request)
html = file_path.read_text(encoding="utf-8")
html = html.replace(
'href="/" data-dynamic-url="canonical"',
f'href="{escape(canonical_url, quote=True)}" data-dynamic-url="canonical"',
)
html = html.replace(
'content="/" data-dynamic-url="canonical"',
f'content="{escape(canonical_url, quote=True)}" data-dynamic-url="canonical"',
)
html = html.replace(
'content="/og-image.png" data-dynamic-url="og-image"',
f'content="{escape(og_image_url, quote=True)}" data-dynamic-url="og-image"',
)
resp = HTMLResponse(html)
resp.headers["Cache-Control"] = "no-cache"
return resp
def _serve_file(file_path: Path, request: Request) -> Response:
if file_path.name == "index.html":
return _serve_index(file_path, request)
resp = FileResponse(str(file_path))
if HASHED_PATTERN.search(file_path.name):
resp.headers["Cache-Control"] = "public, max-age=31536000, immutable"
@ -160,7 +200,7 @@ def _mount_static(app: FastAPI, static_dir: Path) -> None:
return resp
@app.get("/{full_path:path}", include_in_schema=False)
async def spa_fallback(full_path: str) -> Response:
async def spa_fallback(request: Request, full_path: str) -> Response:
# API routes are handled by the API router (registered before this);
# if execution reaches here for an /api/* path, it really is unknown.
if full_path.startswith("api/"):
@ -174,12 +214,12 @@ def _mount_static(app: FastAPI, static_dir: Path) -> None:
except ValueError:
raise HTTPException(status_code=404, detail="Not found")
if candidate.is_file():
return _serve_file(candidate)
return _serve_file(candidate, request)
# SPA fallback to index.html.
index = static_dir / "index.html"
if index.is_file():
return _serve_file(index)
return _serve_file(index, request)
raise HTTPException(status_code=404, detail="Not found")

View file

@ -1,10 +1,8 @@
-- Life Towers v4 initial schema.
-- SQLite with WAL mode and foreign keys enabled at connection time.
-- WAL mode, foreign keys, and busy_timeout are applied per-connection in
-- db._apply_pragmas(), so they are not (and need not be) set here.
-- All timestamps are unix epoch seconds (INTEGER).
PRAGMA journal_mode = WAL;
PRAGMA foreign_keys = ON;
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
created_at INTEGER NOT NULL,
@ -17,6 +15,7 @@ CREATE TABLE IF NOT EXISTS pages (
position INTEGER NOT NULL,
name TEXT NOT NULL,
hide_create_tower_button INTEGER NOT NULL DEFAULT 0 CHECK (hide_create_tower_button IN (0, 1)),
keep_tasks_open INTEGER NOT NULL DEFAULT 0 CHECK (keep_tasks_open IN (0, 1)),
default_date_from INTEGER,
default_date_to INTEGER,
created_at INTEGER NOT NULL,
@ -47,6 +46,7 @@ CREATE TABLE IF NOT EXISTS blocks (
tag TEXT NOT NULL DEFAULT '',
description TEXT NOT NULL DEFAULT '',
is_done INTEGER NOT NULL DEFAULT 0 CHECK (is_done IN (0, 1)),
difficulty INTEGER NOT NULL DEFAULT 1 CHECK (difficulty >= 1),
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
) STRICT;

View file

@ -1,2 +0,0 @@
ALTER TABLE pages ADD COLUMN keep_tasks_open INTEGER NOT NULL DEFAULT 0
CHECK (keep_tasks_open IN (0, 1));

View file

@ -8,12 +8,14 @@ from pydantic import BaseModel, Field, field_validator, model_validator
import uuid as _uuid_mod
def _is_uuidv4(value: str) -> bool:
def _canonical_uuidv4(value: str) -> str:
try:
u = _uuid_mod.UUID(value)
return u.version == 4
if u.version == 4:
return str(u)
except (ValueError, AttributeError):
return False
pass
raise ValueError("must be a UUIDv4")
class HslColor(BaseModel):
@ -27,14 +29,13 @@ class BlockIn(BaseModel):
tag: str = Field(max_length=200)
description: str = Field(max_length=10_000)
is_done: bool
difficulty: int = Field(default=1, ge=1, le=100)
created_at: Optional[int] = None
@field_validator("id")
@classmethod
def validate_id(cls, v: str) -> str:
if not _is_uuidv4(v):
raise ValueError(f"id must be a UUIDv4, got: {v!r}")
return v
return _canonical_uuidv4(v)
class BlockOut(BaseModel):
@ -42,6 +43,7 @@ class BlockOut(BaseModel):
tag: str
description: str
is_done: bool
difficulty: int
created_at: int
@ -54,9 +56,7 @@ class TowerIn(BaseModel):
@field_validator("id")
@classmethod
def validate_id(cls, v: str) -> str:
if not _is_uuidv4(v):
raise ValueError(f"id must be a UUIDv4, got: {v!r}")
return v
return _canonical_uuidv4(v)
class TowerOut(BaseModel):
@ -78,9 +78,7 @@ class PageIn(BaseModel):
@field_validator("id")
@classmethod
def validate_id(cls, v: str) -> str:
if not _is_uuidv4(v):
raise ValueError(f"id must be a UUIDv4, got: {v!r}")
return v
return _canonical_uuidv4(v)
class PageOut(BaseModel):
@ -137,9 +135,7 @@ class RegisterRequest(BaseModel):
@field_validator("token")
@classmethod
def validate_token(cls, v: str) -> str:
if not _is_uuidv4(v):
raise ValueError("token must be a UUIDv4")
return v
return _canonical_uuidv4(v)
class RegisterResponse(BaseModel):

View file

@ -51,15 +51,71 @@ async def client(tmp_path: Path) -> AsyncGenerator[AsyncClient, None]:
db_module._DB_PATH = None
# ---------------------------------------------------------------------------
# Health
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_health(client: AsyncClient) -> None:
resp = await client.get("/api/v1/health")
async def test_spa_index_injects_absolute_open_graph_urls(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
static_dir = tmp_path / "static"
static_dir.mkdir()
(static_dir / "index.html").write_text(
"""
<!doctype html>
<html>
<head>
<link rel="canonical" href="/" data-dynamic-url="canonical" />
<meta property="og:url" content="/" data-dynamic-url="canonical" />
<meta property="og:image" content="/og-image.png" data-dynamic-url="og-image" />
<meta name="twitter:image" content="/og-image.png" data-dynamic-url="og-image" />
</head>
</html>
""",
encoding="utf-8",
)
(static_dir / "og-image.png").write_bytes(b"fake png")
monkeypatch.setenv("LIFE_TOWERS_STATIC_DIR", str(static_dir))
app = create_app()
async with AsyncClient(
transport=ASGITransport(app=app),
base_url="https://towers.example",
) as c:
resp = await c.get("/tasks?utm_source=test")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
assert (
'<link rel="canonical" href="https://towers.example/tasks" '
'data-dynamic-url="canonical" />'
) in resp.text
assert (
'<meta property="og:url" content="https://towers.example/tasks" '
'data-dynamic-url="canonical" />'
) in resp.text
assert (
'<meta property="og:image" content="https://towers.example/og-image.png" '
'data-dynamic-url="og-image" />'
) in resp.text
assert (
'<meta name="twitter:image" content="https://towers.example/og-image.png" '
'data-dynamic-url="og-image" />'
) in resp.text
monkeypatch.setenv("LIFE_TOWERS_PUBLIC_URL", "https://public.example/towers")
app = create_app()
async with AsyncClient(
transport=ASGITransport(app=app),
base_url="https://internal.example",
) as c:
resp = await c.get("/")
assert resp.status_code == 200
assert (
'<link rel="canonical" href="https://public.example/towers/" '
'data-dynamic-url="canonical" />'
) in resp.text
assert (
'<meta property="og:image" content="https://public.example/towers/og-image.png" '
'data-dynamic-url="og-image" />'
) in resp.text
# ---------------------------------------------------------------------------
@ -99,6 +155,21 @@ async def test_register_non_uuidv4_token(client: AsyncClient) -> None:
assert resp.status_code == 400
@pytest.mark.asyncio
async def test_uuid_inputs_are_canonicalized(client: AsyncClient) -> None:
token = make_uuidv4()
upper_token = token.upper()
register_resp = await client.post("/api/v1/register", json={"token": upper_token})
assert register_resp.status_code == 200
assert register_resp.json() == {"user_id": token}
data_resp = await client.get(
"/api/v1/data",
headers={"Authorization": f"Bearer {upper_token}"},
)
assert data_resp.status_code == 200
# ---------------------------------------------------------------------------
# Auth / GET /data
# ---------------------------------------------------------------------------
@ -148,6 +219,7 @@ def _make_tree() -> dict:
"tag": f"tag-{pi}-{ti}-{bi}",
"description": f"desc-{pi}-{ti}-{bi}",
"is_done": False,
"difficulty": bi + 1,
"created_at": 1700000000 + bi,
}
)
@ -199,6 +271,42 @@ async def test_round_trip(client: AsyncClient) -> None:
for bi, block in enumerate(tower["blocks"]):
assert block["id"] == tree["pages"][pi]["towers"][ti]["blocks"][bi]["id"]
assert block["tag"] == tree["pages"][pi]["towers"][ti]["blocks"][bi]["tag"]
assert (
block["difficulty"]
== tree["pages"][pi]["towers"][ti]["blocks"][bi]["difficulty"]
)
@pytest.mark.asyncio
async def test_difficulty_defaults_to_one_when_omitted(client: AsyncClient) -> None:
"""A block sent without `difficulty` is stored and returned as 1."""
token = make_uuidv4()
await client.post("/api/v1/register", json={"token": token})
headers = {"Authorization": f"Bearer {token}"}
tree = _make_tree()
# Drop difficulty from the very first block.
del tree["pages"][0]["towers"][0]["blocks"][0]["difficulty"]
put_resp = await client.put("/api/v1/data", json=tree, headers=headers)
assert put_resp.status_code == 204
data = (await client.get("/api/v1/data", headers=headers)).json()
assert data["pages"][0]["towers"][0]["blocks"][0]["difficulty"] == 1
@pytest.mark.asyncio
async def test_difficulty_must_be_positive(client: AsyncClient) -> None:
"""difficulty < 1 is rejected by validation."""
token = make_uuidv4()
await client.post("/api/v1/register", json={"token": token})
headers = {"Authorization": f"Bearer {token}"}
tree = _make_tree()
tree["pages"][0]["towers"][0]["blocks"][0]["difficulty"] = 0
resp = await client.put("/api/v1/data", json=tree, headers=headers)
assert resp.status_code == 400
# ---------------------------------------------------------------------------
@ -218,6 +326,34 @@ async def test_put_duplicate_page_id(client: AsyncClient) -> None:
resp = await client.put("/api/v1/data", json=tree, headers=headers)
assert resp.status_code == 400 # pydantic validation error → 400 bad_request per spec
assert resp.json() == {"error": "bad_request", "detail": "Validation failed"}
@pytest.mark.asyncio
async def test_put_cross_user_id_conflict_returns_409(client: AsyncClient) -> None:
first_token = make_uuidv4()
second_token = make_uuidv4()
await client.post("/api/v1/register", json={"token": first_token})
await client.post("/api/v1/register", json={"token": second_token})
tree = _make_tree()
first_resp = await client.put(
"/api/v1/data",
json=tree,
headers={"Authorization": f"Bearer {first_token}"},
)
assert first_resp.status_code == 204
second_resp = await client.put(
"/api/v1/data",
json=tree,
headers={"Authorization": f"Bearer {second_token}"},
)
assert second_resp.status_code == 409
assert second_resp.json() == {
"error": "conflict",
"detail": "Submitted IDs conflict with existing data",
}
@pytest.mark.asyncio