Retries and simplification

This commit is contained in:
Andras Schmelczer 2026-04-26 14:22:33 +01:00
parent d513b17f93
commit de65fbee9f
6 changed files with 391 additions and 345 deletions

View file

@ -30,18 +30,20 @@ python3 display.py --saturation 1.5 --contrast 1.1 --gamma 0.85
4. Sends to e-ink display driver
**`src/lib/immich.py`** — Immich API client. Key behaviors:
- `PhotoHistory` tracks displayed photos in `photo_history.json` to avoid repeats (resets after 7 days)
- `_pick_weighted_random()` biases selection: 50% chance favorites, 50% chance recent (last 7 days), otherwise random
- Filters photos by orientation (portrait/landscape) based on EXIF data including rotation tags
- `PhotoHistory` tracks displayed photos in `photo_history.json` to avoid repeats (resets after 7 days). Asset is only marked displayed after a successful download.
- `_pick_weighted_random()` biases selection: 20% favorites, 50% recently-added (last 30 days, by Immich `createdAt`), otherwise uniform random
- Filters photos by orientation (portrait/landscape) based on EXIF data including rotation tags. Raises if nothing matches the requested orientation.
- Downloads preview-size thumbnails, not originals
- Asset lists (people-search and album) are cached on disk in `/tmp/frame_cache/` for 1 hour
- `urlopen` calls retry transient failures twice (3s, 10s backoff)
**`src/lib/homeassistant.py`** — Simple Home Assistant REST client for presence detection.
**`src/lib/waveshare_epd/epd7in3e.py`** — Modified Waveshare driver. The `getbuffer()` method handles the full image pipeline:
- Center-crops to 800x480 (or 480x800)
- Enhances saturation/contrast/gamma for e-ink (defaults: saturation=1.4, contrast=1.2, gamma=0.9)
- Atkinson dithering to 6-color palette using numba JIT
- Packs into 4-bit-per-pixel buffer (two pixels per byte)
- Enhances saturation/contrast/gamma for e-ink (caller passes values; CLI defaults live in `display.py`: saturation=1.3, contrast=1.05, gamma=0.90)
- Atkinson dithering to 6-color palette using numba JIT; produces palette indices directly (no Pillow quantize round-trip)
- Packs into 4-bit-per-pixel buffer (two pixels per byte) via numpy
**`src/lib/waveshare_epd/epdconfig.py`** — GPIO/SPI hardware config. **Critical: PWR pin is BCM 27** (not default 18).
@ -55,4 +57,5 @@ python3 display.py --saturation 1.5 --contrast 1.1 --gamma 0.85
- **Dependencies on Pi**: `python3-pil python3-opencv python3-numba python3-smbus spidev gpiozero`
- **Config via environment variables**: `IMMICH_URL`, `IMMICH_API_KEY`, `HA_URL`, `HA_TOKEN` (with hardcoded defaults in display.py)
- **Uses only stdlib `urllib`** — no requests library; the Immich client uses `urllib.request` directly
- **Single-instance lock** at `/tmp/frame.lock` (fcntl) — overlapping cron runs exit cleanly
- `sys.path.append` is used to add `lib/` to the path from display.py

View file

@ -1,5 +1,6 @@
#!/usr/bin/env python3
import argparse
import fcntl
import os
import sys
from datetime import datetime
@ -12,6 +13,8 @@ from waveshare_epd import epd7in3e
from immich import ImmichClient, get_random_photo_of_people, get_random_photo_from_album
from homeassistant import HomeAssistantClient
LOCK_FILE = "/tmp/frame.lock"
IMMICH_URL = os.environ.get("IMMICH_URL", "https://immich.example.com")
IMMICH_API_KEY = os.environ.get("IMMICH_API_KEY", "REDACTED_IMMICH_API_KEY")
@ -51,6 +54,13 @@ def main() -> None:
parser.add_argument("--no-enhance", action="store_true")
args = parser.parse_args()
lock_fd = open(LOCK_FILE, "w")
try:
fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
except BlockingIOError:
print("Another instance running, skipping")
sys.exit(0)
now = datetime.now()
print(f"Time: {now.strftime('%H:%M')}")

View file

@ -1,25 +1,37 @@
#!/usr/bin/env python3
import json
from urllib.request import Request, urlopen
class HomeAssistantClient:
def __init__(self, base_url: str, token: str):
self.base_url = base_url.rstrip("/")
self.token = token
def get_state(self, entity_id: str) -> dict:
url = f"{self.base_url}/api/states/{entity_id}"
req = Request(url, headers={
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json",
})
with urlopen(req, timeout=30) as resp:
return json.loads(resp.read().decode())
def is_person_home(self, entity_id: str) -> bool:
try:
return self.get_state(entity_id).get("state") == "home"
except Exception as e:
print(f"Failed to check {entity_id}: {e}")
return False
#!/usr/bin/env python3
import json
import time
from urllib.error import URLError
from urllib.request import Request, urlopen
RETRY_DELAYS = (3, 10)
class HomeAssistantClient:
def __init__(self, base_url: str, token: str):
self.base_url = base_url.rstrip("/")
self.token = token
def get_state(self, entity_id: str) -> dict:
url = f"{self.base_url}/api/states/{entity_id}"
req = Request(url, headers={
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json",
})
last_err: Exception | None = None
for attempt in range(len(RETRY_DELAYS) + 1):
try:
with urlopen(req, timeout=30) as resp:
return json.loads(resp.read().decode())
except (URLError, TimeoutError) as e:
last_err = e
if attempt < len(RETRY_DELAYS):
time.sleep(RETRY_DELAYS[attempt])
raise last_err
def is_person_home(self, entity_id: str) -> bool:
try:
return self.get_state(entity_id).get("state") == "home"
except Exception as e:
print(f"Failed to check {entity_id}: {e}")
return False

View file

@ -1,260 +1,307 @@
#!/usr/bin/env python3
import json
import random
import tempfile
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from urllib.request import Request, urlopen
from progress import ProgressBar
HISTORY_FILE = Path(__file__).parent.parent / "photo_history.json"
HISTORY_MAX_AGE_DAYS = 7
class PhotoHistory:
"""Track displayed photos to avoid repeats. Clears after 7 days."""
def __init__(self, path: Path = HISTORY_FILE):
self.path = path
self.displayed: set[str] = set()
self.created_at: datetime | None = None
self._load()
def _load(self) -> None:
if not self.path.exists():
self._reset()
return
try:
data = json.loads(self.path.read_text())
self.created_at = datetime.fromisoformat(data.get("created_at", ""))
if self.created_at.tzinfo is None:
self.created_at = self.created_at.replace(tzinfo=timezone.utc)
if datetime.now(timezone.utc) - self.created_at > timedelta(days=HISTORY_MAX_AGE_DAYS):
print(f"Photo history expired (>{HISTORY_MAX_AGE_DAYS} days), clearing...")
self._reset()
else:
self.displayed = set(data.get("displayed", []))
except (json.JSONDecodeError, ValueError, KeyError):
self._reset()
def _reset(self) -> None:
self.displayed = set()
self.created_at = datetime.now(timezone.utc)
self._save()
def _save(self) -> None:
self.path.write_text(json.dumps({
"created_at": self.created_at.isoformat(),
"displayed": list(self.displayed),
}, indent=2))
def mark_displayed(self, asset_id: str) -> None:
self.displayed.add(asset_id)
self._save()
def filter_new(self, assets: list[dict]) -> list[dict]:
return [a for a in assets if a.get("id") not in self.displayed]
_history: PhotoHistory | None = None
_people_cache: dict[str, str] = {} # name -> id cache
def get_history() -> PhotoHistory:
global _history
if _history is None:
_history = PhotoHistory()
return _history
@dataclass
class ImmichClient:
base_url: str
api_key: str
def _request(self, method: str, endpoint: str, data: dict | None = None,
show_progress: bool = False, progress_desc: str = "Fetching") -> dict:
url = f"{self.base_url.rstrip('/')}/api/{endpoint.lstrip('/')}"
headers = {"x-api-key": self.api_key}
body = None
if data is not None:
headers["Content-Type"] = "application/json"
body = json.dumps(data).encode()
req = Request(url, data=body, headers=headers, method=method)
with urlopen(req, timeout=30) as resp:
total_size = resp.headers.get('Content-Length')
if total_size and show_progress:
total_size = int(total_size)
progress = ProgressBar(total_size, desc=progress_desc)
chunks = bytearray()
while chunk := resp.read(8192):
chunks.extend(chunk)
progress.update(len(chunk))
progress.finish()
return json.loads(chunks.decode())
return json.loads(resp.read().decode())
def get_people(self) -> list[dict]:
return self._request("GET", "/people")["people"]
def get_person_id(self, name: str) -> str | None:
for person in self.get_people():
if person["name"].lower() == name.lower():
return person["id"]
return None
def search_assets_by_people(self, person_ids: list[str]) -> list[dict]:
items = []
page = 1
while True:
result = self._request("POST", "/search/metadata", {
"personIds": person_ids,
"size": 250,
"page": page,
"type": "IMAGE",
"withExif": True,
})
batch = result.get("assets", {}).get("items", [])
items.extend(batch)
if not batch or not result.get("assets", {}).get("nextPage"):
break
page += 1
return items
def download_asset(self, asset_id: str, dest: Path, show_progress: bool = True) -> Path:
url = f"{self.base_url.rstrip('/')}/api/assets/{asset_id}/thumbnail?size=preview"
req = Request(url, headers={"x-api-key": self.api_key})
with urlopen(req, timeout=30) as resp:
total_size = resp.headers.get('Content-Length')
if total_size and show_progress:
total_size = int(total_size)
progress = ProgressBar(total_size, desc="Downloading")
data = bytearray()
while chunk := resp.read(8192):
data.extend(chunk)
progress.update(len(chunk))
progress.finish()
dest.write_bytes(bytes(data))
else:
dest.write_bytes(resp.read())
return dest
def get_album_id(self, name: str) -> str | None:
for album in self._request("GET", "/albums"):
if album["albumName"].lower() == name.lower():
return album["id"]
return None
def get_album_assets(self, album_id: str, show_progress: bool = False) -> list[dict]:
album = self._request("GET", f"/albums/{album_id}",
show_progress=show_progress, progress_desc="Fetching album")
return album.get("assets", [])
def _is_portrait(asset: dict) -> bool | None:
"""Check if asset displays as portrait, accounting for EXIF orientation."""
exif = asset.get("exifInfo") or {}
width = exif.get("exifImageWidth") or 0
height = exif.get("exifImageHeight") or 0
if not (width and height):
return None
# EXIF orientation 6 and 8 mean 90° rotation (swap dimensions)
orientation = str(exif.get("orientation") or "1")
if orientation in ("6", "8"):
width, height = height, width
return height > width
def _filter_by_orientation(assets: list[dict], portrait: bool) -> list[dict]:
"""Filter assets by orientation, accounting for EXIF rotation."""
filtered = []
no_dimensions = 0
for asset in assets:
is_portrait = _is_portrait(asset)
if is_portrait is not None:
if is_portrait == portrait:
filtered.append(asset)
else:
no_dimensions += 1
if no_dimensions:
print(f"Note: {no_dimensions}/{len(assets)} photos missing dimension data")
return filtered
def _pick_weighted_random(assets: list[dict]) -> dict:
"""Pick random asset, biased towards favorites (20%) and recently added photos (50%)."""
if not assets:
raise ValueError("No assets to choose from")
cutoff = datetime.now(timezone.utc) - timedelta(days=30)
favorites = [a for a in assets if a.get("isFavorite")]
recent = []
for asset in assets:
date_str = asset.get("createdAt", "")
try:
if datetime.fromisoformat(date_str.replace("Z", "+00:00")) >= cutoff:
recent.append(asset)
except (ValueError, AttributeError):
pass
if favorites and random.random() < 0.2:
return random.choice(favorites)
if recent and random.random() < 0.5:
return random.choice(recent)
return random.choice(assets)
def _download_random_asset(client: ImmichClient, assets: list[dict]) -> Path:
history = get_history()
new_assets = history.filter_new(assets)
if new_assets:
print(f"Photos: {len(new_assets)} new / {len(assets)} total")
asset = _pick_weighted_random(new_assets)
else:
print(f"All {len(assets)} photos shown, picking from full list")
asset = _pick_weighted_random(assets)
history.mark_displayed(asset["id"])
dest = Path(tempfile.gettempdir()) / "immich_photo.jpg"
return client.download_asset(asset["id"], dest)
def get_random_photo_of_people(client: ImmichClient, names: list[str], orientation: int = 0) -> Path:
person_ids = [pid for name in names if (pid := client.get_person_id(name))]
if not person_ids:
raise ValueError(f"No people found: {names}")
assets = client.search_assets_by_people(person_ids)
if not assets:
raise ValueError(f"No photos found for: {names}")
portrait = orientation in (90, 270)
filtered = _filter_by_orientation(assets, portrait)
if filtered:
assets = filtered
else:
print(f"No {'portrait' if portrait else 'landscape'} photos, using any orientation")
return _download_random_asset(client, assets)
def get_random_photo_from_album(client: ImmichClient, album_name: str, orientation: int = 0) -> Path:
album_id = client.get_album_id(album_name)
if not album_id:
raise ValueError(f"Album not found: {album_name}")
assets = [a for a in client.get_album_assets(album_id) if a.get("type") == "IMAGE"]
if not assets:
raise ValueError(f"No photos in album: {album_name}")
portrait = orientation in (90, 270)
filtered = _filter_by_orientation(assets, portrait)
if filtered:
assets = filtered
else:
print(f"No {'portrait' if portrait else 'landscape'} photos, using any orientation")
return _download_random_asset(client, assets)
#!/usr/bin/env python3
import hashlib
import json
import random
import tempfile
import time
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen
from progress import ProgressBar
HISTORY_FILE = Path(__file__).parent.parent / "photo_history.json"
HISTORY_MAX_AGE_DAYS = 7
CACHE_DIR = Path(tempfile.gettempdir()) / "frame_cache"
CACHE_TTL_SECONDS = 3600
RETRY_DELAYS = (3, 10)
def _urlopen_with_retry(req: Request, timeout: int = 30):
"""urlopen wrapper that retries transient network failures."""
last_err: Exception | None = None
for attempt in range(len(RETRY_DELAYS) + 1):
try:
return urlopen(req, timeout=timeout)
except (URLError, TimeoutError) as e:
last_err = e
if attempt < len(RETRY_DELAYS):
time.sleep(RETRY_DELAYS[attempt])
raise last_err
def _cache_get(key: str) -> list[dict] | None:
path = CACHE_DIR / f"{key}.json"
if not path.exists():
return None
if time.time() - path.stat().st_mtime > CACHE_TTL_SECONDS:
return None
try:
return json.loads(path.read_text())
except (json.JSONDecodeError, OSError):
return None
def _cache_set(key: str, value: list[dict]) -> None:
CACHE_DIR.mkdir(exist_ok=True)
(CACHE_DIR / f"{key}.json").write_text(json.dumps(value))
class PhotoHistory:
"""Track displayed photos to avoid repeats. Clears after 7 days."""
def __init__(self, path: Path = HISTORY_FILE):
self.path = path
self.displayed: set[str] = set()
self.created_at: datetime | None = None
self._load()
def _load(self) -> None:
if not self.path.exists():
self._reset()
return
try:
data = json.loads(self.path.read_text())
self.created_at = datetime.fromisoformat(data.get("created_at", ""))
if self.created_at.tzinfo is None:
self.created_at = self.created_at.replace(tzinfo=timezone.utc)
if datetime.now(timezone.utc) - self.created_at > timedelta(days=HISTORY_MAX_AGE_DAYS):
print(f"Photo history expired (>{HISTORY_MAX_AGE_DAYS} days), clearing...")
self._reset()
else:
self.displayed = set(data.get("displayed", []))
except (json.JSONDecodeError, ValueError, KeyError):
self._reset()
def _reset(self) -> None:
self.displayed = set()
self.created_at = datetime.now(timezone.utc)
self._save()
def _save(self) -> None:
self.path.write_text(json.dumps({
"created_at": self.created_at.isoformat(),
"displayed": list(self.displayed),
}, indent=2))
def mark_displayed(self, asset_id: str) -> None:
self.displayed.add(asset_id)
self._save()
def filter_new(self, assets: list[dict]) -> list[dict]:
return [a for a in assets if a.get("id") not in self.displayed]
_history: PhotoHistory | None = None
def get_history() -> PhotoHistory:
global _history
if _history is None:
_history = PhotoHistory()
return _history
@dataclass
class ImmichClient:
base_url: str
api_key: str
def _request(self, method: str, endpoint: str, data: dict | None = None,
show_progress: bool = False, progress_desc: str = "Fetching") -> dict:
url = f"{self.base_url.rstrip('/')}/api/{endpoint.lstrip('/')}"
headers = {"x-api-key": self.api_key}
body = None
if data is not None:
headers["Content-Type"] = "application/json"
body = json.dumps(data).encode()
req = Request(url, data=body, headers=headers, method=method)
with _urlopen_with_retry(req, timeout=30) as resp:
total_size = resp.headers.get('Content-Length')
if total_size and show_progress:
total_size = int(total_size)
progress = ProgressBar(total_size, desc=progress_desc)
chunks = bytearray()
while chunk := resp.read(8192):
chunks.extend(chunk)
progress.update(len(chunk))
progress.finish()
return json.loads(chunks.decode())
return json.loads(resp.read().decode())
def get_people(self) -> list[dict]:
return self._request("GET", "/people")["people"]
def get_person_id(self, name: str) -> str | None:
for person in self.get_people():
if person["name"].lower() == name.lower():
return person["id"]
return None
def search_assets_by_people(self, person_ids: list[str]) -> list[dict]:
key = "people_" + hashlib.md5("_".join(sorted(person_ids)).encode()).hexdigest()
cached = _cache_get(key)
if cached is not None:
return cached
items = []
page = 1
while True:
result = self._request("POST", "/search/metadata", {
"personIds": person_ids,
"size": 250,
"page": page,
"type": "IMAGE",
"withExif": True,
})
batch = result.get("assets", {}).get("items", [])
items.extend(batch)
if not batch or not result.get("assets", {}).get("nextPage"):
break
page += 1
_cache_set(key, items)
return items
def download_asset(self, asset_id: str, dest: Path, show_progress: bool = True) -> Path:
url = f"{self.base_url.rstrip('/')}/api/assets/{asset_id}/thumbnail?size=preview"
req = Request(url, headers={"x-api-key": self.api_key})
with _urlopen_with_retry(req, timeout=30) as resp:
total_size = resp.headers.get('Content-Length')
if total_size and show_progress:
total_size = int(total_size)
progress = ProgressBar(total_size, desc="Downloading")
data = bytearray()
while chunk := resp.read(8192):
data.extend(chunk)
progress.update(len(chunk))
progress.finish()
dest.write_bytes(bytes(data))
else:
dest.write_bytes(resp.read())
return dest
def get_album_id(self, name: str) -> str | None:
for album in self._request("GET", "/albums"):
if album["albumName"].lower() == name.lower():
return album["id"]
return None
def get_album_assets(self, album_id: str, show_progress: bool = False) -> list[dict]:
key = f"album_{album_id}"
cached = _cache_get(key)
if cached is not None:
return cached
album = self._request("GET", f"/albums/{album_id}",
show_progress=show_progress, progress_desc="Fetching album")
assets = album.get("assets", [])
_cache_set(key, assets)
return assets
def _is_portrait(asset: dict) -> bool | None:
"""Check if asset displays as portrait, accounting for EXIF orientation."""
exif = asset.get("exifInfo") or {}
width = exif.get("exifImageWidth") or 0
height = exif.get("exifImageHeight") or 0
if not (width and height):
return None
# EXIF orientation 6 and 8 mean 90° rotation (swap dimensions)
orientation = str(exif.get("orientation") or "1")
if orientation in ("6", "8"):
width, height = height, width
return height > width
def _filter_by_orientation(assets: list[dict], portrait: bool) -> list[dict]:
"""Filter assets by orientation, accounting for EXIF rotation."""
filtered = []
no_dimensions = 0
for asset in assets:
is_portrait = _is_portrait(asset)
if is_portrait is not None:
if is_portrait == portrait:
filtered.append(asset)
else:
no_dimensions += 1
if no_dimensions:
print(f"Note: {no_dimensions}/{len(assets)} photos missing dimension data")
return filtered
def _pick_weighted_random(assets: list[dict]) -> dict:
"""Pick random asset, biased towards favorites (20%) and recently added photos (50%)."""
if not assets:
raise ValueError("No assets to choose from")
cutoff = datetime.now(timezone.utc) - timedelta(days=30)
favorites = [a for a in assets if a.get("isFavorite")]
recent = []
for asset in assets:
date_str = asset.get("createdAt", "")
try:
if datetime.fromisoformat(date_str.replace("Z", "+00:00")) >= cutoff:
recent.append(asset)
except (ValueError, AttributeError):
pass
if favorites and random.random() < 0.2:
return random.choice(favorites)
if recent and random.random() < 0.5:
return random.choice(recent)
return random.choice(assets)
def _download_random_asset(client: ImmichClient, assets: list[dict]) -> Path:
history = get_history()
new_assets = history.filter_new(assets)
if new_assets:
print(f"Photos: {len(new_assets)} new / {len(assets)} total")
asset = _pick_weighted_random(new_assets)
else:
print(f"All {len(assets)} photos shown, picking from full list")
asset = _pick_weighted_random(assets)
dest = Path(tempfile.gettempdir()) / "immich_photo.jpg"
path = client.download_asset(asset["id"], dest)
history.mark_displayed(asset["id"])
return path
def get_random_photo_of_people(client: ImmichClient, names: list[str], orientation: int = 0) -> Path:
person_ids = [pid for name in names if (pid := client.get_person_id(name))]
if not person_ids:
raise ValueError(f"No people found: {names}")
assets = client.search_assets_by_people(person_ids)
if not assets:
raise ValueError(f"No photos found for: {names}")
portrait = orientation in (90, 270)
filtered = _filter_by_orientation(assets, portrait)
if not filtered:
raise ValueError(f"No {'portrait' if portrait else 'landscape'} photos available")
return _download_random_asset(client, filtered)
def get_random_photo_from_album(client: ImmichClient, album_name: str, orientation: int = 0) -> Path:
album_id = client.get_album_id(album_name)
if not album_id:
raise ValueError(f"Album not found: {album_name}")
assets = [a for a in client.get_album_assets(album_id) if a.get("type") == "IMAGE"]
if not assets:
raise ValueError(f"No photos in album: {album_name}")
portrait = orientation in (90, 270)
filtered = _filter_by_orientation(assets, portrait)
if not filtered:
raise ValueError(f"No {'portrait' if portrait else 'landscape'} photos in album: {album_name}")
return _download_random_asset(client, filtered)

