Add caching
This commit is contained in:
parent
07d926161e
commit
b87c1dd859
1 changed files with 17 additions and 2 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from torch.utils.data import Dataset
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from editor.utils import compute_histogram
|
||||
from .random_edit import random_edit
|
||||
from PIL import Image
|
||||
|
|
@ -20,11 +20,13 @@ class HistogramDataset(Dataset):
|
|||
bin_count: int = 32,
|
||||
target_size=(480, 480),
|
||||
delete_corrupt_images: bool = False,
|
||||
cache_path: Optional[Path] = None,
|
||||
):
|
||||
self._paths = sorted(paths)
|
||||
self._edit_count = edit_count
|
||||
self._bin_count = bin_count
|
||||
self._target_size = target_size
|
||||
self._cache_path = cache_path
|
||||
|
||||
if delete_corrupt_images:
|
||||
self._delete_corrupt_images()
|
||||
|
|
@ -44,6 +46,14 @@ class HistogramDataset(Dataset):
|
|||
return len(self._paths) * self._edit_count
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self._cache_path is not None:
|
||||
self._cached_data_path = self._cache_path / f"{idx}.pt"
|
||||
if self._cached_data_path.exists():
|
||||
try:
|
||||
return torch.load(self._cached_data_path)
|
||||
except:
|
||||
print(f"Failed to load {self._cached_data_path}, regenerating...")
|
||||
|
||||
original_idx = idx // self._edit_count
|
||||
original_path = self._paths[original_idx]
|
||||
original = Image.open(original_path)
|
||||
|
|
@ -58,7 +68,12 @@ class HistogramDataset(Dataset):
|
|||
edited, bins=self._bin_count, normalize=True
|
||||
)
|
||||
|
||||
return (
|
||||
result = (
|
||||
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
||||
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
|
||||
)
|
||||
|
||||
if self._cache_path is not None:
|
||||
torch.save(result, self._cached_data_path)
|
||||
|
||||
return result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue