Fix algorithm
This commit is contained in:
parent
f821e162f6
commit
cb585538a3
3 changed files with 54 additions and 55 deletions
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
import numpy as np
|
||||
from utils import generate_rotation_matrices, compute_histogram
|
||||
from histogram_transfer.pdf_transfer_1d import pdf_transfer_1d
|
||||
|
|
@ -6,10 +6,11 @@ from histogram_transfer.regrain import regrain
|
|||
|
||||
|
||||
def apply_histogram(
|
||||
source_img: np.ndarray,
|
||||
source_img: np.ndarray | Any,
|
||||
target_histogram: np.ndarray,
|
||||
*,
|
||||
iterations: int = 25,
|
||||
alpha: float = 0.3,
|
||||
source_histogram: Optional[np.ndarray] = None,
|
||||
should_regrain: bool = True,
|
||||
):
|
||||
|
|
@ -38,6 +39,7 @@ def apply_histogram(
|
|||
] # we ignore colours without occurences
|
||||
).T # unique rgb colours, 3xN
|
||||
source_colours_original = source_colours.copy()
|
||||
|
||||
source_counts = np.array(
|
||||
[value for value in np.nditer(source_histogram) if value > 0]
|
||||
) # cardinality of each rgb colour in source_colours, N
|
||||
|
|
@ -49,8 +51,8 @@ def apply_histogram(
|
|||
[value for value in np.nditer(target_histogram) if value > 0]
|
||||
)
|
||||
|
||||
for rotation in generate_rotation_matrices(iterations):
|
||||
source_colour_adjustment = np.zeros_like(source_colours)
|
||||
source_colour_adjustment = np.zeros_like(source_colours)
|
||||
for iteration, rotation in enumerate(generate_rotation_matrices(iterations)):
|
||||
rotated_source_colours = rotation @ source_colours
|
||||
rotated_target_histogram_colours = rotation @ target_histogram_colours
|
||||
|
||||
|
|
@ -58,23 +60,23 @@ def apply_histogram(
|
|||
rotated_target_histogram_colours.round(out=rotated_target_histogram_colours)
|
||||
|
||||
assert rotation.shape[0] == 3
|
||||
for i in range(rotation.shape[0]): # for each axis (rgb)
|
||||
for axis in range(rotation.shape[0]): # for each axis (rgb)
|
||||
sorted_source_colours = sorted(
|
||||
enumerate(source_counts),
|
||||
key=lambda x: rotated_source_colours[i, x[0]],
|
||||
key=lambda x: rotated_source_colours[axis, x[0]],
|
||||
)
|
||||
sorted_source_counts = [v for _, v in sorted_source_colours]
|
||||
sorted_source_indices = [i for i, _ in sorted_source_colours]
|
||||
sorted_source_counts = [v for _, v in sorted_source_colours]
|
||||
|
||||
sorted_target_histogram_counts = [
|
||||
v
|
||||
for _, v in sorted(
|
||||
enumerate(target_histogram_counts),
|
||||
key=lambda x: rotated_target_histogram_colours[i, x[0]],
|
||||
key=lambda x: rotated_target_histogram_colours[axis, x[0]],
|
||||
)
|
||||
]
|
||||
sorted_target_histogram_colours = [
|
||||
v for v in sorted(rotated_target_histogram_colours[i, :])
|
||||
v for v in sorted(rotated_target_histogram_colours[axis, :])
|
||||
]
|
||||
|
||||
sorted_new_source_colours = pdf_transfer_1d(
|
||||
|
|
@ -82,24 +84,22 @@ def apply_histogram(
|
|||
sorted_target_histogram_counts,
|
||||
sorted_target_histogram_colours,
|
||||
)
|
||||
source_colour_adjustment[i, :] = [
|
||||
source_colour_adjustment[axis, :] = [
|
||||
v
|
||||
for _, v in sorted(
|
||||
enumerate(sorted_new_source_colours),
|
||||
key=lambda x: sorted_source_indices[x[0]],
|
||||
)
|
||||
]
|
||||
|
||||
source_colours = source_colours + (
|
||||
source_colours = source_colours + (1 - (iteration / iterations) ** alpha) * (
|
||||
rotation.T @ (source_colour_adjustment - rotated_source_colours)
|
||||
)
|
||||
source_colours.clip(0, bins, out=source_colours)
|
||||
|
||||
lut = np.zeros((256, 256, 256, 3), dtype=np.uint8)
|
||||
scale = 256 // bins
|
||||
source_colours = source_colours.T * scale
|
||||
source_colours.clip(0, 256, out=source_colours)
|
||||
source_colours.round(out=source_colours)
|
||||
source_colours.clip(0, 255, out=source_colours)
|
||||
lut = np.zeros((256, 256, 256, 3), dtype=np.uint8)
|
||||
|
||||
for old, new in zip(source_colours_original.T, source_colours):
|
||||
for x in range(scale):
|
||||
for y in range(scale):
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -53,14 +53,10 @@ def _rotation_matrix(
|
|||
|
||||
def _check_rotation_matrix(R: NDArray[np.float64]):
|
||||
# Check if the matrix is square
|
||||
if R.shape != (3, 3):
|
||||
raise ValueError("Matrix must be 3x3.")
|
||||
assert R.shape == (3, 3), "Matrix must be 3x3"
|
||||
|
||||
# Check orthogonality: R.T * R should be close to the identity matrix
|
||||
I = np.eye(3)
|
||||
if not np.allclose(np.dot(R.T, R), I):
|
||||
raise ValueError("allclose")
|
||||
assert np.allclose(np.dot(R.T, R), I)
|
||||
|
||||
# Check determinant: Should be +1
|
||||
if not np.isclose(np.linalg.det(R), 1.0):
|
||||
raise ValueError(f"det {np.linalg.det(R)}")
|
||||
assert np.isclose(np.linalg.det(R), 1.0), f"det {np.linalg.det(R)}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue