Fix loading bug

This commit is contained in:
Andras Schmelczer 2024-09-08 10:32:36 +01:00
parent aecd7ec9cb
commit 6b1fafd586
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C

View file

@ -53,23 +53,23 @@ class HistogramDataset(Dataset):
return random_edit(original_image, seed=original_idx * 7919 + edit_idx) return random_edit(original_image, seed=original_idx * 7919 + edit_idx)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
assert self._cache_path is not None, "Cache path is not set"
original_idx = idx // self._edit_count original_idx = idx // self._edit_count
edit_idx = idx % self._edit_count edit_idx = idx % self._edit_count
cached_data_path = None cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin"
if self._cache_path is not None: cached_data_path.parent.mkdir(parents=True, exist_ok=True)
cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin"
cached_data_path.parent.mkdir(parents=True, exist_ok=True)
if cached_data_path and cached_data_path.exists(): if cached_data_path.exists():
try: try:
edited_histogram, original_histogram = self.read_2_histograms( edited_histogram, original_histogram = self.read_2_histograms(
cached_data_path, self._bin_count cached_data_path, self._bin_count
) )
logging.debug(f"Loaded {cached_data_path} from cache") logging.debug(f"Loaded {cached_data_path} from cache")
except: except:
cached_data_path.unlink(missing_ok=True)
logging.warning(f"Failed to load {cached_data_path}, regenerating...") logging.warning(f"Failed to load {cached_data_path}, regenerating...")
else: elif not cached_data_path.exists():
edited = self.get_edited_image(original_idx, edit_idx) edited = self.get_edited_image(original_idx, edit_idx)
edited_histogram = compute_histogram( edited_histogram = compute_histogram(
edited, bins=self._bin_count, normalize=True edited, bins=self._bin_count, normalize=True
@ -80,16 +80,15 @@ class HistogramDataset(Dataset):
original, bins=self._bin_count, normalize=True original, bins=self._bin_count, normalize=True
) )
if cached_data_path: try:
try: self.save_2_histograms(
self.save_2_histograms( edited_histogram,
edited_histogram, original_histogram,
original_histogram, cached_data_path,
cached_data_path, )
) logging.debug(f"Saved {cached_data_path} to cache")
logging.debug(f"Saved {cached_data_path} to cache") except:
except: logging.warning(f"Failed to save {cached_data_path}")
logging.warning(f"Failed to save {cached_data_path}")
return ( return (
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0), torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),