minor improvements
This commit is contained in:
parent
6353ae7f78
commit
703019c4cd
8 changed files with 123 additions and 214 deletions
|
|
@ -1 +1,2 @@
|
|||
from .color_transfer import ColorTransfer
|
||||
from .regrain import regrain
|
||||
from .pdf_transfer_1d import pdf_transfer_1d
|
||||
|
|
|
|||
|
|
@ -1,209 +0,0 @@
|
|||
from scipy.ndimage import zoom
|
||||
import numpy as np
|
||||
from scipy.stats import special_ortho_group
|
||||
|
||||
|
||||
class ColorTransfer:
    """Transfer the colour distribution of a reference image onto an input image.

    The n-dimensional transfer is carried out as a sequence of 1-D
    distribution matches, one pass per rotation matrix created in
    ``__init__``.
    """

    def __init__(
        self,
        iteration_count: int = 10,
        histogram_dimensions: int = 3,
        eps=1e-6,
    ):
        """Create the rotation matrices and the regrain helper.

        Args:
            iteration_count: number of rotation matrices (one transfer pass each).
            histogram_dimensions: dimension of each rotation matrix.
            eps: margin added around the value range in the 1-D transfer.
        """
        self.eps = eps
        rotations = []
        for index in range(iteration_count):
            # The seed is a fixed function of the index, so the matrices are
            # reproducible for a given pair of constructor arguments.
            seed = index * 67
            rotations.append(
                special_ortho_group.rvs(dim=histogram_dimensions, random_state=seed)
            )
        self.rotation_matrices = rotations
        self.RG = Regrain()

    def __call__(self, img_arr_in, img_arr_ref, regrain=False):
        """Apply probability density function transfer.

        img_o = t(img_i) so that f_{t(img_i)}(r, g, b) = f_{img_r}(r, g, b),
        where f_{img}(r, g, b) is the probability density function of img's
        rgb values.

        Args:
            img_arr_in: bgr numpy array of input image.
            img_arr_ref: bgr numpy array of reference image.
            regrain: when truthy, pass the result through ``self.RG.regrain``.
        Returns:
            img_arr_out: transferred bgr numpy array of input image.
        """
        height, width, channels = img_arr_in.shape

        # (h, w, c) -> (c, h*w), scaled by 1/255.
        flat_in = img_arr_in.reshape(-1, channels).transpose() / 255.0
        flat_ref = img_arr_ref.reshape(-1, channels).transpose() / 255.0

        flat_out = self.pdf_transfer_nd(
            arr_in=flat_in, arr_ref=flat_ref, step_size=0.2
        )

        # Clamp to [0, 1], convert to 8-bit, then (c, h*w) -> (h, w, c).
        flat_out = np.clip(flat_out, 0, 1)
        quantized = (255.0 * flat_out).astype("uint8")
        transferred = quantized.transpose().reshape(height, width, channels)

        if not regrain:
            return transferred
        return self.RG.regrain(img_arr_in=img_arr_in, img_arr_col=transferred)

    def pdf_transfer_nd(self, arr_in=None, arr_ref=None, step_size=1):
        """Apply n-dim probability density function transfer.

        Args:
            arr_in: shape=(n, x).
            arr_ref: shape=(n, x).
            step_size: arr = arr + step_size * delta_arr.
        Returns:
            arr_out: shape=(n, x).
        """
        current = np.array(arr_in)  # work on a copy of the input
        for rotation in self.rotation_matrices:
            rotated_in = np.matmul(rotation, current)
            rotated_ref = np.matmul(rotation, arr_ref)

            # One 1-D transfer per rotated axis.
            rotated_out = np.zeros(rotated_in.shape)
            for axis, (row_in, row_ref) in enumerate(zip(rotated_in, rotated_ref)):
                rotated_out[axis] = self._pdf_transfer_1d(row_in, row_ref)

            # Rotate the displacement back with the transposed matrix.
            displacement = np.matmul(rotation.transpose(), rotated_out - rotated_in)
            current = step_size * displacement + current
        return current

    def _pdf_transfer_1d(self, arr_in=None, arr_ref=None, n=300):
        """Apply 1-dim probability density function transfer.

        Args:
            arr_in: 1d numpy input array.
            arr_ref: 1d numpy reference array.
            n: number of histogram bins used to discretise the values.
        Returns:
            arr_out: transferred input array.
        """
        pooled = np.concatenate((arr_in, arr_ref))

        # Common bin edges covering both arrays, widened by eps on each side.
        low = pooled.min() - self.eps
        high = pooled.max() + self.eps
        span = high - low
        edges = np.array([low + span * k / n for k in range(n + 1)])
        counts_in = np.histogram(arr_in, edges)[0]
        counts_ref = np.histogram(arr_ref, edges)[0]
        left_edges = edges[:-1]

        # Normalised cumulative histograms.
        running_in = np.cumsum(counts_in)
        running_ref = np.cumsum(counts_ref)
        cdf_in = running_in / running_in[-1]
        cdf_ref = running_ref / running_ref[-1]

        # Bin-to-value lookup, pinned to the range ends outside the reference CDF.
        lookup = np.interp(cdf_in, cdf_ref, left_edges)
        lookup[cdf_in <= cdf_ref[0]] = low
        lookup[cdf_in >= cdf_ref[-1]] = high

        return np.interp(arr_in, left_edges, lookup)
