Improvements

This commit is contained in:
Andras Schmelczer 2024-04-12 20:11:56 +01:00
parent 44e0c129ec
commit 38b21135e2
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
7 changed files with 12038 additions and 16054 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -14,8 +14,8 @@ def plot_histograms(hists, histogram_per_row: int = 3):
fig.add_trace(_get_3d_scatter_plot_from_histogram(hist), row=1, col=i)
fig.update_layout(
width=1200,
height=600,
showlegend=False,
autosize=True,
scene1=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
scene2=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
)

View file

@ -1,52 +1,63 @@
from torch.utils.data import Dataset
from typing import Generator, Tuple, List
from typing import List
from editor.utils import compute_histogram
from .random_edit import random_edit
from PIL import Image
from tqdm import tqdm
import torch
from pathlib import Path
import PIL.Image
PIL.Image.MAX_IMAGE_PIXELS = None
class HistogramDataset(Dataset):
def __init__(
self, paths: List[Path], expected_edit_count: int = 5, bin_count: int = 32
self,
paths: List[Path],
edit_count: int = 5,
bin_count: int = 32,
target_size=(480, 480),
delete_corrupt_images: bool = False,
):
self._paths = paths
self._expected_edit_count = expected_edit_count
self._paths = sorted(paths)
self._edit_count = edit_count
self._bin_count = bin_count
self._pairs = list(self._get_pairs())
self._target_size = target_size
def _get_pairs(self) -> Generator[Tuple[Path, Path], None, None]:
if delete_corrupt_images:
self._delete_corrupt_images()
def _delete_corrupt_images(self) -> None:
deleted_count = 0
for path in tqdm(self._paths):
if len(list(path.glob("*.jpg"))) != self._expected_edit_count + 1:
continue
original_path = path / "original.jpg"
try:
Image.open(original_path)
Image.open(path)
except:
print(f"Failed to open {original_path}")
continue
yield original_path, original_path # The model should leave the original image unchanged
for i in range(self._expected_edit_count):
try:
Image.open(path / f"{i}.jpg")
except:
print(f'Failed to open {path / f"{i}.jpg"}')
break
yield original_path, path / f"{i}.jpg"
print(f"Failed to open {path}, deleting...")
deleted_count += 1
path.unlink()
print(f"Deleted {deleted_count} corrupt images")
def __len__(self):
return len(self._pairs)
return len(self._paths) * self._edit_count
def __getitem__(self, idx):
original, edited = self._pairs[idx]
original_idx = idx // self._edit_count
original_path = self._paths[original_idx]
original = Image.open(original_path)
original.thumbnail(self._target_size, Image.Resampling.LANCZOS)
edited = random_edit(original, seed=idx)
original_histogram = compute_histogram(
original, bins=self._bin_count, normalize=True
)
edited_histogram = compute_histogram(
edited, bins=self._bin_count, normalize=True
)
return (
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),

View file

@ -0,0 +1,19 @@
from PIL import Image, ImageEnhance
from ..utils import random, get_colour_lut, apply_pixel_shader
from ..operations import add_noise, add_random_colour_spill
import numpy as np
def random_edit(img: Image.Image, seed: int = 42) -> Image.Image:
    """Apply a seeded sequence of random edits to an image.

    The edits run in a fixed order: additive noise, a contrast change, a
    random colour spill, and finally a remapping of the saturation and value
    channels through two lookup tables while the image is in HSV mode.

    Args:
        img: The image to edit.
        seed: Value passed to ``np.random.seed`` before any edit is applied.

    Returns:
        The edited image, converted to ``"RGB"`` mode.

    Note:
        This function reseeds NumPy's *global* random state and does not
        restore it, so any later ``np.random`` call in the same process is
        affected.
    """
    # Seed first so every random draw below follows from `seed`.
    # NOTE(review): presumably the helpers (`random`, `add_noise`,
    # `add_random_colour_spill`, `get_colour_lut`) draw from `np.random`;
    # confirm, otherwise the seed has no effect on them.
    np.random.seed(seed)

    img = add_noise(img, random(0, 0.2))
    img = ImageEnhance.Contrast(img).enhance(random(0.5, 2))
    img = add_random_colour_spill(img, 1.3)

    # Saturation and value are only addressable as separate channels in HSV.
    img = img.convert("HSV")
    saturation_lut = get_colour_lut(variance=0.3, count=5, type="linear")
    brightness_lut = get_colour_lut(variance=0.3, count=5, type="cubic")
    # Hue passes through untouched; S and V are remapped through their LUTs.
    img = apply_pixel_shader(
        img, lambda h, s, v: (h, saturation_lut[s], brightness_lut[v])
    )

    return img.convert("RGB")

View file

@ -3,14 +3,15 @@ import numpy as np
def compute_histogram(
image_path, bins: int, value_range=(0, 256), normalize: bool = True
image: Image, bins: int, value_range=(0, 256), normalize: bool = True
):
image = Image.open(image_path)
image = np.array(image)
histogram, _ = np.histogramdd(
image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range]
).astype(np.float64)
)
histogram = histogram.astype(np.float32)
if normalize:
histogram = histogram / np.sum(histogram)

File diff suppressed because one or more lines are too long