bipolaroidbipolaroid/editor/histogram_transfer/regrain.py

81 lines
3 KiB
Python

from scipy.ndimage import zoom
import numpy as np
# Divisor applied to the gradient term of the smoothness weight in _solve;
# larger values make that weight fall off more slowly as the gradient grows.
SMOOTHNESS = 1
# Solver iteration counts handed to _solve as ``nbit`` by _regrain_rec.
# NOTE(review): presumably one entry per pyramid level, finest level first --
# confirm against the reference implementation this was translated from.
NBITS = [4, 16, 32, 64, 64, 64]
def regrain(img_arr_in, img_arr_col):
    """Combine the gradients of ``img_arr_in`` with the colours of ``img_arr_col``.

    Both inputs are divided by 255 before processing, so they are expected to
    hold values on a 0-255 scale.  The result is clamped to [0, 1], scaled
    back by 255 and cast to ``uint8`` (the cast truncates, it does not round).

    Args:
        img_arr_in: array whose local gradients (grain) should be kept.
        img_arr_col: array providing the target colours.

    Returns:
        A ``uint8`` array holding the regrained image.
    """
    # Work on a 0-1 scale throughout the multi-level solve.
    grain_src = img_arr_in / 255.0
    colour_src = img_arr_col / 255.0

    # The initial estimate is a copy of the normalised input; level 0 is the
    # full-resolution level.
    blended = _regrain_rec(np.array(grain_src), grain_src, colour_src, 0)

    # The solver can over/undershoot, so clamp before the 8-bit conversion.
    return (255.0 * np.clip(blended, 0, 1)).astype("uint8")
def _regrain_rec(img_arr_out, img_arr_in, img_arr_col, level):
    """Run one level of the coarse-to-fine regrain recursion.

    While the iteration schedule ``NBITS`` has a further entry and the halved
    image is still larger than 20 pixels in both dimensions, the three arrays
    are downsampled, solved recursively at ``level + 1``, and the coarse
    result is upsampled to seed this level.  The level is then refined by
    ``_solve`` using the iteration count ``NBITS[level]``.

    Args:
        img_arr_out: current estimate of the result, shape (h, w, c).
        img_arr_in: image whose gradients are preserved, shape (h, w, c).
        img_arr_col: image providing the target colours, shape (h, w, c).
        level: recursion depth; 0 is the full-resolution level.

    Returns:
        The refined estimate for this level, same shape as ``img_arr_in``.
    """
    h, w, _ = img_arr_in.shape
    # Ceil-halve so odd sizes never shrink to zero.
    h2 = (h + 1) // 2
    w2 = (w + 1) // 2

    # Index the schedule by depth.  The previous code passed ``NBITS[1:]`` as
    # an extra positional argument (a TypeError on every recursion) and always
    # read ``NBITS[0]``.  Clamping keeps a direct call with a large ``level``
    # from running off the end of the schedule.
    step = min(level, len(NBITS) - 1)

    if step + 1 < len(NBITS) and h2 > 20 and w2 > 20:
        coarse_in = _resize_image(img_arr_in, w2, h2)
        coarse_col = _resize_image(img_arr_col, w2, h2)
        coarse_out = _resize_image(img_arr_out, w2, h2)
        coarse_out = _regrain_rec(coarse_out, coarse_in, coarse_col, level + 1)
        # The upsampled coarse solution becomes this level's starting point.
        img_arr_out = _resize_image(coarse_out, w, h)

    return _solve(img_arr_out, img_arr_in, img_arr_col, NBITS[step], level)
def _solve(img_arr_out, img_arr_in, img_arr_col, nbit, level, eps=1e-6):
    """Refine ``img_arr_out`` with ``nbit`` damped relaxation sweeps.

    Each sweep replaces every pixel by a weighted average of

    * the target colour ``img_arr_col`` (weight ``psi``), and
    * each of its four neighbours in the current estimate, offset by the
      difference between ``img_arr_in`` at the pixel and at that neighbour
      (weights ``phi1``..``phi4``),

    and then blends the average with the previous estimate (80% new, 20% old).
    Border pixels use themselves as the missing neighbour.

    Args:
        img_arr_out: starting estimate, shape (h, w, c).
        img_arr_in: image whose gradients are preserved, shape (h, w, c).
        img_arr_col: image providing the target colours, shape (h, w, c).
        nbit: number of sweeps to run.
        level: pyramid level; the smoothness weight is scaled by 2**-level.
        eps: small constant added to the denominator to avoid division by 0.

    Returns:
        The refined estimate, shape (h, w, c).

    Raises:
        ValueError: if ``img_arr_in`` is not 3-dimensional.
    """
    if img_arr_in.ndim != 3:
        raise ValueError("img_arr_in must have shape (height, width, channels)")

    # Neighbour lookups with the border row/column replicated.
    def prev_rows(arr):
        return np.concatenate((arr[:1, :], arr[:-1, :]), axis=0)

    def prev_cols(arr):
        return np.concatenate((arr[:, :1], arr[:, :-1]), axis=1)

    def next_rows(arr):
        return np.concatenate((arr[1:, :], arr[-1:, :]), axis=0)

    def next_cols(arr):
        return np.concatenate((arr[:, 1:], arr[:, -1:]), axis=1)

    # Gradient magnitude of the input, pooled over channels -> (h, w, 1).
    delta_x = next_cols(img_arr_in) - prev_cols(img_arr_in)
    delta_y = next_rows(img_arr_in) - prev_rows(img_arr_in)
    delta = np.sqrt((delta_x**2 + delta_y**2).sum(axis=2, keepdims=True))

    # Colour-fidelity weight: grows with the gradient, capped at 1.
    psi = 256 * delta / 5
    psi[psi > 1] = 1
    # Smoothness weight: shrinks as the gradient grows, halves per level.
    phi = 30 * 2 ** (-level) / (1 + 10 * delta / SMOOTHNESS)
    phi1 = (next_cols(phi) + phi) / 2
    phi2 = (next_rows(phi) + phi) / 2
    phi3 = (prev_cols(phi) + phi) / 2
    phi4 = (prev_rows(phi) + phi) / 2

    rho = 1 / 5.0
    keep = 1 - rho

    # Everything below depends only on the inputs, so compute it once instead
    # of on every sweep.  The (h, w, 1) weights broadcast over the channel
    # axis, which makes explicit tiling unnecessary.
    denom = psi + phi1 + phi2 + phi3 + phi4 + eps
    data_term = psi * img_arr_col
    smooth_terms = (
        (next_cols, phi1, next_cols(img_arr_in)),
        (next_rows, phi2, next_rows(img_arr_in)),
        (prev_cols, phi3, prev_cols(img_arr_in)),
        (prev_rows, phi4, prev_rows(img_arr_in)),
    )

    for _ in range(nbit):
        # Accumulate in the original left-to-right order so results stay
        # bit-identical; never modify ``data_term`` in place.
        num = data_term
        for shift, weight, shifted_in in smooth_terms:
            num = num + weight * (shift(img_arr_out) - shifted_in + img_arr_in)
        img_arr_out = num / denom * keep + rho * img_arr_out
    return img_arr_out
def _resize_image(data, target_width, target_height):
    """Resample a 3-D array to ``target_height`` x ``target_width``.

    The first two axes are scaled with ``scipy.ndimage.zoom`` (default
    settings); the third axis is left at its original size.
    """
    src_height, src_width = data.shape[0], data.shape[1]
    factors = (target_height / src_height, target_width / src_width, 1)
    return zoom(data, factors)