From 6b1fafd5861d4fb8d633ef93e6b823b3bd37c8f2 Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Sun, 8 Sep 2024 10:32:36 +0100 Subject: [PATCH] Fix loading bug --- src/training/histogram_dataset.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/training/histogram_dataset.py b/src/training/histogram_dataset.py index b9d526a..b260121 100644 --- a/src/training/histogram_dataset.py +++ b/src/training/histogram_dataset.py @@ -53,23 +53,23 @@ class HistogramDataset(Dataset): return random_edit(original_image, seed=original_idx * 7919 + edit_idx) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + assert self._cache_path is not None, "Cache path is not set" original_idx = idx // self._edit_count edit_idx = idx % self._edit_count - cached_data_path = None - if self._cache_path is not None: - cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin" - cached_data_path.parent.mkdir(parents=True, exist_ok=True) + cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin" + cached_data_path.parent.mkdir(parents=True, exist_ok=True) - if cached_data_path and cached_data_path.exists(): + if cached_data_path.exists(): try: edited_histogram, original_histogram = self.read_2_histograms( cached_data_path, self._bin_count ) logging.debug(f"Loaded {cached_data_path} from cache") except: + cached_data_path.unlink(missing_ok=True) logging.warning(f"Failed to load {cached_data_path}, regenerating...") - else: + elif not cached_data_path.exists(): edited = self.get_edited_image(original_idx, edit_idx) edited_histogram = compute_histogram( edited, bins=self._bin_count, normalize=True @@ -80,16 +80,15 @@ class HistogramDataset(Dataset): original, bins=self._bin_count, normalize=True ) - if cached_data_path: - try: - self.save_2_histograms( - edited_histogram, - original_histogram, - cached_data_path, - ) - logging.debug(f"Saved {cached_data_path} to cache") - except: - logging.warning(f"Failed to save {cached_data_path}") + try: + self.save_2_histograms( + edited_histogram, + original_histogram, + cached_data_path, + ) + logging.debug(f"Saved {cached_data_path} to cache") + except: + logging.warning(f"Failed to save {cached_data_path}") return ( torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),