View file

@ -47,8 +47,3 @@ class ProgressBar:
"""Complete the progress bar."""
self.current = self.total
self._render()
def print_status(msg: str) -> None:
"""Print a status message."""
print(f" {msg}")

View file

@ -2,38 +2,33 @@
# Waveshare 7.3" 6-color e-Paper driver (modified)
# Original: Waveshare team, 2022-10-20
import sys
import numpy as np
import cv2
from PIL import Image, ImageEnhance
from numba import jit
from progress import ProgressBar
from . import epdconfig
EPD_WIDTH = 800
EPD_HEIGHT = 480
DEFAULT_SATURATION = 1.4
DEFAULT_CONTRAST = 1.2
DEFAULT_GAMMA = 0.9
# 6-color e-ink encoding: indices 0,1,2,3,5,6 are wire-format colors;
# 4 is reserved/unused (filled with BLACK so nearest-color never picks it).
PALETTE_RGB = np.array([
[0, 0, 0], # BLACK
[255, 255, 255], # WHITE
[255, 255, 0], # YELLOW
[255, 0, 0], # RED
[0, 0, 255], # BLUE
[0, 255, 0], # GREEN
[0, 0, 0], # 0: BLACK
[255, 255, 255], # 1: WHITE
[255, 255, 0], # 2: YELLOW
[255, 0, 0], # 3: RED
[0, 0, 0], # 4: unused
[0, 0, 255], # 5: BLUE
[0, 255, 0], # 6: GREEN
], dtype=np.float64)
PERCEPTUAL_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float64)
def _enhance_for_eink(image: Image.Image, saturation: float = None,
contrast: float = None, gamma: float = None) -> Image.Image:
saturation = saturation or DEFAULT_SATURATION
contrast = contrast or DEFAULT_CONTRAST
gamma = gamma or DEFAULT_GAMMA
def _enhance_for_eink(image: Image.Image, saturation: float,
contrast: float, gamma: float) -> Image.Image:
img = image.convert('RGB')
if saturation != 1.0:
img = ImageEnhance.Color(img).enhance(saturation)
@ -66,18 +61,6 @@ def _crop_center(image: Image.Image, target_w: int, target_h: int,
return Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
def _render_progress(desc: str, current: int, total: int, width: int = 30) -> None:
if total == 0:
return
percent = int(100 * current / total)
filled = int(width * current / total)
bar = "" * filled + "" * (width - filled)
sys.stdout.write(f"\r{desc}: |{bar}| {percent:3d}%")
sys.stdout.flush()
if current >= total:
print()
@jit(nopython=True, cache=True)
def _find_nearest_color(r, g, b, palette, weights):
best_idx, best_dist = 0, 1e10
@ -92,14 +75,14 @@ def _find_nearest_color(r, g, b, palette, weights):
@jit(nopython=True, cache=True)
def _atkinson_dither_rows(img, palette, weights, start_row, end_row):
def _atkinson_dither_rows(img, palette, weights, indices, start_row, end_row):
height, width = img.shape[:2]
for y in range(start_row, end_row):
for x in range(width):
old_r, old_g, old_b = img[y, x, 0], img[y, x, 1], img[y, x, 2]
idx = _find_nearest_color(old_r, old_g, old_b, palette, weights)
indices[y, x] = idx
new_r, new_g, new_b = palette[idx, 0], palette[idx, 1], palette[idx, 2]
img[y, x, 0], img[y, x, 1], img[y, x, 2] = new_r, new_g, new_b
err_r, err_g, err_b = (old_r - new_r) / 8.0, (old_g - new_g) / 8.0, (old_b - new_b) / 8.0
@ -127,23 +110,25 @@ def _atkinson_dither_rows(img, palette, weights, start_row, end_row):
img[y + 2, x, 0] += err_r
img[y + 2, x, 1] += err_g
img[y + 2, x, 2] += err_b
return img
def _dither_atkinson(image: Image.Image, show_progress: bool = True) -> Image.Image:
def _dither_atkinson(image: Image.Image, show_progress: bool = True) -> np.ndarray:
"""Atkinson-dither to the e-ink palette and return a uint8 array of palette indices."""
img = np.array(image.convert('RGB'), dtype=np.float64)
height = img.shape[0]
height, width = img.shape[:2]
indices = np.zeros((height, width), dtype=np.uint8)
if show_progress:
print("Dithering...")
progress = ProgressBar(height, desc="Dithering")
chunk_size = 48
for i in range((height + chunk_size - 1) // chunk_size):
start, end = i * chunk_size, min((i + 1) * chunk_size, height)
img = _atkinson_dither_rows(img, PALETTE_RGB, PERCEPTUAL_WEIGHTS, start, end)
_atkinson_dither_rows(img, PALETTE_RGB, PERCEPTUAL_WEIGHTS, indices, start, end)
if show_progress:
_render_progress("Dithering", end, height)
progress.set(end)
return Image.fromarray(np.clip(img, 0, 255).astype(np.uint8), 'RGB')
return indices
class EPD:
@ -253,11 +238,8 @@ class EPD:
self.wait_busy()
return 0
def getbuffer(self, image, saturation=None, contrast=None, gamma=None,
enhance=True, show_progress=True):
pal_image = Image.new("P", (1, 1))
pal_image.putpalette((0,0,0, 255,255,255, 255,255,0, 255,0,0, 0,0,0, 0,0,255, 0,255,0) + (0,0,0)*249)
def getbuffer(self, image, saturation: float, contrast: float, gamma: float,
enhance: bool = True, show_progress: bool = True):
image = image.convert('RGB')
imwidth, imheight = image.size
@ -271,16 +253,13 @@ class EPD:
print("Enhancing...")
image = _enhance_for_eink(image, saturation, contrast, gamma)
image = _dither_atkinson(image, show_progress)
indices = _dither_atkinson(image, show_progress)
if show_progress:
print("Packing buffer...")
image_6color = image.quantize(palette=pal_image, dither=Image.Dither.NONE)
buf_6color = bytearray(image_6color.tobytes('raw'))
buf = [0x00] * (self.width * self.height // 2)
for i in range(0, len(buf_6color), 2):
buf[i // 2] = (buf_6color[i] << 4) + buf_6color[i + 1]
flat = indices.reshape(-1)
packed = (flat[0::2].astype(np.uint8) << 4) | flat[1::2].astype(np.uint8)
buf = packed.tolist()
if show_progress:
print("Ready")