Fix algorithm

This commit is contained in:
Andras Schmelczer 2024-08-26 21:38:51 +01:00
parent f821e162f6
commit cb585538a3
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
3 changed files with 54 additions and 55 deletions

View file

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Any
import numpy as np
from utils import generate_rotation_matrices, compute_histogram
from histogram_transfer.pdf_transfer_1d import pdf_transfer_1d
@ -6,10 +6,11 @@ from histogram_transfer.regrain import regrain
def apply_histogram(
source_img: np.ndarray,
source_img: np.ndarray | Any,
target_histogram: np.ndarray,
*,
iterations: int = 25,
alpha: float = 0.3,
source_histogram: Optional[np.ndarray] = None,
should_regrain: bool = True,
):
@ -38,6 +39,7 @@ def apply_histogram(
] # we ignore colours without occurences
).T # unique rgb colours, 3xN
source_colours_original = source_colours.copy()
source_counts = np.array(
[value for value in np.nditer(source_histogram) if value > 0]
) # cardinality of each rgb colour in source_colours, N
@ -49,8 +51,8 @@ def apply_histogram(
[value for value in np.nditer(target_histogram) if value > 0]
)
for rotation in generate_rotation_matrices(iterations):
source_colour_adjustment = np.zeros_like(source_colours)
source_colour_adjustment = np.zeros_like(source_colours)
for iteration, rotation in enumerate(generate_rotation_matrices(iterations)):
rotated_source_colours = rotation @ source_colours
rotated_target_histogram_colours = rotation @ target_histogram_colours
@ -58,23 +60,23 @@ def apply_histogram(
rotated_target_histogram_colours.round(out=rotated_target_histogram_colours)
assert rotation.shape[0] == 3
for i in range(rotation.shape[0]): # for each axis (rgb)
for axis in range(rotation.shape[0]): # for each axis (rgb)
sorted_source_colours = sorted(
enumerate(source_counts),
key=lambda x: rotated_source_colours[i, x[0]],
key=lambda x: rotated_source_colours[axis, x[0]],
)
sorted_source_counts = [v for _, v in sorted_source_colours]
sorted_source_indices = [i for i, _ in sorted_source_colours]
sorted_source_counts = [v for _, v in sorted_source_colours]
sorted_target_histogram_counts = [
v
for _, v in sorted(
enumerate(target_histogram_counts),
key=lambda x: rotated_target_histogram_colours[i, x[0]],
key=lambda x: rotated_target_histogram_colours[axis, x[0]],
)
]
sorted_target_histogram_colours = [
v for v in sorted(rotated_target_histogram_colours[i, :])
v for v in sorted(rotated_target_histogram_colours[axis, :])
]
sorted_new_source_colours = pdf_transfer_1d(
@ -82,24 +84,22 @@ def apply_histogram(
sorted_target_histogram_counts,
sorted_target_histogram_colours,
)
source_colour_adjustment[i, :] = [
source_colour_adjustment[axis, :] = [
v
for _, v in sorted(
enumerate(sorted_new_source_colours),
key=lambda x: sorted_source_indices[x[0]],
)
]
source_colours = source_colours + (
source_colours = source_colours + (1 - (iteration / iterations) ** alpha) * (
rotation.T @ (source_colour_adjustment - rotated_source_colours)
)
source_colours.clip(0, bins, out=source_colours)
lut = np.zeros((256, 256, 256, 3), dtype=np.uint8)
scale = 256 // bins
source_colours = source_colours.T * scale
source_colours.clip(0, 256, out=source_colours)
source_colours.round(out=source_colours)
source_colours.clip(0, 255, out=source_colours)
lut = np.zeros((256, 256, 256, 3), dtype=np.uint8)
for old, new in zip(source_colours_original.T, source_colours):
for x in range(scale):
for y in range(scale):

File diff suppressed because one or more lines are too long

View file

@ -53,14 +53,10 @@ def _rotation_matrix(
def _check_rotation_matrix(R: NDArray[np.float64]):
# Check if the matrix is square
if R.shape != (3, 3):
raise ValueError("Matrix must be 3x3.")
assert R.shape == (3, 3), "Matrix must be 3x3"
# Check orthogonality: R.T * R should be close to the identity matrix
I = np.eye(3)
if not np.allclose(np.dot(R.T, R), I):
raise ValueError("allclose")
assert np.allclose(np.dot(R.T, R), I)
# Check determinant: Should be +1
if not np.isclose(np.linalg.det(R), 1.0):
raise ValueError(f"det {np.linalg.det(R)}")
assert np.isclose(np.linalg.det(R), 1.0), f"det {np.linalg.det(R)}"