Update helpers

This commit is contained in:
Andras Schmelczer 2024-04-08 08:02:31 +01:00
parent c0b0dacd99
commit f5c03db198
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
6 changed files with 140 additions and 4 deletions

View file

@ -0,0 +1,2 @@
from .display_images import display_images
from .plot_histograms import plot_histograms

View file

@ -0,0 +1,25 @@
import matplotlib.pyplot as plt
from typing import List
from PIL.Image import Image
from math import ceil
def display_images(images: List[Image], titles: List[str], images_per_row: int = 3):
fig, axes = plt.subplots(
nrows=ceil(len(images) / images_per_row),
ncols=min(images_per_row, len(images)),
figsize=(12, 8),
)
axes = axes.flatten()
for i, (title, image) in enumerate(zip(titles, images)):
axes[i].imshow(image)
axes[i].axis("off")
axes[i].set_title(title)
for i in range(len(images), len(axes)):
axes[i].axis("off")
plt.tight_layout()
plt.show()

View file

@ -0,0 +1,51 @@
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from math import ceil
def plot_histograms(hists, histogram_per_row: int = 3):
cols = min(histogram_per_row, len(hists))
fig = make_subplots(
rows=ceil(len(hists) / histogram_per_row),
cols=cols,
specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(1)],
)
for i, hist in enumerate(hists, start=1):
fig.add_trace(_get_3d_scatter_plot_from_histogram(hist), row=1, col=i)
fig.update_layout(
width=1200,
height=600,
scene1=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
scene2=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
)
fig.show()
def _get_3d_scatter_plot_from_histogram(hist):
x, y, z, marker_size = [], [], [], []
bins = len(hist)
for i, row in enumerate(hist):
for j, col in enumerate(row):
for k, value in enumerate(col):
if value > 0:
x.append(i)
y.append(j)
z.append(k)
marker_size.append(value)
return go.Scatter3d(
x=x,
y=y,
z=z,
mode="markers",
marker=dict(
size=[min(20, ms * 10000) for ms in marker_size],
color=[
f"rgb({xi*256/bins},{yi*256/bins},{zi*256/bins})"
for xi, yi, zi in zip(x, y, z)
],
opacity=0.8,
),
)

View file

@ -0,0 +1 @@
from .histogram_dataset import HistogramDataset

View file

@ -0,0 +1,53 @@
from torch.utils.data import Dataset
from typing import Generator, Tuple, List
from editor.utils import compute_histogram
from PIL import Image
from tqdm import tqdm
import torch
from pathlib import Path
class HistogramDataset(Dataset):
def __init__(
self, paths: List[Path], expected_edit_count: int = 5, bin_count: int = 32
):
self._paths = paths
self._expected_edit_count = expected_edit_count
self._bin_count = bin_count
self._pairs = list(self._get_pairs())
def _get_pairs(self) -> Generator[Tuple[Path, Path], None, None]:
for path in tqdm(self._paths):
if len(list(path.glob("*.jpg"))) != self._expected_edit_count + 1:
continue
original_path = path / "original.jpg"
try:
Image.open(original_path)
except:
print(f"Failed to open {original_path}")
continue
yield original_path, original_path # The model should leave the original image unchanged
for i in range(self._expected_edit_count):
try:
Image.open(path / f"{i}.jpg")
except:
print(f'Failed to open {path / f"{i}.jpg"}')
break
yield original_path, path / f"{i}.jpg"
def __len__(self):
return len(self._pairs)
def __getitem__(self, idx):
original, edited = self._pairs[idx]
original_histogram = compute_histogram(
original, bins=self._bin_count, normalize=True
)
edited_histogram = compute_histogram(
edited, bins=self._bin_count, normalize=True
)
return (
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
)

View file

@ -2,13 +2,17 @@ from PIL import Image
import numpy as np
def compute_histogram(image_path, bins: int, value_range=(0, 256)):
def compute_histogram(
image_path, bins: int, value_range=(0, 256), normalize: bool = True
):
image = Image.open(image_path)
image = np.array(image)
histogram, _ = np.histogramdd(
image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range]
)
histogram = histogram / np.sum(histogram)
).astype(np.float64)
return histogram.astype(np.float32)
if normalize:
histogram = histogram / np.sum(histogram)
return histogram