Make cached files 10x smaller
This commit is contained in:
parent
9d2830f8ff
commit
9e286264b8
1 changed files with 62 additions and 21 deletions
|
|
@ -6,8 +6,10 @@ from PIL import Image
|
|||
import logging
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
import struct
|
||||
import zlib
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = None
|
||||
|
||||
|
|
@ -54,33 +56,72 @@ class HistogramDataset(Dataset):
|
|||
original_idx = idx // self._edit_count
|
||||
edit_idx = idx % self._edit_count
|
||||
|
||||
edited_histogram = None
|
||||
original_histogram = None
|
||||
cached_data_path = None
|
||||
|
||||
if self._cache_path is not None:
|
||||
_cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.pt"
|
||||
_cached_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if _cached_data_path.exists():
|
||||
try:
|
||||
return torch.load(_cached_data_path)
|
||||
except:
|
||||
logging.warning(
|
||||
f"Failed to load {_cached_data_path}, regenerating..."
|
||||
)
|
||||
cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin"
|
||||
cached_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
edited = self.get_edited_image(original_idx, edit_idx)
|
||||
edited_histogram = compute_histogram(
|
||||
edited, bins=self._bin_count, normalize=True
|
||||
)
|
||||
if cached_data_path and cached_data_path.exists():
|
||||
try:
|
||||
edited_histogram, original_histogram = self.read_2_histograms(
|
||||
cached_data_path, self._bin_count
|
||||
)
|
||||
except:
|
||||
logging.warning(f"Failed to load {cached_data_path}, regenerating...")
|
||||
else:
|
||||
edited = self.get_edited_image(original_idx, edit_idx)
|
||||
edited_histogram = compute_histogram(
|
||||
edited, bins=self._bin_count, normalize=True
|
||||
)
|
||||
|
||||
original = self.get_original_image(original_idx)
|
||||
original_histogram = compute_histogram(
|
||||
original, bins=self._bin_count, normalize=True
|
||||
)
|
||||
original = self.get_original_image(original_idx)
|
||||
original_histogram = compute_histogram(
|
||||
original, bins=self._bin_count, normalize=True
|
||||
)
|
||||
|
||||
if cached_data_path:
|
||||
self.save_2_histograms(
|
||||
edited_histogram,
|
||||
original_histogram,
|
||||
cached_data_path,
|
||||
)
|
||||
|
||||
result = (
|
||||
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
||||
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
|
||||
)
|
||||
|
||||
if self._cache_path is not None:
|
||||
torch.save(result, _cached_data_path)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
def save_2_histograms(tensor1: np.ndarray, tensor2: np.ndarray, path: Path):
    """Write two equally sized histograms to ``path`` as one compressed blob.

    Both inputs are flattened in C order, converted to float32 and stored
    back to back (``tensor1`` first, then ``tensor2``) in native byte order.
    The raw bytes are zlib-compressed at level 9 before being written.

    Args:
        tensor1: First histogram; any shape.
        tensor2: Second histogram; must hold as many elements as ``tensor1``.
        path: Destination file; it is overwritten if it already exists.

    Raises:
        ValueError: If the two histograms differ in element count.
    """
    first = np.asarray(tensor1, dtype=np.float32).ravel()
    second = np.asarray(tensor2, dtype=np.float32).ravel()

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if first.size != second.size:
        raise ValueError(
            f"Histogram sizes differ: {first.size} != {second.size}"
        )

    # tobytes() emits the same native float32 layout that
    # struct.pack(f"{n}f{n}f", ...) would, without boxing every element
    # into a Python float first.
    raw_bytes = np.concatenate((first, second)).tobytes()
    compressed_bytes = zlib.compress(raw_bytes, level=9)
    with open(path, "wb") as f:
        f.write(compressed_bytes)
|
||||
|
||||
@staticmethod
def read_2_histograms(path: Path, bin_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Load two cubic float32 histograms from a zlib-compressed file.

    The decompressed payload is expected to hold ``2 * bin_count**3``
    native-order float32 values: the first histogram followed by the second.

    Args:
        path: File to read.
        bin_count: Number of bins along each of the three histogram axes.

    Returns:
        Two arrays of shape ``(bin_count, bin_count, bin_count)`` and dtype
        float32, in the order they are stored in the file.

    Raises:
        ValueError: If the decompressed payload does not contain exactly
            ``2 * bin_count**3`` float32 values.
        zlib.error: If the file content is not valid zlib data.
    """
    length = bin_count**3
    with open(path, "rb") as f:
        packed_data = f.read()

    # Decode the floats straight from the buffer rather than going through
    # struct.unpack, which would build a tuple of Python floats first.
    values = np.frombuffer(zlib.decompress(packed_data), dtype=np.float32)
    if values.size != 2 * length:
        raise ValueError(
            f"Expected {2 * length} float32 values in {path}, got {values.size}"
        )

    shape = (bin_count, bin_count, bin_count)
    # frombuffer yields read-only views over the bytes object; copy so the
    # returned arrays are independent and writable.
    return (
        values[:length].reshape(shape).copy(),
        values[length:].reshape(shape).copy(),
    )
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue