move files

This commit is contained in:
Andras Schmelczer 2024-05-09 21:22:28 +01:00
parent 1a41fd6829
commit 231e22cac8
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
36 changed files with 15580 additions and 79653 deletions

0
src/editor/__init__.py Normal file
View file

View file

@ -0,0 +1,3 @@
from .regrain import regrain
from .pdf_transfer_1d import pdf_transfer_1d
from .pdf_transfer_3d import pdf_transfer_3d

View file

@ -0,0 +1,13 @@
import numpy as np
def pdf_transfer_1d(pX: np.ndarray, pY: np.ndarray) -> np.ndarray:
PX = np.cumsum(pX + np.finfo(float).eps)
PX /= PX[-1]
PY = np.cumsum(pY + np.finfo(float).eps)
PY /= PY[-1]
f = np.interp(PX, PY, np.arange(len(pX)))
return f

View file

@ -0,0 +1,46 @@
import numpy as np
from editor.utils import generate_rotation_matrices
from editor.histogram_transfer import pdf_transfer_1d
from editor.histogram_transfer import regrain
EPSILON = 1e-6
def pdf_transfer_3d(
source: np.ndarray,
target_flattened: np.ndarray,
relaxation: float = 1,
bin_count: int = 1000,
iterations: int = 25,
smoothness: float = 1,
):
[h, w, c] = source.shape
source_flattened = source.reshape(-1, c).transpose()
rotation_matrices = generate_rotation_matrices(iterations)
for i, rotation in enumerate(rotation_matrices, start=1):
D0R = rotation @ source_flattened
D1R = rotation @ target_flattened
D0R_ = np.zeros_like(source_flattened)
for i in range(rotation.shape[0]):
datamin = min(np.min(D0R[i, :]), np.min(D1R[i, :])) - EPSILON
datamax = max(np.max(D0R[i, :]), np.max(D1R[i, :])) + EPSILON
u = np.linspace(datamin, datamax, bin_count)
p0R, _ = np.histogram(D0R[i, :], bins=u, density=True)
p1R, _ = np.histogram(D1R[i, :], bins=u, density=True)
f = pdf_transfer_1d(p0R, p1R)
mapped_values = (
np.interp(D0R[i, :], u[:-1], f) * (datamax - datamin) / (bin_count - 1)
+ datamin
)
D0R_[i, :] = mapped_values
source_flattened = source_flattened + relaxation * (rotation.T @ (D0R_ - D0R))
source_flattened.clip(0, 255, out=source_flattened)
result = source_flattened.astype(np.uint8).transpose().reshape(h, w, c)
return regrain(source, result, smoothness=smoothness)

View file

@ -0,0 +1,87 @@
from scipy.ndimage import zoom
import numpy as np
NBITS = [4, 16, 32, 64, 64, 64]
def regrain(img_arr_in, img_arr_col, smoothness: float = 1):
"""keep gradient of img_arr_in and color of img_arr_col."""
img_arr_in = img_arr_in / 255.0
img_arr_col = img_arr_col / 255.0
img_arr_out = np.array(img_arr_in)
img_arr_out = _regrain_rec(
img_arr_out, img_arr_in, img_arr_col, NBITS, 0, smoothness
)
img_arr_out[img_arr_out < 0] = 0
img_arr_out[img_arr_out > 1] = 1
img_arr_out = (255.0 * img_arr_out).astype("uint8")
return img_arr_out
def _regrain_rec(img_arr_out, img_arr_in, img_arr_col, nbits, level, smoothness):
[h, w, _] = img_arr_in.shape
h2 = (h + 1) // 2
w2 = (w + 1) // 2
if len(nbits) > 1 and h2 > 20 and w2 > 20:
resize_arr_in = _resize_image(img_arr_in, w2, h2)
resize_arr_col = _resize_image(img_arr_col, w2, h2)
resize_arr_out = _resize_image(img_arr_out, w2, h2)
resize_arr_out = _regrain_rec(
resize_arr_out,
resize_arr_in,
resize_arr_col,
nbits[1:],
level + 1,
smoothness,
)
img_arr_out = _resize_image(resize_arr_out, w, h)
img_arr_out = _solve(
img_arr_out, img_arr_in, img_arr_col, nbits[0], level, smoothness
)
return img_arr_out
def _solve(img_arr_out, img_arr_in, img_arr_col, nbit, level, smoothness):
[width, height, c] = img_arr_in.shape
first_pad_0 = lambda arr: np.concatenate((arr[:1, :], arr[:-1, :]), axis=0)
first_pad_1 = lambda arr: np.concatenate((arr[:, :1], arr[:, :-1]), axis=1)
last_pad_0 = lambda arr: np.concatenate((arr[1:, :], arr[-1:, :]), axis=0)
last_pad_1 = lambda arr: np.concatenate((arr[:, 1:], arr[:, -1:]), axis=1)
delta_x = last_pad_1(img_arr_in) - first_pad_1(img_arr_in)
delta_y = last_pad_0(img_arr_in) - first_pad_0(img_arr_in)
delta = np.sqrt((delta_x**2 + delta_y**2).sum(axis=2, keepdims=True))
psi = 256 * delta / 5
psi[psi > 1] = 1
phi = 30 * 2 ** (-level) / (1 + 10 * delta / smoothness)
phi1 = (last_pad_1(phi) + phi) / 2
phi2 = (last_pad_0(phi) + phi) / 2
phi3 = (first_pad_1(phi) + phi) / 2
phi4 = (first_pad_0(phi) + phi) / 2
rho = 1 / 5.0
for i in range(nbit):
den = psi + phi1 + phi2 + phi3 + phi4
num = (
np.tile(psi, [1, 1, c]) * img_arr_col
+ np.tile(phi1, [1, 1, c])
* (last_pad_1(img_arr_out) - last_pad_1(img_arr_in) + img_arr_in)
+ np.tile(phi2, [1, 1, c])
* (last_pad_0(img_arr_out) - last_pad_0(img_arr_in) + img_arr_in)
+ np.tile(phi3, [1, 1, c])
* (first_pad_1(img_arr_out) - first_pad_1(img_arr_in) + img_arr_in)
+ np.tile(phi4, [1, 1, c])
* (first_pad_0(img_arr_out) - first_pad_0(img_arr_in) + img_arr_in)
)
img_arr_out = (
num / np.tile(den + 1e-6, [1, 1, c]) * (1 - rho) + rho * img_arr_out
)
return img_arr_out
def _resize_image(data, target_width, target_height):
return zoom(data, (target_height / data.shape[0], target_width / data.shape[1], 1))

View file

View file

@ -0,0 +1,3 @@
from .add_noise import add_noise
from .change_temperature import change_temperature
from .add_random_colour_spill import add_random_colour_spill

View file

@ -0,0 +1,11 @@
import numpy as np
from PIL import Image
def add_noise(img: Image, alpha: float) -> Image:
img = img.convert("RGB")
width, height = img.size
random_colors = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
random_img = Image.fromarray(random_colors, mode="RGB")
result = Image.blend(img, random_img, alpha)
return result

View file

@ -0,0 +1,20 @@
from PIL import Image
from ..utils import random
def add_random_colour_spill(image: Image, range: float) -> Image:
matrix = (
random(1 / range, range),
0.0,
0.0,
0.0,
0.0,
random(1 / range, range),
0.0,
0.0,
0.0,
0.0,
random(1 / range, range),
0.0,
)
return image.convert("RGB", matrix)

View file

@ -0,0 +1,42 @@
from PIL import Image
kelvin_table = {
1000: (255, 56, 0),
1500: (255, 109, 0),
2000: (255, 137, 18),
2500: (255, 161, 72),
3000: (255, 180, 107),
3500: (255, 196, 137),
4000: (255, 209, 163),
4500: (255, 219, 186),
5000: (255, 228, 206),
5500: (255, 236, 224),
6000: (255, 243, 239),
6500: (255, 249, 253),
7000: (245, 243, 255),
7500: (235, 238, 255),
8000: (227, 233, 255),
8500: (220, 229, 255),
9000: (214, 225, 255),
9500: (208, 222, 255),
10000: (204, 219, 255),
}
def change_temperature(image: Image, temperature: float) -> Image:
r, g, b = kelvin_table[temperature]
matrix = (
r / 255.0,
0.0,
0.0,
0.0,
0.0,
g / 255.0,
0.0,
0.0,
0.0,
0.0,
b / 255.0,
0.0,
)
return image.convert("RGB", matrix)

View file

@ -0,0 +1,3 @@
from .histogram_dataset import HistogramDataset
from .random_edit import random_edit
from .progressive_pooling_loss import ProgressivePoolingLoss

View file

@ -0,0 +1,89 @@
from torch.utils.data import Dataset
from typing import List, Optional, Tuple
from editor.utils import compute_histogram
from .random_edit import random_edit
from PIL import Image
from tqdm import tqdm
import torch
from pathlib import Path
import PIL.Image
PIL.Image.MAX_IMAGE_PIXELS = None
class HistogramDataset(Dataset):
def __init__(
self,
paths: List[Path],
edit_count: int = 5,
bin_count: int = 32,
target_size=(480, 480),
delete_corrupt_images: bool = False,
cache_path: Optional[Path] = None,
):
self._paths = sorted(paths)
self._edit_count = edit_count
self._bin_count = bin_count
self._target_size = target_size
self._cache_path = cache_path
if delete_corrupt_images:
self._delete_corrupt_images()
def _delete_corrupt_images(self) -> None:
deleted_count = 0
for path in tqdm(self._paths):
try:
Image.open(path)
except:
print(f"Failed to open {path}, deleting...")
deleted_count += 1
path.unlink()
print(f"Deleted {deleted_count} corrupt images")
def __len__(self):
return len(self._paths) * self._edit_count
def get_original_image(self, original_idx: int) -> Image.Image:
original_path = self._paths[original_idx]
original = Image.open(original_path)
original.thumbnail(
self._target_size, Image.Resampling.LANCZOS
) # size will be at most target_size, the aspect ratio is preserved
return original
def get_edited_image(self, original_idx: int, edit_idx: int) -> Image.Image:
original_image = self.get_original_image(original_idx)
return random_edit(original_image, seed=edit_idx)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
if self._cache_path is not None:
self._cached_data_path = self._cache_path / f"{idx}.pt"
if self._cached_data_path.exists():
try:
return torch.load(self._cached_data_path)
except:
print(f"Failed to load {self._cached_data_path}, regenerating...")
original_idx = idx // self._edit_count
original = self.get_original_image(original_idx)
edited = random_edit(original, seed=idx)
edited_histogram = compute_histogram(
edited, bins=self._bin_count, normalize=True
)
original_histogram = compute_histogram(
original, bins=self._bin_count, normalize=True
)
result = (
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
)
if self._cache_path is not None:
torch.save(result, self._cached_data_path)
return result

View file

@ -0,0 +1,38 @@
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
class ProgressivePoolingLoss(nn.Module):
def __init__(self, target_sizes: List[int], damping: float):
super(ProgressivePoolingLoss, self).__init__()
self._target_sizes = target_sizes
self._damping = damping
def forward(self, tensor_a, tensor_b):
assert (
tensor_a.size() == tensor_b.size()
), f"Input tensors must have the same size, got {tensor_a.size()} and {tensor_b.size()}"
assert (
len(tensor_a.size()) == 5
), f"Input tensors must have 5 dimensions, got {tensor_a.size()}"
_minibatch_size, _channels, depth, height, width = tensor_a.size()
assert depth == height == width, "Input tensors must be cubes."
loss = 0.0
weight = 1
for target_size in self._target_sizes:
pool_size = depth // target_size
pooled_a = F.avg_pool3d(tensor_a, pool_size) * (pool_size**3)
pooled_b = F.avg_pool3d(tensor_b, pool_size) * (pool_size**3)
diff = torch.abs(pooled_a - pooled_b)
loss += diff.mean() * weight
weight *= self._damping
return loss

View file

@ -0,0 +1,19 @@
from PIL import Image, ImageEnhance
from ..utils import random, get_colour_lut, apply_pixel_shader
from ..operations import add_noise, add_random_colour_spill
import numpy as np
def random_edit(img: Image, seed: int = 42) -> Image:
np.random.seed(seed)
img = add_noise(img, random(0, 0.2))
img = ImageEnhance.Contrast(img).enhance(random(0.5, 2))
img = add_random_colour_spill(img, 1.3)
img = img.convert("HSV")
saturation_lut = get_colour_lut(variance=0.3, count=5, type="linear")
brightness_lut = get_colour_lut(variance=0.3, count=5, type="cubic")
img = apply_pixel_shader(
img, lambda h, s, v: (h, saturation_lut[s], brightness_lut[v])
)
img = img.convert("RGB")
return img

View file

@ -0,0 +1,7 @@
from .interpolate import interpolate
from .random import random
from .apply_pixel_shader import apply_pixel_shader
from .get_colour_lut import get_colour_lut
from .compute_histogram import compute_histogram
from .kldiv import kldiv
from .generate_rotation_matrices import generate_rotation_matrices

View file

@ -0,0 +1,14 @@
from typing import Callable, Tuple
from PIL import Image
def apply_pixel_shader(
img: Image, callback: Callable[[int, int, int], Tuple[int, int, int]]
):
width, height = img.size
pixels = img.load()
for x in range(width):
for y in range(height):
r, g, b = pixels[x, y]
pixels[x, y] = callback(r, g, b)
return img

View file

@ -0,0 +1,22 @@
from PIL import Image
import numpy as np
def compute_histogram(
image: Image.Image | np.ndarray,
bins: int,
value_range=(0, 256),
normalize: bool = True,
) -> np.ndarray:
image = np.array(image) if isinstance(image, Image.Image) else image
histogram, _ = np.histogramdd(
image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range]
)
histogram = histogram.astype(np.float32)
if normalize:
histogram = histogram / np.sum(histogram)
return histogram

View file

@ -0,0 +1,66 @@
from random import shuffle
from typing import List, Tuple
import numpy as np
from functools import lru_cache
from numpy.typing import NDArray
@lru_cache
def generate_rotation_matrices(count: int) -> List[NDArray[np.float64]]:
axes = fibonacci_sphere(count)
shuffle(axes)
angles = np.linspace(0, 2 * np.pi, count, endpoint=False)
matrices = [_rotation_matrix(axis, angle) for axis, angle in zip(axes, angles)]
for matrix in matrices:
_check_rotation_matrix(matrix)
return matrices
def fibonacci_sphere(samples: int) -> List[Tuple[float, float, float]]:
points = []
phi = np.pi * (3.0 - np.sqrt(5.0)) # Golden angle in radians
for i in range(samples):
y = 1 - (i / float(samples - 1)) * 2 # y goes from 1 to -1
radius = np.sqrt(1 - y * y) # radius at y
theta = phi * i # golden angle increment
x = np.cos(theta) * radius
z = np.sin(theta) * radius
points.append([x, y, z])
return points
def _rotation_matrix(
axis: Tuple[float, float, float], theta: float
) -> NDArray[np.float64]:
axis = np.asarray(axis)
axis = axis / np.sqrt(np.dot(axis, axis))
a = np.cos(theta / 2.0)
b, c, d = -axis * np.sin(theta / 2.0)
aa, bb, cc, dd = a * a, b * b, c * c, d * d
bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
return np.array(
[
[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
[2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
[2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc],
]
)
def _check_rotation_matrix(R: NDArray[np.float64]):
# Check if the matrix is square
if R.shape != (3, 3):
raise ValueError("Matrix must be 3x3.")
# Check orthogonality: R.T * R should be close to the identity matrix
I = np.eye(3)
if not np.allclose(np.dot(R.T, R), I):
raise ValueError("allclose")
# Check determinant: Should be +1
if not np.isclose(np.linalg.det(R), 1.0):
raise ValueError(f"det {np.linalg.det(R)}")

View file

@ -0,0 +1,21 @@
import numpy as np
from typing import List
from .random import random
from .interpolate import interpolate, INTERPOLATION_TYPE
def get_edit_points(variance: float, count: int) -> List[float]:
return [
random(i / (count - 1) - variance, i / (count - 1) + variance)
for i in range(count)
]
def get_colour_lut(
variance=0.1, count=5, type: INTERPOLATION_TYPE = "cubic"
) -> List[int]:
edit_points = get_edit_points(variance=variance, count=count)
return [
round(interpolate(edit_points, i / 255, type=type) * 255)
for i in np.linspace(0, 255, 256)
]

View file

@ -0,0 +1,35 @@
import numpy as np
from scipy.interpolate import CubicSpline
from typing import List, Literal
INTERPOLATION_TYPE = Literal["cubic", "linear"]
def interpolate(
control_points: List[float], t: float, type: INTERPOLATION_TYPE
) -> float:
control_points = sorted(control_points)
if type == "cubic":
x = np.linspace(0, 1, len(control_points))
cs = CubicSpline(x, control_points)
return cs(t)
if type == "linear":
n = len(control_points) - 1
segment_indices = np.linspace(0, 1, n + 1)
index = np.searchsorted(segment_indices, t, side="right") - 1
if t == 1:
return control_points[-1]
else:
t_normalized = (t - segment_indices[index]) / (
segment_indices[index + 1] - segment_indices[index]
)
return control_points[index] + t_normalized * (
control_points[index + 1] - control_points[index]
)
raise ValueError("Invalid type")

11
src/editor/utils/kldiv.py Normal file
View file

@ -0,0 +1,11 @@
import numpy as np
def kldiv(P: np.ndarray, Q: np.ndarray) -> float:
P /= P.sum()
Q /= Q.sum()
P_safe = np.maximum(P, np.finfo(float).eps)
Q_safe = np.maximum(Q, np.finfo(float).eps)
return np.sum(P_safe * np.log(P_safe / Q_safe))

View file

@ -0,0 +1,10 @@
import numpy as np
def random(min: float = 0, max: float = 1):
mu = (max + min) / 2 # Mean of the distribution
sigma = (
max - min
) / 6 # Standard deviation, chosen so that ~99.7% fall within [min_val, max_val]
sample = np.random.normal(mu, sigma)
return np.clip(sample, min, max)

View file

@ -0,0 +1,3 @@
from .display_images import display_images
from .plot_histograms_in_3d import plot_histograms_in_3d
from .plot_histograms_in_2d import plot_histograms_in_2d

View file

@ -0,0 +1,25 @@
import matplotlib.pyplot as plt
from typing import Dict
from PIL.Image import Image
from math import ceil
def display_images(images: Dict[str, Image], images_per_row: int = 3):
fig, axes = plt.subplots(
nrows=ceil(len(images) / images_per_row),
ncols=min(images_per_row, len(images)),
figsize=(12, 8),
)
axes = axes.flatten()
for i, (title, image) in enumerate(images.items()):
axes[i].imshow(image)
axes[i].axis("off")
axes[i].set_title(title)
for i in range(len(images), len(axes)):
axes[i].axis("off")
plt.tight_layout()
plt.show()

View file

@ -0,0 +1,32 @@
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict
def plot_histograms_in_2d(histograms: Dict[str, np.ndarray]):
fig = plt.figure(figsize=(15, 5))
for i, (title, histogram) in enumerate(histograms.items(), 1):
ax = fig.add_subplot(1, 3, i, projection="3d")
size = histogram.shape[0]
x, y, z = np.indices(histogram.shape)
x = x.flatten()
y = y.flatten()
z = z.flatten()
values = histogram.flatten()
sizes = values * 5000
colors = np.vstack((x, y, z)).T / (size - 1)
sc = ax.scatter(x, y, z, c=colors, s=sizes, marker="o", alpha=0.5)
ax.set_xlim([0, (size - 1)])
ax.set_ylim([0, (size - 1)])
ax.set_zlim([0, (size - 1)])
ax.set_title(title)
return fig

View file

@ -0,0 +1,62 @@
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from math import ceil
from typing import Dict
import numpy as np
def plot_histograms_in_3d(
histograms: Dict[str, np.ndarray], histogram_per_row: int = 3
):
cols = min(histogram_per_row, len(histograms))
rows = ceil(len(histograms) / histogram_per_row)
fig = make_subplots(
rows=rows,
cols=cols,
specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(rows)],
)
for i, (title, histogram) in enumerate(histograms.items()):
fig.add_trace(
_get_3d_scatter_plot_from_histogram(title, histogram),
row=(i // (histogram_per_row + 1)) + 1,
col=(i % histogram_per_row) + 1,
)
scenes = {
f"scene{i}": dict(camera=dict(eye=dict(x=0.1, y=0, z=2)))
for i in range(1, len(histograms) + 1)
}
fig.update_layout(**scenes)
fig.show()
def _get_3d_scatter_plot_from_histogram(title, histogram):
x, y, z, marker_size = [], [], [], []
bins = len(histogram)
for i, row in enumerate(histogram):
for j, col in enumerate(row):
for k, value in enumerate(col):
if value > 0:
x.append(i)
y.append(j)
z.append(k)
marker_size.append(value)
return go.Scatter3d(
x=x,
y=y,
z=z,
mode="markers",
marker=dict(
size=[min(20, ms * 10000) for ms in marker_size],
color=[
f"rgb({xi*256/bins},{yi*256/bins},{zi*256/bins})"
for xi, yi, zi in zip(x, y, z)
],
opacity=1,
line=dict(width=0),
),
name=title,
)