Optimise pdf transfer
This commit is contained in:
parent
28425d53af
commit
131cc6132b
5 changed files with 198 additions and 18366 deletions
|
|
@ -1,3 +1,2 @@
|
||||||
from .regrain import regrain
|
from .regrain import regrain
|
||||||
from .pdf_transfer_3d import pdf_transfer_3d
|
|
||||||
from .apply_histogram import apply_histogram
|
from .apply_histogram import apply_histogram
|
||||||
|
|
|
||||||
|
|
@ -1,40 +1,112 @@
|
||||||
from histogram_transfer import pdf_transfer_3d
|
from typing import Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.ndimage import zoom
|
from utils import generate_rotation_matrices, compute_histogram
|
||||||
|
from histogram_transfer.pdf_transfer_1d import pdf_transfer_1d
|
||||||
|
from histogram_transfer.regrain import regrain
|
||||||
|
|
||||||
|
|
||||||
def apply_histogram(original_image, target_histogram, bin_count: int):
|
def apply_histogram(
|
||||||
actual_predicted_histogram = target_histogram.cpu().detach().numpy().squeeze()
|
source_img: np.ndarray,
|
||||||
|
target_histogram: np.ndarray,
|
||||||
|
*,
|
||||||
|
iterations: int = 25,
|
||||||
|
source_histogram: Optional[np.ndarray] = None,
|
||||||
|
should_regrain: bool = True,
|
||||||
|
):
|
||||||
|
if not isinstance(source_img, np.ndarray):
|
||||||
|
source_img = np.array(source_img)
|
||||||
|
|
||||||
scale = 64 / bin_count
|
assert (
|
||||||
scaled_predicted_histogram = zoom(actual_predicted_histogram, scale, order=3)
|
target_histogram.shape[0]
|
||||||
scaled_predicted_histogram = (
|
== target_histogram.shape[1]
|
||||||
scaled_predicted_histogram / scaled_predicted_histogram.sum()
|
== target_histogram.shape[2]
|
||||||
|
), "Histograms must be 3D"
|
||||||
|
|
||||||
|
bins = target_histogram.shape[0]
|
||||||
|
assert 256 % bins == 0, "Bin size must be a factor of 256"
|
||||||
|
|
||||||
|
if source_histogram is None:
|
||||||
|
source_histogram = compute_histogram(source_img, bins=bins)
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
source_histogram.shape == target_histogram.shape
|
||||||
|
), "Source and target histograms must be the same shape"
|
||||||
|
|
||||||
|
source_colours = np.array(
|
||||||
|
[
|
||||||
|
index for index, value in np.ndenumerate(source_histogram) if value > 0
|
||||||
|
] # we ignore colours without occurences
|
||||||
|
).T # unique rgb colours, 3xN
|
||||||
|
source_colours_original = source_colours.copy()
|
||||||
|
source_counts = np.array(
|
||||||
|
[value for value in np.nditer(source_histogram) if value > 0]
|
||||||
|
) # cardinality of each rgb colour in source_colours, N
|
||||||
|
|
||||||
|
target_histogram_colours = np.array(
|
||||||
|
[index for index, value in np.ndenumerate(target_histogram) if value > 0]
|
||||||
|
).T
|
||||||
|
target_histogram_counts = np.array(
|
||||||
|
[value for value in np.nditer(target_histogram) if value > 0]
|
||||||
)
|
)
|
||||||
|
|
||||||
[h, w, _] = np.array(original_image).shape
|
for rotation in generate_rotation_matrices(iterations):
|
||||||
|
source_colour_adjustment = np.zeros_like(source_colours)
|
||||||
|
rotated_source_colours = rotation @ source_colours
|
||||||
|
rotated_target_histogram_colours = rotation @ target_histogram_colours
|
||||||
|
|
||||||
histogram = np.round(scaled_predicted_histogram * h * w).astype(int)
|
rotated_source_colours.round(out=rotated_source_colours)
|
||||||
|
rotated_target_histogram_colours.round(out=rotated_target_histogram_colours)
|
||||||
|
|
||||||
rgb_vectors = []
|
assert rotation.shape[0] == 3
|
||||||
|
for i in range(rotation.shape[0]): # for each axis (rgb)
|
||||||
for r in range(histogram.shape[0]):
|
sorted_source_colours = sorted(
|
||||||
for g in range(histogram.shape[1]):
|
enumerate(source_counts),
|
||||||
for b in range(histogram.shape[2]):
|
key=lambda x: rotated_source_colours[i, x[0]],
|
||||||
# Append the RGB value 'count' times to the list
|
|
||||||
for _ in range(histogram[r, g, b]):
|
|
||||||
rgb_vectors.append([r, g, b])
|
|
||||||
|
|
||||||
rgb_vectors = np.array(rgb_vectors)
|
|
||||||
np.random.shuffle(rgb_vectors)
|
|
||||||
rgb_vectors = rgb_vectors * 256 / 64
|
|
||||||
|
|
||||||
return pdf_transfer_3d(
|
|
||||||
source=np.array(original_image),
|
|
||||||
target_flattened=rgb_vectors.transpose(),
|
|
||||||
relaxation=0.9,
|
|
||||||
bin_count=3500,
|
|
||||||
iterations=50,
|
|
||||||
smoothness=1,
|
|
||||||
should_regrain=True,
|
|
||||||
)
|
)
|
||||||
|
sorted_source_counts = [v for _, v in sorted_source_colours]
|
||||||
|
sorted_source_indices = [i for i, _ in sorted_source_colours]
|
||||||
|
|
||||||
|
sorted_target_histogram_counts = [
|
||||||
|
v
|
||||||
|
for _, v in sorted(
|
||||||
|
enumerate(target_histogram_counts),
|
||||||
|
key=lambda x: rotated_target_histogram_colours[i, x[0]],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
sorted_target_histogram_colours = [
|
||||||
|
v for v in sorted(rotated_target_histogram_colours[i, :])
|
||||||
|
]
|
||||||
|
|
||||||
|
sorted_new_source_colours = pdf_transfer_1d(
|
||||||
|
sorted_source_counts,
|
||||||
|
sorted_target_histogram_counts,
|
||||||
|
sorted_target_histogram_colours,
|
||||||
|
)
|
||||||
|
source_colour_adjustment[i, :] = [
|
||||||
|
v
|
||||||
|
for _, v in sorted(
|
||||||
|
enumerate(sorted_new_source_colours),
|
||||||
|
key=lambda x: sorted_source_indices[x[0]],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
source_colours = source_colours + (
|
||||||
|
rotation.T @ (source_colour_adjustment - rotated_source_colours)
|
||||||
|
)
|
||||||
|
source_colours.clip(0, bins, out=source_colours)
|
||||||
|
|
||||||
|
lut = np.zeros((256, 256, 256, 3), dtype=np.uint8)
|
||||||
|
scale = 256 // bins
|
||||||
|
source_colours = source_colours.T * scale
|
||||||
|
source_colours.clip(0, 256, out=source_colours)
|
||||||
|
source_colours.round(out=source_colours)
|
||||||
|
for old, new in zip(source_colours_original.T, source_colours):
|
||||||
|
for x in range(scale):
|
||||||
|
for y in range(scale):
|
||||||
|
for z in range(scale):
|
||||||
|
lut[old[0] * scale + x, old[1] * scale + y, old[2] * scale + z] = (
|
||||||
|
new
|
||||||
|
)
|
||||||
|
|
||||||
|
result = lut[source_img[:, :, 0], source_img[:, :, 1], source_img[:, :, 2]]
|
||||||
|
return regrain(source_img, result) if should_regrain else result
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
|
|
@ -1,47 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
from utils import generate_rotation_matrices
|
|
||||||
from histogram_transfer import pdf_transfer_1d
|
|
||||||
from histogram_transfer import regrain
|
|
||||||
|
|
||||||
|
|
||||||
EPSILON = 1e-6
|
|
||||||
|
|
||||||
|
|
||||||
def pdf_transfer_3d(
|
|
||||||
source: np.ndarray,
|
|
||||||
target_flattened: np.ndarray,
|
|
||||||
relaxation: float = 1,
|
|
||||||
bin_count: int = 1000,
|
|
||||||
iterations: int = 25,
|
|
||||||
smoothness: float = 1,
|
|
||||||
should_regrain: bool = True,
|
|
||||||
):
|
|
||||||
[h, w, c] = source.shape
|
|
||||||
source_flattened = source.reshape(-1, c).transpose()
|
|
||||||
|
|
||||||
rotation_matrices = generate_rotation_matrices(iterations)
|
|
||||||
for i, rotation in enumerate(rotation_matrices, start=1):
|
|
||||||
D0R = rotation @ source_flattened
|
|
||||||
D1R = rotation @ target_flattened
|
|
||||||
D0R_ = np.zeros_like(source_flattened)
|
|
||||||
|
|
||||||
for i in range(rotation.shape[0]):
|
|
||||||
datamin = min(np.min(D0R[i, :]), np.min(D1R[i, :])) - EPSILON
|
|
||||||
datamax = max(np.max(D0R[i, :]), np.max(D1R[i, :])) + EPSILON
|
|
||||||
u = np.linspace(datamin, datamax, bin_count)
|
|
||||||
|
|
||||||
p0R, _ = np.histogram(D0R[i, :], bins=u, density=True)
|
|
||||||
p1R, _ = np.histogram(D1R[i, :], bins=u, density=True)
|
|
||||||
|
|
||||||
f = pdf_transfer_1d(p0R, p1R)
|
|
||||||
mapped_values = (
|
|
||||||
np.interp(D0R[i, :], u[:-1], f) * (datamax - datamin) / (bin_count - 1)
|
|
||||||
+ datamin
|
|
||||||
)
|
|
||||||
D0R_[i, :] = mapped_values
|
|
||||||
|
|
||||||
source_flattened = source_flattened + relaxation * (rotation.T @ (D0R_ - D0R))
|
|
||||||
source_flattened.clip(0, 255, out=source_flattened)
|
|
||||||
|
|
||||||
result = source_flattened.astype(np.uint8).transpose().reshape(h, w, c)
|
|
||||||
return regrain(source, result, smoothness=smoothness) if should_regrain else result
|
|
||||||
|
|
@ -32699,7 +32699,6 @@
|
||||||
" apply_histogram,\n",
|
" apply_histogram,\n",
|
||||||
" edits.values(),\n",
|
" edits.values(),\n",
|
||||||
" outputs,\n",
|
" outputs,\n",
|
||||||
" [hyperparameters[\"bin_count\"] for _ in range(len(edits))],\n",
|
|
||||||
" )\n",
|
" )\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -32755,7 +32754,6 @@
|
||||||
" apply_histogram,\n",
|
" apply_histogram,\n",
|
||||||
" input_photos,\n",
|
" input_photos,\n",
|
||||||
" output_histograms,\n",
|
" output_histograms,\n",
|
||||||
" [hyperparameters[\"bin_count\"] for _ in range(len(edits))],\n",
|
|
||||||
" )\n",
|
" )\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue