Refactor histogram transfer: drop the ColorTransfer/Regrain classes in favour
of module-level regrain() and pdf_transfer_1d(), and add a kldiv() utility.

Review notes on this copy of the patch:
  * Line breaks and indentation were rebuilt from a whitespace-collapsed copy.
    Removed and context lines are a best-effort reconstruction; check them
    against the pre-image before applying.
  * regrain.py: _regrain_rec() now takes the iteration schedule as an explicit
    `nbits` parameter. The recorded version declared four parameters but
    recursed with five arguments (TypeError) and always read the full
    module-level NBITS list instead of the per-level remainder.
  * regrain.py: _solve() no longer binds the unused, misnamed width/height.
  * kldiv.py: inputs are normalised into new arrays instead of with `/=`, so
    the caller's arrays are left untouched and integer arrays are accepted.
  * Blob ids on the `index` lines are those of the originally recorded content.

diff --git a/editor/histogram_transfer/__init__.py b/editor/histogram_transfer/__init__.py
index 8cfa959..649c62b 100644
--- a/editor/histogram_transfer/__init__.py
+++ b/editor/histogram_transfer/__init__.py
@@ -1 +1,2 @@
-from .color_transfer import ColorTransfer
+from .regrain import regrain
+from .pdf_transfer_1d import pdf_transfer_1d
diff --git a/editor/histogram_transfer/color_transfer.py b/editor/histogram_transfer/color_transfer.py
deleted file mode 100644
index a2e6432..0000000
--- a/editor/histogram_transfer/color_transfer.py
+++ /dev/null
@@ -1,209 +0,0 @@
-from scipy.ndimage import zoom
-import numpy as np
-from scipy.stats import special_ortho_group
-
-
-class ColorTransfer:
-    def __init__(
-        self,
-        iteration_count: int = 10,
-        histogram_dimensions: int = 3,
-        eps=1e-6,
-    ):
-        self.eps = eps
-        self.rotation_matrices = [
-            special_ortho_group.rvs(dim=histogram_dimensions, random_state=i * 67)
-            for i in range(iteration_count)
-        ]
-        self.RG = Regrain()
-
-    def __call__(self, img_arr_in, img_arr_ref, regrain=False):
-        """Apply probability density function transfer.
-
-        img_o = t(img_i) so that f_{t(img_i)}(r, g, b) = f_{img_r}(r, g, b),
-        where f_{img}(r, g, b) is the probability density function of img's rgb values.
-        Args:
-            img_arr_in: bgr numpy array of input image.
-            img_arr_ref: bgr numpy array of reference image.
-        Returns:
-            img_arr_out: transfered bgr numpy array of input image.
-        """
-
-        # reshape (h, w, c) to normalized (c, h*w)
-        [h, w, c] = img_arr_in.shape
-        reshape_arr_in = img_arr_in.reshape(-1, c).transpose() / 255.0
-        reshape_arr_ref = img_arr_ref.reshape(-1, c).transpose() / 255.0
-
-        # pdf transfer
-        reshape_arr_out = self.pdf_transfer_nd(
-            arr_in=reshape_arr_in, arr_ref=reshape_arr_ref, step_size=0.2
-        )
-
-        # reshape (c, h*w) to (h, w, c)
-        reshape_arr_out[reshape_arr_out < 0] = 0
-        reshape_arr_out[reshape_arr_out > 1] = 1
-        reshape_arr_out = (255.0 * reshape_arr_out).astype("uint8")
-        img_arr_out = reshape_arr_out.transpose().reshape(h, w, c)
-
-        if regrain:
-            img_arr_out = self.RG.regrain(
-                img_arr_in=img_arr_in, img_arr_col=img_arr_out
-            )
-
-        return img_arr_out
-
-    def pdf_transfer_nd(self, arr_in=None, arr_ref=None, step_size=1):
-        """Apply n-dim probability density function transfer.
-
-        Args:
-            arr_in: shape=(n, x).
-            arr_ref: shape=(n, x).
-            step_size: arr = arr + step_size * delta_arr.
-        Returns:
-            arr_out: shape=(n, x).
-        """
-        # n times of 1d-pdf-transfer
-        arr_out = np.array(arr_in)
-        for rotation_matrix in self.rotation_matrices:
-            rot_arr_in = np.matmul(rotation_matrix, arr_out)
-            rot_arr_ref = np.matmul(rotation_matrix, arr_ref)
-            rot_arr_out = np.zeros(rot_arr_in.shape)
-            for i in range(rot_arr_out.shape[0]):
-                rot_arr_out[i] = self._pdf_transfer_1d(rot_arr_in[i], rot_arr_ref[i])
-            rot_delta_arr = rot_arr_out - rot_arr_in
-            delta_arr = np.matmul(
-                rotation_matrix.transpose(), rot_delta_arr
-            )  # np.linalg.solve(rotation_matrix, rot_delta_arr)
-            arr_out = step_size * delta_arr + arr_out
-        return arr_out
-
-    # def _pdf_transfer_1d(self, arr_in: np.ndarray, arr_ref: np.ndarray):
-    #     nbins = max(arr_in.shape)
-    #     eps = 1e-6  # small damping term that facilitates the inversion
-
-    #     PX = np.cumsum(arr_in + eps)
-    #     PX /= PX[-1]
-
-    #     PY = np.cumsum(arr_ref + eps)
-    #     PY /= PY[-1]
-
-    #     f = np.interp(PX, PY, np.arange(nbins, ))
-
-    #     # f[PX <= PY[0]] = 0
-    #     # f[PX >= PY[-1]] = nbins - 1
-    #     return f
-
-    def _pdf_transfer_1d(self, arr_in=None, arr_ref=None, n=300):
-        """Apply 1-dim probability density function transfer.
-
-        Args:
-            arr_in: 1d numpy input array.
-            arr_ref: 1d numpy reference array.
-            n: discretization num of distribution of image's pixels.
-        Returns:
-            arr_out: transfered input array.
-        """
-
-        arr = np.concatenate((arr_in, arr_ref))
-        # discretization as histogram
-        min_v = arr.min() - self.eps
-        max_v = arr.max() + self.eps
-        xs = np.array([min_v + (max_v - min_v) * i / n for i in range(n + 1)])
-        hist_in, _ = np.histogram(arr_in, xs)
-        hist_ref, _ = np.histogram(arr_ref, xs)
-        xs = xs[:-1]
-        # compute probability distribution
-        cum_in = np.cumsum(hist_in)
-        cum_ref = np.cumsum(hist_ref)
-        d_in = cum_in / cum_in[-1]
-        d_ref = cum_ref / cum_ref[-1]
-        # transfer
-        t_d_in = np.interp(d_in, d_ref, xs)
-        t_d_in[d_in <= d_ref[0]] = min_v
-        t_d_in[d_in >= d_ref[-1]] = max_v
-        arr_out = np.interp(arr_in, xs, t_d_in)
-        return arr_out
-
-
-class Regrain:
-    def __init__(self, smoothness=1):
-        """To understand the meaning of these params, refer to paper07."""
-        self.nbits = [4, 16, 32, 64, 64, 64]
-        self.smoothness = smoothness
-        self.level = 0
-
-    def regrain(self, img_arr_in=None, img_arr_col=None):
-        """keep gradient of img_arr_in and color of img_arr_col."""
-
-        img_arr_in = img_arr_in / 255.0
-        img_arr_col = img_arr_col / 255.0
-        img_arr_out = np.array(img_arr_in)
-        img_arr_out = self.regrain_rec(
-            img_arr_out, img_arr_in, img_arr_col, self.nbits, self.level
-        )
-        img_arr_out[img_arr_out < 0] = 0
-        img_arr_out[img_arr_out > 1] = 1
-        img_arr_out = (255.0 * img_arr_out).astype("uint8")
-        return img_arr_out
-
-    def regrain_rec(self, img_arr_out, img_arr_in, img_arr_col, nbits, level):
-        """direct translation of matlab code."""
-
-        [h, w, _] = img_arr_in.shape
-        h2 = (h + 1) // 2
-        w2 = (w + 1) // 2
-        if len(nbits) > 1 and h2 > 20 and w2 > 20:
-            resize_arr_in = resize_image(img_arr_in, w2, h2)
-            resize_arr_col = resize_image(img_arr_col, w2, h2)
-            resize_arr_out = resize_image(img_arr_out, w2, h2)
-            resize_arr_out = self.regrain_rec(
-                resize_arr_out, resize_arr_in, resize_arr_col, nbits[1:], level + 1
-            )
-            img_arr_out = resize_image(resize_arr_out, w, h)
-        img_arr_out = self.solve(img_arr_out, img_arr_in, img_arr_col, nbits[0], level)
-        return img_arr_out
-
-    def solve(self, img_arr_out, img_arr_in, img_arr_col, nbit, level, eps=1e-6):
-        """direct translation of matlab code."""
-
-        [width, height, c] = img_arr_in.shape
-        first_pad_0 = lambda arr: np.concatenate((arr[:1, :], arr[:-1, :]), axis=0)
-        first_pad_1 = lambda arr: np.concatenate((arr[:, :1], arr[:, :-1]), axis=1)
-        last_pad_0 = lambda arr: np.concatenate((arr[1:, :], arr[-1:, :]), axis=0)
-        last_pad_1 = lambda arr: np.concatenate((arr[:, 1:], arr[:, -1:]), axis=1)
-
-        delta_x = last_pad_1(img_arr_in) - first_pad_1(img_arr_in)
-        delta_y = last_pad_0(img_arr_in) - first_pad_0(img_arr_in)
-        delta = np.sqrt((delta_x**2 + delta_y**2).sum(axis=2, keepdims=True))
-
-        psi = 256 * delta / 5
-        psi[psi > 1] = 1
-        phi = 30 * 2 ** (-level) / (1 + 10 * delta / self.smoothness)
-
-        phi1 = (last_pad_1(phi) + phi) / 2
-        phi2 = (last_pad_0(phi) + phi) / 2
-        phi3 = (first_pad_1(phi) + phi) / 2
-        phi4 = (first_pad_0(phi) + phi) / 2
-
-        rho = 1 / 5.0
-        for i in range(nbit):
-            den = psi + phi1 + phi2 + phi3 + phi4
-            num = (
-                np.tile(psi, [1, 1, c]) * img_arr_col
-                + np.tile(phi1, [1, 1, c])
-                * (last_pad_1(img_arr_out) - last_pad_1(img_arr_in) + img_arr_in)
-                + np.tile(phi2, [1, 1, c])
-                * (last_pad_0(img_arr_out) - last_pad_0(img_arr_in) + img_arr_in)
-                + np.tile(phi3, [1, 1, c])
-                * (first_pad_1(img_arr_out) - first_pad_1(img_arr_in) + img_arr_in)
-                + np.tile(phi4, [1, 1, c])
-                * (first_pad_0(img_arr_out) - first_pad_0(img_arr_in) + img_arr_in)
-            )
-            img_arr_out = (
-                num / np.tile(den + eps, [1, 1, c]) * (1 - rho) + rho * img_arr_out
-            )
-        return img_arr_out
-
-
-def resize_image(data, target_width, target_height):
-    return zoom(data, (target_height / data.shape[0], target_width / data.shape[1], 1))
diff --git a/editor/histogram_transfer/pdf_transfer_1d.py b/editor/histogram_transfer/pdf_transfer_1d.py
new file mode 100644
index 0000000..71e0925
--- /dev/null
+++ b/editor/histogram_transfer/pdf_transfer_1d.py
@@ -0,0 +1,13 @@
+import numpy as np
+
+
+def pdf_transfer_1d(pX: np.ndarray, pY: np.ndarray) -> np.ndarray:
+    PX = np.cumsum(pX + np.finfo(float).eps)
+    PX /= PX[-1]
+
+    PY = np.cumsum(pY + np.finfo(float).eps)
+    PY /= PY[-1]
+
+    f = np.interp(PX, PY, np.arange(len(pX)))
+
+    return f
diff --git a/editor/histogram_transfer/regrain.py b/editor/histogram_transfer/regrain.py
new file mode 100644
index 0000000..093b7d6
--- /dev/null
+++ b/editor/histogram_transfer/regrain.py
@@ -0,0 +1,81 @@
+from scipy.ndimage import zoom
+import numpy as np
+
+
+SMOOTHNESS = 1
+NBITS = [4, 16, 32, 64, 64, 64]
+
+
+def regrain(img_arr_in, img_arr_col):
+    """Keep the gradient of img_arr_in and the color of img_arr_col; returns uint8."""
+
+    img_arr_in = img_arr_in / 255.0
+    img_arr_col = img_arr_col / 255.0
+    img_arr_out = np.array(img_arr_in)
+    img_arr_out = _regrain_rec(img_arr_out, img_arr_in, img_arr_col, NBITS, 0)
+    img_arr_out[img_arr_out < 0] = 0
+    img_arr_out[img_arr_out > 1] = 1
+    img_arr_out = (255.0 * img_arr_out).astype("uint8")
+    return img_arr_out
+
+
+def _regrain_rec(img_arr_out, img_arr_in, img_arr_col, nbits, level):
+    """Refine at half resolution with nbits[1:] first, then solve with nbits[0]."""
+
+    [h, w, _] = img_arr_in.shape
+    h2 = (h + 1) // 2
+    w2 = (w + 1) // 2
+    if len(nbits) > 1 and h2 > 20 and w2 > 20:
+        resize_arr_in = _resize_image(img_arr_in, w2, h2)
+        resize_arr_col = _resize_image(img_arr_col, w2, h2)
+        resize_arr_out = _resize_image(img_arr_out, w2, h2)
+        resize_arr_out = _regrain_rec(
+            resize_arr_out, resize_arr_in, resize_arr_col, nbits[1:], level + 1
+        )
+        img_arr_out = _resize_image(resize_arr_out, w, h)
+    img_arr_out = _solve(img_arr_out, img_arr_in, img_arr_col, nbits[0], level)
+    return img_arr_out
+
+
+def _solve(img_arr_out, img_arr_in, img_arr_col, nbit, level, eps=1e-6):
+    _, _, c = img_arr_in.shape
+    first_pad_0 = lambda arr: np.concatenate((arr[:1, :], arr[:-1, :]), axis=0)
+    first_pad_1 = lambda arr: np.concatenate((arr[:, :1], arr[:, :-1]), axis=1)
+    last_pad_0 = lambda arr: np.concatenate((arr[1:, :], arr[-1:, :]), axis=0)
+    last_pad_1 = lambda arr: np.concatenate((arr[:, 1:], arr[:, -1:]), axis=1)
+
+    delta_x = last_pad_1(img_arr_in) - first_pad_1(img_arr_in)
+    delta_y = last_pad_0(img_arr_in) - first_pad_0(img_arr_in)
+    delta = np.sqrt((delta_x**2 + delta_y**2).sum(axis=2, keepdims=True))
+
+    psi = 256 * delta / 5
+    psi[psi > 1] = 1
+    phi = 30 * 2 ** (-level) / (1 + 10 * delta / SMOOTHNESS)
+
+    phi1 = (last_pad_1(phi) + phi) / 2
+    phi2 = (last_pad_0(phi) + phi) / 2
+    phi3 = (first_pad_1(phi) + phi) / 2
+    phi4 = (first_pad_0(phi) + phi) / 2
+
+    rho = 1 / 5.0
+    for _ in range(nbit):
+        den = psi + phi1 + phi2 + phi3 + phi4
+        num = (
+            np.tile(psi, [1, 1, c]) * img_arr_col
+            + np.tile(phi1, [1, 1, c])
+            * (last_pad_1(img_arr_out) - last_pad_1(img_arr_in) + img_arr_in)
+            + np.tile(phi2, [1, 1, c])
+            * (last_pad_0(img_arr_out) - last_pad_0(img_arr_in) + img_arr_in)
+            + np.tile(phi3, [1, 1, c])
+            * (first_pad_1(img_arr_out) - first_pad_1(img_arr_in) + img_arr_in)
+            + np.tile(phi4, [1, 1, c])
+            * (first_pad_0(img_arr_out) - first_pad_0(img_arr_in) + img_arr_in)
+        )
+        img_arr_out = (
+            num / np.tile(den + eps, [1, 1, c]) * (1 - rho) + rho * img_arr_out
+        )
+    return img_arr_out
+
+
+def _resize_image(data, target_width, target_height):
+    return zoom(data, (target_height / data.shape[0], target_width / data.shape[1], 1))
diff --git a/editor/utils/__init__.py b/editor/utils/__init__.py
index 019dd0f..2306aae 100644
--- a/editor/utils/__init__.py
+++ b/editor/utils/__init__.py
@@ -3,3 +3,4 @@ from .random import random
 from .apply_pixel_shader import apply_pixel_shader
 from .get_colour_lut import get_colour_lut
 from .compute_histogram import compute_histogram
