Fix caching
This commit is contained in:
parent
b31fa39ca4
commit
09aceae9d4
2 changed files with 26 additions and 16 deletions
|
|
@ -2,7 +2,7 @@ from pathlib import Path
|
|||
|
||||
DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE1/data/unsplash").glob("*.jpg"))
|
||||
|
||||
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE1/data/cache2")
|
||||
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE1/data/cache")
|
||||
CACHE_PATH.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
MODELS_PATH = Path("/home/andras/projects/bipolaroid/models")
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from editor.utils import compute_histogram
|
|||
from .random_edit import random_edit
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -18,7 +19,7 @@ class HistogramDataset(Dataset):
|
|||
paths: List[Path],
|
||||
edit_count: int = 5,
|
||||
bin_count: int = 32,
|
||||
target_size=(480, 480),
|
||||
target_size=(240, 240),
|
||||
delete_corrupt_images: bool = False,
|
||||
cache_path: Optional[Path] = None,
|
||||
):
|
||||
|
|
@ -27,6 +28,11 @@ class HistogramDataset(Dataset):
|
|||
self._bin_count = bin_count
|
||||
self._target_size = target_size
|
||||
self._cache_path = cache_path
|
||||
if self._cache_path:
|
||||
self._cache_path = (
|
||||
self._cache_path
|
||||
/ f"{self._bin_count}bins_{self._target_size[0]}x{self._target_size[1]}px"
|
||||
)
|
||||
|
||||
if delete_corrupt_images:
|
||||
self._delete_corrupt_images()
|
||||
|
|
@ -37,10 +43,10 @@ class HistogramDataset(Dataset):
|
|||
try:
|
||||
Image.open(path)
|
||||
except:
|
||||
print(f"Failed to open {path}, deleting...")
|
||||
logging.warning(f"Failed to open {path}, deleting...")
|
||||
deleted_count += 1
|
||||
path.unlink()
|
||||
print(f"Deleted {deleted_count} corrupt images")
|
||||
logging.info(f"Deleted {deleted_count} corrupt images")
|
||||
|
||||
def __len__(self):
|
||||
return len(self._paths) * self._edit_count
|
||||
|
|
@ -55,25 +61,29 @@ class HistogramDataset(Dataset):
|
|||
|
||||
def get_edited_image(self, original_idx: int, edit_idx: int) -> Image.Image:
|
||||
original_image = self.get_original_image(original_idx)
|
||||
return random_edit(original_image, seed=edit_idx)
|
||||
return random_edit(original_image, seed=original_idx * 7919 + edit_idx)
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self._cache_path is not None:
|
||||
self._cached_data_path = self._cache_path / f"{idx}.pt"
|
||||
if self._cached_data_path.exists():
|
||||
try:
|
||||
return torch.load(self._cached_data_path)
|
||||
except:
|
||||
print(f"Failed to load {self._cached_data_path}, regenerating...")
|
||||
|
||||
original_idx = idx // self._edit_count
|
||||
original = self.get_original_image(original_idx)
|
||||
edited = random_edit(original, seed=idx)
|
||||
edit_idx = idx % self._edit_count
|
||||
|
||||
if self._cache_path is not None:
|
||||
_cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.pt"
|
||||
_cached_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if _cached_data_path.exists():
|
||||
try:
|
||||
return torch.load(_cached_data_path)
|
||||
except:
|
||||
logging.warning(
|
||||
f"Failed to load {_cached_data_path}, regenerating..."
|
||||
)
|
||||
|
||||
edited = self.get_edited_image(original_idx, edit_idx)
|
||||
edited_histogram = compute_histogram(
|
||||
edited, bins=self._bin_count, normalize=True
|
||||
)
|
||||
|
||||
original = self.get_original_image(original_idx)
|
||||
original_histogram = compute_histogram(
|
||||
original, bins=self._bin_count, normalize=True
|
||||
)
|
||||
|
|
@ -84,6 +94,6 @@ class HistogramDataset(Dataset):
|
|||
)
|
||||
|
||||
if self._cache_path is not None:
|
||||
torch.save(result, self._cached_data_path)
|
||||
torch.save(result, _cached_data_path)
|
||||
|
||||
return result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue