Improvements
This commit is contained in:
parent
44e0c129ec
commit
38b21135e2
7 changed files with 12038 additions and 16054 deletions
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -14,8 +14,8 @@ def plot_histograms(hists, histogram_per_row: int = 3):
|
||||||
fig.add_trace(_get_3d_scatter_plot_from_histogram(hist), row=1, col=i)
|
fig.add_trace(_get_3d_scatter_plot_from_histogram(hist), row=1, col=i)
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
width=1200,
|
showlegend=False,
|
||||||
height=600,
|
autosize=True,
|
||||||
scene1=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
|
scene1=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
|
||||||
scene2=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
|
scene2=dict(xaxis_title="R", yaxis_title="G", zaxis_title="B"),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,52 +1,63 @@
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing import Generator, Tuple, List
|
from typing import List
|
||||||
from editor.utils import compute_histogram
|
from editor.utils import compute_histogram
|
||||||
|
from .random_edit import random_edit
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
|
||||||
|
PIL.Image.MAX_IMAGE_PIXELS = None
|
||||||
|
|
||||||
|
|
||||||
class HistogramDataset(Dataset):
|
class HistogramDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, paths: List[Path], expected_edit_count: int = 5, bin_count: int = 32
|
self,
|
||||||
|
paths: List[Path],
|
||||||
|
edit_count: int = 5,
|
||||||
|
bin_count: int = 32,
|
||||||
|
target_size=(480, 480),
|
||||||
|
delete_corrupt_images: bool = False,
|
||||||
):
|
):
|
||||||
self._paths = paths
|
self._paths = sorted(paths)
|
||||||
self._expected_edit_count = expected_edit_count
|
self._edit_count = edit_count
|
||||||
self._bin_count = bin_count
|
self._bin_count = bin_count
|
||||||
self._pairs = list(self._get_pairs())
|
self._target_size = target_size
|
||||||
|
|
||||||
def _get_pairs(self) -> Generator[Tuple[Path, Path], None, None]:
|
if delete_corrupt_images:
|
||||||
|
self._delete_corrupt_images()
|
||||||
|
|
||||||
|
def _delete_corrupt_images(self) -> None:
|
||||||
|
deleted_count = 0
|
||||||
for path in tqdm(self._paths):
|
for path in tqdm(self._paths):
|
||||||
if len(list(path.glob("*.jpg"))) != self._expected_edit_count + 1:
|
|
||||||
continue
|
|
||||||
|
|
||||||
original_path = path / "original.jpg"
|
|
||||||
try:
|
try:
|
||||||
Image.open(original_path)
|
Image.open(path)
|
||||||
except:
|
except:
|
||||||
print(f"Failed to open {original_path}")
|
print(f"Failed to open {path}, deleting...")
|
||||||
continue
|
deleted_count += 1
|
||||||
yield original_path, original_path # The model should leave the original image unchanged
|
path.unlink()
|
||||||
for i in range(self._expected_edit_count):
|
print(f"Deleted {deleted_count} corrupt images")
|
||||||
try:
|
|
||||||
Image.open(path / f"{i}.jpg")
|
|
||||||
except:
|
|
||||||
print(f'Failed to open {path / f"{i}.jpg"}')
|
|
||||||
break
|
|
||||||
yield original_path, path / f"{i}.jpg"
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._pairs)
|
return len(self._paths) * self._edit_count
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
original, edited = self._pairs[idx]
|
original_idx = idx // self._edit_count
|
||||||
|
original_path = self._paths[original_idx]
|
||||||
|
original = Image.open(original_path)
|
||||||
|
original.thumbnail(self._target_size, Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
edited = random_edit(original, seed=idx)
|
||||||
|
|
||||||
original_histogram = compute_histogram(
|
original_histogram = compute_histogram(
|
||||||
original, bins=self._bin_count, normalize=True
|
original, bins=self._bin_count, normalize=True
|
||||||
)
|
)
|
||||||
edited_histogram = compute_histogram(
|
edited_histogram = compute_histogram(
|
||||||
edited, bins=self._bin_count, normalize=True
|
edited, bins=self._bin_count, normalize=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
||||||
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
|
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
|
||||||
|
|
|
||||||
19
editor/training/random_edit.py
Normal file
19
editor/training/random_edit.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
from PIL import Image, ImageEnhance
|
||||||
|
from ..utils import random, get_colour_lut, apply_pixel_shader
|
||||||
|
from ..operations import add_noise, add_random_colour_spill
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def random_edit(img: Image, seed: int = 42) -> Image:
|
||||||
|
np.random.seed(seed)
|
||||||
|
img = add_noise(img, random(0, 0.2))
|
||||||
|
img = ImageEnhance.Contrast(img).enhance(random(0.5, 2))
|
||||||
|
img = add_random_colour_spill(img, 1.3)
|
||||||
|
img = img.convert("HSV")
|
||||||
|
saturation_lut = get_colour_lut(variance=0.3, count=5, type="linear")
|
||||||
|
brightness_lut = get_colour_lut(variance=0.3, count=5, type="cubic")
|
||||||
|
img = apply_pixel_shader(
|
||||||
|
img, lambda h, s, v: (h, saturation_lut[s], brightness_lut[v])
|
||||||
|
)
|
||||||
|
img = img.convert("RGB")
|
||||||
|
return img
|
||||||
|
|
@ -3,14 +3,15 @@ import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def compute_histogram(
|
def compute_histogram(
|
||||||
image_path, bins: int, value_range=(0, 256), normalize: bool = True
|
image: Image, bins: int, value_range=(0, 256), normalize: bool = True
|
||||||
):
|
):
|
||||||
image = Image.open(image_path)
|
|
||||||
image = np.array(image)
|
image = np.array(image)
|
||||||
|
|
||||||
histogram, _ = np.histogramdd(
|
histogram, _ = np.histogramdd(
|
||||||
image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range]
|
image.reshape(-1, 3), bins=bins, range=[value_range, value_range, value_range]
|
||||||
).astype(np.float64)
|
)
|
||||||
|
|
||||||
|
histogram = histogram.astype(np.float32)
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
histogram = histogram / np.sum(histogram)
|
histogram = histogram / np.sum(histogram)
|
||||||
|
|
|
||||||
27567
show_histograms.ipynb
27567
show_histograms.ipynb
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue