Update helpers
This commit is contained in:
parent
c0b0dacd99
commit
f5c03db198
6 changed files with 140 additions and 4 deletions
2
editor/ploting/__init__.py
Normal file
2
editor/ploting/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from .display_images import display_images
|
||||
from .plot_histograms import plot_histograms
|
||||
25
editor/ploting/display_images.py
Normal file
25
editor/ploting/display_images.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import matplotlib.pyplot as plt
|
||||
from typing import List
|
||||
from PIL.Image import Image
|
||||
from math import ceil
|
||||
|
||||
|
||||
def display_images(images: List[Image], titles: List[str], images_per_row: int = 3):
|
||||
fig, axes = plt.subplots(
|
||||
nrows=ceil(len(images) / images_per_row),
|
||||
ncols=min(images_per_row, len(images)),
|
||||
figsize=(12, 8),
|
||||
)
|
||||
|
||||
axes = axes.flatten()
|
||||
|
||||
for i, (title, image) in enumerate(zip(titles, images)):
|
||||
axes[i].imshow(image)
|
||||
axes[i].axis("off")
|
||||
axes[i].set_title(title)
|
||||
|
||||
for i in range(len(images), len(axes)):
|
||||
axes[i].axis("off")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
51
editor/ploting/plot_histograms.py
Normal file
51
editor/ploting/plot_histograms.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
from plotly.subplots import make_subplots
|
||||
import plotly.graph_objects as go
|
||||
from math import ceil
|
||||
|
||||
|
||||
def plot_histograms(hists, histogram_per_row: int = 3):
|
||||
cols = min(histogram_per_row, len(hists))
|
||||
fig = make_subplots(
|
||||
rows=ceil(len(hists) / histogram_per_row),
|
||||
cols=cols,
|
||||
specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(1)],
|
||||
)
|
||||
for i, hist in enumerate(hists, start=1):
|
||||
fig.add_trace(_get_3d_scatter_plot_from_histogram(hist), row=1, col=i)
|
||||
|
||||
fig.update_layout(
|
||||
width=1200,
|
||||
height=600,
|
||||
scene1=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
|
||||
scene2=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
|
||||
)
|
||||
fig.show()
|
||||
|
||||
|
||||
def _get_3d_scatter_plot_from_histogram(hist):
|
||||
x, y, z, marker_size = [], [], [], []
|
||||
bins = len(hist)
|
||||
|
||||
for i, row in enumerate(hist):
|
||||
for j, col in enumerate(row):
|
||||
for k, value in enumerate(col):
|
||||
if value > 0:
|
||||
x.append(i)
|
||||
y.append(j)
|
||||
z.append(k)
|
||||
marker_size.append(value)
|
||||
|
||||
return go.Scatter3d(
|
||||
x=x,
|
||||
y=y,
|
||||
z=z,
|
||||
mode="markers",
|
||||
marker=dict(
|
||||
size=[min(20, ms * 10000) for ms in marker_size],
|
||||
color=[
|
||||
f"rgb({xi*256/bins},{yi*256/bins},{zi*256/bins})"
|
||||
for xi, yi, zi in zip(x, y, z)
|
||||
],
|
||||
opacity=0.8,
|
||||
),
|
||||
)
|
||||
1
editor/training/__init__.py
Normal file
1
editor/training/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .histogram_dataset import HistogramDataset
|
||||
53
editor/training/histogram_dataset.py
Normal file
53
editor/training/histogram_dataset.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
from torch.utils.data import Dataset
|
||||
from typing import Generator, Tuple, List
|
||||
from editor.utils import compute_histogram
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class HistogramDataset(Dataset):
|
||||
def __init__(
|
||||
self, paths: List[Path], expected_edit_count: int = 5, bin_count: int = 32
|
||||
):
|
||||
self._paths = paths
|
||||
self._expected_edit_count = expected_edit_count
|
||||
self._bin_count = bin_count
|
||||
self._pairs = list(self._get_pairs())
|
||||
|
||||
def _get_pairs(self) -> Generator[Tuple[Path, Path], None, None]:
|
||||
for path in tqdm(self._paths):
|
||||
if len(list(path.glob("*.jpg"))) != self._expected_edit_count + 1:
|
||||
continue
|
||||
|
||||
original_path = path / "original.jpg"
|
||||
try:
|
||||
Image.open(original_path)
|
||||
except:
|
||||
print(f"Failed to open {original_path}")
|
||||
continue
|
||||
yield original_path, original_path # The model should leave the original image unchanged
|
||||
for i in range(self._expected_edit_count):
|
||||
try:
|
||||
Image.open(path / f"{i}.jpg")
|
||||
except:
|
||||
print(f'Failed to open {path / f"{i}.jpg"}')
|
||||
break
|
||||
yield original_path, path / f"{i}.jpg"
|
||||
|
||||
def __len__(self):
|
||||
return len(self._pairs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
original, edited = self._pairs[idx]
|
||||
original_histogram = compute_histogram(
|
||||
original, bins=self._bin_count, normalize=True
|
||||
)
|
||||
edited_histogram = compute_histogram(
|
||||
edited, bins=self._bin_count, normalize=True
|
||||
)
|
||||
return (
|
||||
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
||||
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
|
||||
)
|
||||
|
|
@ -2,13 +2,17 @@ from PIL import Image
|
|||
import numpy as np
|
||||
|
||||
|
||||
def compute_histogram(image_path, bins: int, value_range=(0, 256)):
|
||||
def compute_histogram(
|
||||
image_path, bins: int, value_range=(0, 256), normalize: bool = True
|
||||
):
|
||||
image = Image.open(image_path)
|
||||
image = np.array(image)
|
||||
|
||||
histogram, _ = np.histogramdd(
|
||||
image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range]
|
||||
)
|
||||
histogram = histogram / np.sum(histogram)
|
||||
).astype(np.float64)
|
||||
|
||||
return histogram.astype(np.float32)
|
||||
if normalize:
|
||||
histogram = histogram / np.sum(histogram)
|
||||
|
||||
return histogram
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue