move files
This commit is contained in:
parent
1a41fd6829
commit
231e22cac8
36 changed files with 15580 additions and 79653 deletions
0
src/editor/__init__.py
Normal file
0
src/editor/__init__.py
Normal file
3
src/editor/histogram_transfer/__init__.py
Normal file
3
src/editor/histogram_transfer/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .regrain import regrain
|
||||
from .pdf_transfer_1d import pdf_transfer_1d
|
||||
from .pdf_transfer_3d import pdf_transfer_3d
|
||||
13
src/editor/histogram_transfer/pdf_transfer_1d.py
Normal file
13
src/editor/histogram_transfer/pdf_transfer_1d.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def pdf_transfer_1d(pX: np.ndarray, pY: np.ndarray) -> np.ndarray:
|
||||
PX = np.cumsum(pX + np.finfo(float).eps)
|
||||
PX /= PX[-1]
|
||||
|
||||
PY = np.cumsum(pY + np.finfo(float).eps)
|
||||
PY /= PY[-1]
|
||||
|
||||
f = np.interp(PX, PY, np.arange(len(pX)))
|
||||
|
||||
return f
|
||||
46
src/editor/histogram_transfer/pdf_transfer_3d.py
Normal file
46
src/editor/histogram_transfer/pdf_transfer_3d.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import numpy as np
|
||||
from editor.utils import generate_rotation_matrices
|
||||
from editor.histogram_transfer import pdf_transfer_1d
|
||||
from editor.histogram_transfer import regrain
|
||||
|
||||
|
||||
EPSILON = 1e-6
|
||||
|
||||
|
||||
def pdf_transfer_3d(
|
||||
source: np.ndarray,
|
||||
target_flattened: np.ndarray,
|
||||
relaxation: float = 1,
|
||||
bin_count: int = 1000,
|
||||
iterations: int = 25,
|
||||
smoothness: float = 1,
|
||||
):
|
||||
[h, w, c] = source.shape
|
||||
source_flattened = source.reshape(-1, c).transpose()
|
||||
|
||||
rotation_matrices = generate_rotation_matrices(iterations)
|
||||
for i, rotation in enumerate(rotation_matrices, start=1):
|
||||
D0R = rotation @ source_flattened
|
||||
D1R = rotation @ target_flattened
|
||||
D0R_ = np.zeros_like(source_flattened)
|
||||
|
||||
for i in range(rotation.shape[0]):
|
||||
datamin = min(np.min(D0R[i, :]), np.min(D1R[i, :])) - EPSILON
|
||||
datamax = max(np.max(D0R[i, :]), np.max(D1R[i, :])) + EPSILON
|
||||
u = np.linspace(datamin, datamax, bin_count)
|
||||
|
||||
p0R, _ = np.histogram(D0R[i, :], bins=u, density=True)
|
||||
p1R, _ = np.histogram(D1R[i, :], bins=u, density=True)
|
||||
|
||||
f = pdf_transfer_1d(p0R, p1R)
|
||||
mapped_values = (
|
||||
np.interp(D0R[i, :], u[:-1], f) * (datamax - datamin) / (bin_count - 1)
|
||||
+ datamin
|
||||
)
|
||||
D0R_[i, :] = mapped_values
|
||||
|
||||
source_flattened = source_flattened + relaxation * (rotation.T @ (D0R_ - D0R))
|
||||
source_flattened.clip(0, 255, out=source_flattened)
|
||||
|
||||
result = source_flattened.astype(np.uint8).transpose().reshape(h, w, c)
|
||||
return regrain(source, result, smoothness=smoothness)
|
||||
87
src/editor/histogram_transfer/regrain.py
Normal file
87
src/editor/histogram_transfer/regrain.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
from scipy.ndimage import zoom
|
||||
import numpy as np
|
||||
|
||||
|
||||
NBITS = [4, 16, 32, 64, 64, 64]
|
||||
|
||||
|
||||
def regrain(img_arr_in, img_arr_col, smoothness: float = 1):
|
||||
"""keep gradient of img_arr_in and color of img_arr_col."""
|
||||
|
||||
img_arr_in = img_arr_in / 255.0
|
||||
img_arr_col = img_arr_col / 255.0
|
||||
img_arr_out = np.array(img_arr_in)
|
||||
img_arr_out = _regrain_rec(
|
||||
img_arr_out, img_arr_in, img_arr_col, NBITS, 0, smoothness
|
||||
)
|
||||
img_arr_out[img_arr_out < 0] = 0
|
||||
img_arr_out[img_arr_out > 1] = 1
|
||||
img_arr_out = (255.0 * img_arr_out).astype("uint8")
|
||||
return img_arr_out
|
||||
|
||||
|
||||
def _regrain_rec(img_arr_out, img_arr_in, img_arr_col, nbits, level, smoothness):
|
||||
[h, w, _] = img_arr_in.shape
|
||||
h2 = (h + 1) // 2
|
||||
w2 = (w + 1) // 2
|
||||
if len(nbits) > 1 and h2 > 20 and w2 > 20:
|
||||
resize_arr_in = _resize_image(img_arr_in, w2, h2)
|
||||
resize_arr_col = _resize_image(img_arr_col, w2, h2)
|
||||
resize_arr_out = _resize_image(img_arr_out, w2, h2)
|
||||
resize_arr_out = _regrain_rec(
|
||||
resize_arr_out,
|
||||
resize_arr_in,
|
||||
resize_arr_col,
|
||||
nbits[1:],
|
||||
level + 1,
|
||||
smoothness,
|
||||
)
|
||||
img_arr_out = _resize_image(resize_arr_out, w, h)
|
||||
img_arr_out = _solve(
|
||||
img_arr_out, img_arr_in, img_arr_col, nbits[0], level, smoothness
|
||||
)
|
||||
return img_arr_out
|
||||
|
||||
|
||||
def _solve(img_arr_out, img_arr_in, img_arr_col, nbit, level, smoothness):
|
||||
[width, height, c] = img_arr_in.shape
|
||||
first_pad_0 = lambda arr: np.concatenate((arr[:1, :], arr[:-1, :]), axis=0)
|
||||
first_pad_1 = lambda arr: np.concatenate((arr[:, :1], arr[:, :-1]), axis=1)
|
||||
last_pad_0 = lambda arr: np.concatenate((arr[1:, :], arr[-1:, :]), axis=0)
|
||||
last_pad_1 = lambda arr: np.concatenate((arr[:, 1:], arr[:, -1:]), axis=1)
|
||||
|
||||
delta_x = last_pad_1(img_arr_in) - first_pad_1(img_arr_in)
|
||||
delta_y = last_pad_0(img_arr_in) - first_pad_0(img_arr_in)
|
||||
delta = np.sqrt((delta_x**2 + delta_y**2).sum(axis=2, keepdims=True))
|
||||
|
||||
psi = 256 * delta / 5
|
||||
psi[psi > 1] = 1
|
||||
phi = 30 * 2 ** (-level) / (1 + 10 * delta / smoothness)
|
||||
|
||||
phi1 = (last_pad_1(phi) + phi) / 2
|
||||
phi2 = (last_pad_0(phi) + phi) / 2
|
||||
phi3 = (first_pad_1(phi) + phi) / 2
|
||||
phi4 = (first_pad_0(phi) + phi) / 2
|
||||
|
||||
rho = 1 / 5.0
|
||||
for i in range(nbit):
|
||||
den = psi + phi1 + phi2 + phi3 + phi4
|
||||
num = (
|
||||
np.tile(psi, [1, 1, c]) * img_arr_col
|
||||
+ np.tile(phi1, [1, 1, c])
|
||||
* (last_pad_1(img_arr_out) - last_pad_1(img_arr_in) + img_arr_in)
|
||||
+ np.tile(phi2, [1, 1, c])
|
||||
* (last_pad_0(img_arr_out) - last_pad_0(img_arr_in) + img_arr_in)
|
||||
+ np.tile(phi3, [1, 1, c])
|
||||
* (first_pad_1(img_arr_out) - first_pad_1(img_arr_in) + img_arr_in)
|
||||
+ np.tile(phi4, [1, 1, c])
|
||||
* (first_pad_0(img_arr_out) - first_pad_0(img_arr_in) + img_arr_in)
|
||||
)
|
||||
img_arr_out = (
|
||||
num / np.tile(den + 1e-6, [1, 1, c]) * (1 - rho) + rho * img_arr_out
|
||||
)
|
||||
return img_arr_out
|
||||
|
||||
|
||||
def _resize_image(data, target_width, target_height):
|
||||
return zoom(data, (target_height / data.shape[0], target_width / data.shape[1], 1))
|
||||
0
src/editor/image_editor.py
Normal file
0
src/editor/image_editor.py
Normal file
3
src/editor/operations/__init__.py
Normal file
3
src/editor/operations/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .add_noise import add_noise
|
||||
from .change_temperature import change_temperature
|
||||
from .add_random_colour_spill import add_random_colour_spill
|
||||
11
src/editor/operations/add_noise.py
Normal file
11
src/editor/operations/add_noise.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def add_noise(img: Image, alpha: float) -> Image:
|
||||
img = img.convert("RGB")
|
||||
width, height = img.size
|
||||
random_colors = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
random_img = Image.fromarray(random_colors, mode="RGB")
|
||||
result = Image.blend(img, random_img, alpha)
|
||||
return result
|
||||
20
src/editor/operations/add_random_colour_spill.py
Normal file
20
src/editor/operations/add_random_colour_spill.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
from PIL import Image
|
||||
from ..utils import random
|
||||
|
||||
|
||||
def add_random_colour_spill(image: Image, range: float) -> Image:
|
||||
matrix = (
|
||||
random(1 / range, range),
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
random(1 / range, range),
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
random(1 / range, range),
|
||||
0.0,
|
||||
)
|
||||
return image.convert("RGB", matrix)
|
||||
42
src/editor/operations/change_temperature.py
Normal file
42
src/editor/operations/change_temperature.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from PIL import Image
|
||||
|
||||
kelvin_table = {
|
||||
1000: (255, 56, 0),
|
||||
1500: (255, 109, 0),
|
||||
2000: (255, 137, 18),
|
||||
2500: (255, 161, 72),
|
||||
3000: (255, 180, 107),
|
||||
3500: (255, 196, 137),
|
||||
4000: (255, 209, 163),
|
||||
4500: (255, 219, 186),
|
||||
5000: (255, 228, 206),
|
||||
5500: (255, 236, 224),
|
||||
6000: (255, 243, 239),
|
||||
6500: (255, 249, 253),
|
||||
7000: (245, 243, 255),
|
||||
7500: (235, 238, 255),
|
||||
8000: (227, 233, 255),
|
||||
8500: (220, 229, 255),
|
||||
9000: (214, 225, 255),
|
||||
9500: (208, 222, 255),
|
||||
10000: (204, 219, 255),
|
||||
}
|
||||
|
||||
|
||||
def change_temperature(image: Image, temperature: float) -> Image:
|
||||
r, g, b = kelvin_table[temperature]
|
||||
matrix = (
|
||||
r / 255.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
g / 255.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
b / 255.0,
|
||||
0.0,
|
||||
)
|
||||
return image.convert("RGB", matrix)
|
||||
3
src/editor/training/__init__.py
Normal file
3
src/editor/training/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .histogram_dataset import HistogramDataset
|
||||
from .random_edit import random_edit
|
||||
from .progressive_pooling_loss import ProgressivePoolingLoss
|
||||
89
src/editor/training/histogram_dataset.py
Normal file
89
src/editor/training/histogram_dataset.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
from torch.utils.data import Dataset
|
||||
from typing import List, Optional, Tuple
|
||||
from editor.utils import compute_histogram
|
||||
from .random_edit import random_edit
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
import PIL.Image
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = None
|
||||
|
||||
|
||||
class HistogramDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
paths: List[Path],
|
||||
edit_count: int = 5,
|
||||
bin_count: int = 32,
|
||||
target_size=(480, 480),
|
||||
delete_corrupt_images: bool = False,
|
||||
cache_path: Optional[Path] = None,
|
||||
):
|
||||
self._paths = sorted(paths)
|
||||
self._edit_count = edit_count
|
||||
self._bin_count = bin_count
|
||||
self._target_size = target_size
|
||||
self._cache_path = cache_path
|
||||
|
||||
if delete_corrupt_images:
|
||||
self._delete_corrupt_images()
|
||||
|
||||
def _delete_corrupt_images(self) -> None:
|
||||
deleted_count = 0
|
||||
for path in tqdm(self._paths):
|
||||
try:
|
||||
Image.open(path)
|
||||
except:
|
||||
print(f"Failed to open {path}, deleting...")
|
||||
deleted_count += 1
|
||||
path.unlink()
|
||||
print(f"Deleted {deleted_count} corrupt images")
|
||||
|
||||
def __len__(self):
|
||||
return len(self._paths) * self._edit_count
|
||||
|
||||
def get_original_image(self, original_idx: int) -> Image.Image:
|
||||
original_path = self._paths[original_idx]
|
||||
original = Image.open(original_path)
|
||||
original.thumbnail(
|
||||
self._target_size, Image.Resampling.LANCZOS
|
||||
) # size will be at most target_size, the aspect ratio is preserved
|
||||
return original
|
||||
|
||||
def get_edited_image(self, original_idx: int, edit_idx: int) -> Image.Image:
|
||||
original_image = self.get_original_image(original_idx)
|
||||
return random_edit(original_image, seed=edit_idx)
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self._cache_path is not None:
|
||||
self._cached_data_path = self._cache_path / f"{idx}.pt"
|
||||
if self._cached_data_path.exists():
|
||||
try:
|
||||
return torch.load(self._cached_data_path)
|
||||
except:
|
||||
print(f"Failed to load {self._cached_data_path}, regenerating...")
|
||||
|
||||
original_idx = idx // self._edit_count
|
||||
original = self.get_original_image(original_idx)
|
||||
edited = random_edit(original, seed=idx)
|
||||
|
||||
edited_histogram = compute_histogram(
|
||||
edited, bins=self._bin_count, normalize=True
|
||||
)
|
||||
|
||||
original_histogram = compute_histogram(
|
||||
original, bins=self._bin_count, normalize=True
|
||||
)
|
||||
|
||||
result = (
|
||||
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
||||
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
|
||||
)
|
||||
|
||||
if self._cache_path is not None:
|
||||
torch.save(result, self._cached_data_path)
|
||||
|
||||
return result
|
||||
38
src/editor/training/progressive_pooling_loss.py
Normal file
38
src/editor/training/progressive_pooling_loss.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from typing import List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ProgressivePoolingLoss(nn.Module):
|
||||
def __init__(self, target_sizes: List[int], damping: float):
|
||||
super(ProgressivePoolingLoss, self).__init__()
|
||||
self._target_sizes = target_sizes
|
||||
self._damping = damping
|
||||
|
||||
def forward(self, tensor_a, tensor_b):
|
||||
assert (
|
||||
tensor_a.size() == tensor_b.size()
|
||||
), f"Input tensors must have the same size, got {tensor_a.size()} and {tensor_b.size()}"
|
||||
|
||||
assert (
|
||||
len(tensor_a.size()) == 5
|
||||
), f"Input tensors must have 5 dimensions, got {tensor_a.size()}"
|
||||
|
||||
_minibatch_size, _channels, depth, height, width = tensor_a.size()
|
||||
assert depth == height == width, "Input tensors must be cubes."
|
||||
|
||||
loss = 0.0
|
||||
weight = 1
|
||||
|
||||
for target_size in self._target_sizes:
|
||||
pool_size = depth // target_size
|
||||
pooled_a = F.avg_pool3d(tensor_a, pool_size) * (pool_size**3)
|
||||
pooled_b = F.avg_pool3d(tensor_b, pool_size) * (pool_size**3)
|
||||
|
||||
diff = torch.abs(pooled_a - pooled_b)
|
||||
|
||||
loss += diff.mean() * weight
|
||||
weight *= self._damping
|
||||
|
||||
return loss
|
||||
19
src/editor/training/random_edit.py
Normal file
19
src/editor/training/random_edit.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from PIL import Image, ImageEnhance
|
||||
from ..utils import random, get_colour_lut, apply_pixel_shader
|
||||
from ..operations import add_noise, add_random_colour_spill
|
||||
import numpy as np
|
||||
|
||||
|
||||
def random_edit(img: Image, seed: int = 42) -> Image:
|
||||
np.random.seed(seed)
|
||||
img = add_noise(img, random(0, 0.2))
|
||||
img = ImageEnhance.Contrast(img).enhance(random(0.5, 2))
|
||||
img = add_random_colour_spill(img, 1.3)
|
||||
img = img.convert("HSV")
|
||||
saturation_lut = get_colour_lut(variance=0.3, count=5, type="linear")
|
||||
brightness_lut = get_colour_lut(variance=0.3, count=5, type="cubic")
|
||||
img = apply_pixel_shader(
|
||||
img, lambda h, s, v: (h, saturation_lut[s], brightness_lut[v])
|
||||
)
|
||||
img = img.convert("RGB")
|
||||
return img
|
||||
7
src/editor/utils/__init__.py
Normal file
7
src/editor/utils/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from .interpolate import interpolate
|
||||
from .random import random
|
||||
from .apply_pixel_shader import apply_pixel_shader
|
||||
from .get_colour_lut import get_colour_lut
|
||||
from .compute_histogram import compute_histogram
|
||||
from .kldiv import kldiv
|
||||
from .generate_rotation_matrices import generate_rotation_matrices
|
||||
14
src/editor/utils/apply_pixel_shader.py
Normal file
14
src/editor/utils/apply_pixel_shader.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from typing import Callable, Tuple
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def apply_pixel_shader(
|
||||
img: Image, callback: Callable[[int, int, int], Tuple[int, int, int]]
|
||||
):
|
||||
width, height = img.size
|
||||
pixels = img.load()
|
||||
for x in range(width):
|
||||
for y in range(height):
|
||||
r, g, b = pixels[x, y]
|
||||
pixels[x, y] = callback(r, g, b)
|
||||
return img
|
||||
22
src/editor/utils/compute_histogram.py
Normal file
22
src/editor/utils/compute_histogram.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_histogram(
|
||||
image: Image.Image | np.ndarray,
|
||||
bins: int,
|
||||
value_range=(0, 256),
|
||||
normalize: bool = True,
|
||||
) -> np.ndarray:
|
||||
image = np.array(image) if isinstance(image, Image.Image) else image
|
||||
|
||||
histogram, _ = np.histogramdd(
|
||||
image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range]
|
||||
)
|
||||
|
||||
histogram = histogram.astype(np.float32)
|
||||
|
||||
if normalize:
|
||||
histogram = histogram / np.sum(histogram)
|
||||
|
||||
return histogram
|
||||
66
src/editor/utils/generate_rotation_matrices.py
Normal file
66
src/editor/utils/generate_rotation_matrices.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
from random import shuffle
|
||||
from typing import List, Tuple
|
||||
import numpy as np
|
||||
from functools import lru_cache
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
@lru_cache
|
||||
def generate_rotation_matrices(count: int) -> List[NDArray[np.float64]]:
|
||||
axes = fibonacci_sphere(count)
|
||||
shuffle(axes)
|
||||
angles = np.linspace(0, 2 * np.pi, count, endpoint=False)
|
||||
matrices = [_rotation_matrix(axis, angle) for axis, angle in zip(axes, angles)]
|
||||
for matrix in matrices:
|
||||
_check_rotation_matrix(matrix)
|
||||
return matrices
|
||||
|
||||
|
||||
def fibonacci_sphere(samples: int) -> List[Tuple[float, float, float]]:
|
||||
points = []
|
||||
phi = np.pi * (3.0 - np.sqrt(5.0)) # Golden angle in radians
|
||||
for i in range(samples):
|
||||
y = 1 - (i / float(samples - 1)) * 2 # y goes from 1 to -1
|
||||
radius = np.sqrt(1 - y * y) # radius at y
|
||||
|
||||
theta = phi * i # golden angle increment
|
||||
|
||||
x = np.cos(theta) * radius
|
||||
z = np.sin(theta) * radius
|
||||
|
||||
points.append([x, y, z])
|
||||
return points
|
||||
|
||||
|
||||
def _rotation_matrix(
|
||||
axis: Tuple[float, float, float], theta: float
|
||||
) -> NDArray[np.float64]:
|
||||
axis = np.asarray(axis)
|
||||
axis = axis / np.sqrt(np.dot(axis, axis))
|
||||
a = np.cos(theta / 2.0)
|
||||
b, c, d = -axis * np.sin(theta / 2.0)
|
||||
|
||||
aa, bb, cc, dd = a * a, b * b, c * c, d * d
|
||||
bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
|
||||
return np.array(
|
||||
[
|
||||
[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
|
||||
[2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
|
||||
[2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _check_rotation_matrix(R: NDArray[np.float64]):
|
||||
# Check if the matrix is square
|
||||
if R.shape != (3, 3):
|
||||
raise ValueError("Matrix must be 3x3.")
|
||||
|
||||
# Check orthogonality: R.T * R should be close to the identity matrix
|
||||
I = np.eye(3)
|
||||
if not np.allclose(np.dot(R.T, R), I):
|
||||
raise ValueError("allclose")
|
||||
|
||||
# Check determinant: Should be +1
|
||||
if not np.isclose(np.linalg.det(R), 1.0):
|
||||
raise ValueError(f"det {np.linalg.det(R)}")
|
||||
21
src/editor/utils/get_colour_lut.py
Normal file
21
src/editor/utils/get_colour_lut.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
import numpy as np
|
||||
from typing import List
|
||||
from .random import random
|
||||
from .interpolate import interpolate, INTERPOLATION_TYPE
|
||||
|
||||
|
||||
def get_edit_points(variance: float, count: int) -> List[float]:
|
||||
return [
|
||||
random(i / (count - 1) - variance, i / (count - 1) + variance)
|
||||
for i in range(count)
|
||||
]
|
||||
|
||||
|
||||
def get_colour_lut(
|
||||
variance=0.1, count=5, type: INTERPOLATION_TYPE = "cubic"
|
||||
) -> List[int]:
|
||||
edit_points = get_edit_points(variance=variance, count=count)
|
||||
return [
|
||||
round(interpolate(edit_points, i / 255, type=type) * 255)
|
||||
for i in np.linspace(0, 255, 256)
|
||||
]
|
||||
35
src/editor/utils/interpolate.py
Normal file
35
src/editor/utils/interpolate.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
import numpy as np
|
||||
from scipy.interpolate import CubicSpline
|
||||
from typing import List, Literal
|
||||
|
||||
|
||||
INTERPOLATION_TYPE = Literal["cubic", "linear"]
|
||||
|
||||
|
||||
def interpolate(
|
||||
control_points: List[float], t: float, type: INTERPOLATION_TYPE
|
||||
) -> float:
|
||||
control_points = sorted(control_points)
|
||||
|
||||
if type == "cubic":
|
||||
x = np.linspace(0, 1, len(control_points))
|
||||
cs = CubicSpline(x, control_points)
|
||||
return cs(t)
|
||||
|
||||
if type == "linear":
|
||||
n = len(control_points) - 1
|
||||
segment_indices = np.linspace(0, 1, n + 1)
|
||||
|
||||
index = np.searchsorted(segment_indices, t, side="right") - 1
|
||||
|
||||
if t == 1:
|
||||
return control_points[-1]
|
||||
else:
|
||||
t_normalized = (t - segment_indices[index]) / (
|
||||
segment_indices[index + 1] - segment_indices[index]
|
||||
)
|
||||
return control_points[index] + t_normalized * (
|
||||
control_points[index + 1] - control_points[index]
|
||||
)
|
||||
|
||||
raise ValueError("Invalid type")
|
||||
11
src/editor/utils/kldiv.py
Normal file
11
src/editor/utils/kldiv.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def kldiv(P: np.ndarray, Q: np.ndarray) -> float:
|
||||
P /= P.sum()
|
||||
Q /= Q.sum()
|
||||
|
||||
P_safe = np.maximum(P, np.finfo(float).eps)
|
||||
Q_safe = np.maximum(Q, np.finfo(float).eps)
|
||||
|
||||
return np.sum(P_safe * np.log(P_safe / Q_safe))
|
||||
10
src/editor/utils/random.py
Normal file
10
src/editor/utils/random.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def random(min: float = 0, max: float = 1):
|
||||
mu = (max + min) / 2 # Mean of the distribution
|
||||
sigma = (
|
||||
max - min
|
||||
) / 6 # Standard deviation, chosen so that ~99.7% fall within [min_val, max_val]
|
||||
sample = np.random.normal(mu, sigma)
|
||||
return np.clip(sample, min, max)
|
||||
3
src/editor/visualisation/__init__.py
Normal file
3
src/editor/visualisation/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .display_images import display_images
|
||||
from .plot_histograms_in_3d import plot_histograms_in_3d
|
||||
from .plot_histograms_in_2d import plot_histograms_in_2d
|
||||
25
src/editor/visualisation/display_images.py
Normal file
25
src/editor/visualisation/display_images.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import matplotlib.pyplot as plt
|
||||
from typing import Dict
|
||||
from PIL.Image import Image
|
||||
from math import ceil
|
||||
|
||||
|
||||
def display_images(images: Dict[str, Image], images_per_row: int = 3):
|
||||
fig, axes = plt.subplots(
|
||||
nrows=ceil(len(images) / images_per_row),
|
||||
ncols=min(images_per_row, len(images)),
|
||||
figsize=(12, 8),
|
||||
)
|
||||
|
||||
axes = axes.flatten()
|
||||
|
||||
for i, (title, image) in enumerate(images.items()):
|
||||
axes[i].imshow(image)
|
||||
axes[i].axis("off")
|
||||
axes[i].set_title(title)
|
||||
|
||||
for i in range(len(images), len(axes)):
|
||||
axes[i].axis("off")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
32
src/editor/visualisation/plot_histograms_in_2d.py
Normal file
32
src/editor/visualisation/plot_histograms_in_2d.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def plot_histograms_in_2d(histograms: Dict[str, np.ndarray]):
|
||||
fig = plt.figure(figsize=(15, 5))
|
||||
|
||||
for i, (title, histogram) in enumerate(histograms.items(), 1):
|
||||
ax = fig.add_subplot(1, 3, i, projection="3d")
|
||||
|
||||
size = histogram.shape[0]
|
||||
|
||||
x, y, z = np.indices(histogram.shape)
|
||||
x = x.flatten()
|
||||
y = y.flatten()
|
||||
z = z.flatten()
|
||||
values = histogram.flatten()
|
||||
|
||||
sizes = values * 5000
|
||||
|
||||
colors = np.vstack((x, y, z)).T / (size - 1)
|
||||
|
||||
sc = ax.scatter(x, y, z, c=colors, s=sizes, marker="o", alpha=0.5)
|
||||
|
||||
ax.set_xlim([0, (size - 1)])
|
||||
ax.set_ylim([0, (size - 1)])
|
||||
ax.set_zlim([0, (size - 1)])
|
||||
ax.set_title(title)
|
||||
|
||||
return fig
|
||||
62
src/editor/visualisation/plot_histograms_in_3d.py
Normal file
62
src/editor/visualisation/plot_histograms_in_3d.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from plotly.subplots import make_subplots
|
||||
import plotly.graph_objects as go
|
||||
from math import ceil
|
||||
from typing import Dict
|
||||
import numpy as np
|
||||
|
||||
|
||||
def plot_histograms_in_3d(
|
||||
histograms: Dict[str, np.ndarray], histogram_per_row: int = 3
|
||||
):
|
||||
cols = min(histogram_per_row, len(histograms))
|
||||
rows = ceil(len(histograms) / histogram_per_row)
|
||||
fig = make_subplots(
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(rows)],
|
||||
)
|
||||
|
||||
for i, (title, histogram) in enumerate(histograms.items()):
|
||||
fig.add_trace(
|
||||
_get_3d_scatter_plot_from_histogram(title, histogram),
|
||||
row=(i // (histogram_per_row + 1)) + 1,
|
||||
col=(i % histogram_per_row) + 1,
|
||||
)
|
||||
|
||||
scenes = {
|
||||
f"scene{i}": dict(camera=dict(eye=dict(x=0.1, y=0, z=2)))
|
||||
for i in range(1, len(histograms) + 1)
|
||||
}
|
||||
fig.update_layout(**scenes)
|
||||
fig.show()
|
||||
|
||||
|
||||
def _get_3d_scatter_plot_from_histogram(title, histogram):
|
||||
x, y, z, marker_size = [], [], [], []
|
||||
bins = len(histogram)
|
||||
|
||||
for i, row in enumerate(histogram):
|
||||
for j, col in enumerate(row):
|
||||
for k, value in enumerate(col):
|
||||
if value > 0:
|
||||
x.append(i)
|
||||
y.append(j)
|
||||
z.append(k)
|
||||
marker_size.append(value)
|
||||
|
||||
return go.Scatter3d(
|
||||
x=x,
|
||||
y=y,
|
||||
z=z,
|
||||
mode="markers",
|
||||
marker=dict(
|
||||
size=[min(20, ms * 10000) for ms in marker_size],
|
||||
color=[
|
||||
f"rgb({xi*256/bins},{yi*256/bins},{zi*256/bins})"
|
||||
for xi, yi, zi in zip(x, y, z)
|
||||
],
|
||||
opacity=1,
|
||||
line=dict(width=0),
|
||||
),
|
||||
name=title,
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue