Remove editor module
This commit is contained in:
parent
e5959268c1
commit
c966866abc
37 changed files with 7752 additions and 7345 deletions
3
src/training/__init__.py
Normal file
3
src/training/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .histogram_dataset import HistogramDataset
|
||||
from .random_edit import random_edit
|
||||
from .progressive_pooling_loss import ProgressivePoolingLoss
|
||||
102
src/training/histogram_dataset.py
Normal file
102
src/training/histogram_dataset.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
from torch.utils.data import Dataset
|
||||
from typing import List, Optional, Tuple
|
||||
from utils import compute_histogram
|
||||
from .random_edit import random_edit
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
import PIL.Image
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = None
|
||||
|
||||
|
||||
class HistogramDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
/,
|
||||
paths: List[Path],
|
||||
bin_count: int,
|
||||
edit_count: int = 5,
|
||||
target_size=(240, 240),
|
||||
delete_corrupt_images: bool = False,
|
||||
cache_path: Optional[Path] = None,
|
||||
):
|
||||
self._paths = sorted(paths)
|
||||
logging.info(f"Loaded {len(self._paths)} original images")
|
||||
|
||||
self._edit_count = edit_count
|
||||
self._bin_count = bin_count
|
||||
self._target_size = target_size
|
||||
self._cache_path = cache_path
|
||||
if self._cache_path:
|
||||
self._cache_path = (
|
||||
self._cache_path
|
||||
/ f"{self._bin_count}bins_{self._target_size[0]}x{self._target_size[1]}px"
|
||||
)
|
||||
|
||||
if delete_corrupt_images:
|
||||
self._delete_corrupt_images()
|
||||
|
||||
def _delete_corrupt_images(self) -> None:
|
||||
deleted_count = 0
|
||||
for path in tqdm(self._paths):
|
||||
try:
|
||||
Image.open(path)
|
||||
except:
|
||||
logging.warning(f"Failed to open {path}, deleting...")
|
||||
deleted_count += 1
|
||||
path.unlink()
|
||||
logging.info(f"Deleted {deleted_count} corrupt images")
|
||||
|
||||
def __len__(self):
|
||||
return len(self._paths) * self._edit_count
|
||||
|
||||
def get_original_image(self, original_idx: int) -> Image.Image:
|
||||
original_path = self._paths[original_idx]
|
||||
original = Image.open(original_path)
|
||||
original.thumbnail(
|
||||
self._target_size, Image.Resampling.LANCZOS
|
||||
) # size will be at most target_size, the aspect ratio is preserved
|
||||
return original
|
||||
|
||||
def get_edited_image(self, original_idx: int, edit_idx: int) -> Image.Image:
|
||||
original_image = self.get_original_image(original_idx)
|
||||
return random_edit(original_image, seed=original_idx * 7919 + edit_idx)
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
original_idx = idx // self._edit_count
|
||||
edit_idx = idx % self._edit_count
|
||||
|
||||
if self._cache_path is not None:
|
||||
_cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.pt"
|
||||
_cached_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if _cached_data_path.exists():
|
||||
try:
|
||||
return torch.load(_cached_data_path)
|
||||
except:
|
||||
logging.warning(
|
||||
f"Failed to load {_cached_data_path}, regenerating..."
|
||||
)
|
||||
|
||||
edited = self.get_edited_image(original_idx, edit_idx)
|
||||
edited_histogram = compute_histogram(
|
||||
edited, bins=self._bin_count, normalize=True
|
||||
)
|
||||
|
||||
original = self.get_original_image(original_idx)
|
||||
original_histogram = compute_histogram(
|
||||
original, bins=self._bin_count, normalize=True
|
||||
)
|
||||
|
||||
result = (
|
||||
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
||||
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
|
||||
)
|
||||
|
||||
if self._cache_path is not None:
|
||||
torch.save(result, _cached_data_path)
|
||||
|
||||
return result
|
||||
38
src/training/progressive_pooling_loss.py
Normal file
38
src/training/progressive_pooling_loss.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from typing import List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ProgressivePoolingLoss(nn.Module):
|
||||
def __init__(self, target_sizes: List[int], damping: float):
|
||||
super(ProgressivePoolingLoss, self).__init__()
|
||||
self._target_sizes = target_sizes
|
||||
self._damping = damping
|
||||
|
||||
def forward(self, tensor_a, tensor_b):
|
||||
assert (
|
||||
tensor_a.size() == tensor_b.size()
|
||||
), f"Input tensors must have the same size, got {tensor_a.size()} and {tensor_b.size()}"
|
||||
|
||||
assert (
|
||||
len(tensor_a.size()) == 5
|
||||
), f"Input tensors must have 5 dimensions, got {tensor_a.size()}"
|
||||
|
||||
_minibatch_size, _channels, depth, height, width = tensor_a.size()
|
||||
assert depth == height == width, "Input tensors must be cubes."
|
||||
|
||||
loss = 0.0
|
||||
weight = 1
|
||||
|
||||
for target_size in self._target_sizes:
|
||||
pool_size = depth // target_size
|
||||
pooled_a = F.avg_pool3d(tensor_a, pool_size) * (pool_size**3)
|
||||
pooled_b = F.avg_pool3d(tensor_b, pool_size) * (pool_size**3)
|
||||
|
||||
diff = torch.abs(pooled_a - pooled_b)
|
||||
|
||||
loss += diff.mean() * weight
|
||||
weight *= self._damping
|
||||
|
||||
return loss
|
||||
33
src/training/random_edit.py
Normal file
33
src/training/random_edit.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
from PIL import Image, ImageEnhance
|
||||
from utils import random
|
||||
from operations import (
|
||||
add_noise,
|
||||
add_random_colour_spill,
|
||||
get_random_gamma,
|
||||
adjust_gamma,
|
||||
apply_pixel_shader,
|
||||
get_random_brightness_lut,
|
||||
get_random_saturation_per_hue_lut,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
|
||||
def random_edit(img: Image, seed: int = 42) -> Image:
|
||||
np.random.seed(seed)
|
||||
img = img.convert("RGB")
|
||||
|
||||
img = ImageEnhance.Contrast(img).enhance(random(0.5, 1.5))
|
||||
img = adjust_gamma(img, get_random_gamma())
|
||||
|
||||
img = img.convert("HSV")
|
||||
saturation_lut = get_random_saturation_per_hue_lut()
|
||||
brightness_lut = get_random_brightness_lut()
|
||||
img = apply_pixel_shader(
|
||||
img, lambda h, s, v: (h, round(s * saturation_lut[h]), brightness_lut[v])
|
||||
)
|
||||
img = img.convert("RGB")
|
||||
|
||||
img = add_random_colour_spill(img, 0.2)
|
||||
img = add_noise(img, random(0, 0.1))
|
||||
|
||||
return img
|
||||
Loading…
Add table
Add a link
Reference in a new issue