Fix loading bug

This commit is contained in:
Andras Schmelczer 2024-09-08 10:32:36 +01:00
parent aecd7ec9cb
commit 6b1fafd586
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C

View file

@ -53,23 +53,23 @@ class HistogramDataset(Dataset):
return random_edit(original_image, seed=original_idx * 7919 + edit_idx) return random_edit(original_image, seed=original_idx * 7919 + edit_idx)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
assert self._cache_path is not None, "Cache path is not set"
original_idx = idx // self._edit_count original_idx = idx // self._edit_count
edit_idx = idx % self._edit_count edit_idx = idx % self._edit_count
cached_data_path = None
if self._cache_path is not None:
cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin" cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin"
cached_data_path.parent.mkdir(parents=True, exist_ok=True) cached_data_path.parent.mkdir(parents=True, exist_ok=True)
if cached_data_path and cached_data_path.exists(): if cached_data_path.exists():
try: try:
edited_histogram, original_histogram = self.read_2_histograms( edited_histogram, original_histogram = self.read_2_histograms(
cached_data_path, self._bin_count cached_data_path, self._bin_count
) )
logging.debug(f"Loaded {cached_data_path} from cache") logging.debug(f"Loaded {cached_data_path} from cache")
except: except:
cached_data_path.unlink(missing_ok=True)
logging.warning(f"Failed to load {cached_data_path}, regenerating...") logging.warning(f"Failed to load {cached_data_path}, regenerating...")
else: elif not cached_data_path.exists():
edited = self.get_edited_image(original_idx, edit_idx) edited = self.get_edited_image(original_idx, edit_idx)
edited_histogram = compute_histogram( edited_histogram = compute_histogram(
edited, bins=self._bin_count, normalize=True edited, bins=self._bin_count, normalize=True
@ -80,7 +80,6 @@ class HistogramDataset(Dataset):
original, bins=self._bin_count, normalize=True original, bins=self._bin_count, normalize=True
) )
if cached_data_path:
try: try:
self.save_2_histograms( self.save_2_histograms(
edited_histogram, edited_histogram,