backend: tidy modules, consolidate schema migrations, expand API tests
This commit is contained in:
parent
4156d1d469
commit
d9724a462d
9 changed files with 254 additions and 74 deletions
|
|
@ -1,8 +1,6 @@
|
|||
"""APIRouter with all Life Towers endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import Annotated
|
||||
|
||||
|
|
@ -12,6 +10,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
|||
from .auth import get_current_user
|
||||
from .db import db_connection
|
||||
from .limits import limiter
|
||||
from .logging import token_log_id
|
||||
from .models import (
|
||||
BlockOut,
|
||||
DataIn,
|
||||
|
|
@ -53,7 +52,7 @@ async def register(request: Request, body: RegisterRequest) -> RegisterResponse:
|
|||
(now, token),
|
||||
)
|
||||
conn.commit()
|
||||
logger.info("user_registered", user_id=token, new=existing is None)
|
||||
logger.info("user_registered", user_id=token_log_id(token), new=existing is None)
|
||||
return RegisterResponse(user_id=token)
|
||||
|
||||
|
||||
|
|
@ -94,7 +93,7 @@ async def get_data(
|
|||
|
||||
block_rows = conn.execute(
|
||||
"""
|
||||
SELECT id, tag, description, is_done, created_at
|
||||
SELECT id, tag, description, is_done, difficulty, created_at
|
||||
FROM blocks
|
||||
WHERE tower_id = ?
|
||||
ORDER BY position
|
||||
|
|
@ -108,6 +107,7 @@ async def get_data(
|
|||
tag=b["tag"],
|
||||
description=b["description"],
|
||||
is_done=bool(b["is_done"]),
|
||||
difficulty=b["difficulty"],
|
||||
created_at=b["created_at"],
|
||||
)
|
||||
for b in block_rows
|
||||
|
|
@ -213,8 +213,9 @@ async def put_data(
|
|||
"""
|
||||
INSERT INTO blocks
|
||||
(id, tower_id, user_id, position, tag,
|
||||
description, is_done, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
description, is_done, difficulty,
|
||||
created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
block.id,
|
||||
|
|
@ -224,6 +225,7 @@ async def put_data(
|
|||
block.tag,
|
||||
block.description,
|
||||
1 if block.is_done else 0,
|
||||
block.difficulty,
|
||||
created_at,
|
||||
now,
|
||||
),
|
||||
|
|
@ -236,8 +238,14 @@ async def put_data(
|
|||
)
|
||||
|
||||
conn.commit()
|
||||
except sqlite3.IntegrityError as exc:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Submitted IDs conflict with existing data",
|
||||
) from exc
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
logger.info("data_replaced", user_id=user_id, pages=len(body.pages))
|
||||
logger.info("data_replaced", user_id=token_log_id(user_id), pages=len(body.pages))
|
||||
|
|
|
|||
|
|
@ -2,11 +2,10 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from .db import db_connection
|
||||
from .models import _canonical_uuidv4
|
||||
|
||||
# Single generic detail used for ALL 401 responses. Per spec, the response
|
||||
# must not distinguish between missing / malformed / unknown tokens — that
|
||||
|
|
@ -21,25 +20,28 @@ def _unauthorized() -> HTTPException:
|
|||
)
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> str:
|
||||
"""Dependency that extracts and validates a Bearer token, returns user_id."""
|
||||
def extract_bearer_token(request: Request) -> str | None:
|
||||
"""Return the raw Bearer token from the Authorization header, or None."""
|
||||
auth_header = request.headers.get("Authorization") or request.headers.get(
|
||||
"authorization"
|
||||
)
|
||||
if not auth_header:
|
||||
raise _unauthorized()
|
||||
|
||||
return None
|
||||
parts = auth_header.split()
|
||||
if len(parts) != 2 or parts[0].lower() != "bearer":
|
||||
raise _unauthorized()
|
||||
if len(parts) == 2 and parts[0].lower() == "bearer":
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
token = parts[1]
|
||||
|
||||
def get_current_user(request: Request) -> str:
|
||||
"""Dependency that extracts and validates a Bearer token, returns user_id."""
|
||||
token = extract_bearer_token(request)
|
||||
if token is None:
|
||||
raise _unauthorized()
|
||||
|
||||
try:
|
||||
u = uuid.UUID(token)
|
||||
if u.version != 4:
|
||||
raise ValueError("Not v4")
|
||||
except (ValueError, AttributeError):
|
||||
token = _canonical_uuidv4(token)
|
||||
except ValueError:
|
||||
raise _unauthorized()
|
||||
|
||||
with db_connection() as conn:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from fastapi import Request, Response
|
|||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from .auth import extract_bearer_token
|
||||
|
||||
PAYLOAD_LIMIT_BYTES = 2 * 1024 * 1024 # 2 MiB
|
||||
|
||||
_TOO_LARGE_BODY = json.dumps(
|
||||
|
|
@ -20,12 +22,7 @@ _TOO_LARGE_BODY = json.dumps(
|
|||
|
||||
def _get_token_or_ip(request: Request) -> str:
|
||||
"""Key function for rate limiting: use Bearer token if present, else IP."""
|
||||
auth = request.headers.get("Authorization") or request.headers.get("authorization")
|
||||
if auth:
|
||||
parts = auth.split()
|
||||
if len(parts) == 2 and parts[0].lower() == "bearer":
|
||||
return parts[1]
|
||||
return get_remote_address(request)
|
||||
return extract_bearer_token(request) or get_remote_address(request)
|
||||
|
||||
|
||||
limiter = Limiter(key_func=_get_token_or_ip, default_limits=[])
|
||||
|
|
|
|||
|
|
@ -4,10 +4,13 @@ from __future__ import annotations
|
|||
|
||||
import time
|
||||
import uuid as _uuid_mod
|
||||
from hashlib import sha256
|
||||
|
||||
import structlog
|
||||
from fastapi import Request, Response
|
||||
|
||||
from .auth import extract_bearer_token
|
||||
|
||||
|
||||
def configure_logging() -> None:
|
||||
"""Configure structlog for JSON output."""
|
||||
|
|
@ -26,19 +29,19 @@ def configure_logging() -> None:
|
|||
)
|
||||
|
||||
|
||||
def token_log_id(token: str) -> str:
|
||||
return sha256(token.encode("utf-8")).hexdigest()[:12]
|
||||
|
||||
|
||||
async def request_logging_middleware(request: Request, call_next) -> Response:
|
||||
"""Log each request with method, path, status, duration_ms, user_id, request_id."""
|
||||
"""Log each request without writing bearer credentials to the log stream."""
|
||||
request_id = str(_uuid_mod.uuid4())
|
||||
structlog.contextvars.clear_contextvars()
|
||||
structlog.contextvars.bind_contextvars(request_id=request_id)
|
||||
|
||||
# Extract user_id from Authorization header for logging (no DB call here)
|
||||
auth = request.headers.get("Authorization") or request.headers.get("authorization")
|
||||
user_id: str | None = None
|
||||
if auth:
|
||||
parts = auth.split()
|
||||
if len(parts) == 2 and parts[0].lower() == "bearer":
|
||||
user_id = parts[1]
|
||||
token = extract_bearer_token(request)
|
||||
user_id = token_log_id(token) if token else None
|
||||
|
||||
start = time.monotonic()
|
||||
response = await call_next(request)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,17 @@
|
|||
"""ASGI app, lifespan, static files mount, route registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from html import escape
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from urllib.parse import urlsplit, urlunsplit
|
||||
|
||||
import structlog
|
||||
from fastapi import FastAPI, HTTPException, Request, Response
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
||||
from slowapi import _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
|
@ -35,7 +34,6 @@ STATUS_CODE_MAP: dict[int, str] = {
|
|||
413: "payload_too_large",
|
||||
422: "bad_request",
|
||||
429: "rate_limited",
|
||||
507: "quota_exceeded",
|
||||
500: "server_error",
|
||||
}
|
||||
|
||||
|
|
@ -92,10 +90,15 @@ def create_app() -> FastAPI:
|
|||
request: Request, exc: RequestValidationError
|
||||
) -> JSONResponse:
|
||||
fields = sorted(
|
||||
{".".join(str(loc) for loc in e.get("loc", ()) if loc != "body") for e in exc.errors()}
|
||||
field
|
||||
for field in {
|
||||
".".join(str(loc) for loc in e.get("loc", ()) if loc != "body")
|
||||
for e in exc.errors()
|
||||
}
|
||||
if field
|
||||
)
|
||||
if fields:
|
||||
detail_str = "Validation failed for: " + ", ".join(f for f in fields if f)
|
||||
detail_str = "Validation failed for: " + ", ".join(fields)
|
||||
else:
|
||||
detail_str = "Validation failed"
|
||||
return JSONResponse(
|
||||
|
|
@ -113,7 +116,7 @@ def create_app() -> FastAPI:
|
|||
code = STATUS_CODE_MAP.get(exc.status_code, "server_error")
|
||||
detail = {"error": code, "detail": str(exc.detail)}
|
||||
|
||||
headers = getattr(exc, "headers", None) or {}
|
||||
headers = exc.headers or {}
|
||||
return JSONResponse(status_code=exc.status_code, content=detail, headers=headers)
|
||||
|
||||
# Generic 500 handler
|
||||
|
|
@ -151,7 +154,44 @@ def _mount_static(app: FastAPI, static_dir: Path) -> None:
|
|||
r"[-.][A-Za-z0-9]{8,}\.(?:js|css|woff2?|png|jpe?g|svg|ico|map)$"
|
||||
)
|
||||
|
||||
def _serve_file(file_path: Path) -> FileResponse:
|
||||
def _absolute_meta_urls(request: Request) -> tuple[str, str]:
|
||||
configured_public_url = os.environ.get("LIFE_TOWERS_PUBLIC_URL", "").strip()
|
||||
if configured_public_url:
|
||||
public_root = configured_public_url.rstrip("/") + "/"
|
||||
return public_root, f"{public_root}og-image.png"
|
||||
|
||||
parts = urlsplit(str(request.url))
|
||||
canonical_url = urlunsplit((parts.scheme, parts.netloc, parts.path or "/", "", ""))
|
||||
|
||||
root_path = str(request.scope.get("root_path") or "").strip("/")
|
||||
og_image_path = f"/{root_path}/og-image.png" if root_path else "/og-image.png"
|
||||
og_image_url = urlunsplit((parts.scheme, parts.netloc, og_image_path, "", ""))
|
||||
return canonical_url, og_image_url
|
||||
|
||||
def _serve_index(file_path: Path, request: Request) -> HTMLResponse:
|
||||
canonical_url, og_image_url = _absolute_meta_urls(request)
|
||||
html = file_path.read_text(encoding="utf-8")
|
||||
html = html.replace(
|
||||
'href="/" data-dynamic-url="canonical"',
|
||||
f'href="{escape(canonical_url, quote=True)}" data-dynamic-url="canonical"',
|
||||
)
|
||||
html = html.replace(
|
||||
'content="/" data-dynamic-url="canonical"',
|
||||
f'content="{escape(canonical_url, quote=True)}" data-dynamic-url="canonical"',
|
||||
)
|
||||
html = html.replace(
|
||||
'content="/og-image.png" data-dynamic-url="og-image"',
|
||||
f'content="{escape(og_image_url, quote=True)}" data-dynamic-url="og-image"',
|
||||
)
|
||||
|
||||
resp = HTMLResponse(html)
|
||||
resp.headers["Cache-Control"] = "no-cache"
|
||||
return resp
|
||||
|
||||
def _serve_file(file_path: Path, request: Request) -> Response:
|
||||
if file_path.name == "index.html":
|
||||
return _serve_index(file_path, request)
|
||||
|
||||
resp = FileResponse(str(file_path))
|
||||
if HASHED_PATTERN.search(file_path.name):
|
||||
resp.headers["Cache-Control"] = "public, max-age=31536000, immutable"
|
||||
|
|
@ -160,7 +200,7 @@ def _mount_static(app: FastAPI, static_dir: Path) -> None:
|
|||
return resp
|
||||
|
||||
@app.get("/{full_path:path}", include_in_schema=False)
|
||||
async def spa_fallback(full_path: str) -> Response:
|
||||
async def spa_fallback(request: Request, full_path: str) -> Response:
|
||||
# API routes are handled by the API router (registered before this);
|
||||
# if execution reaches here for an /api/* path, it really is unknown.
|
||||
if full_path.startswith("api/"):
|
||||
|
|
@ -174,12 +214,12 @@ def _mount_static(app: FastAPI, static_dir: Path) -> None:
|
|||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
if candidate.is_file():
|
||||
return _serve_file(candidate)
|
||||
return _serve_file(candidate, request)
|
||||
|
||||
# SPA fallback to index.html.
|
||||
index = static_dir / "index.html"
|
||||
if index.is_file():
|
||||
return _serve_file(index)
|
||||
return _serve_file(index, request)
|
||||
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
-- Life Towers v4 initial schema.
|
||||
-- SQLite with WAL mode and foreign keys enabled at connection time.
|
||||
-- WAL mode, foreign keys, and busy_timeout are applied per-connection in
|
||||
-- db._apply_pragmas(), so they are not (and need not be) set here.
|
||||
-- All timestamps are unix epoch seconds (INTEGER).
|
||||
|
||||
PRAGMA journal_mode = WAL;
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at INTEGER NOT NULL,
|
||||
|
|
@ -17,6 +15,7 @@ CREATE TABLE IF NOT EXISTS pages (
|
|||
position INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
hide_create_tower_button INTEGER NOT NULL DEFAULT 0 CHECK (hide_create_tower_button IN (0, 1)),
|
||||
keep_tasks_open INTEGER NOT NULL DEFAULT 0 CHECK (keep_tasks_open IN (0, 1)),
|
||||
default_date_from INTEGER,
|
||||
default_date_to INTEGER,
|
||||
created_at INTEGER NOT NULL,
|
||||
|
|
@ -47,6 +46,7 @@ CREATE TABLE IF NOT EXISTS blocks (
|
|||
tag TEXT NOT NULL DEFAULT '',
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
is_done INTEGER NOT NULL DEFAULT 0 CHECK (is_done IN (0, 1)),
|
||||
difficulty INTEGER NOT NULL DEFAULT 1 CHECK (difficulty >= 1),
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
) STRICT;
|
||||
|
|
|
|||
|
|
@ -1,2 +0,0 @@
|
|||
ALTER TABLE pages ADD COLUMN keep_tasks_open INTEGER NOT NULL DEFAULT 0
|
||||
CHECK (keep_tasks_open IN (0, 1));
|
||||
|
|
@ -8,12 +8,14 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
|||
import uuid as _uuid_mod
|
||||
|
||||
|
||||
def _is_uuidv4(value: str) -> bool:
|
||||
def _canonical_uuidv4(value: str) -> str:
|
||||
try:
|
||||
u = _uuid_mod.UUID(value)
|
||||
return u.version == 4
|
||||
if u.version == 4:
|
||||
return str(u)
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
pass
|
||||
raise ValueError("must be a UUIDv4")
|
||||
|
||||
|
||||
class HslColor(BaseModel):
|
||||
|
|
@ -27,14 +29,13 @@ class BlockIn(BaseModel):
|
|||
tag: str = Field(max_length=200)
|
||||
description: str = Field(max_length=10_000)
|
||||
is_done: bool
|
||||
difficulty: int = Field(default=1, ge=1, le=100)
|
||||
created_at: Optional[int] = None
|
||||
|
||||
@field_validator("id")
|
||||
@classmethod
|
||||
def validate_id(cls, v: str) -> str:
|
||||
if not _is_uuidv4(v):
|
||||
raise ValueError(f"id must be a UUIDv4, got: {v!r}")
|
||||
return v
|
||||
return _canonical_uuidv4(v)
|
||||
|
||||
|
||||
class BlockOut(BaseModel):
|
||||
|
|
@ -42,6 +43,7 @@ class BlockOut(BaseModel):
|
|||
tag: str
|
||||
description: str
|
||||
is_done: bool
|
||||
difficulty: int
|
||||
created_at: int
|
||||
|
||||
|
||||
|
|
@ -54,9 +56,7 @@ class TowerIn(BaseModel):
|
|||
@field_validator("id")
|
||||
@classmethod
|
||||
def validate_id(cls, v: str) -> str:
|
||||
if not _is_uuidv4(v):
|
||||
raise ValueError(f"id must be a UUIDv4, got: {v!r}")
|
||||
return v
|
||||
return _canonical_uuidv4(v)
|
||||
|
||||
|
||||
class TowerOut(BaseModel):
|
||||
|
|
@ -78,9 +78,7 @@ class PageIn(BaseModel):
|
|||
@field_validator("id")
|
||||
@classmethod
|
||||
def validate_id(cls, v: str) -> str:
|
||||
if not _is_uuidv4(v):
|
||||
raise ValueError(f"id must be a UUIDv4, got: {v!r}")
|
||||
return v
|
||||
return _canonical_uuidv4(v)
|
||||
|
||||
|
||||
class PageOut(BaseModel):
|
||||
|
|
@ -137,9 +135,7 @@ class RegisterRequest(BaseModel):
|
|||
@field_validator("token")
|
||||
@classmethod
|
||||
def validate_token(cls, v: str) -> str:
|
||||
if not _is_uuidv4(v):
|
||||
raise ValueError("token must be a UUIDv4")
|
||||
return v
|
||||
return _canonical_uuidv4(v)
|
||||
|
||||
|
||||
class RegisterResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -51,15 +51,71 @@ async def client(tmp_path: Path) -> AsyncGenerator[AsyncClient, None]:
|
|||
db_module._DB_PATH = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health(client: AsyncClient) -> None:
|
||||
resp = await client.get("/api/v1/health")
|
||||
async def test_spa_index_injects_absolute_open_graph_urls(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
static_dir = tmp_path / "static"
|
||||
static_dir.mkdir()
|
||||
(static_dir / "index.html").write_text(
|
||||
"""
|
||||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<link rel="canonical" href="/" data-dynamic-url="canonical" />
|
||||
<meta property="og:url" content="/" data-dynamic-url="canonical" />
|
||||
<meta property="og:image" content="/og-image.png" data-dynamic-url="og-image" />
|
||||
<meta name="twitter:image" content="/og-image.png" data-dynamic-url="og-image" />
|
||||
</head>
|
||||
</html>
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(static_dir / "og-image.png").write_bytes(b"fake png")
|
||||
monkeypatch.setenv("LIFE_TOWERS_STATIC_DIR", str(static_dir))
|
||||
|
||||
app = create_app()
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app),
|
||||
base_url="https://towers.example",
|
||||
) as c:
|
||||
resp = await c.get("/tasks?utm_source=test")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"status": "ok"}
|
||||
assert (
|
||||
'<link rel="canonical" href="https://towers.example/tasks" '
|
||||
'data-dynamic-url="canonical" />'
|
||||
) in resp.text
|
||||
assert (
|
||||
'<meta property="og:url" content="https://towers.example/tasks" '
|
||||
'data-dynamic-url="canonical" />'
|
||||
) in resp.text
|
||||
assert (
|
||||
'<meta property="og:image" content="https://towers.example/og-image.png" '
|
||||
'data-dynamic-url="og-image" />'
|
||||
) in resp.text
|
||||
assert (
|
||||
'<meta name="twitter:image" content="https://towers.example/og-image.png" '
|
||||
'data-dynamic-url="og-image" />'
|
||||
) in resp.text
|
||||
|
||||
monkeypatch.setenv("LIFE_TOWERS_PUBLIC_URL", "https://public.example/towers")
|
||||
app = create_app()
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app),
|
||||
base_url="https://internal.example",
|
||||
) as c:
|
||||
resp = await c.get("/")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert (
|
||||
'<link rel="canonical" href="https://public.example/towers/" '
|
||||
'data-dynamic-url="canonical" />'
|
||||
) in resp.text
|
||||
assert (
|
||||
'<meta property="og:image" content="https://public.example/towers/og-image.png" '
|
||||
'data-dynamic-url="og-image" />'
|
||||
) in resp.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -99,6 +155,21 @@ async def test_register_non_uuidv4_token(client: AsyncClient) -> None:
|
|||
assert resp.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uuid_inputs_are_canonicalized(client: AsyncClient) -> None:
|
||||
token = make_uuidv4()
|
||||
upper_token = token.upper()
|
||||
register_resp = await client.post("/api/v1/register", json={"token": upper_token})
|
||||
assert register_resp.status_code == 200
|
||||
assert register_resp.json() == {"user_id": token}
|
||||
|
||||
data_resp = await client.get(
|
||||
"/api/v1/data",
|
||||
headers={"Authorization": f"Bearer {upper_token}"},
|
||||
)
|
||||
assert data_resp.status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth / GET /data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -148,6 +219,7 @@ def _make_tree() -> dict:
|
|||
"tag": f"tag-{pi}-{ti}-{bi}",
|
||||
"description": f"desc-{pi}-{ti}-{bi}",
|
||||
"is_done": False,
|
||||
"difficulty": bi + 1,
|
||||
"created_at": 1700000000 + bi,
|
||||
}
|
||||
)
|
||||
|
|
@ -199,6 +271,42 @@ async def test_round_trip(client: AsyncClient) -> None:
|
|||
for bi, block in enumerate(tower["blocks"]):
|
||||
assert block["id"] == tree["pages"][pi]["towers"][ti]["blocks"][bi]["id"]
|
||||
assert block["tag"] == tree["pages"][pi]["towers"][ti]["blocks"][bi]["tag"]
|
||||
assert (
|
||||
block["difficulty"]
|
||||
== tree["pages"][pi]["towers"][ti]["blocks"][bi]["difficulty"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_difficulty_defaults_to_one_when_omitted(client: AsyncClient) -> None:
|
||||
"""A block sent without `difficulty` is stored and returned as 1."""
|
||||
token = make_uuidv4()
|
||||
await client.post("/api/v1/register", json={"token": token})
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
tree = _make_tree()
|
||||
# Drop difficulty from the very first block.
|
||||
del tree["pages"][0]["towers"][0]["blocks"][0]["difficulty"]
|
||||
|
||||
put_resp = await client.put("/api/v1/data", json=tree, headers=headers)
|
||||
assert put_resp.status_code == 204
|
||||
|
||||
data = (await client.get("/api/v1/data", headers=headers)).json()
|
||||
assert data["pages"][0]["towers"][0]["blocks"][0]["difficulty"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_difficulty_must_be_positive(client: AsyncClient) -> None:
|
||||
"""difficulty < 1 is rejected by validation."""
|
||||
token = make_uuidv4()
|
||||
await client.post("/api/v1/register", json={"token": token})
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
tree = _make_tree()
|
||||
tree["pages"][0]["towers"][0]["blocks"][0]["difficulty"] = 0
|
||||
|
||||
resp = await client.put("/api/v1/data", json=tree, headers=headers)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -218,6 +326,34 @@ async def test_put_duplicate_page_id(client: AsyncClient) -> None:
|
|||
|
||||
resp = await client.put("/api/v1/data", json=tree, headers=headers)
|
||||
assert resp.status_code == 400 # pydantic validation error → 400 bad_request per spec
|
||||
assert resp.json() == {"error": "bad_request", "detail": "Validation failed"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_put_cross_user_id_conflict_returns_409(client: AsyncClient) -> None:
|
||||
first_token = make_uuidv4()
|
||||
second_token = make_uuidv4()
|
||||
await client.post("/api/v1/register", json={"token": first_token})
|
||||
await client.post("/api/v1/register", json={"token": second_token})
|
||||
|
||||
tree = _make_tree()
|
||||
first_resp = await client.put(
|
||||
"/api/v1/data",
|
||||
json=tree,
|
||||
headers={"Authorization": f"Bearer {first_token}"},
|
||||
)
|
||||
assert first_resp.status_code == 204
|
||||
|
||||
second_resp = await client.put(
|
||||
"/api/v1/data",
|
||||
json=tree,
|
||||
headers={"Authorization": f"Bearer {second_token}"},
|
||||
)
|
||||
assert second_resp.status_code == 409
|
||||
assert second_resp.json() == {
|
||||
"error": "conflict",
|
||||
"detail": "Submitted IDs conflict with existing data",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue