Improvements

This commit is contained in:
Andras Schmelczer 2024-04-12 20:11:56 +01:00
parent 44e0c129ec
commit 38b21135e2
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
7 changed files with 12038 additions and 16054 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -14,8 +14,8 @@ def plot_histograms(hists, histogram_per_row: int = 3):
fig.add_trace(_get_3d_scatter_plot_from_histogram(hist), row=1, col=i) fig.add_trace(_get_3d_scatter_plot_from_histogram(hist), row=1, col=i)
fig.update_layout( fig.update_layout(
width=1200, showlegend=False,
height=600, autosize=True,
scene1=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"), scene1=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
scene2=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"), scene2=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
) )

View file

@ -1,52 +1,63 @@
from torch.utils.data import Dataset from torch.utils.data import Dataset
from typing import Generator, Tuple, List from typing import List
from editor.utils import compute_histogram from editor.utils import compute_histogram
from .random_edit import random_edit
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
import torch import torch
from pathlib import Path from pathlib import Path
import PIL.Image
PIL.Image.MAX_IMAGE_PIXELS = None
class HistogramDataset(Dataset): class HistogramDataset(Dataset):
def __init__( def __init__(
self, paths: List[Path], expected_edit_count: int = 5, bin_count: int = 32 self,
paths: List[Path],
edit_count: int = 5,
bin_count: int = 32,
target_size=(480, 480),
delete_corrupt_images: bool = False,
): ):
self._paths = paths self._paths = sorted(paths)
self._expected_edit_count = expected_edit_count self._edit_count = edit_count
self._bin_count = bin_count self._bin_count = bin_count
self._pairs = list(self._get_pairs()) self._target_size = target_size
def _get_pairs(self) -> Generator[Tuple[Path, Path], None, None]: if delete_corrupt_images:
self._delete_corrupt_images()
def _delete_corrupt_images(self) -> None:
deleted_count = 0
for path in tqdm(self._paths): for path in tqdm(self._paths):
if len(list(path.glob("*.jpg"))) != self._expected_edit_count + 1:
continue
original_path = path / "original.jpg"
try: try:
Image.open(original_path) Image.open(path)
except: except:
print(f"Failed to open {original_path}") print(f"Failed to open {path}, deleting...")
continue deleted_count += 1
yield original_path, original_path # The model should leave the original image unchanged path.unlink()
for i in range(self._expected_edit_count): print(f"Deleted {deleted_count} corrupt images")
try:
Image.open(path / f"{i}.jpg")
except:
print(f'Failed to open {path / f"{i}.jpg"}')
break
yield original_path, path / f"{i}.jpg"
def __len__(self): def __len__(self):
return len(self._pairs) return len(self._paths) * self._edit_count
def __getitem__(self, idx): def __getitem__(self, idx):
original, edited = self._pairs[idx] original_idx = idx // self._edit_count
original_path = self._paths[original_idx]
original = Image.open(original_path)
original.thumbnail(self._target_size, Image.Resampling.LANCZOS)
edited = random_edit(original, seed=idx)
original_histogram = compute_histogram( original_histogram = compute_histogram(
original, bins=self._bin_count, normalize=True original, bins=self._bin_count, normalize=True
) )
edited_histogram = compute_histogram( edited_histogram = compute_histogram(
edited, bins=self._bin_count, normalize=True edited, bins=self._bin_count, normalize=True
) )
return ( return (
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0), torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0), torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),

View file

@ -0,0 +1,19 @@
from PIL import Image, ImageEnhance
from ..utils import random, get_colour_lut, apply_pixel_shader
from ..operations import add_noise, add_random_colour_spill
import numpy as np
def random_edit(img: Image.Image, seed: int = 42) -> Image.Image:
    """Apply a seeded sequence of random edits to an image.

    The edits run in a fixed order: noise, contrast enhancement, a colour
    spill, then a remapping of the saturation and value channels in HSV
    space through two lookup tables. The hue channel is passed through
    untouched.

    Args:
        img: Image to edit.
        seed: Value handed to ``np.random.seed`` before any edit is applied.

    Returns:
        The edited image, converted to RGB mode.
    """
    # NOTE(review): this reseeds NumPy's *global* generator, so it also
    # affects every later np.random draw made elsewhere in the process.
    # Presumably the helpers below draw from np.random, which is what would
    # make the edit reproducible per seed - confirm against their sources.
    np.random.seed(seed)

    img = add_noise(img, random(0, 0.2))
    img = ImageEnhance.Contrast(img).enhance(random(0.5, 2))
    img = add_random_colour_spill(img, 1.3)

    # Saturation and brightness are remapped in HSV space.
    img = img.convert("HSV")
    saturation_lut = get_colour_lut(variance=0.3, count=5, type="linear")
    brightness_lut = get_colour_lut(variance=0.3, count=5, type="cubic")

    def _remap_saturation_and_value(h, s, v):
        # Channel values are used directly as indices into the tables.
        return h, saturation_lut[s], brightness_lut[v]

    img = apply_pixel_shader(img, _remap_saturation_and_value)
    return img.convert("RGB")

View file

@ -3,14 +3,15 @@ import numpy as np
def compute_histogram( def compute_histogram(
image_path, bins: int, value_range=(0, 256), normalize: bool = True image: Image, bins: int, value_range=(0, 256), normalize: bool = True
): ):
image = Image.open(image_path)
image = np.array(image) image = np.array(image)
histogram, _ = np.histogramdd( histogram, _ = np.histogramdd(
image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range] image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range]
).astype(np.float64) )
histogram = histogram.astype(np.float32)
if normalize: if normalize:
histogram = histogram / np.sum(histogram) histogram = histogram / np.sum(histogram)

File diff suppressed because one or more lines are too long