Make cached files 10x smaller

This commit is contained in:
Andras Schmelczer 2024-08-28 20:42:10 +01:00
parent 9d2830f8ff
commit 9e286264b8
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C

View file

@ -6,8 +6,10 @@ from PIL import Image
import logging
import torch
from pathlib import Path
import struct
import zlib
import PIL.Image
import numpy as np
PIL.Image.MAX_IMAGE_PIXELS = None
@ -54,33 +56,72 @@ class HistogramDataset(Dataset):
original_idx = idx // self._edit_count
edit_idx = idx % self._edit_count
edited_histogram = None
original_histogram = None
cached_data_path = None
if self._cache_path is not None:
_cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.pt"
_cached_data_path.parent.mkdir(parents=True, exist_ok=True)
if _cached_data_path.exists():
try:
return torch.load(_cached_data_path)
except:
logging.warning(
f"Failed to load {_cached_data_path}, regenerating..."
)
cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin"
cached_data_path.parent.mkdir(parents=True, exist_ok=True)
edited = self.get_edited_image(original_idx, edit_idx)
edited_histogram = compute_histogram(
edited, bins=self._bin_count, normalize=True
)
if cached_data_path and cached_data_path.exists():
try:
edited_histogram, original_histogram = self.read_2_histograms(
cached_data_path, self._bin_count
)
except:
logging.warning(f"Failed to load {cached_data_path}, regenerating...")
else:
edited = self.get_edited_image(original_idx, edit_idx)
edited_histogram = compute_histogram(
edited, bins=self._bin_count, normalize=True
)
original = self.get_original_image(original_idx)
original_histogram = compute_histogram(
original, bins=self._bin_count, normalize=True
)
original = self.get_original_image(original_idx)
original_histogram = compute_histogram(
original, bins=self._bin_count, normalize=True
)
if cached_data_path:
self.save_2_histograms(
edited_histogram,
original_histogram,
cached_data_path,
)
result = (
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
)
if self._cache_path is not None:
torch.save(result, _cached_data_path)
return result
@staticmethod
def save_2_histograms(tensor1: np.ndarray, tensor2: np.ndarray, path: Path):
    """Write two equally sized histograms to ``path`` as one compressed blob.

    Both inputs are flattened in C order, converted to float32 and stored
    back to back (``tensor1`` first); the result is zlib-compressed at
    level 9. The uncompressed payload has the same byte layout as packing
    every value with ``struct`` using the native ``f`` format.

    Args:
        tensor1: First histogram; any shape.
        tensor2: Second histogram; must hold as many elements as ``tensor1``.
        path: Destination file; overwritten if it already exists.

    Raises:
        ValueError: If the two histograms differ in element count.
    """
    first = np.asarray(tensor1, dtype=np.float32).ravel()
    second = np.asarray(tensor2, dtype=np.float32).ravel()

    # Explicit check rather than `assert`, which is removed under `python -O`.
    if first.size != second.size:
        raise ValueError(
            f"Histograms must have the same number of elements, "
            f"got {first.size} and {second.size}"
        )

    # tobytes() dumps the float32 buffer directly instead of boxing every
    # element into a Python float just to hand it to struct.pack.
    payload = np.concatenate((first, second)).tobytes()

    with open(path, "wb") as f:
        f.write(zlib.compress(payload, level=9))
@staticmethod
def read_2_histograms(path: Path, bin_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Load two cubic float32 histograms from a zlib-compressed file.

    The decompressed payload is expected to hold ``2 * bin_count**3``
    native float32 values: the first histogram followed by the second.

    Args:
        path: File to read.
        bin_count: Number of bins along each of the three axes.

    Returns:
        Two writable float32 arrays, each shaped
        ``(bin_count, bin_count, bin_count)``, in stored order.

    Raises:
        zlib.error: If the file content is not valid zlib data.
        ValueError: If the decompressed payload does not have the size
            implied by ``bin_count``.
    """
    length = bin_count**3
    shape = (bin_count, bin_count, bin_count)

    with open(path, "rb") as f:
        raw = zlib.decompress(f.read())

    expected_bytes = 2 * length * np.dtype(np.float32).itemsize
    if len(raw) != expected_bytes:
        raise ValueError(
            f"Expected {expected_bytes} bytes of histogram data in {path}, "
            f"got {len(raw)}"
        )

    # frombuffer reinterprets the bytes in place, avoiding a tuple of
    # 2 * bin_count**3 Python floats as an intermediate.
    values = np.frombuffer(raw, dtype=np.float32)

    # frombuffer yields a read-only view over `raw`; copy so that callers get
    # independent, writable arrays.
    return (
        values[:length].reshape(shape).copy(),
        values[length:].reshape(shape).copy(),
    )