|
||||
|
||||
|
||||
class Regrain:
    """Combine the gradients of one image with the colours of another."""

    def __init__(self, smoothness=1):
        """To understand the meaning of these params, refer to paper07."""
        self.smoothness = smoothness
        self.level = 0
        # Iteration count per pyramid level, finest level first.
        self.nbits = [4, 16, 32, 64, 64, 64]

    def regrain(self, img_arr_in=None, img_arr_col=None):
        """Keep the gradient of img_arr_in and the colour of img_arr_col.

        Both inputs are divided by 255 before solving; the result is clamped
        to [0, 1], multiplied by 255 and returned as uint8.
        """
        scaled_in = img_arr_in / 255.0
        scaled_col = img_arr_col / 255.0
        blended = self.regrain_rec(
            np.array(scaled_in), scaled_in, scaled_col, self.nbits, self.level
        )
        return (255.0 * np.clip(blended, 0, 1)).astype("uint8")

    def regrain_rec(self, img_arr_out, img_arr_in, img_arr_col, nbits, level):
        """Direct translation of matlab code.

        Recurses on half-size copies while more than one ``nbits`` entry is
        left and both halved sides exceed 20, then solves at this resolution
        for ``nbits[0]`` iterations.
        """
        height, width, _ = img_arr_in.shape
        half_h = (height + 1) // 2
        half_w = (width + 1) // 2

        descend = len(nbits) > 1 and half_h > 20 and half_w > 20
        if descend:
            coarse_in = resize_image(img_arr_in, half_w, half_h)
            coarse_col = resize_image(img_arr_col, half_w, half_h)
            coarse_out = resize_image(img_arr_out, half_w, half_h)
            coarse_out = self.regrain_rec(
                coarse_out, coarse_in, coarse_col, nbits[1:], level + 1
            )
            # The upsampled coarse solution seeds this level.
            img_arr_out = resize_image(coarse_out, width, height)

        return self.solve(img_arr_out, img_arr_in, img_arr_col, nbits[0], level)

    def solve(self, img_arr_out, img_arr_in, img_arr_col, nbit, level, eps=1e-6):
        """Direct translation of matlab code.

        Runs ``nbit`` damped update sweeps and returns the updated array.
        """
        _rows, _cols, _channels = img_arr_in.shape  # input must be rank 3

        # Shifted copies with the border row/column repeated.
        def prev_row(arr):
            return np.concatenate((arr[:1, :], arr[:-1, :]), axis=0)

        def prev_col(arr):
            return np.concatenate((arr[:, :1], arr[:, :-1]), axis=1)

        def next_row(arr):
            return np.concatenate((arr[1:, :], arr[-1:, :]), axis=0)

        def next_col(arr):
            return np.concatenate((arr[:, 1:], arr[:, -1:]), axis=1)

        in_next_col = next_col(img_arr_in)
        in_next_row = next_row(img_arr_in)
        in_prev_col = prev_col(img_arr_in)
        in_prev_row = prev_row(img_arr_in)

        # Gradient magnitude of the input, summed over the channel axis.
        delta_x = in_next_col - in_prev_col
        delta_y = in_next_row - in_prev_row
        delta = np.sqrt((delta_x**2 + delta_y**2).sum(axis=2, keepdims=True))

        psi = 256 * delta / 5
        psi[psi > 1] = 1
        phi = 30 * 2 ** (-level) / (1 + 10 * delta / self.smoothness)

        # Each neighbour weight is the mean of phi at the pixel and at that neighbour.
        phi_next_col = (next_col(phi) + phi) / 2
        phi_next_row = (next_row(phi) + phi) / 2
        phi_prev_col = (prev_col(phi) + phi) / 2
        phi_prev_row = (prev_row(phi) + phi) / 2

        rho = 1 / 5.0
        # The denominator does not depend on img_arr_out, so it is built once;
        # the (h, w, 1) weights broadcast over the channel axis.
        den = psi + phi_next_col + phi_next_row + phi_prev_col + phi_prev_row
        damped_den = den + eps

        neighbours = (
            (phi_next_col, next_col, in_next_col),
            (phi_next_row, next_row, in_next_row),
            (phi_prev_col, prev_col, in_prev_col),
            (phi_prev_row, prev_row, in_prev_row),
        )
        for _ in range(nbit):
            num = psi * img_arr_col
            for weight, shift, shifted_in in neighbours:
                num = num + weight * (shift(img_arr_out) - shifted_in + img_arr_in)
            img_arr_out = num / damped_den * (1 - rho) + rho * img_arr_out
        return img_arr_out
|
||||
|
||||
|
||||
def resize_image(data, target_width, target_height):
    """Rescale the first two axes of ``data`` with ``scipy.ndimage.zoom``.

    The zoom factors are ``target_height / data.shape[0]`` and
    ``target_width / data.shape[1]``; the third axis uses a factor of 1.
    """
    source_height, source_width = data.shape[0], data.shape[1]
    factors = (target_height / source_height, target_width / source_width, 1)
    return zoom(data, factors)
|
||||
13
editor/histogram_transfer/pdf_transfer_1d.py
Normal file
13
editor/histogram_transfer/pdf_transfer_1d.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def pdf_transfer_1d(pX: np.ndarray, pY: np.ndarray) -> np.ndarray:
    """Map the bins of histogram ``pX`` onto bin positions of histogram ``pY``.

    Both histograms are turned into normalised cumulative sums; each
    cumulative value of ``pX`` is then looked up on the cumulative curve of
    ``pY`` by linear interpolation over the indices ``0 .. len(pX) - 1``.

    Returns:
        One interpolated position per entry of ``pX``.
    """
    tiny = np.finfo(float).eps

    def _normalised_cumsum(pdf):
        # Machine epsilon is added to every bin before accumulating.
        running = np.cumsum(pdf + tiny)
        return running / running[-1]

    positions = np.arange(len(pX))
    return np.interp(_normalised_cumsum(pX), _normalised_cumsum(pY), positions)
|
||||
81
editor/histogram_transfer/regrain.py
Normal file
81
editor/histogram_transfer/regrain.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
from scipy.ndimage import zoom
|
||||
import numpy as np
|
||||
|
||||
|
||||
SMOOTHNESS = 1
|
||||
NBITS = [4, 16, 32, 64, 64, 64]
|
||||
|
||||
|
||||
def regrain(img_arr_in, img_arr_col):
    """Keep the gradient of img_arr_in and the colour of img_arr_col.

    Both inputs are divided by 255 before the recursive solve starts at
    level 0; the result is clamped to [0, 1], multiplied by 255 and returned
    as uint8.
    """
    scaled_in = img_arr_in / 255.0
    scaled_col = img_arr_col / 255.0
    # The solver starts from a copy of the scaled input.
    blended = _regrain_rec(np.array(scaled_in), scaled_in, scaled_col, 0)
    return (255.0 * np.clip(blended, 0, 1)).astype("uint8")
