Optimise pdf transfer

This commit is contained in:
Andras Schmelczer 2024-08-25 22:13:41 +01:00
parent 28425d53af
commit 131cc6132b
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
5 changed files with 198 additions and 18366 deletions

View file

@ -1,3 +1,2 @@
from .regrain import regrain from .regrain import regrain
from .pdf_transfer_3d import pdf_transfer_3d
from .apply_histogram import apply_histogram from .apply_histogram import apply_histogram

View file

@ -1,40 +1,112 @@
from histogram_transfer import pdf_transfer_3d from typing import Optional
import numpy as np import numpy as np
from scipy.ndimage import zoom from utils import generate_rotation_matrices, compute_histogram
from histogram_transfer.pdf_transfer_1d import pdf_transfer_1d
from histogram_transfer.regrain import regrain
def apply_histogram(original_image, target_histogram, bin_count: int): def apply_histogram(
actual_predicted_histogram = target_histogram.cpu().detach().numpy().squeeze() source_img: np.ndarray,
target_histogram: np.ndarray,
*,
iterations: int = 25,
source_histogram: Optional[np.ndarray] = None,
should_regrain: bool = True,
):
if not isinstance(source_img, np.ndarray):
source_img = np.array(source_img)
scale = 64 / bin_count assert (
scaled_predicted_histogram = zoom(actual_predicted_histogram, scale, order=3) target_histogram.shape[0]
scaled_predicted_histogram = ( == target_histogram.shape[1]
scaled_predicted_histogram / scaled_predicted_histogram.sum() == target_histogram.shape[2]
), "Histograms must be 3D"
bins = target_histogram.shape[0]
assert 256 % bins == 0, "Bin size must be a factor of 256"
if source_histogram is None:
source_histogram = compute_histogram(source_img, bins=bins)
else:
assert (
source_histogram.shape == target_histogram.shape
), "Source and target histograms must be the same shape"
source_colours = np.array(
[
index for index, value in np.ndenumerate(source_histogram) if value > 0
] # we ignore colours without occurrences
).T # unique rgb colours, 3xN
source_colours_original = source_colours.copy()
source_counts = np.array(
[value for value in np.nditer(source_histogram) if value > 0]
) # cardinality of each rgb colour in source_colours, N
target_histogram_colours = np.array(
[index for index, value in np.ndenumerate(target_histogram) if value > 0]
).T
target_histogram_counts = np.array(
[value for value in np.nditer(target_histogram) if value > 0]
) )
[h, w, _] = np.array(original_image).shape for rotation in generate_rotation_matrices(iterations):
source_colour_adjustment = np.zeros_like(source_colours)
rotated_source_colours = rotation @ source_colours
rotated_target_histogram_colours = rotation @ target_histogram_colours
histogram = np.round(scaled_predicted_histogram * h * w).astype(int) rotated_source_colours.round(out=rotated_source_colours)
rotated_target_histogram_colours.round(out=rotated_target_histogram_colours)
rgb_vectors = [] assert rotation.shape[0] == 3
for i in range(rotation.shape[0]): # for each axis (rgb)
for r in range(histogram.shape[0]): sorted_source_colours = sorted(
for g in range(histogram.shape[1]): enumerate(source_counts),
for b in range(histogram.shape[2]): key=lambda x: rotated_source_colours[i, x[0]],
# Append the RGB value 'count' times to the list
for _ in range(histogram[r, g, b]):
rgb_vectors.append([r, g, b])
rgb_vectors = np.array(rgb_vectors)
np.random.shuffle(rgb_vectors)
rgb_vectors = rgb_vectors * 256 / 64
return pdf_transfer_3d(
source=np.array(original_image),
target_flattened=rgb_vectors.transpose(),
relaxation=0.9,
bin_count=3500,
iterations=50,
smoothness=1,
should_regrain=True,
) )
sorted_source_counts = [v for _, v in sorted_source_colours]
sorted_source_indices = [i for i, _ in sorted_source_colours]
sorted_target_histogram_counts = [
v
for _, v in sorted(
enumerate(target_histogram_counts),
key=lambda x: rotated_target_histogram_colours[i, x[0]],
)
]
sorted_target_histogram_colours = [
v for v in sorted(rotated_target_histogram_colours[i, :])
]
sorted_new_source_colours = pdf_transfer_1d(
sorted_source_counts,
sorted_target_histogram_counts,
sorted_target_histogram_colours,
)
source_colour_adjustment[i, :] = [
v
for _, v in sorted(
enumerate(sorted_new_source_colours),
key=lambda x: sorted_source_indices[x[0]],
)
]
source_colours = source_colours + (
rotation.T @ (source_colour_adjustment - rotated_source_colours)
)
source_colours.clip(0, bins, out=source_colours)
lut = np.zeros((256, 256, 256, 3), dtype=np.uint8)
scale = 256 // bins
source_colours = source_colours.T * scale
source_colours.clip(0, 256, out=source_colours)
source_colours.round(out=source_colours)
for old, new in zip(source_colours_original.T, source_colours):
for x in range(scale):
for y in range(scale):
for z in range(scale):
lut[old[0] * scale + x, old[1] * scale + y, old[2] * scale + z] = (
new
)
result = lut[source_img[:, :, 0], source_img[:, :, 1], source_img[:, :, 2]]
return regrain(source_img, result) if should_regrain else result

File diff suppressed because one or more lines are too long

View file

@ -1,47 +0,0 @@
import numpy as np
from utils import generate_rotation_matrices
from histogram_transfer import pdf_transfer_1d
from histogram_transfer import regrain
EPSILON = 1e-6
def pdf_transfer_3d(
    source: np.ndarray,
    target_flattened: np.ndarray,
    relaxation: float = 1,
    bin_count: int = 1000,
    iterations: int = 25,
    smoothness: float = 1,
    should_regrain: bool = True,
):
    """Iteratively move the colour distribution of ``source`` towards a target.

    For every rotation matrix, both point sets are rotated, each rotated axis
    is matched independently through a 1D transfer computed from the two
    marginal histograms, and the source samples are moved towards the matched
    positions.

    Args:
        source: Image array of shape ``(h, w, c)``.
        target_flattened: Target samples; multiplied on the left by each
            rotation, so presumably shaped ``(c, M)`` -- TODO confirm with
            callers.
        relaxation: Fraction of the computed displacement applied per
            iteration.
        bin_count: Number of edges used for the per-axis marginal histograms.
        iterations: Passed to ``generate_rotation_matrices``.
        smoothness: Forwarded to ``regrain``.
        should_regrain: When true, the result is post-processed with
            ``regrain(source, result, smoothness=...)``.

    Returns:
        A ``uint8`` array of shape ``(h, w, c)``, or whatever ``regrain``
        returns when ``should_regrain`` is true.
    """
    [h, w, c] = source.shape
    source_flattened = source.reshape(-1, c).transpose()

    for rotation in generate_rotation_matrices(iterations):
        D0R = rotation @ source_flattened
        D1R = rotation @ target_flattened
        # Allocate the per-axis output as float with the rotated shape. Taking
        # the dtype from ``source_flattened`` would make this an integer array
        # on the first pass for integer images and silently truncate or wrap
        # the mapped values.
        D0R_ = np.zeros_like(D0R, dtype=np.float64)

        for axis in range(rotation.shape[0]):
            # Common histogram support for both marginals, padded so the
            # extreme samples fall strictly inside the outermost bins.
            datamin = min(np.min(D0R[axis, :]), np.min(D1R[axis, :])) - EPSILON
            datamax = max(np.max(D0R[axis, :]), np.max(D1R[axis, :])) + EPSILON
            u = np.linspace(datamin, datamax, bin_count)

            p0R, _ = np.histogram(D0R[axis, :], bins=u, density=True)
            p1R, _ = np.histogram(D1R[axis, :], bins=u, density=True)

            # NOTE(review): ``f`` appears to be expressed in bin-index units,
            # since it is rescaled to data units below -- confirm against
            # pdf_transfer_1d.
            f = pdf_transfer_1d(p0R, p1R)

            D0R_[axis, :] = (
                np.interp(D0R[axis, :], u[:-1], f)
                * (datamax - datamin)
                / (bin_count - 1)
                + datamin
            )

        # Rotate the displacement back to the original axes and apply it.
        source_flattened = source_flattened + relaxation * (
            rotation.T @ (D0R_ - D0R)
        )
        # NOTE(review): clip placement (once per iteration) was reconstructed
        # from an indentation-less listing -- confirm against the original.
        source_flattened.clip(0, 255, out=source_flattened)

    result = source_flattened.astype(np.uint8).transpose().reshape(h, w, c)
    return regrain(source, result, smoothness=smoothness) if should_regrain else result

View file

@ -32699,7 +32699,6 @@
" apply_histogram,\n", " apply_histogram,\n",
" edits.values(),\n", " edits.values(),\n",
" outputs,\n", " outputs,\n",
" [hyperparameters[\"bin_count\"] for _ in range(len(edits))],\n",
" )\n", " )\n",
" )\n", " )\n",
"\n", "\n",
@ -32755,7 +32754,6 @@
" apply_histogram,\n", " apply_histogram,\n",
" input_photos,\n", " input_photos,\n",
" output_histograms,\n", " output_histograms,\n",
" [hyperparameters[\"bin_count\"] for _ in range(len(edits))],\n",
" )\n", " )\n",
" )\n", " )\n",
"\n", "\n",