diff --git a/src/config.py b/src/config.py index a844c3c..441635a 100644 --- a/src/config.py +++ b/src/config.py @@ -2,7 +2,7 @@ from pathlib import Path DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE1/data/unsplash").glob("*.jpg")) -CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE1/data/cache2") +CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE1/data/cache") CACHE_PATH.mkdir(exist_ok=True, parents=True) MODELS_PATH = Path("/home/andras/projects/bipolaroid/models") diff --git a/src/editor/training/histogram_dataset.py b/src/editor/training/histogram_dataset.py index baf549c..fe39deb 100644 --- a/src/editor/training/histogram_dataset.py +++ b/src/editor/training/histogram_dataset.py @@ -4,6 +4,7 @@ from editor.utils import compute_histogram from .random_edit import random_edit from PIL import Image from tqdm import tqdm +import logging import torch from pathlib import Path @@ -18,7 +19,7 @@ class HistogramDataset(Dataset): paths: List[Path], edit_count: int = 5, bin_count: int = 32, - target_size=(480, 480), + target_size=(240, 240), delete_corrupt_images: bool = False, cache_path: Optional[Path] = None, ): @@ -27,6 +28,11 @@ class HistogramDataset(Dataset): self._bin_count = bin_count self._target_size = target_size self._cache_path = cache_path + if self._cache_path: + self._cache_path = ( + self._cache_path + / f"{self._bin_count}bins_{self._target_size[0]}x{self._target_size[1]}px" + ) if delete_corrupt_images: self._delete_corrupt_images() @@ -37,10 +43,10 @@ class HistogramDataset(Dataset): try: Image.open(path) except: - print(f"Failed to open {path}, deleting...") + logging.warning(f"Failed to open {path}, deleting...") deleted_count += 1 path.unlink() - print(f"Deleted {deleted_count} corrupt images") + logging.info(f"Deleted {deleted_count} corrupt images") def __len__(self): return len(self._paths) * self._edit_count @@ -55,25 +61,29 @@ class HistogramDataset(Dataset): def get_edited_image(self, original_idx: int, edit_idx: int) -> Image.Image: original_image = self.get_original_image(original_idx) - return random_edit(original_image, seed=edit_idx) + return random_edit(original_image, seed=original_idx * 7919 + edit_idx) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: - if self._cache_path is not None: - self._cached_data_path = self._cache_path / f"{idx}.pt" - if self._cached_data_path.exists(): - try: - return torch.load(self._cached_data_path) - except: - print(f"Failed to load {self._cached_data_path}, regenerating...") - original_idx = idx // self._edit_count - original = self.get_original_image(original_idx) - edited = random_edit(original, seed=idx) + edit_idx = idx % self._edit_count + if self._cache_path is not None: + _cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.pt" + _cached_data_path.parent.mkdir(parents=True, exist_ok=True) + if _cached_data_path.exists(): + try: + return torch.load(_cached_data_path) + except: + logging.warning( + f"Failed to load {_cached_data_path}, regenerating..." + ) + + edited = self.get_edited_image(original_idx, edit_idx) edited_histogram = compute_histogram( edited, bins=self._bin_count, normalize=True ) + original = self.get_original_image(original_idx) original_histogram = compute_histogram( original, bins=self._bin_count, normalize=True ) @@ -84,6 +94,6 @@ class HistogramDataset(Dataset): ) if self._cache_path is not None: - torch.save(result, self._cached_data_path) + torch.save(result, _cached_data_path) return result