|
||||
|
||||
|
||||
def _regrain_rec(img_arr_out, img_arr_in, img_arr_col, level):
    """Solve one pyramid level, recursing into a half-size level first.

    Direct translation of matlab code. The iteration count for each level is
    ``NBITS[level]``.

    Args:
        img_arr_out: current estimate, same shape as ``img_arr_in``.
        img_arr_in: image whose gradients are kept, shape (h, w, c).
        img_arr_col: image whose colours are kept, shape (h, w, c).
        level: pyramid depth of this call; 0 is full resolution.
    Returns:
        The estimate after solving at this level.
    """
    height, width, _ = img_arr_in.shape
    half_h = (height + 1) // 2
    half_w = (width + 1) // 2

    # Recurse only while NBITS has an entry for the next level and the halved
    # image is still larger than 20 pixels on both sides.
    if level < len(NBITS) - 1 and half_h > 20 and half_w > 20:
        coarse_in = _resize_image(img_arr_in, half_w, half_h)
        coarse_col = _resize_image(img_arr_col, half_w, half_h)
        coarse_out = _resize_image(img_arr_out, half_w, half_h)
        # Only the level is passed down; the schedule is looked up from it.
        coarse_out = _regrain_rec(coarse_out, coarse_in, coarse_col, level + 1)
        # The upsampled coarse solution seeds this level.
        img_arr_out = _resize_image(coarse_out, width, height)

    return _solve(img_arr_out, img_arr_in, img_arr_col, NBITS[level], level)
|
||||
|
||||
|
||||
def _solve(img_arr_out, img_arr_in, img_arr_col, nbit, level, eps=1e-6):
    """Run ``nbit`` damped update sweeps and return the updated array.

    Args:
        img_arr_out: starting estimate, same shape as ``img_arr_in``.
        img_arr_in: image whose gradients are kept, shape (h, w, c).
        img_arr_col: image whose colours are kept.
        nbit: number of sweeps.
        level: pyramid depth; scales the neighbour weights by 2 ** (-level).
        eps: added to the denominator before dividing.
    """
    _rows, _cols, _channels = img_arr_in.shape  # input must be rank 3

    # Shifted copies with the border row/column repeated.
    def prev_row(arr):
        return np.concatenate((arr[:1, :], arr[:-1, :]), axis=0)

    def prev_col(arr):
        return np.concatenate((arr[:, :1], arr[:, :-1]), axis=1)

    def next_row(arr):
        return np.concatenate((arr[1:, :], arr[-1:, :]), axis=0)

    def next_col(arr):
        return np.concatenate((arr[:, 1:], arr[:, -1:]), axis=1)

    in_next_col = next_col(img_arr_in)
    in_next_row = next_row(img_arr_in)
    in_prev_col = prev_col(img_arr_in)
    in_prev_row = prev_row(img_arr_in)

    # Gradient magnitude of the input, summed over the channel axis.
    delta_x = in_next_col - in_prev_col
    delta_y = in_next_row - in_prev_row
    delta = np.sqrt((delta_x**2 + delta_y**2).sum(axis=2, keepdims=True))

    psi = 256 * delta / 5
    psi[psi > 1] = 1
    phi = 30 * 2 ** (-level) / (1 + 10 * delta / SMOOTHNESS)

    # Each neighbour weight is the mean of phi at the pixel and at that neighbour.
    phi_next_col = (next_col(phi) + phi) / 2
    phi_next_row = (next_row(phi) + phi) / 2
    phi_prev_col = (prev_col(phi) + phi) / 2
    phi_prev_row = (prev_row(phi) + phi) / 2

    rho = 1 / 5.0
    # The denominator does not depend on img_arr_out, so it is built once;
    # the (h, w, 1) weights broadcast over the channel axis.
    den = psi + phi_next_col + phi_next_row + phi_prev_col + phi_prev_row
    damped_den = den + eps

    neighbours = (
        (phi_next_col, next_col, in_next_col),
        (phi_next_row, next_row, in_next_row),
        (phi_prev_col, prev_col, in_prev_col),
        (phi_prev_row, prev_row, in_prev_row),
    )
    for _ in range(nbit):
        num = psi * img_arr_col
        for weight, shift, shifted_in in neighbours:
            num = num + weight * (shift(img_arr_out) - shifted_in + img_arr_in)
        img_arr_out = num / damped_den * (1 - rho) + rho * img_arr_out
    return img_arr_out
|
||||
|
||||
|
||||
def _resize_image(data, target_width, target_height):
|
||||
return zoom(data, (target_height / data.shape[0], target_width / data.shape[1], 1))
|
||||
|
|
@ -3,3 +3,4 @@ from .random import random
|
|||
from .apply_pixel_shader import apply_pixel_shader
|
||||
from .get_colour_lut import get_colour_lut
|
||||
from .compute_histogram import compute_histogram
|
||||
from .kldiv import kldiv
|
||||
|
|
|
|||
|
|
@ -3,9 +3,12 @@ import numpy as np
|
|||
|
||||
|
||||
def compute_histogram(
|
||||
image: Image, bins: int, value_range=(0, 256), normalize: bool = True
|
||||
):
|
||||
image = np.array(image)
|
||||
image: Image.Image | np.ndarray,
|
||||
bins: int,
|
||||
value_range=(0, 256),
|
||||
normalize: bool = True,
|
||||
) -> np.ndarray:
|
||||
image = np.array(image) if isinstance(image, Image.Image) else image
|
||||
|
||||
histogram, _ = np.histogramdd(
|
||||
image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range]
|
||||
|
|
|
|||
11
editor/utils/kldiv.py
Normal file
11
editor/utils/kldiv.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def kldiv(P: np.ndarray, Q: np.ndarray) -> float:
    """Return the Kullback-Leibler divergence KL(P || Q), in nats.

    Both inputs are normalised to sum to 1 before the divergence is computed,
    so unnormalised histograms (including integer counts) are accepted.
    Entries below machine epsilon are raised to machine epsilon so that the
    logarithm and the ratio stay finite.

    Args:
        P: non-negative weights of the first distribution.
        Q: non-negative weights of the second distribution, same shape as P.
    Returns:
        sum(P * log(P / Q)) over the normalised, floored distributions.
        The result is NaN when either input sums to zero.
    """
    tiny = np.finfo(float).eps

    # Normalise out of place on float data: in-place division would overwrite
    # the caller's arrays and is rejected by NumPy for integer arrays.
    p = np.asarray(P, dtype=float)
    q = np.asarray(Q, dtype=float)
    p = p / p.sum()
    q = q / q.sum()

    p_safe = np.maximum(p, tiny)
    q_safe = np.maximum(q, tiny)

    return float(np.sum(p_safe * np.log(p_safe / q_safe)))
|
||||
|
|
@ -15,12 +15,19 @@ def plot_histograms_in_3d(
|
|||
cols=cols,
|
||||
specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(rows)],
|
||||
)
|
||||
|
||||
for i, (title, histogram) in enumerate(histograms.items()):
|
||||
fig.add_trace(
|
||||
_get_3d_scatter_plot_from_histogram(title, histogram),
|
||||
row=(i // (histogram_per_row + 1)) + 1,
|
||||
col=(i % histogram_per_row) + 1,
|
||||
)
|
||||
|
||||
scenes = {
|
||||
f"scene{i}": dict(camera=dict(eye=dict(x=0.1, y=0, z=2)))
|
||||
for i in range(1, len(histograms) + 1)
|
||||
}
|
||||
fig.update_layout(**scenes)
|
||||
fig.show()
|
||||
|
||||
|
||||
|
|
@ -48,7 +55,8 @@ def _get_3d_scatter_plot_from_histogram(title, histogram):
|
|||
f"rgb({xi*256/bins},{yi*256/bins},{zi*256/bins})"
|
||||
for xi, yi, zi in zip(x, y, z)
|
||||
],
|
||||
opacity=0.8,
|
||||
opacity=1,
|
||||
line=dict(width=0),
|
||||
),
|
||||
name=title,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue