Refactor
This commit is contained in:
parent
eec9ee0275
commit
294f2fab12
9 changed files with 62140 additions and 11540 deletions
|
|
@ -1,2 +0,0 @@
|
|||
from .display_images import display_images
|
||||
from .plot_histograms import plot_histograms
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
from torch.utils.data import Dataset
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from editor.utils import compute_histogram
|
||||
from .random_edit import random_edit
|
||||
from PIL import Image
|
||||
|
|
@ -45,7 +45,19 @@ class HistogramDataset(Dataset):
|
|||
def __len__(self):
|
||||
return len(self._paths) * self._edit_count
|
||||
|
||||
def __getitem__(self, idx):
|
||||
def get_original_image(self, original_idx: int) -> Image.Image:
|
||||
original_path = self._paths[original_idx]
|
||||
original = Image.open(original_path)
|
||||
original.thumbnail(
|
||||
self._target_size, Image.Resampling.LANCZOS
|
||||
) # size will be at most target_size, the aspect ratio is preserved
|
||||
return original
|
||||
|
||||
def get_edited_image(self, original_idx: int, edit_idx: int) -> Image.Image:
|
||||
original_image = self.get_original_image(original_idx)
|
||||
return random_edit(original_image, seed=edit_idx)
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self._cache_path is not None:
|
||||
self._cached_data_path = self._cache_path / f"{idx}.pt"
|
||||
if self._cached_data_path.exists():
|
||||
|
|
@ -55,10 +67,7 @@ class HistogramDataset(Dataset):
|
|||
print(f"Failed to load {self._cached_data_path}, regenerating...")
|
||||
|
||||
original_idx = idx // self._edit_count
|
||||
original_path = self._paths[original_idx]
|
||||
original = Image.open(original_path)
|
||||
original.thumbnail(self._target_size, Image.Resampling.LANCZOS)
|
||||
|
||||
original = self.get_original_image(original_idx)
|
||||
edited = random_edit(original, seed=idx)
|
||||
|
||||
edited_histogram = compute_histogram(
|
||||
|
|
|
|||
3
editor/visualisation/__init__.py
Normal file
3
editor/visualisation/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .display_images import display_images
|
||||
from .plot_histograms_in_3d import plot_histograms_in_3d
|
||||
from .plot_histograms_in_2d import plot_histograms_in_2d
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
import matplotlib.pyplot as plt
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
from PIL.Image import Image
|
||||
from math import ceil
|
||||
|
||||
|
||||
def display_images(images: List[Image], titles: List[str], images_per_row: int = 3):
|
||||
def display_images(images: Dict[str, Image], images_per_row: int = 3):
|
||||
fig, axes = plt.subplots(
|
||||
nrows=ceil(len(images) / images_per_row),
|
||||
ncols=min(images_per_row, len(images)),
|
||||
|
|
@ -13,7 +13,7 @@ def display_images(images: List[Image], titles: List[str], images_per_row: int =
|
|||
|
||||
axes = axes.flatten()
|
||||
|
||||
for i, (title, image) in enumerate(zip(titles, images)):
|
||||
for i, (title, image) in enumerate(images.items()):
|
||||
axes[i].imshow(image)
|
||||
axes[i].axis("off")
|
||||
axes[i].set_title(title)
|
||||
32
editor/visualisation/plot_histograms_in_2d.py
Normal file
32
editor/visualisation/plot_histograms_in_2d.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def plot_histograms_in_2d(histograms: Dict[str, np.ndarray]):
|
||||
fig = plt.figure(figsize=(15, 5))
|
||||
|
||||
for i, (title, histogram) in enumerate(histograms.items(), 1):
|
||||
ax = fig.add_subplot(1, 3, i, projection="3d")
|
||||
|
||||
size = histogram.shape[0]
|
||||
|
||||
x, y, z = np.indices(histogram.shape)
|
||||
x = x.flatten()
|
||||
y = y.flatten()
|
||||
z = z.flatten()
|
||||
values = histogram.flatten()
|
||||
|
||||
sizes = values * 5000
|
||||
|
||||
colors = np.vstack((x, y, z)).T / (size - 1)
|
||||
|
||||
sc = ax.scatter(x, y, z, c=colors, s=sizes, marker="o", alpha=0.5)
|
||||
|
||||
ax.set_xlim([0, (size - 1)])
|
||||
ax.set_ylim([0, (size - 1)])
|
||||
ax.set_zlim([0, (size - 1)])
|
||||
ax.set_title(title)
|
||||
|
||||
return fig
|
||||
|
|
@ -1,32 +1,34 @@
|
|||
from plotly.subplots import make_subplots
|
||||
import plotly.graph_objects as go
|
||||
from math import ceil
|
||||
from typing import Dict
|
||||
import numpy as np
|
||||
|
||||
|
||||
def plot_histograms(hists, histogram_per_row: int = 3):
|
||||
cols = min(histogram_per_row, len(hists))
|
||||
def plot_histograms_in_3d(
|
||||
histograms: Dict[str, np.ndarray], histogram_per_row: int = 3
|
||||
):
|
||||
cols = min(histogram_per_row, len(histograms))
|
||||
rows = ceil(len(histograms) / histogram_per_row)
|
||||
fig = make_subplots(
|
||||
rows=ceil(len(hists) / histogram_per_row),
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(1)],
|
||||
)
|
||||
for i, hist in enumerate(hists, start=1):
|
||||
fig.add_trace(_get_3d_scatter_plot_from_histogram(hist), row=1, col=i)
|
||||
|
||||
fig.update_layout(
|
||||
showlegend=False,
|
||||
autosize=True,
|
||||
scene1=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
|
||||
scene2=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
|
||||
specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(rows)],
|
||||
)
|
||||
for i, (title, histogram) in enumerate(histograms.items()):
|
||||
fig.add_trace(
|
||||
_get_3d_scatter_plot_from_histogram(title, histogram),
|
||||
row=(i // (histogram_per_row + 1)) + 1,
|
||||
col=(i % histogram_per_row) + 1,
|
||||
)
|
||||
fig.show()
|
||||
|
||||
|
||||
def _get_3d_scatter_plot_from_histogram(hist):
|
||||
def _get_3d_scatter_plot_from_histogram(title, histogram):
|
||||
x, y, z, marker_size = [], [], [], []
|
||||
bins = len(hist)
|
||||
bins = len(histogram)
|
||||
|
||||
for i, row in enumerate(hist):
|
||||
for i, row in enumerate(histogram):
|
||||
for j, col in enumerate(row):
|
||||
for k, value in enumerate(col):
|
||||
if value > 0:
|
||||
|
|
@ -48,4 +50,5 @@ def _get_3d_scatter_plot_from_histogram(hist):
|
|||
],
|
||||
opacity=0.8,
|
||||
),
|
||||
name=title,
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue