From 9e286264b88668512e10dc6230125c59875f804d Mon Sep 17 00:00:00 2001
From: Andras Schmelczer
Date: Wed, 28 Aug 2024 20:42:10 +0100
Subject: [PATCH] Make cached files 10x smaller

---
Review notes (text in this section is ignored by git-am):
 * A cache entry that exists but fails to load is now really regenerated.
   Previously the warning was logged, the recompute branch (an `else` of
   the exists() check) was skipped, and torch.tensor(None) raised.
 * Only OSError, ValueError and zlib.error are treated as "bad cache
   entry"; the bare `except:` also swallowed KeyboardInterrupt.
 * Histograms are (de)serialised with ndarray.tobytes()/np.frombuffer()
   instead of struct.pack()/unpack(). The on-disk layout (native float32,
   first histogram followed by the second, zlib level 9) is unchanged, so
   .bin files written by the earlier revision of this patch stay readable.

 src/training/histogram_dataset.py | 94 ++++++++++++++++++++++++++++++--------
 1 file changed, 73 insertions(+), 21 deletions(-)

diff --git a/src/training/histogram_dataset.py b/src/training/histogram_dataset.py
index d51b005..b7d4d9e 100644
--- a/src/training/histogram_dataset.py
+++ b/src/training/histogram_dataset.py
@@ -6,8 +6,9 @@ from PIL import Image
 import logging
 import torch
 from pathlib import Path
-
+import zlib
 import PIL.Image
+import numpy as np
 
 PIL.Image.MAX_IMAGE_PIXELS = None
 
@@ -54,33 +55,84 @@ class HistogramDataset(Dataset):
         original_idx = idx // self._edit_count
         edit_idx = idx % self._edit_count
 
+        edited_histogram = None
+        original_histogram = None
+        cached_data_path = None
+
         if self._cache_path is not None:
-            _cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.pt"
-            _cached_data_path.parent.mkdir(parents=True, exist_ok=True)
-            if _cached_data_path.exists():
-                try:
-                    return torch.load(_cached_data_path)
-                except:
-                    logging.warning(
-                        f"Failed to load {_cached_data_path}, regenerating..."
-                    )
+            cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin"
+            cached_data_path.parent.mkdir(parents=True, exist_ok=True)
 
-        edited = self.get_edited_image(original_idx, edit_idx)
-        edited_histogram = compute_histogram(
-            edited, bins=self._bin_count, normalize=True
-        )
+        if cached_data_path is not None and cached_data_path.exists():
+            try:
+                edited_histogram, original_histogram = self.read_2_histograms(
+                    cached_data_path, self._bin_count
+                )
+            except (OSError, ValueError, zlib.error):
+                # A corrupt or stale entry is not fatal: both histograms are
+                # still None here, so they are recomputed and re-cached below.
+                logging.warning(f"Failed to load {cached_data_path}, regenerating...")
 
-        original = self.get_original_image(original_idx)
-        original_histogram = compute_histogram(
-            original, bins=self._bin_count, normalize=True
-        )
+        if edited_histogram is None or original_histogram is None:
+            edited = self.get_edited_image(original_idx, edit_idx)
+            edited_histogram = compute_histogram(
+                edited, bins=self._bin_count, normalize=True
+            )
+
+            original = self.get_original_image(original_idx)
+            original_histogram = compute_histogram(
+                original, bins=self._bin_count, normalize=True
+            )
+
+            if cached_data_path is not None:
+                self.save_2_histograms(
+                    edited_histogram,
+                    original_histogram,
+                    cached_data_path,
+                )
 
         result = (
             torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
             torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
         )
 
-        if self._cache_path is not None:
-            torch.save(result, _cached_data_path)
-
         return result
+
+    @staticmethod
+    def save_2_histograms(tensor1: np.ndarray, tensor2: np.ndarray, path: Path):
+        """Store two equally sized histograms as zlib-compressed float32 values."""
+        flat_array1 = np.asarray(tensor1, dtype=np.float32).ravel()
+        flat_array2 = np.asarray(tensor2, dtype=np.float32).ravel()
+
+        if flat_array1.size != flat_array2.size:
+            raise ValueError("Both histograms must have the same number of values")
+
+        # Native-endian float32 values, first histogram followed by the second;
+        # tobytes() avoids turning every value into a Python float.
+        packed_bytes = flat_array1.tobytes() + flat_array2.tobytes()
+        compressed_bytes = zlib.compress(packed_bytes, level=9)
+        with open(path, "wb") as f:
+            f.write(compressed_bytes)
+
+    @staticmethod
+    def read_2_histograms(path: Path, bin_count: int) -> Tuple[np.ndarray, np.ndarray]:
+        """Load the two (bin_count, bin_count, bin_count) histograms stored at path.
+
+        Raises ValueError if the file does not hold exactly 2 * bin_count**3 values.
+        """
+        shape = (bin_count, bin_count, bin_count)
+        length = bin_count**3
+        with open(path, "rb") as f:
+            packed_data = f.read()
+
+        values = np.frombuffer(zlib.decompress(packed_data), dtype=np.float32)
+        if values.size != 2 * length:
+            raise ValueError(
+                f"{path} holds {values.size} values, expected {2 * length}"
+            )
+
+        # frombuffer yields read-only views; copy to hand out writable arrays.
+        return (
+            values[:length].reshape(shape).copy(),
+            values[length:].reshape(shape).copy(),
+        )