Refactor
This commit is contained in:
parent
eec9ee0275
commit
294f2fab12
9 changed files with 62140 additions and 11540 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from torch.utils.data import Dataset
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from editor.utils import compute_histogram
|
||||
from .random_edit import random_edit
|
||||
from PIL import Image
|
||||
|
|
@ -45,7 +45,19 @@ class HistogramDataset(Dataset):
|
|||
def __len__(self):
|
||||
return len(self._paths) * self._edit_count
|
||||
|
||||
def __getitem__(self, idx):
|
||||
def get_original_image(self, original_idx: int) -> Image.Image:
|
||||
original_path = self._paths[original_idx]
|
||||
original = Image.open(original_path)
|
||||
original.thumbnail(
|
||||
self._target_size, Image.Resampling.LANCZOS
|
||||
) # size will be at most target_size, the aspect ratio is preserved
|
||||
return original
|
||||
|
||||
def get_edited_image(self, original_idx: int, edit_idx: int) -> Image.Image:
|
||||
original_image = self.get_original_image(original_idx)
|
||||
return random_edit(original_image, seed=edit_idx)
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self._cache_path is not None:
|
||||
self._cached_data_path = self._cache_path / f"{idx}.pt"
|
||||
if self._cached_data_path.exists():
|
||||
|
|
@ -55,10 +67,7 @@ class HistogramDataset(Dataset):
|
|||
print(f"Failed to load {self._cached_data_path}, regenerating...")
|
||||
|
||||
original_idx = idx // self._edit_count
|
||||
original_path = self._paths[original_idx]
|
||||
original = Image.open(original_path)
|
||||
original.thumbnail(self._target_size, Image.Resampling.LANCZOS)
|
||||
|
||||
original = self.get_original_image(original_idx)
|
||||
edited = random_edit(original, seed=idx)
|
||||
|
||||
edited_histogram = compute_histogram(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue