Add caching

This commit is contained in:
Andras Schmelczer 2024-04-12 20:49:17 +01:00
parent 07d926161e
commit b87c1dd859
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C

View file

@@ -1,5 +1,5 @@
from torch.utils.data import Dataset
from typing import List
from typing import List, Optional
from editor.utils import compute_histogram
from .random_edit import random_edit
from PIL import Image
@@ -20,11 +20,13 @@ class HistogramDataset(Dataset):
bin_count: int = 32,
target_size=(480, 480),
delete_corrupt_images: bool = False,
cache_path: Optional[Path] = None,
):
self._paths = sorted(paths)
self._edit_count = edit_count
self._bin_count = bin_count
self._target_size = target_size
self._cache_path = cache_path
if delete_corrupt_images:
self._delete_corrupt_images()
@@ -44,6 +46,14 @@ class HistogramDataset(Dataset):
return len(self._paths) * self._edit_count
def __getitem__(self, idx):
if self._cache_path is not None:
self._cached_data_path = self._cache_path / f"{idx}.pt"
if self._cached_data_path.exists():
try:
return torch.load(self._cached_data_path)
except:
print(f"Failed to load {self._cached_data_path}, regenerating...")
original_idx = idx // self._edit_count
original_path = self._paths[original_idx]
original = Image.open(original_path)
@@ -58,7 +68,12 @@ class HistogramDataset(Dataset):
edited, bins=self._bin_count, normalize=True
)
return (
result = (
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
)
if self._cache_path is not None:
torch.save(result, self._cached_data_path)
return result