diff --git a/backend/src/life_towers/api.py b/backend/src/life_towers/api.py index 3a8fd4a..8fbcc7f 100644 --- a/backend/src/life_towers/api.py +++ b/backend/src/life_towers/api.py @@ -1,8 +1,6 @@ -"""APIRouter with all Life Towers endpoints.""" - from __future__ import annotations -import json +import sqlite3 import time from typing import Annotated @@ -12,6 +10,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request from .auth import get_current_user from .db import db_connection from .limits import limiter +from .logging import token_log_id from .models import ( BlockOut, DataIn, @@ -53,7 +52,7 @@ async def register(request: Request, body: RegisterRequest) -> RegisterResponse: (now, token), ) conn.commit() - logger.info("user_registered", user_id=token, new=existing is None) + logger.info("user_registered", user_id=token_log_id(token), new=existing is None) return RegisterResponse(user_id=token) @@ -94,7 +93,7 @@ async def get_data( block_rows = conn.execute( """ - SELECT id, tag, description, is_done, created_at + SELECT id, tag, description, is_done, difficulty, created_at FROM blocks WHERE tower_id = ? ORDER BY position @@ -108,6 +107,7 @@ async def get_data( tag=b["tag"], description=b["description"], is_done=bool(b["is_done"]), + difficulty=b["difficulty"], created_at=b["created_at"], ) for b in block_rows @@ -213,8 +213,9 @@ async def put_data( """ INSERT INTO blocks (id, tower_id, user_id, position, tag, - description, is_done, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + description, is_done, difficulty, + created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( block.id, @@ -224,6 +225,7 @@ async def put_data( block.tag, block.description, 1 if block.is_done else 0, + block.difficulty, created_at, now, ), @@ -236,8 +238,14 @@ async def put_data( ) conn.commit() + except sqlite3.IntegrityError as exc: + conn.rollback() + raise HTTPException( + status_code=409, + detail="Submitted IDs conflict with existing data", + ) from exc except Exception: conn.rollback() raise - logger.info("data_replaced", user_id=user_id, pages=len(body.pages)) + logger.info("data_replaced", user_id=token_log_id(user_id), pages=len(body.pages)) diff --git a/backend/src/life_towers/auth.py b/backend/src/life_towers/auth.py index 4142948..45f396a 100644 --- a/backend/src/life_towers/auth.py +++ b/backend/src/life_towers/auth.py @@ -2,11 +2,10 @@ from __future__ import annotations -import uuid - from fastapi import HTTPException, Request from .db import db_connection +from .models import _canonical_uuidv4 # Single generic detail used for ALL 401 responses. Per spec, the response # must not distinguish between missing / malformed / unknown tokens — that @@ -21,25 +20,28 @@ def _unauthorized() -> HTTPException: ) -def get_current_user(request: Request) -> str: - """Dependency that extracts and validates a Bearer token, returns user_id.""" +def extract_bearer_token(request: Request) -> str | None: + """Return the raw Bearer token from the Authorization header, or None.""" auth_header = request.headers.get("Authorization") or request.headers.get( "authorization" ) if not auth_header: - raise _unauthorized() - + return None parts = auth_header.split() - if len(parts) != 2 or parts[0].lower() != "bearer": - raise _unauthorized() + if len(parts) == 2 and parts[0].lower() == "bearer": + return parts[1] + return None - token = parts[1] + +def get_current_user(request: Request) -> str: + """Dependency that extracts and validates a Bearer token, returns user_id.""" + token = extract_bearer_token(request) + if token is None: + raise _unauthorized() try: - u = uuid.UUID(token) - if u.version != 4: - raise ValueError("Not v4") - except (ValueError, AttributeError): + token = _canonical_uuidv4(token) + except ValueError: raise _unauthorized() with db_connection() as conn: diff --git a/backend/src/life_towers/limits.py b/backend/src/life_towers/limits.py index b566d84..bc6f90d 100644 --- a/backend/src/life_towers/limits.py +++ b/backend/src/life_towers/limits.py @@ -8,6 +8,8 @@ from fastapi import Request, Response from slowapi import Limiter from slowapi.util import get_remote_address +from .auth import extract_bearer_token + PAYLOAD_LIMIT_BYTES = 2 * 1024 * 1024 # 2 MiB _TOO_LARGE_BODY = json.dumps( @@ -20,12 +22,7 @@ _TOO_LARGE_BODY = json.dumps( def _get_token_or_ip(request: Request) -> str: """Key function for rate limiting: use Bearer token if present, else IP.""" - auth = request.headers.get("Authorization") or request.headers.get("authorization") - if auth: - parts = auth.split() - if len(parts) == 2 and parts[0].lower() == "bearer": - return parts[1] - return get_remote_address(request) + return extract_bearer_token(request) or get_remote_address(request) limiter = Limiter(key_func=_get_token_or_ip, default_limits=[]) diff --git a/backend/src/life_towers/logging.py b/backend/src/life_towers/logging.py index fea46c0..539b01b 100644 --- a/backend/src/life_towers/logging.py +++ b/backend/src/life_towers/logging.py @@ -4,10 +4,13 @@ from __future__ import annotations import time import uuid as _uuid_mod +from hashlib import sha256 import structlog from fastapi import Request, Response +from .auth import extract_bearer_token + def configure_logging() -> None: """Configure structlog for JSON output.""" @@ -26,19 +29,19 @@ def configure_logging() -> None: ) +def token_log_id(token: str) -> str: + return sha256(token.encode("utf-8")).hexdigest()[:12] + + async def request_logging_middleware(request: Request, call_next) -> Response: - """Log each request with method, path, status, duration_ms, user_id, request_id.""" + """Log each request without writing bearer credentials to the log stream.""" request_id = str(_uuid_mod.uuid4()) structlog.contextvars.clear_contextvars() structlog.contextvars.bind_contextvars(request_id=request_id) # Extract user_id from Authorization header for logging (no DB call here) - auth = request.headers.get("Authorization") or request.headers.get("authorization") - user_id: str | None = None - if auth: - parts = auth.split() - if len(parts) == 2 and parts[0].lower() == "bearer": - user_id = parts[1] + token = extract_bearer_token(request) + user_id = token_log_id(token) if token else None start = time.monotonic() response = await call_next(request) diff --git a/backend/src/life_towers/main.py b/backend/src/life_towers/main.py index b97eba7..b9be419 100644 --- a/backend/src/life_towers/main.py +++ b/backend/src/life_towers/main.py @@ -1,18 +1,17 @@ -"""ASGI app, lifespan, static files mount, route registration.""" - from __future__ import annotations -import json import os from contextlib import asynccontextmanager +from html import escape from pathlib import Path from typing import AsyncGenerator +from urllib.parse import urlsplit, urlunsplit import structlog from fastapi import FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse, JSONResponse +from fastapi.responses import FileResponse, HTMLResponse, JSONResponse from slowapi import _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded from starlette.middleware.base import BaseHTTPMiddleware @@ -35,7 +34,6 @@ STATUS_CODE_MAP: dict[int, str] = { 413: "payload_too_large", 422: "bad_request", 429: "rate_limited", - 507: "quota_exceeded", 500: "server_error", } @@ -92,10 +90,15 @@ def create_app() -> FastAPI: request: Request, exc: RequestValidationError ) -> JSONResponse: fields = sorted( - {".".join(str(loc) for loc in e.get("loc", ()) if loc != "body") for e in exc.errors()} + field + for field in { + ".".join(str(loc) for loc in e.get("loc", ()) if loc != "body") + for e in exc.errors() + } + if field ) if fields: - detail_str = "Validation failed for: " + ", ".join(f for f in fields if f) + detail_str = "Validation failed for: " + ", ".join(fields) else: detail_str = "Validation failed" return JSONResponse( @@ -113,7 +116,7 @@ def create_app() -> FastAPI: code = STATUS_CODE_MAP.get(exc.status_code, "server_error") detail = {"error": code, "detail": str(exc.detail)} - headers = getattr(exc, "headers", None) or {} + headers = exc.headers or {} return JSONResponse(status_code=exc.status_code, content=detail, headers=headers) # Generic 500 handler @@ -151,7 +154,44 @@ def _mount_static(app: FastAPI, static_dir: Path) -> None: r"[-.][A-Za-z0-9]{8,}\.(?:js|css|woff2?|png|jpe?g|svg|ico|map)$" ) - def _serve_file(file_path: Path) -> FileResponse: + def _absolute_meta_urls(request: Request) -> tuple[str, str]: + configured_public_url = os.environ.get("LIFE_TOWERS_PUBLIC_URL", "").strip() + if configured_public_url: + public_root = configured_public_url.rstrip("/") + "/" + return public_root, f"{public_root}og-image.png" + + parts = urlsplit(str(request.url)) + canonical_url = urlunsplit((parts.scheme, parts.netloc, parts.path or "/", "", "")) + + root_path = str(request.scope.get("root_path") or "").strip("/") + og_image_path = f"/{root_path}/og-image.png" if root_path else "/og-image.png" + og_image_url = urlunsplit((parts.scheme, parts.netloc, og_image_path, "", "")) + return canonical_url, og_image_url + + def _serve_index(file_path: Path, request: Request) -> HTMLResponse: + canonical_url, og_image_url = _absolute_meta_urls(request) + html = file_path.read_text(encoding="utf-8") + html = html.replace( + 'href="/" data-dynamic-url="canonical"', + f'href="{escape(canonical_url, quote=True)}" data-dynamic-url="canonical"', + ) + html = html.replace( + 'content="/" data-dynamic-url="canonical"', + f'content="{escape(canonical_url, quote=True)}" data-dynamic-url="canonical"', + ) + html = html.replace( + 'content="/og-image.png" data-dynamic-url="og-image"', + f'content="{escape(og_image_url, quote=True)}" data-dynamic-url="og-image"', + ) + + resp = HTMLResponse(html) + resp.headers["Cache-Control"] = "no-cache" + return resp + + def _serve_file(file_path: Path, request: Request) -> Response: + if file_path.name == "index.html": + return _serve_index(file_path, request) + resp = FileResponse(str(file_path)) if HASHED_PATTERN.search(file_path.name): resp.headers["Cache-Control"] = "public, max-age=31536000, immutable" @@ -160,7 +200,7 @@ def _mount_static(app: FastAPI, static_dir: Path) -> None: return resp @app.get("/{full_path:path}", include_in_schema=False) - async def spa_fallback(full_path: str) -> Response: + async def spa_fallback(request: Request, full_path: str) -> Response: # API routes are handled by the API router (registered before this); # if execution reaches here for an /api/* path, it really is unknown. if full_path.startswith("api/"): @@ -174,12 +214,12 @@ def _mount_static(app: FastAPI, static_dir: Path) -> None: except ValueError: raise HTTPException(status_code=404, detail="Not found") if candidate.is_file(): - return _serve_file(candidate) + return _serve_file(candidate, request) # SPA fallback to index.html. index = static_dir / "index.html" if index.is_file(): - return _serve_file(index) + return _serve_file(index, request) raise HTTPException(status_code=404, detail="Not found") diff --git a/backend/src/life_towers/migrations/001_initial.sql b/backend/src/life_towers/migrations/001_initial.sql index 14ab170..372c02f 100644 --- a/backend/src/life_towers/migrations/001_initial.sql +++ b/backend/src/life_towers/migrations/001_initial.sql @@ -1,10 +1,8 @@ -- Life Towers v4 initial schema. --- SQLite with WAL mode and foreign keys enabled at connection time. +-- WAL mode, foreign keys, and busy_timeout are applied per-connection in +-- db._apply_pragmas(), so they are not (and need not be) set here. -- All timestamps are unix epoch seconds (INTEGER). -PRAGMA journal_mode = WAL; -PRAGMA foreign_keys = ON; - CREATE TABLE IF NOT EXISTS users ( id TEXT PRIMARY KEY, created_at INTEGER NOT NULL, @@ -17,6 +15,7 @@ CREATE TABLE IF NOT EXISTS pages ( position INTEGER NOT NULL, name TEXT NOT NULL, hide_create_tower_button INTEGER NOT NULL DEFAULT 0 CHECK (hide_create_tower_button IN (0, 1)), + keep_tasks_open INTEGER NOT NULL DEFAULT 0 CHECK (keep_tasks_open IN (0, 1)), default_date_from INTEGER, default_date_to INTEGER, created_at INTEGER NOT NULL, @@ -47,6 +46,7 @@ CREATE TABLE IF NOT EXISTS blocks ( tag TEXT NOT NULL DEFAULT '', description TEXT NOT NULL DEFAULT '', is_done INTEGER NOT NULL DEFAULT 0 CHECK (is_done IN (0, 1)), + difficulty INTEGER NOT NULL DEFAULT 1 CHECK (difficulty >= 1), created_at INTEGER NOT NULL, updated_at INTEGER NOT NULL ) STRICT; diff --git a/backend/src/life_towers/migrations/002_keep_tasks_open.sql b/backend/src/life_towers/migrations/002_keep_tasks_open.sql deleted file mode 100644 index f9417ad..0000000 --- a/backend/src/life_towers/migrations/002_keep_tasks_open.sql +++ /dev/null @@ -1,2 +0,0 @@ -ALTER TABLE pages ADD COLUMN keep_tasks_open INTEGER NOT NULL DEFAULT 0 - CHECK (keep_tasks_open IN (0, 1)); diff --git a/backend/src/life_towers/models.py b/backend/src/life_towers/models.py index d848ad4..b40d5e3 100644 --- a/backend/src/life_towers/models.py +++ b/backend/src/life_towers/models.py @@ -8,12 +8,14 @@ from pydantic import BaseModel, Field, field_validator, model_validator import uuid as _uuid_mod -def _is_uuidv4(value: str) -> bool: +def _canonical_uuidv4(value: str) -> str: try: u = _uuid_mod.UUID(value) - return u.version == 4 + if u.version == 4: + return str(u) except (ValueError, AttributeError): - return False + pass + raise ValueError("must be a UUIDv4") class HslColor(BaseModel): @@ -27,14 +29,13 @@ class BlockIn(BaseModel): tag: str = Field(max_length=200) description: str = Field(max_length=10_000) is_done: bool + difficulty: int = Field(default=1, ge=1, le=100) created_at: Optional[int] = None @field_validator("id") @classmethod def validate_id(cls, v: str) -> str: - if not _is_uuidv4(v): - raise ValueError(f"id must be a UUIDv4, got: {v!r}") - return v + return _canonical_uuidv4(v) class BlockOut(BaseModel): @@ -42,6 +43,7 @@ class BlockOut(BaseModel): tag: str description: str is_done: bool + difficulty: int created_at: int @@ -54,9 +56,7 @@ class TowerIn(BaseModel): @field_validator("id") @classmethod def validate_id(cls, v: str) -> str: - if not _is_uuidv4(v): - raise ValueError(f"id must be a UUIDv4, got: {v!r}") - return v + return _canonical_uuidv4(v) class TowerOut(BaseModel): @@ -78,9 +78,7 @@ class PageIn(BaseModel): @field_validator("id") @classmethod def validate_id(cls, v: str) -> str: - if not _is_uuidv4(v): - raise ValueError(f"id must be a UUIDv4, got: {v!r}") - return v + return _canonical_uuidv4(v) class PageOut(BaseModel): @@ -137,9 +135,7 @@ class RegisterRequest(BaseModel): @field_validator("token") @classmethod def validate_token(cls, v: str) -> str: - if not _is_uuidv4(v): - raise ValueError("token must be a UUIDv4") - return v + return _canonical_uuidv4(v) class RegisterResponse(BaseModel): diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index 868d344..a4f8538 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -51,15 +51,71 @@ async def client(tmp_path: Path) -> AsyncGenerator[AsyncClient, None]: db_module._DB_PATH = None -# --------------------------------------------------------------------------- -# Health -# --------------------------------------------------------------------------- - @pytest.mark.asyncio -async def test_health(client: AsyncClient) -> None: - resp = await client.get("/api/v1/health") +async def test_spa_index_injects_absolute_open_graph_urls( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + static_dir = tmp_path / "static" + static_dir.mkdir() + (static_dir / "index.html").write_text( + """ + + +
+ + + + + + +""", + encoding="utf-8", + ) + (static_dir / "og-image.png").write_bytes(b"fake png") + monkeypatch.setenv("LIFE_TOWERS_STATIC_DIR", str(static_dir)) + + app = create_app() + async with AsyncClient( + transport=ASGITransport(app=app), + base_url="https://towers.example", + ) as c: + resp = await c.get("/tasks?utm_source=test") + assert resp.status_code == 200 - assert resp.json() == {"status": "ok"} + assert ( + '' + ) in resp.text + assert ( + '' + ) in resp.text + assert ( + '' + ) in resp.text + assert ( + '' + ) in resp.text + + monkeypatch.setenv("LIFE_TOWERS_PUBLIC_URL", "https://public.example/towers") + app = create_app() + async with AsyncClient( + transport=ASGITransport(app=app), + base_url="https://internal.example", + ) as c: + resp = await c.get("/") + + assert resp.status_code == 200 + assert ( + '' + ) in resp.text + assert ( + '' + ) in resp.text # --------------------------------------------------------------------------- @@ -99,6 +155,21 @@ async def test_register_non_uuidv4_token(client: AsyncClient) -> None: assert resp.status_code == 400 +@pytest.mark.asyncio +async def test_uuid_inputs_are_canonicalized(client: AsyncClient) -> None: + token = make_uuidv4() + upper_token = token.upper() + register_resp = await client.post("/api/v1/register", json={"token": upper_token}) + assert register_resp.status_code == 200 + assert register_resp.json() == {"user_id": token} + + data_resp = await client.get( + "/api/v1/data", + headers={"Authorization": f"Bearer {upper_token}"}, + ) + assert data_resp.status_code == 200 + + # --------------------------------------------------------------------------- # Auth / GET /data # --------------------------------------------------------------------------- @@ -148,6 +219,7 @@ def _make_tree() -> dict: "tag": f"tag-{pi}-{ti}-{bi}", "description": f"desc-{pi}-{ti}-{bi}", "is_done": False, + "difficulty": bi + 1, "created_at": 1700000000 + bi, } ) @@ -199,6 +271,42 @@ async def test_round_trip(client: AsyncClient) -> None: for bi, block in enumerate(tower["blocks"]): assert block["id"] == tree["pages"][pi]["towers"][ti]["blocks"][bi]["id"] assert block["tag"] == tree["pages"][pi]["towers"][ti]["blocks"][bi]["tag"] + assert ( + block["difficulty"] + == tree["pages"][pi]["towers"][ti]["blocks"][bi]["difficulty"] + ) + + +@pytest.mark.asyncio +async def test_difficulty_defaults_to_one_when_omitted(client: AsyncClient) -> None: + """A block sent without `difficulty` is stored and returned as 1.""" + token = make_uuidv4() + await client.post("/api/v1/register", json={"token": token}) + headers = {"Authorization": f"Bearer {token}"} + + tree = _make_tree() + # Drop difficulty from the very first block. + del tree["pages"][0]["towers"][0]["blocks"][0]["difficulty"] + + put_resp = await client.put("/api/v1/data", json=tree, headers=headers) + assert put_resp.status_code == 204 + + data = (await client.get("/api/v1/data", headers=headers)).json() + assert data["pages"][0]["towers"][0]["blocks"][0]["difficulty"] == 1 + + +@pytest.mark.asyncio +async def test_difficulty_must_be_positive(client: AsyncClient) -> None: + """difficulty < 1 is rejected by validation.""" + token = make_uuidv4() + await client.post("/api/v1/register", json={"token": token}) + headers = {"Authorization": f"Bearer {token}"} + + tree = _make_tree() + tree["pages"][0]["towers"][0]["blocks"][0]["difficulty"] = 0 + + resp = await client.put("/api/v1/data", json=tree, headers=headers) + assert resp.status_code == 400 # --------------------------------------------------------------------------- @@ -218,6 +326,34 @@ async def test_put_duplicate_page_id(client: AsyncClient) -> None: resp = await client.put("/api/v1/data", json=tree, headers=headers) assert resp.status_code == 400 # pydantic validation error → 400 bad_request per spec + assert resp.json() == {"error": "bad_request", "detail": "Validation failed"} + + +@pytest.mark.asyncio +async def test_put_cross_user_id_conflict_returns_409(client: AsyncClient) -> None: + first_token = make_uuidv4() + second_token = make_uuidv4() + await client.post("/api/v1/register", json={"token": first_token}) + await client.post("/api/v1/register", json={"token": second_token}) + + tree = _make_tree() + first_resp = await client.put( + "/api/v1/data", + json=tree, + headers={"Authorization": f"Bearer {first_token}"}, + ) + assert first_resp.status_code == 204 + + second_resp = await client.put( + "/api/v1/data", + json=tree, + headers={"Authorization": f"Bearer {second_token}"}, + ) + assert second_resp.status_code == 409 + assert second_resp.json() == { + "error": "conflict", + "detail": "Submitted IDs conflict with existing data", + } @pytest.mark.asyncio