diff --git a/editor/training/histogram_dataset.py b/editor/training/histogram_dataset.py index 2693d20..7d44e4e 100644 --- a/editor/training/histogram_dataset.py +++ b/editor/training/histogram_dataset.py @@ -1,5 +1,5 @@ from torch.utils.data import Dataset -from typing import List +from typing import List, Optional from editor.utils import compute_histogram from .random_edit import random_edit from PIL import Image @@ -20,11 +20,13 @@ class HistogramDataset(Dataset): bin_count: int = 32, target_size=(480, 480), delete_corrupt_images: bool = False, + cache_path: Optional[Path] = None, ): self._paths = sorted(paths) self._edit_count = edit_count self._bin_count = bin_count self._target_size = target_size + self._cache_path = cache_path if delete_corrupt_images: self._delete_corrupt_images() @@ -44,6 +46,14 @@ class HistogramDataset(Dataset): return len(self._paths) * self._edit_count def __getitem__(self, idx): + if self._cache_path is not None: + self._cached_data_path = self._cache_path / f"{idx}.pt" + if self._cached_data_path.exists(): + try: + return torch.load(self._cached_data_path) + except: + print(f"Failed to load {self._cached_data_path}, regenerating...") + original_idx = idx // self._edit_count original_path = self._paths[original_idx] original = Image.open(original_path) @@ -58,7 +68,12 @@ class HistogramDataset(Dataset): edited, bins=self._bin_count, normalize=True ) - return ( + result = ( torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0), torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0), ) + + if self._cache_path is not None: + torch.save(result, self._cached_data_path) + + return result