Fix loading bug
This commit is contained in:
parent
aecd7ec9cb
commit
6b1fafd586
1 changed files with 15 additions and 16 deletions
|
|
@ -53,23 +53,23 @@ class HistogramDataset(Dataset):
|
||||||
return random_edit(original_image, seed=original_idx * 7919 + edit_idx)
|
return random_edit(original_image, seed=original_idx * 7919 + edit_idx)
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert self._cache_path is not None, "Cache path is not set"
|
||||||
original_idx = idx // self._edit_count
|
original_idx = idx // self._edit_count
|
||||||
edit_idx = idx % self._edit_count
|
edit_idx = idx % self._edit_count
|
||||||
|
|
||||||
cached_data_path = None
|
cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin"
|
||||||
if self._cache_path is not None:
|
cached_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin"
|
|
||||||
cached_data_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
if cached_data_path and cached_data_path.exists():
|
if cached_data_path.exists():
|
||||||
try:
|
try:
|
||||||
edited_histogram, original_histogram = self.read_2_histograms(
|
edited_histogram, original_histogram = self.read_2_histograms(
|
||||||
cached_data_path, self._bin_count
|
cached_data_path, self._bin_count
|
||||||
)
|
)
|
||||||
logging.debug(f"Loaded {cached_data_path} from cache")
|
logging.debug(f"Loaded {cached_data_path} from cache")
|
||||||
except:
|
except:
|
||||||
|
cached_data_path.unlink(missing_ok=True)
|
||||||
logging.warning(f"Failed to load {cached_data_path}, regenerating...")
|
logging.warning(f"Failed to load {cached_data_path}, regenerating...")
|
||||||
else:
|
elif not cached_data_path.exists():
|
||||||
edited = self.get_edited_image(original_idx, edit_idx)
|
edited = self.get_edited_image(original_idx, edit_idx)
|
||||||
edited_histogram = compute_histogram(
|
edited_histogram = compute_histogram(
|
||||||
edited, bins=self._bin_count, normalize=True
|
edited, bins=self._bin_count, normalize=True
|
||||||
|
|
@ -80,16 +80,15 @@ class HistogramDataset(Dataset):
|
||||||
original, bins=self._bin_count, normalize=True
|
original, bins=self._bin_count, normalize=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if cached_data_path:
|
try:
|
||||||
try:
|
self.save_2_histograms(
|
||||||
self.save_2_histograms(
|
edited_histogram,
|
||||||
edited_histogram,
|
original_histogram,
|
||||||
original_histogram,
|
cached_data_path,
|
||||||
cached_data_path,
|
)
|
||||||
)
|
logging.debug(f"Saved {cached_data_path} to cache")
|
||||||
logging.debug(f"Saved {cached_data_path} to cache")
|
except:
|
||||||
except:
|
logging.warning(f"Failed to save {cached_data_path}")
|
||||||
logging.warning(f"Failed to save {cached_data_path}")
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue