From f5c03db198702b8528195d04a3ebc62af17d2233 Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Mon, 8 Apr 2024 08:02:31 +0100 Subject: [PATCH] Update helpers --- editor/ploting/__init__.py | 2 ++ editor/ploting/display_images.py | 25 +++++++++++++ editor/ploting/plot_histograms.py | 51 ++++++++++++++++++++++++++ editor/training/__init__.py | 1 + editor/training/histogram_dataset.py | 53 ++++++++++++++++++++++++++++ editor/utils/compute_histogram.py | 12 ++++--- 6 files changed, 140 insertions(+), 4 deletions(-) create mode 100644 editor/ploting/__init__.py create mode 100644 editor/ploting/display_images.py create mode 100644 editor/ploting/plot_histograms.py create mode 100644 editor/training/__init__.py create mode 100644 editor/training/histogram_dataset.py diff --git a/editor/ploting/__init__.py b/editor/ploting/__init__.py new file mode 100644 index 0000000..00c2417 --- /dev/null +++ b/editor/ploting/__init__.py @@ -0,0 +1,2 @@ +from .display_images import display_images +from .plot_histograms import plot_histograms diff --git a/editor/ploting/display_images.py b/editor/ploting/display_images.py new file mode 100644 index 0000000..d05d036 --- /dev/null +++ b/editor/ploting/display_images.py @@ -0,0 +1,25 @@ +import matplotlib.pyplot as plt +from typing import List +from PIL.Image import Image +from math import ceil + + +def display_images(images: List[Image], titles: List[str], images_per_row: int = 3): + fig, axes = plt.subplots( + nrows=ceil(len(images) / images_per_row), + ncols=min(images_per_row, len(images)), + figsize=(12, 8), + ) + + axes = axes.flatten() + + for i, (title, image) in enumerate(zip(titles, images)): + axes[i].imshow(image) + axes[i].axis("off") + axes[i].set_title(title) + + for i in range(len(images), len(axes)): + axes[i].axis("off") + + plt.tight_layout() + plt.show() diff --git a/editor/ploting/plot_histograms.py b/editor/ploting/plot_histograms.py new file mode 100644 index 0000000..1159193 --- /dev/null +++ b/editor/ploting/plot_histograms.py @@ -0,0 +1,51 @@ +from plotly.subplots import make_subplots +import plotly.graph_objects as go +from math import ceil + + +def plot_histograms(hists, histogram_per_row: int = 3): + cols = min(histogram_per_row, len(hists)) + fig = make_subplots( + rows=ceil(len(hists) / histogram_per_row), + cols=cols, + specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(1)], + ) + for i, hist in enumerate(hists, start=1): + fig.add_trace(_get_3d_scatter_plot_from_histogram(hist), row=1, col=i) + + fig.update_layout( + width=1200, + height=600, + scene1=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"), + scene2=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"), + ) + fig.show() + + +def _get_3d_scatter_plot_from_histogram(hist): + x, y, z, marker_size = [], [], [], [] + bins = len(hist) + + for i, row in enumerate(hist): + for j, col in enumerate(row): + for k, value in enumerate(col): + if value > 0: + x.append(i) + y.append(j) + z.append(k) + marker_size.append(value) + + return go.Scatter3d( + x=x, + y=y, + z=z, + mode="markers", + marker=dict( + size=[min(20, ms * 10000) for ms in marker_size], + color=[ + f"rgb({xi*256/bins},{yi*256/bins},{zi*256/bins})" + for xi, yi, zi in zip(x, y, z) + ], + opacity=0.8, + ), + ) diff --git a/editor/training/__init__.py b/editor/training/__init__.py new file mode 100644 index 0000000..2054ca4 --- /dev/null +++ b/editor/training/__init__.py @@ -0,0 +1 @@ +from .histogram_dataset import HistogramDataset diff --git a/editor/training/histogram_dataset.py b/editor/training/histogram_dataset.py new file mode 100644 index 0000000..9746b0f --- /dev/null +++ b/editor/training/histogram_dataset.py @@ -0,0 +1,53 @@ +from torch.utils.data import Dataset +from typing import Generator, Tuple, List +from editor.utils import compute_histogram +from PIL import Image +from tqdm import tqdm +import torch +from pathlib import Path + + +class HistogramDataset(Dataset): + def __init__( + self, paths: List[Path], expected_edit_count: int = 5, bin_count: int = 32 + ): + self._paths = paths + self._expected_edit_count = expected_edit_count + self._bin_count = bin_count + self._pairs = list(self._get_pairs()) + + def _get_pairs(self) -> Generator[Tuple[Path, Path], None, None]: + for path in tqdm(self._paths): + if len(list(path.glob("*.jpg"))) != self._expected_edit_count + 1: + continue + + original_path = path / "original.jpg" + try: + Image.open(original_path) + except: + print(f"Failed to open {original_path}") + continue + yield original_path, original_path # The model should leave the original image unchanged + for i in range(self._expected_edit_count): + try: + Image.open(path / f"{i}.jpg") + except: + print(f'Failed to open {path / f"{i}.jpg"}') + break + yield original_path, path / f"{i}.jpg" + + def __len__(self): + return len(self._pairs) + + def __getitem__(self, idx): + original, edited = self._pairs[idx] + original_histogram = compute_histogram( + original, bins=self._bin_count, normalize=True + ) + edited_histogram = compute_histogram( + edited, bins=self._bin_count, normalize=True + ) + return ( + torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0), + torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0), + ) diff --git a/editor/utils/compute_histogram.py b/editor/utils/compute_histogram.py index 530d3b5..1af64b8 100644 --- a/editor/utils/compute_histogram.py +++ b/editor/utils/compute_histogram.py @@ -2,13 +2,17 @@ from PIL import Image import numpy as np -def compute_histogram(image_path, bins: int, value_range=(0, 256)): +def compute_histogram( + image_path, bins: int, value_range=(0, 256), normalize: bool = True +): image = Image.open(image_path) image = np.array(image) histogram, _ = np.histogramdd( image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range] - ) - histogram = histogram / np.sum(histogram) + ).astype(np.float64) - return histogram.astype(np.float32) + if normalize: + histogram = histogram / np.sum(histogram) + + return histogram