This commit is contained in:
Andras Schmelczer 2024-04-28 12:19:19 +01:00
parent eec9ee0275
commit 294f2fab12
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
9 changed files with 62140 additions and 11540 deletions

View file

@ -1,5 +1,5 @@
from torch.utils.data import Dataset
from typing import List, Optional
from typing import List, Optional, Tuple
from editor.utils import compute_histogram
from .random_edit import random_edit
from PIL import Image
@ -45,7 +45,19 @@ class HistogramDataset(Dataset):
def __len__(self):
return len(self._paths) * self._edit_count
def __getitem__(self, idx):
def get_original_image(self, original_idx: int) -> Image.Image:
original_path = self._paths[original_idx]
original = Image.open(original_path)
original.thumbnail(
self._target_size, Image.Resampling.LANCZOS
) # size will be at most target_size, the aspect ratio is preserved
return original
def get_edited_image(self, original_idx: int, edit_idx: int) -> Image.Image:
original_image = self.get_original_image(original_idx)
return random_edit(original_image, seed=edit_idx)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
if self._cache_path is not None:
self._cached_data_path = self._cache_path / f"{idx}.pt"
if self._cached_data_path.exists():
@ -55,10 +67,7 @@ class HistogramDataset(Dataset):
print(f"Failed to load {self._cached_data_path}, regenerating...")
original_idx = idx // self._edit_count
original_path = self._paths[original_idx]
original = Image.open(original_path)
original.thumbnail(self._target_size, Image.Resampling.LANCZOS)
original = self.get_original_image(original_idx)
edited = random_edit(original, seed=idx)
edited_histogram = compute_histogram(