From cd70c4cdcabcebe81f77394f64bf4d12ae29490d Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Sun, 26 Apr 2026 14:22:33 +0100 Subject: [PATCH] Retries and simplification --- CLAUDE.md | 15 +- src/display.py | 10 + src/lib/homeassistant.py | 62 ++-- src/lib/immich.py | 567 ++++++++++++++++-------------- src/lib/progress.py | 5 - src/lib/waveshare_epd/epd7in3e.py | 77 ++-- 6 files changed, 391 insertions(+), 345 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 8674010..46734c4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -30,18 +30,20 @@ python3 display.py --saturation 1.5 --contrast 1.1 --gamma 0.85 4. Sends to e-ink display driver **`src/lib/immich.py`** — Immich API client. Key behaviors: -- `PhotoHistory` tracks displayed photos in `photo_history.json` to avoid repeats (resets after 7 days) -- `_pick_weighted_random()` biases selection: 50% chance favorites, 50% chance recent (last 7 days), otherwise random -- Filters photos by orientation (portrait/landscape) based on EXIF data including rotation tags +- `PhotoHistory` tracks displayed photos in `photo_history.json` to avoid repeats (resets after 7 days). Asset is only marked displayed after a successful download. +- `_pick_weighted_random()` biases selection: 20% favorites, 50% recently-added (last 30 days, by Immich `createdAt`), otherwise uniform random +- Filters photos by orientation (portrait/landscape) based on EXIF data including rotation tags. Raises if nothing matches the requested orientation. - Downloads preview-size thumbnails, not originals +- Asset lists (people-search and album) are cached on disk in `/tmp/frame_cache/` for 1 hour +- `urlopen` calls retry transient failures twice (3s, 10s backoff) **`src/lib/homeassistant.py`** — Simple Home Assistant REST client for presence detection. **`src/lib/waveshare_epd/epd7in3e.py`** — Modified Waveshare driver. The `getbuffer()` method handles the full image pipeline: - Center-crops to 800x480 (or 480x800) -- Enhances saturation/contrast/gamma for e-ink (defaults: saturation=1.4, contrast=1.2, gamma=0.9) -- Atkinson dithering to 6-color palette using numba JIT -- Packs into 4-bit-per-pixel buffer (two pixels per byte) +- Enhances saturation/contrast/gamma for e-ink (caller passes values; CLI defaults live in `display.py`: saturation=1.3, contrast=1.05, gamma=0.90) +- Atkinson dithering to 6-color palette using numba JIT; produces palette indices directly (no Pillow quantize round-trip) +- Packs into 4-bit-per-pixel buffer (two pixels per byte) via numpy **`src/lib/waveshare_epd/epdconfig.py`** — GPIO/SPI hardware config. **Critical: PWR pin is BCM 27** (not default 18). @@ -55,4 +57,5 @@ python3 display.py --saturation 1.5 --contrast 1.1 --gamma 0.85 - **Dependencies on Pi**: `python3-pil python3-opencv python3-numba python3-smbus spidev gpiozero` - **Config via environment variables**: `IMMICH_URL`, `IMMICH_API_KEY`, `HA_URL`, `HA_TOKEN` (with hardcoded defaults in display.py) - **Uses only stdlib `urllib`** — no requests library; the Immich client uses `urllib.request` directly +- **Single-instance lock** at `/tmp/frame.lock` (fcntl) — overlapping cron runs exit cleanly - `sys.path.append` is used to add `lib/` to the path from display.py diff --git a/src/display.py b/src/display.py index 14216e4..fc2a3fe 100644 --- a/src/display.py +++ b/src/display.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import argparse +import fcntl import os import sys from datetime import datetime @@ -12,6 +13,8 @@ from waveshare_epd import epd7in3e from immich import ImmichClient, get_random_photo_of_people, get_random_photo_from_album from homeassistant import HomeAssistantClient +LOCK_FILE = "/tmp/frame.lock" + IMMICH_URL = os.environ.get("IMMICH_URL", "https://immich.schmelczer.dev") IMMICH_API_KEY = os.environ.get("IMMICH_API_KEY", "6crxVS1JLTJxsfGlzVhN2kefdL4EP7HPkkoMk9L6ZOE") @@ -51,6 +54,13 @@ def main() -> None: parser.add_argument("--no-enhance", action="store_true") args = parser.parse_args() + lock_fd = open(LOCK_FILE, "w") + try: + fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError: + print("Another instance running, skipping") + sys.exit(0) + now = datetime.now() print(f"Time: {now.strftime('%H:%M')}") diff --git a/src/lib/homeassistant.py b/src/lib/homeassistant.py index fe7458e..30f3fec 100644 --- a/src/lib/homeassistant.py +++ b/src/lib/homeassistant.py @@ -1,25 +1,37 @@ -#!/usr/bin/env python3 -import json -from urllib.request import Request, urlopen - - -class HomeAssistantClient: - def __init__(self, base_url: str, token: str): - self.base_url = base_url.rstrip("/") - self.token = token - - def get_state(self, entity_id: str) -> dict: - url = f"{self.base_url}/api/states/{entity_id}" - req = Request(url, headers={ - "Authorization": f"Bearer {self.token}", - "Content-Type": "application/json", - }) - with urlopen(req, timeout=30) as resp: - return json.loads(resp.read().decode()) - - def is_person_home(self, entity_id: str) -> bool: - try: - return self.get_state(entity_id).get("state") == "home" - except Exception as e: - print(f"Failed to check {entity_id}: {e}") - return False +#!/usr/bin/env python3 +import json +import time +from urllib.error import URLError +from urllib.request import Request, urlopen + +RETRY_DELAYS = (3, 10) + + +class HomeAssistantClient: + def __init__(self, base_url: str, token: str): + self.base_url = base_url.rstrip("/") + self.token = token + + def get_state(self, entity_id: str) -> dict: + url = f"{self.base_url}/api/states/{entity_id}" + req = Request(url, headers={ + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json", + }) + last_err: Exception | None = None + for attempt in range(len(RETRY_DELAYS) + 1): + try: + with urlopen(req, timeout=30) as resp: + return json.loads(resp.read().decode()) + except (URLError, TimeoutError) as e: + last_err = e + if attempt < len(RETRY_DELAYS): + time.sleep(RETRY_DELAYS[attempt]) + raise last_err + + def is_person_home(self, entity_id: str) -> bool: + try: + return self.get_state(entity_id).get("state") == "home" + except Exception as e: + print(f"Failed to check {entity_id}: {e}") + return False diff --git a/src/lib/immich.py b/src/lib/immich.py index 0bb2028..08f3680 100644 --- a/src/lib/immich.py +++ b/src/lib/immich.py @@ -1,260 +1,307 @@ -#!/usr/bin/env python3 -import json -import random -import tempfile -from dataclasses import dataclass -from datetime import datetime, timedelta, timezone -from pathlib import Path -from urllib.request import Request, urlopen - -from progress import ProgressBar - -HISTORY_FILE = Path(__file__).parent.parent / "photo_history.json" -HISTORY_MAX_AGE_DAYS = 7 - - -class PhotoHistory: - """Track displayed photos to avoid repeats. Clears after 7 days.""" - - def __init__(self, path: Path = HISTORY_FILE): - self.path = path - self.displayed: set[str] = set() - self.created_at: datetime | None = None - self._load() - - def _load(self) -> None: - if not self.path.exists(): - self._reset() - return - try: - data = json.loads(self.path.read_text()) - self.created_at = datetime.fromisoformat(data.get("created_at", "")) - if self.created_at.tzinfo is None: - self.created_at = self.created_at.replace(tzinfo=timezone.utc) - if datetime.now(timezone.utc) - self.created_at > timedelta(days=HISTORY_MAX_AGE_DAYS): - print(f"Photo history expired (>{HISTORY_MAX_AGE_DAYS} days), clearing...") - self._reset() - else: - self.displayed = set(data.get("displayed", [])) - except (json.JSONDecodeError, ValueError, KeyError): - self._reset() - - def _reset(self) -> None: - self.displayed = set() - self.created_at = datetime.now(timezone.utc) - self._save() - - def _save(self) -> None: - self.path.write_text(json.dumps({ - "created_at": self.created_at.isoformat(), - "displayed": list(self.displayed), - }, indent=2)) - - def mark_displayed(self, asset_id: str) -> None: - self.displayed.add(asset_id) - self._save() - - def filter_new(self, assets: list[dict]) -> list[dict]: - return [a for a in assets if a.get("id") not in self.displayed] - - -_history: PhotoHistory | None = None -_people_cache: dict[str, str] = {} # name -> id cache - - -def get_history() -> PhotoHistory: - global _history - if _history is None: - _history = PhotoHistory() - return _history - - -@dataclass -class ImmichClient: - base_url: str - api_key: str - - def _request(self, method: str, endpoint: str, data: dict | None = None, - show_progress: bool = False, progress_desc: str = "Fetching") -> dict: - url = f"{self.base_url.rstrip('/')}/api/{endpoint.lstrip('/')}" - headers = {"x-api-key": self.api_key} - body = None - if data is not None: - headers["Content-Type"] = "application/json" - body = json.dumps(data).encode() - - req = Request(url, data=body, headers=headers, method=method) - with urlopen(req, timeout=30) as resp: - total_size = resp.headers.get('Content-Length') - if total_size and show_progress: - total_size = int(total_size) - progress = ProgressBar(total_size, desc=progress_desc) - chunks = bytearray() - while chunk := resp.read(8192): - chunks.extend(chunk) - progress.update(len(chunk)) - progress.finish() - return json.loads(chunks.decode()) - return json.loads(resp.read().decode()) - - def get_people(self) -> list[dict]: - return self._request("GET", "/people")["people"] - - def get_person_id(self, name: str) -> str | None: - for person in self.get_people(): - if person["name"].lower() == name.lower(): - return person["id"] - return None - - def search_assets_by_people(self, person_ids: list[str]) -> list[dict]: - items = [] - page = 1 - while True: - result = self._request("POST", "/search/metadata", { - "personIds": person_ids, - "size": 250, - "page": page, - "type": "IMAGE", - "withExif": True, - }) - batch = result.get("assets", {}).get("items", []) - items.extend(batch) - if not batch or not result.get("assets", {}).get("nextPage"): - break - page += 1 - return items - - def download_asset(self, asset_id: str, dest: Path, show_progress: bool = True) -> Path: - url = f"{self.base_url.rstrip('/')}/api/assets/{asset_id}/thumbnail?size=preview" - req = Request(url, headers={"x-api-key": self.api_key}) - with urlopen(req, timeout=30) as resp: - total_size = resp.headers.get('Content-Length') - if total_size and show_progress: - total_size = int(total_size) - progress = ProgressBar(total_size, desc="Downloading") - data = bytearray() - while chunk := resp.read(8192): - data.extend(chunk) - progress.update(len(chunk)) - progress.finish() - dest.write_bytes(bytes(data)) - else: - dest.write_bytes(resp.read()) - return dest - - def get_album_id(self, name: str) -> str | None: - for album in self._request("GET", "/albums"): - if album["albumName"].lower() == name.lower(): - return album["id"] - return None - - def get_album_assets(self, album_id: str, show_progress: bool = False) -> list[dict]: - album = self._request("GET", f"/albums/{album_id}", - show_progress=show_progress, progress_desc="Fetching album") - return album.get("assets", []) - - -def _is_portrait(asset: dict) -> bool | None: - """Check if asset displays as portrait, accounting for EXIF orientation.""" - exif = asset.get("exifInfo") or {} - width = exif.get("exifImageWidth") or 0 - height = exif.get("exifImageHeight") or 0 - if not (width and height): - return None - # EXIF orientation 6 and 8 mean 90° rotation (swap dimensions) - orientation = str(exif.get("orientation") or "1") - if orientation in ("6", "8"): - width, height = height, width - return height > width - - -def _filter_by_orientation(assets: list[dict], portrait: bool) -> list[dict]: - """Filter assets by orientation, accounting for EXIF rotation.""" - filtered = [] - no_dimensions = 0 - for asset in assets: - is_portrait = _is_portrait(asset) - if is_portrait is not None: - if is_portrait == portrait: - filtered.append(asset) - else: - no_dimensions += 1 - if no_dimensions: - print(f"Note: {no_dimensions}/{len(assets)} photos missing dimension data") - return filtered - - -def _pick_weighted_random(assets: list[dict]) -> dict: - """Pick random asset, biased towards favorites (20%) and recently added photos (50%).""" - if not assets: - raise ValueError("No assets to choose from") - - cutoff = datetime.now(timezone.utc) - timedelta(days=30) - favorites = [a for a in assets if a.get("isFavorite")] - recent = [] - for asset in assets: - date_str = asset.get("createdAt", "") - try: - if datetime.fromisoformat(date_str.replace("Z", "+00:00")) >= cutoff: - recent.append(asset) - except (ValueError, AttributeError): - pass - - if favorites and random.random() < 0.2: - return random.choice(favorites) - if recent and random.random() < 0.5: - return random.choice(recent) - return random.choice(assets) - - -def _download_random_asset(client: ImmichClient, assets: list[dict]) -> Path: - history = get_history() - new_assets = history.filter_new(assets) - - if new_assets: - print(f"Photos: {len(new_assets)} new / {len(assets)} total") - asset = _pick_weighted_random(new_assets) - else: - print(f"All {len(assets)} photos shown, picking from full list") - asset = _pick_weighted_random(assets) - - history.mark_displayed(asset["id"]) - dest = Path(tempfile.gettempdir()) / "immich_photo.jpg" - return client.download_asset(asset["id"], dest) - - -def get_random_photo_of_people(client: ImmichClient, names: list[str], orientation: int = 0) -> Path: - person_ids = [pid for name in names if (pid := client.get_person_id(name))] - if not person_ids: - raise ValueError(f"No people found: {names}") - - assets = client.search_assets_by_people(person_ids) - - if not assets: - raise ValueError(f"No photos found for: {names}") - - portrait = orientation in (90, 270) - filtered = _filter_by_orientation(assets, portrait) - if filtered: - assets = filtered - else: - print(f"No {'portrait' if portrait else 'landscape'} photos, using any orientation") - return _download_random_asset(client, assets) - - -def get_random_photo_from_album(client: ImmichClient, album_name: str, orientation: int = 0) -> Path: - album_id = client.get_album_id(album_name) - if not album_id: - raise ValueError(f"Album not found: {album_name}") - - assets = [a for a in client.get_album_assets(album_id) if a.get("type") == "IMAGE"] - if not assets: - raise ValueError(f"No photos in album: {album_name}") - - portrait = orientation in (90, 270) - filtered = _filter_by_orientation(assets, portrait) - if filtered: - assets = filtered - else: - print(f"No {'portrait' if portrait else 'landscape'} photos, using any orientation") - return _download_random_asset(client, assets) +#!/usr/bin/env python3 +import hashlib +import json +import random +import tempfile +import time +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from pathlib import Path +from urllib.error import URLError +from urllib.request import Request, urlopen + +from progress import ProgressBar + +HISTORY_FILE = Path(__file__).parent.parent / "photo_history.json" +HISTORY_MAX_AGE_DAYS = 7 + +CACHE_DIR = Path(tempfile.gettempdir()) / "frame_cache" +CACHE_TTL_SECONDS = 3600 + +RETRY_DELAYS = (3, 10) + + +def _urlopen_with_retry(req: Request, timeout: int = 30): + """urlopen wrapper that retries transient network failures.""" + last_err: Exception | None = None + for attempt in range(len(RETRY_DELAYS) + 1): + try: + return urlopen(req, timeout=timeout) + except (URLError, TimeoutError) as e: + last_err = e + if attempt < len(RETRY_DELAYS): + time.sleep(RETRY_DELAYS[attempt]) + raise last_err + + +def _cache_get(key: str) -> list[dict] | None: + path = CACHE_DIR / f"{key}.json" + if not path.exists(): + return None + if time.time() - path.stat().st_mtime > CACHE_TTL_SECONDS: + return None + try: + return json.loads(path.read_text()) + except (json.JSONDecodeError, OSError): + return None + + +def _cache_set(key: str, value: list[dict]) -> None: + CACHE_DIR.mkdir(exist_ok=True) + (CACHE_DIR / f"{key}.json").write_text(json.dumps(value)) + + +class PhotoHistory: + """Track displayed photos to avoid repeats. Clears after 7 days.""" + + def __init__(self, path: Path = HISTORY_FILE): + self.path = path + self.displayed: set[str] = set() + self.created_at: datetime | None = None + self._load() + + def _load(self) -> None: + if not self.path.exists(): + self._reset() + return + try: + data = json.loads(self.path.read_text()) + self.created_at = datetime.fromisoformat(data.get("created_at", "")) + if self.created_at.tzinfo is None: + self.created_at = self.created_at.replace(tzinfo=timezone.utc) + if datetime.now(timezone.utc) - self.created_at > timedelta(days=HISTORY_MAX_AGE_DAYS): + print(f"Photo history expired (>{HISTORY_MAX_AGE_DAYS} days), clearing...") + self._reset() + else: + self.displayed = set(data.get("displayed", [])) + except (json.JSONDecodeError, ValueError, KeyError): + self._reset() + + def _reset(self) -> None: + self.displayed = set() + self.created_at = datetime.now(timezone.utc) + self._save() + + def _save(self) -> None: + self.path.write_text(json.dumps({ + "created_at": self.created_at.isoformat(), + "displayed": list(self.displayed), + }, indent=2)) + + def mark_displayed(self, asset_id: str) -> None: + self.displayed.add(asset_id) + self._save() + + def filter_new(self, assets: list[dict]) -> list[dict]: + return [a for a in assets if a.get("id") not in self.displayed] + + +_history: PhotoHistory | None = None + + +def get_history() -> PhotoHistory: + global _history + if _history is None: + _history = PhotoHistory() + return _history + + +@dataclass +class ImmichClient: + base_url: str + api_key: str + + def _request(self, method: str, endpoint: str, data: dict | None = None, + show_progress: bool = False, progress_desc: str = "Fetching") -> dict: + url = f"{self.base_url.rstrip('/')}/api/{endpoint.lstrip('/')}" + headers = {"x-api-key": self.api_key} + body = None + if data is not None: + headers["Content-Type"] = "application/json" + body = json.dumps(data).encode() + + req = Request(url, data=body, headers=headers, method=method) + with _urlopen_with_retry(req, timeout=30) as resp: + total_size = resp.headers.get('Content-Length') + if total_size and show_progress: + total_size = int(total_size) + progress = ProgressBar(total_size, desc=progress_desc) + chunks = bytearray() + while chunk := resp.read(8192): + chunks.extend(chunk) + progress.update(len(chunk)) + progress.finish() + return json.loads(chunks.decode()) + return json.loads(resp.read().decode()) + + def get_people(self) -> list[dict]: + return self._request("GET", "/people")["people"] + + def get_person_id(self, name: str) -> str | None: + for person in self.get_people(): + if person["name"].lower() == name.lower(): + return person["id"] + return None + + def search_assets_by_people(self, person_ids: list[str]) -> list[dict]: + key = "people_" + hashlib.md5("_".join(sorted(person_ids)).encode()).hexdigest() + cached = _cache_get(key) + if cached is not None: + return cached + + items = [] + page = 1 + while True: + result = self._request("POST", "/search/metadata", { + "personIds": person_ids, + "size": 250, + "page": page, + "type": "IMAGE", + "withExif": True, + }) + batch = result.get("assets", {}).get("items", []) + items.extend(batch) + if not batch or not result.get("assets", {}).get("nextPage"): + break + page += 1 + _cache_set(key, items) + return items + + def download_asset(self, asset_id: str, dest: Path, show_progress: bool = True) -> Path: + url = f"{self.base_url.rstrip('/')}/api/assets/{asset_id}/thumbnail?size=preview" + req = Request(url, headers={"x-api-key": self.api_key}) + with _urlopen_with_retry(req, timeout=30) as resp: + total_size = resp.headers.get('Content-Length') + if total_size and show_progress: + total_size = int(total_size) + progress = ProgressBar(total_size, desc="Downloading") + data = bytearray() + while chunk := resp.read(8192): + data.extend(chunk) + progress.update(len(chunk)) + progress.finish() + dest.write_bytes(bytes(data)) + else: + dest.write_bytes(resp.read()) + return dest + + def get_album_id(self, name: str) -> str | None: + for album in self._request("GET", "/albums"): + if album["albumName"].lower() == name.lower(): + return album["id"] + return None + + def get_album_assets(self, album_id: str, show_progress: bool = False) -> list[dict]: + key = f"album_{album_id}" + cached = _cache_get(key) + if cached is not None: + return cached + + album = self._request("GET", f"/albums/{album_id}", + show_progress=show_progress, progress_desc="Fetching album") + assets = album.get("assets", []) + _cache_set(key, assets) + return assets + + +def _is_portrait(asset: dict) -> bool | None: + """Check if asset displays as portrait, accounting for EXIF orientation.""" + exif = asset.get("exifInfo") or {} + width = exif.get("exifImageWidth") or 0 + height = exif.get("exifImageHeight") or 0 + if not (width and height): + return None + # EXIF orientation 6 and 8 mean 90° rotation (swap dimensions) + orientation = str(exif.get("orientation") or "1") + if orientation in ("6", "8"): + width, height = height, width + return height > width + + +def _filter_by_orientation(assets: list[dict], portrait: bool) -> list[dict]: + """Filter assets by orientation, accounting for EXIF rotation.""" + filtered = [] + no_dimensions = 0 + for asset in assets: + is_portrait = _is_portrait(asset) + if is_portrait is not None: + if is_portrait == portrait: + filtered.append(asset) + else: + no_dimensions += 1 + if no_dimensions: + print(f"Note: {no_dimensions}/{len(assets)} photos missing dimension data") + return filtered + + +def _pick_weighted_random(assets: list[dict]) -> dict: + """Pick random asset, biased towards favorites (20%) and recently added photos (50%).""" + if not assets: + raise ValueError("No assets to choose from") + + cutoff = datetime.now(timezone.utc) - timedelta(days=30) + favorites = [a for a in assets if a.get("isFavorite")] + recent = [] + for asset in assets: + date_str = asset.get("createdAt", "") + try: + if datetime.fromisoformat(date_str.replace("Z", "+00:00")) >= cutoff: + recent.append(asset) + except (ValueError, AttributeError): + pass + + if favorites and random.random() < 0.2: + return random.choice(favorites) + if recent and random.random() < 0.5: + return random.choice(recent) + return random.choice(assets) + + +def _download_random_asset(client: ImmichClient, assets: list[dict]) -> Path: + history = get_history() + new_assets = history.filter_new(assets) + + if new_assets: + print(f"Photos: {len(new_assets)} new / {len(assets)} total") + asset = _pick_weighted_random(new_assets) + else: + print(f"All {len(assets)} photos shown, picking from full list") + asset = _pick_weighted_random(assets) + + dest = Path(tempfile.gettempdir()) / "immich_photo.jpg" + path = client.download_asset(asset["id"], dest) + history.mark_displayed(asset["id"]) + return path + + +def get_random_photo_of_people(client: ImmichClient, names: list[str], orientation: int = 0) -> Path: + person_ids = [pid for name in names if (pid := client.get_person_id(name))] + if not person_ids: + raise ValueError(f"No people found: {names}") + + assets = client.search_assets_by_people(person_ids) + + if not assets: + raise ValueError(f"No photos found for: {names}") + + portrait = orientation in (90, 270) + filtered = _filter_by_orientation(assets, portrait) + if not filtered: + raise ValueError(f"No {'portrait' if portrait else 'landscape'} photos available") + return _download_random_asset(client, filtered) + + +def get_random_photo_from_album(client: ImmichClient, album_name: str, orientation: int = 0) -> Path: + album_id = client.get_album_id(album_name) + if not album_id: + raise ValueError(f"Album not found: {album_name}") + + assets = [a for a in client.get_album_assets(album_id) if a.get("type") == "IMAGE"] + if not assets: + raise ValueError(f"No photos in album: {album_name}") + + portrait = orientation in (90, 270) + filtered = _filter_by_orientation(assets, portrait) + if not filtered: + raise ValueError(f"No {'portrait' if portrait else 'landscape'} photos in album: {album_name}") + return _download_random_asset(client, filtered) diff --git a/src/lib/progress.py b/src/lib/progress.py index 335544e..2240e9c 100644 --- a/src/lib/progress.py +++ b/src/lib/progress.py @@ -47,8 +47,3 @@ class ProgressBar: """Complete the progress bar.""" self.current = self.total self._render() - - -def print_status(msg: str) -> None: - """Print a status message.""" - print(f" {msg}") diff --git a/src/lib/waveshare_epd/epd7in3e.py b/src/lib/waveshare_epd/epd7in3e.py index 7114130..3c2f235 100644 --- a/src/lib/waveshare_epd/epd7in3e.py +++ b/src/lib/waveshare_epd/epd7in3e.py @@ -2,38 +2,33 @@ # Waveshare 7.3" 6-color e-Paper driver (modified) # Original: Waveshare team, 2022-10-20 -import sys import numpy as np import cv2 from PIL import Image, ImageEnhance from numba import jit +from progress import ProgressBar from . import epdconfig EPD_WIDTH = 800 EPD_HEIGHT = 480 -DEFAULT_SATURATION = 1.4 -DEFAULT_CONTRAST = 1.2 -DEFAULT_GAMMA = 0.9 - +# 6-color e-ink encoding: indices 0,1,2,3,5,6 are wire-format colors; +# 4 is reserved/unused (filled with BLACK so nearest-color never picks it). PALETTE_RGB = np.array([ - [0, 0, 0], # BLACK - [255, 255, 255], # WHITE - [255, 255, 0], # YELLOW - [255, 0, 0], # RED - [0, 0, 255], # BLUE - [0, 255, 0], # GREEN + [0, 0, 0], # 0: BLACK + [255, 255, 255], # 1: WHITE + [255, 255, 0], # 2: YELLOW + [255, 0, 0], # 3: RED + [0, 0, 0], # 4: unused + [0, 0, 255], # 5: BLUE + [0, 255, 0], # 6: GREEN ], dtype=np.float64) PERCEPTUAL_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float64) -def _enhance_for_eink(image: Image.Image, saturation: float = None, - contrast: float = None, gamma: float = None) -> Image.Image: - saturation = saturation or DEFAULT_SATURATION - contrast = contrast or DEFAULT_CONTRAST - gamma = gamma or DEFAULT_GAMMA - +def _enhance_for_eink(image: Image.Image, saturation: float, + contrast: float, gamma: float) -> Image.Image: img = image.convert('RGB') if saturation != 1.0: img = ImageEnhance.Color(img).enhance(saturation) @@ -66,18 +61,6 @@ def _crop_center(image: Image.Image, target_w: int, target_h: int, return Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) -def _render_progress(desc: str, current: int, total: int, width: int = 30) -> None: - if total == 0: - return - percent = int(100 * current / total) - filled = int(width * current / total) - bar = "█" * filled + "░" * (width - filled) - sys.stdout.write(f"\r{desc}: |{bar}| {percent:3d}%") - sys.stdout.flush() - if current >= total: - print() - - @jit(nopython=True, cache=True) def _find_nearest_color(r, g, b, palette, weights): best_idx, best_dist = 0, 1e10 @@ -92,14 +75,14 @@ def _find_nearest_color(r, g, b, palette, weights): @jit(nopython=True, cache=True) -def _atkinson_dither_rows(img, palette, weights, start_row, end_row): +def _atkinson_dither_rows(img, palette, weights, indices, start_row, end_row): height, width = img.shape[:2] for y in range(start_row, end_row): for x in range(width): old_r, old_g, old_b = img[y, x, 0], img[y, x, 1], img[y, x, 2] idx = _find_nearest_color(old_r, old_g, old_b, palette, weights) + indices[y, x] = idx new_r, new_g, new_b = palette[idx, 0], palette[idx, 1], palette[idx, 2] - img[y, x, 0], img[y, x, 1], img[y, x, 2] = new_r, new_g, new_b err_r, err_g, err_b = (old_r - new_r) / 8.0, (old_g - new_g) / 8.0, (old_b - new_b) / 8.0 @@ -127,23 +110,25 @@ def _atkinson_dither_rows(img, palette, weights, start_row, end_row): img[y + 2, x, 0] += err_r img[y + 2, x, 1] += err_g img[y + 2, x, 2] += err_b - return img -def _dither_atkinson(image: Image.Image, show_progress: bool = True) -> Image.Image: +def _dither_atkinson(image: Image.Image, show_progress: bool = True) -> np.ndarray: + """Atkinson-dither to the e-ink palette and return a uint8 array of palette indices.""" img = np.array(image.convert('RGB'), dtype=np.float64) - height = img.shape[0] + height, width = img.shape[:2] + indices = np.zeros((height, width), dtype=np.uint8) if show_progress: print("Dithering...") + progress = ProgressBar(height, desc="Dithering") chunk_size = 48 for i in range((height + chunk_size - 1) // chunk_size): start, end = i * chunk_size, min((i + 1) * chunk_size, height) - img = _atkinson_dither_rows(img, PALETTE_RGB, PERCEPTUAL_WEIGHTS, start, end) + _atkinson_dither_rows(img, PALETTE_RGB, PERCEPTUAL_WEIGHTS, indices, start, end) if show_progress: - _render_progress("Dithering", end, height) + progress.set(end) - return Image.fromarray(np.clip(img, 0, 255).astype(np.uint8), 'RGB') + return indices class EPD: @@ -253,11 +238,8 @@ class EPD: self.wait_busy() return 0 - def getbuffer(self, image, saturation=None, contrast=None, gamma=None, - enhance=True, show_progress=True): - pal_image = Image.new("P", (1, 1)) - pal_image.putpalette((0,0,0, 255,255,255, 255,255,0, 255,0,0, 0,0,0, 0,0,255, 0,255,0) + (0,0,0)*249) - + def getbuffer(self, image, saturation: float, contrast: float, gamma: float, + enhance: bool = True, show_progress: bool = True): image = image.convert('RGB') imwidth, imheight = image.size @@ -271,16 +253,13 @@ class EPD: print("Enhancing...") image = _enhance_for_eink(image, saturation, contrast, gamma) - image = _dither_atkinson(image, show_progress) + indices = _dither_atkinson(image, show_progress) if show_progress: print("Packing buffer...") - image_6color = image.quantize(palette=pal_image, dither=Image.Dither.NONE) - buf_6color = bytearray(image_6color.tobytes('raw')) - - buf = [0x00] * (self.width * self.height // 2) - for i in range(0, len(buf_6color), 2): - buf[i // 2] = (buf_6color[i] << 4) + buf_6color[i + 1] + flat = indices.reshape(-1) + packed = (flat[0::2].astype(np.uint8) << 4) | flat[1::2].astype(np.uint8) + buf = packed.tolist() if show_progress: print("Ready")