+from .kldiv import kldiv
diff --git a/editor/utils/compute_histogram.py b/editor/utils/compute_histogram.py
index db3b5ae..37abbc1 100644
--- a/editor/utils/compute_histogram.py
+++ b/editor/utils/compute_histogram.py
@@ -3,9 +3,12 @@ import numpy as np
 
 
 def compute_histogram(
-    image: Image, bins: int, value_range=(0, 256), normalize: bool = True
-):
-    image = np.array(image)
+    image: Image.Image | np.ndarray,
+    bins: int,
+    value_range=(0, 256),
+    normalize: bool = True,
+) -> np.ndarray:
+    image = np.array(image) if isinstance(image, Image.Image) else image
 
     histogram, _ = np.histogramdd(
         image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range]
diff --git a/editor/utils/kldiv.py b/editor/utils/kldiv.py
new file mode 100644
index 0000000..1f85d25
--- /dev/null
+++ b/editor/utils/kldiv.py
@@ -0,0 +1,12 @@
+import numpy as np
+
+
+def kldiv(P: np.ndarray, Q: np.ndarray) -> float:
+    """Return KL(P || Q) after normalising both inputs to sum to 1."""
+    P = P / P.sum()
+    Q = Q / Q.sum()
+
+    P_safe = np.maximum(P, np.finfo(float).eps)
+    Q_safe = np.maximum(Q, np.finfo(float).eps)
+
+    return float(np.sum(P_safe * np.log(P_safe / Q_safe)))
diff --git a/editor/visualisation/plot_histograms_in_3d.py b/editor/visualisation/plot_histograms_in_3d.py
index 4b16780..87e79fa 100644
--- a/editor/visualisation/plot_histograms_in_3d.py
+++ b/editor/visualisation/plot_histograms_in_3d.py
@@ -15,12 +15,19 @@ def plot_histograms_in_3d(
         cols=cols,
         specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(rows)],
     )
+
     for i, (title, histogram) in enumerate(histograms.items()):
         fig.add_trace(
             _get_3d_scatter_plot_from_histogram(title, histogram),
             row=(i // (histogram_per_row + 1)) + 1,
             col=(i % histogram_per_row) + 1,
         )
+
+    scenes = {
+        f"scene{i}": dict(camera=dict(eye=dict(x=0.1, y=0, z=2)))
+        for i in range(1, len(histograms) + 1)
+    }
+    fig.update_layout(**scenes)
     fig.show()
 
 
@@ -48,7 +55,8 @@ def _get_3d_scatter_plot_from_histogram(title, histogram):
                 f"rgb({xi*256/bins},{yi*256/bins},{zi*256/bins})"
                 for xi, yi, zi in zip(x, y, z)
             ],
-            opacity=0.8,
+            opacity=1,
+            line=dict(width=0),
         ),
         name=title,
     )