Minor improvements

This commit is contained in:
Andras Schmelczer 2024-06-06 08:21:00 +01:00
parent ea7931fc57
commit 9aac55f62d
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
6 changed files with 10 additions and 8 deletions

File diff suppressed because one or more lines are too long

View file

@ -1,7 +1,7 @@
from pathlib import Path
DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE1/data/unsplash").glob("*.jpg"))
TRAIN_SIZE = 0.8
TRAIN_SIZE = 0.9
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE1/data/cache")
CACHE_PATH.mkdir(exist_ok=True, parents=True)

File diff suppressed because one or more lines are too long

View file

@ -14,6 +14,7 @@ def pdf_transfer_3d(
bin_count: int = 1000,
iterations: int = 25,
smoothness: float = 1,
should_regrain: bool = True,
):
[h, w, c] = source.shape
source_flattened = source.reshape(-1, c).transpose()
@ -43,4 +44,4 @@ def pdf_transfer_3d(
source_flattened.clip(0, 255, out=source_flattened)
result = source_flattened.astype(np.uint8).transpose().reshape(h, w, c)
return regrain(source, result, smoothness=smoothness)
return regrain(source, result, smoothness=smoothness) if should_regrain else result

View file

@ -9,6 +9,7 @@ def random_edit(img: Image, seed: int = 42) -> Image:
img = add_noise(img, random(0, 0.2))
img = ImageEnhance.Contrast(img).enhance(random(0.5, 2))
img = add_random_colour_spill(img, 1.3)
img = img.convert("HSV")
saturation_lut = get_colour_lut(variance=0.3, count=5, type="linear")
brightness_lut = get_colour_lut(variance=0.3, count=5, type="cubic")

View file

@ -8,7 +8,7 @@ def display_images(images: Dict[str, Image], images_per_row: int = 3):
fig, axes = plt.subplots(
nrows=ceil(len(images) / images_per_row),
ncols=min(images_per_row, len(images)),
figsize=(12, 8),
figsize=(24, 16),
)
axes = axes.flatten()