move files

This commit is contained in:
Andras Schmelczer 2024-05-09 21:22:28 +01:00
parent 1a41fd6829
commit 231e22cac8
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
36 changed files with 15580 additions and 79653 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

66
src/colour_lut.ipynb Normal file

File diff suppressed because one or more lines are too long

View file

@ -3,7 +3,9 @@
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"metadata": {
"metadata": {}
},
"outputs": [
{
"data": {
@ -19,20 +21,15 @@
"source": [
"from editor.training import random_edit\n",
"from editor.visualisation import display_images\n",
"from editor.training import HistogramDataset\n",
"from config import DATA\n",
"from editor.training import HistogramDataset\n",
"from config import DATA\n",
"\n",
"\n",
"dataset = HistogramDataset(DATA)\n",
"img = dataset.get_original_image(0)\n",
"\n",
"edits = {\n",
" 'Original': img\n",
"}\n",
"edits.update({\n",
" f'Edit {i}': random_edit(img.copy(), seed=i)\n",
" for i in range(1, 9)\n",
"})\n",
"edits = {\"Original\": img}\n",
"edits.update({f\"Edit {i}\": random_edit(img.copy(), seed=i) for i in range(1, 9)})\n",
"\n",
"display_images(edits)"
]

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -18,20 +18,16 @@
],
"source": [
"import torch\n",
"from pathlib import Path\n",
"from config import DATA, MODELS_PATH, CACHE_PATH\n",
"\n",
"\n",
"DATA = Path('/mnt/wsl/PHYSICALDRIVE1/data/unsplash').glob('*.jpg')\n",
"CACHE_PATH = Path('/mnt/wsl/PHYSICALDRIVE1/data/cache2')\n",
"CACHE_PATH.mkdir(exist_ok=True, parents=True)\n",
"BINS = 32\n",
"NUM_EPOCHS = 20\n",
"BATCH_SIZE = 64\n",
"LEARNING_RATE = 0.005\n",
"SCHEDULER_GAMMA = 0.7\n",
"EDIT_COUNT = 25\n",
"LOSS_DAMPING = 2\n",
"MODELS_PATH = Path('models')\n",
"MODELS_PATH.mkdir(exist_ok=True, parents=True)\n",
"BINS = 32\n",
"\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"f'Using device {device}'"
@ -146,43 +142,6 @@
"edited, og = next(iter(train_dataloader))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def plot_histograms(original_histogram, edited_histogram, predicted_histogram):\n",
" fig = plt.figure(figsize=(15, 5))\n",
" tensors = [original_histogram.numpy().squeeze(), edited_histogram.numpy().squeeze(), predicted_histogram.numpy().squeeze()]\n",
"\n",
" for i, tensor in enumerate(tensors, 1):\n",
" ax = fig.add_subplot(1, 3, i, projection='3d')\n",
"\n",
" x, y, z = np.indices(tensor.shape)\n",
" x = x.flatten()\n",
" y = y.flatten()\n",
" z = z.flatten()\n",
" values = tensor.flatten()\n",
"\n",
" sizes = values * 5000 \n",
"\n",
" colors = np.vstack((x, y, z)).T / 31\n",
"\n",
" sc = ax.scatter(x, y, z, c=colors, s=sizes, marker='o', alpha=0.5)\n",
"\n",
" ax.set_xlim([0, 31])\n",
" ax.set_ylim([0, 31])\n",
" ax.set_zlim([0, 31])\n",
"\n",
" ax.set_title(f'Tensor {i}')\n",
" return fig"
]
},
{
"cell_type": "code",
"execution_count": 18,
@ -485,6 +444,7 @@
"from tqdm.notebook import tqdm\n",
"from torch.nn.utils import clip_grad_norm_\n",
"from editor.training import ProgressivePoolingLoss\n",
"from editor.visualisation import plot_histograms_in_2d\n",
"# from geomloss import SamplesLoss \n",
"# import numpy as np\n",
"\n",
@ -541,8 +501,12 @@
" edited_histogram = edited_histogram.to(device)\n",
" original_histogram = original_histogram.to(device)\n",
" predicted_original = model(edited_histogram)\n",
" writer.add_figure(\"Histograms/train/original\", plot_histograms(\n",
" original_histogram.cpu()[0], edited_histogram.cpu()[0], predicted_original.cpu()[0]\n",
" writer.add_figure(\"Histograms/train/original\", plot_histograms_in_2d(\n",
" {\n",
" 'original': original_histogram.cpu()[0].numpy().squeeze(),\n",
" 'edited': edited_histogram.cpu()[0].numpy().squeeze(),\n",
" 'predicted': predicted_original.cpu()[0].numpy().squeeze()\n",
" }\n",
" ), epoch)\n",
" model.train()\n",
" last_model_path = MODELS_PATH / f'model-{epoch}.pth'\n",
@ -92170,7 +92134,7 @@
}
],
"source": [
"from editor.ploting import plot_histograms\n",
"from editor.visualisation import plot_histograms\n",
"\n",
"\n",
"edited_histogram, original_histogram = next(loader)\n",