move files
This commit is contained in:
parent
1a41fd6829
commit
231e22cac8
36 changed files with 15580 additions and 79653 deletions
File diff suppressed because one or more lines are too long
25748
inference.ipynb
25748
inference.ipynb
File diff suppressed because one or more lines are too long
66
src/colour_lut.ipynb
Normal file
66
src/colour_lut.ipynb
Normal file
File diff suppressed because one or more lines are too long
|
|
@ -3,7 +3,9 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"metadata": {}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
|
|
@ -19,20 +21,15 @@
|
|||
"source": [
|
||||
"from editor.training import random_edit\n",
|
||||
"from editor.visualisation import display_images\n",
|
||||
"from editor.training import HistogramDataset\n",
|
||||
"from config import DATA\n",
|
||||
"from editor.training import HistogramDataset\n",
|
||||
"from config import DATA\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"dataset = HistogramDataset(DATA)\n",
|
||||
"img = dataset.get_original_image(0)\n",
|
||||
"\n",
|
||||
"edits = {\n",
|
||||
" 'Original': img\n",
|
||||
"}\n",
|
||||
"edits.update({\n",
|
||||
" f'Edit {i}': random_edit(img.copy(), seed=i)\n",
|
||||
" for i in range(1, 9)\n",
|
||||
"})\n",
|
||||
"edits = {\"Original\": img}\n",
|
||||
"edits.update({f\"Edit {i}\": random_edit(img.copy(), seed=i) for i in range(1, 9)})\n",
|
||||
"\n",
|
||||
"display_images(edits)"
|
||||
]
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -18,20 +18,16 @@
|
|||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from pathlib import Path\n",
|
||||
"from config import DATA, MODELS_PATH, CACHE_PATH\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"DATA = Path('/mnt/wsl/PHYSICALDRIVE1/data/unsplash').glob('*.jpg')\n",
|
||||
"CACHE_PATH = Path('/mnt/wsl/PHYSICALDRIVE1/data/cache2')\n",
|
||||
"CACHE_PATH.mkdir(exist_ok=True, parents=True)\n",
|
||||
"BINS = 32\n",
|
||||
"NUM_EPOCHS = 20\n",
|
||||
"BATCH_SIZE = 64\n",
|
||||
"LEARNING_RATE = 0.005\n",
|
||||
"SCHEDULER_GAMMA = 0.7\n",
|
||||
"EDIT_COUNT = 25\n",
|
||||
"LOSS_DAMPING = 2\n",
|
||||
"MODELS_PATH = Path('models')\n",
|
||||
"MODELS_PATH.mkdir(exist_ok=True, parents=True)\n",
|
||||
"BINS = 32\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
|
||||
"f'Using device {device}'"
|
||||
|
|
@ -146,43 +142,6 @@
|
|||
"edited, og = next(iter(train_dataloader))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def plot_histograms(original_histogram, edited_histogram, predicted_histogram):\n",
|
||||
" fig = plt.figure(figsize=(15, 5))\n",
|
||||
" tensors = [original_histogram.numpy().squeeze(), edited_histogram.numpy().squeeze(), predicted_histogram.numpy().squeeze()]\n",
|
||||
"\n",
|
||||
" for i, tensor in enumerate(tensors, 1):\n",
|
||||
" ax = fig.add_subplot(1, 3, i, projection='3d')\n",
|
||||
"\n",
|
||||
" x, y, z = np.indices(tensor.shape)\n",
|
||||
" x = x.flatten()\n",
|
||||
" y = y.flatten()\n",
|
||||
" z = z.flatten()\n",
|
||||
" values = tensor.flatten()\n",
|
||||
"\n",
|
||||
" sizes = values * 5000 \n",
|
||||
"\n",
|
||||
" colors = np.vstack((x, y, z)).T / 31\n",
|
||||
"\n",
|
||||
" sc = ax.scatter(x, y, z, c=colors, s=sizes, marker='o', alpha=0.5)\n",
|
||||
"\n",
|
||||
" ax.set_xlim([0, 31])\n",
|
||||
" ax.set_ylim([0, 31])\n",
|
||||
" ax.set_zlim([0, 31])\n",
|
||||
"\n",
|
||||
" ax.set_title(f'Tensor {i}')\n",
|
||||
" return fig"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
|
|
@ -485,6 +444,7 @@
|
|||
"from tqdm.notebook import tqdm\n",
|
||||
"from torch.nn.utils import clip_grad_norm_\n",
|
||||
"from editor.training import ProgressivePoolingLoss\n",
|
||||
"from editor.visualisation import plot_histograms_in_2d\n",
|
||||
"# from geomloss import SamplesLoss \n",
|
||||
"# import numpy as np\n",
|
||||
"\n",
|
||||
|
|
@ -541,8 +501,12 @@
|
|||
" edited_histogram = edited_histogram.to(device)\n",
|
||||
" original_histogram = original_histogram.to(device)\n",
|
||||
" predicted_original = model(edited_histogram)\n",
|
||||
" writer.add_figure(\"Histograms/train/original\", plot_histograms(\n",
|
||||
" original_histogram.cpu()[0], edited_histogram.cpu()[0], predicted_original.cpu()[0]\n",
|
||||
" writer.add_figure(\"Histograms/train/original\", plot_histograms_in_2d(\n",
|
||||
" {\n",
|
||||
" 'original': original_histogram.cpu()[0].numpy().squeeze(),\n",
|
||||
" 'edited': edited_histogram.cpu()[0].numpy().squeeze(),\n",
|
||||
" 'predicted': predicted_original.cpu()[0].numpy().squeeze()\n",
|
||||
" }\n",
|
||||
" ), epoch)\n",
|
||||
" model.train()\n",
|
||||
" last_model_path = MODELS_PATH / f'model-{epoch}.pth'\n",
|
||||
|
|
@ -92170,7 +92134,7 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"from editor.ploting import plot_histograms\n",
|
||||
"from editor.visualisation import plot_histograms\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"edited_histogram, original_histogram = next(loader)\n",
|
||||
Loading…
Add table
Add a link
Reference in a new issue