Fix train/test split

This commit is contained in:
Andras Schmelczer 2024-06-04 22:48:07 +01:00
parent af56ec3fec
commit edeac12e37
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
7 changed files with 9738 additions and 26131 deletions

View file

@ -1,6 +1,7 @@
from pathlib import Path
DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE1/data/unsplash").glob("*.jpg"))
TRAIN_SIZE = 0.8
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE1/data/cache")
CACHE_PATH.mkdir(exist_ok=True, parents=True)

14
src/data.py Normal file
View file

@ -0,0 +1,14 @@
import random
from config import DATA, TRAIN_SIZE
random.seed(42)
length = len(DATA)
indices = list(range(length))
random.shuffle(indices)
train_indices = indices[: int(length * TRAIN_SIZE)]
test_indices = indices[int(length * TRAIN_SIZE) :]
TRAIN_DATA = [DATA[i] for i in train_indices]
TEST_DATA = [DATA[i] for i in test_indices]

View file

@ -1,4 +1,3 @@
from .histogram_dataset import HistogramDataset
from .random_edit import random_edit
from .progressive_pooling_loss import ProgressivePoolingLoss
from .create_data_loaders import create_data_loaders

View file

@ -1,46 +0,0 @@
from pathlib import Path
from typing import List, Tuple
from torch.utils.data import DataLoader, random_split
from editor.training import HistogramDataset
import logging
import torch
from config import CACHE_PATH
import os
def create_data_loaders(
data: List[Path],
edit_count: int,
bin_count: int,
training_batch_size: int,
train_size=0.9,
delete_corrupt_images: bool = False,
) -> Tuple[DataLoader, DataLoader]:
dataset = HistogramDataset(
data,
edit_count=edit_count,
bin_count=bin_count,
delete_corrupt_images=delete_corrupt_images,
cache_path=CACHE_PATH,
)
total_size = len(dataset)
train_size = int(train_size * total_size)
test_size = total_size - train_size
train_dataset, test_dataset = random_split(
dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)
train_data_loader = DataLoader(
train_dataset,
batch_size=training_batch_size,
shuffle=True,
num_workers=os.cpu_count(),
)
test_data_loader = DataLoader(
test_dataset, batch_size=1, shuffle=False, num_workers=os.cpu_count()
)
logging.info(
f"Loaded {len(train_dataset)} training images and {len(test_dataset)} test images"
)
return train_data_loader, test_data_loader

View file

@ -24,6 +24,8 @@ class HistogramDataset(Dataset):
cache_path: Optional[Path] = None,
):
self._paths = sorted(paths)
logging.info(f"Loaded {len(self._paths)} original images")
self._edit_count = edit_count
self._bin_count = bin_count
self._target_size = target_size

File diff suppressed because one or more lines are too long

View file

@ -54,7 +54,8 @@
" \"clip_gradients\": [True, False],\n",
" \"learning_rate\": loguniform(0.00001, 0.01),\n",
" \"scheduler_gamma\": uniform(0, 1),\n",
" \"num_epochs\": randint(5, 10),\n",
" \"num_epochs\": [10],\n",
" # \"num_epochs\": randint(5, 10),\n",
" \"model_type\": list(MODELS.keys()),\n",
"}\n",
"hyperparameters = [\n",
@ -77,7 +78,38 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Dict\n",
"from pathlib import Path\n",
"from typing import List, Any, Dict\n",
"from torch.utils.data import DataLoader\n",
"from config import CACHE_PATH\n",
"from editor.training import HistogramDataset\n",
"\n",
"\n",
"def get_data_loader(data: List[Path], hyperparameters: Dict[str, Any]) -> DataLoader:\n",
" return DataLoader(\n",
" dataset=HistogramDataset(\n",
" paths=data,\n",
" edit_count=hyperparameters[\"edit_count\"],\n",
" bin_count=hyperparameters[\"bin_count\"],\n",
" delete_corrupt_images=False,\n",
" cache_path=CACHE_PATH,\n",
" ),\n",
" batch_size=hyperparameters[\"batch_size\"],\n",
" shuffle=True,\n",
" num_workers=os.cpu_count(),\n",
" )\n",
"\n",
"\n",
"def serialise_hparams(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:\n",
" return {k: str(v) if isinstance(v, list) else v for k, v in hyperparameters.items()}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.tensorboard import SummaryWriter\n",
"from pathlib import Path\n",
"from torch.optim import Adam\n",
@ -86,14 +118,14 @@
"from editor.training import ProgressivePoolingLoss\n",
"from editor.utils import get_next_run_name\n",
"from editor.visualisation import plot_histograms_in_2d\n",
"from editor.training import create_data_loaders\n",
"from editor.models import create_model, test_models\n",
"from config import DATA, MODELS_PATH\n",
"from data import TRAIN_DATA, TEST_DATA\n",
"from datetime import timedelta, datetime\n",
"import json\n",
"from config import MODELS_PATH\n",
"\n",
"\n",
"# test_models()\n",
"test_models()\n",
"\n",
"\n",
"def train(\n",
@ -107,12 +139,8 @@
"\n",
" log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n",
" with SummaryWriter(log_dir) as writer:\n",
" train_data_loader, test_data_loader = create_data_loaders(\n",
" data=DATA,\n",
" edit_count=hyperparameters[\"edit_count\"],\n",
" bin_count=hyperparameters[\"bin_count\"],\n",
" training_batch_size=hyperparameters[\"batch_size\"],\n",
" )\n",
" train_data_loader = get_data_loader(TRAIN_DATA, hyperparameters)\n",
" test_data_loader = get_data_loader(TEST_DATA, hyperparameters)\n",
"\n",
" model = (\n",
" create_model(\n",
@ -178,7 +206,7 @@
" writer.add_scalar(\n",
" \"Loss/train/batch\",\n",
" loss,\n",
" epoch * len(train_data_loader) + batch_id,\n",
" global_step=epoch * len(train_data_loader) + batch_id,\n",
" )\n",
" loss.backward()\n",
"\n",
@ -186,18 +214,7 @@
" clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
" optimizer.step()\n",
"\n",
" writer.add_hparams(\n",
" {\n",
" k: str(v) if isinstance(v, list) else v\n",
" for k, v in hyperparameters.items()\n",
" },\n",
" {\n",
" \"Loss/train/epoch\": epoch_loss,\n",
" },\n",
" global_step=epoch,\n",
" run_name=log_dir.absolute(),\n",
" )\n",
" logging.info(f\"Epoch {epoch} loss: {epoch_loss}\")\n",
" logging.info(f\"Epoch {epoch} train loss: {epoch_loss}\")\n",
" with torch.no_grad():\n",
" model.eval()\n",
" loader = iter(test_data_loader)\n",
@ -222,9 +239,48 @@
" ),\n",
" epoch,\n",
" )\n",
"\n",
" epoch_test_loss = 0\n",
" for batch_id, (edited_histogram, original_histogram) in enumerate(\n",
" test_data_loader\n",
" ):\n",
" edited_histogram = edited_histogram.to(device)\n",
" original_histogram = original_histogram.to(device)\n",
"\n",
" predicted_original = model(edited_histogram)\n",
" sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n",
" predicted_original = predicted_original / sum\n",
"\n",
" if hyperparameters[\"loss\"] == \"kl\":\n",
" predicted_original = torch.clamp(\n",
" predicted_original, 0.0000000000000000000001, 1\n",
" )\n",
"\n",
" loss = {\n",
" \"kl\": lambda: loss_function(\n",
" torch.log(predicted_original),\n",
" original_histogram,\n",
" ),\n",
" \"progressive\": lambda: loss_function(\n",
" predicted_original, original_histogram\n",
" ),\n",
" }[hyperparameters[\"loss\"]]()\n",
"\n",
" epoch_test_loss += loss.item()\n",
" writer.add_hparams(\n",
" serialise_hparams(hyperparameters),\n",
" {\n",
" \"Loss/test/epoch\": epoch_test_loss,\n",
" \"Loss/train/epoch\": epoch_loss,\n",
" },\n",
" global_step=epoch,\n",
" run_name=log_dir.absolute(),\n",
" )\n",
" logging.info(f\"Epoch {epoch} test loss: {epoch_test_loss}\")\n",
"\n",
" model.train()\n",
" scheduler.step()\n",
" except Exception as e:\n",
" except Exception:\n",
" raise\n",
" finally:\n",
" logging.info(f\"Saving model to {model_path}\")\n",
@ -236,7 +292,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@ -259,95 +315,40 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-03 22:46:07,734 - INFO - Starting run_96 with hparams {\n",
" \"batch_size\": 16,\n",
" \"bin_count\": 32,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 16,\n",
" \"learning_rate\": 3.291322467520231e-05,\n",
" \"loss\": \"progressive\",\n",
" \"loss_damping\": 1.1967321790868395,\n",
" \"loss_sizes\": [\n",
" 8,\n",
" 32\n",
" ],\n",
" \"model_type\": \"attention2\",\n",
" \"num_epochs\": 5,\n",
" \"scheduler_gamma\": 0.4573812925553331\n",
"}\n",
"2024-06-03 22:46:07,762 - INFO - Loaded 359668 training images and 39964 test images\n",
"2024-06-04 00:08:08,259 - INFO - Epoch 0 loss: 38.42179161275271\n",
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/matplotlib/collections.py:996: RuntimeWarning: invalid value encountered in sqrt\n",
" scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor\n",
"2024-06-04 01:30:02,938 - INFO - Epoch 1 loss: 34.078268383513205\n",
"2024-06-04 02:46:08,066 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_96.pth\n",
"2024-06-04 02:46:08,182 - WARNING - Timeout, aborting experiment\n",
"2024-06-04 02:46:08,479 - INFO - Starting run_97 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 32,\n",
" \"clip_gradients\": true,\n",
" \"edit_count\": 8,\n",
" \"learning_rate\": 5.96886240713341e-05,\n",
" \"loss\": \"progressive\",\n",
" \"loss_damping\": 2.8893045711729517,\n",
" \"loss_sizes\": [\n",
" 4,\n",
" 8,\n",
" 16,\n",
" 32\n",
" ],\n",
" \"model_type\": \"attention2\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.9315193474157711\n",
"}\n",
"2024-06-04 02:46:08,500 - INFO - Loaded 179834 training images and 19982 test images\n",
"2024-06-04 06:46:16,877 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_97.pth\n",
"2024-06-04 06:46:28,422 - WARNING - Timeout, aborting experiment\n",
"2024-06-04 06:46:28,437 - INFO - Starting run_98 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 16,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 16,\n",
" \"learning_rate\": 0.0019552772361485543,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"SimpleCNN\",\n",
" \"num_epochs\": 5,\n",
" \"scheduler_gamma\": 0.022346077394851838\n",
"}\n",
"2024-06-04 06:46:28,475 - INFO - Loaded 359668 training images and 39964 test images\n",
"2024-06-04 06:48:11,407 - INFO - Epoch 0 loss: 20430.093976140022\n",
"2024-06-04 06:49:52,288 - INFO - Epoch 1 loss: 14717.722860097885\n",
"2024-06-04 06:51:32,993 - INFO - Epoch 2 loss: 13855.800803661346\n",
"2024-06-04 06:53:13,588 - INFO - Epoch 3 loss: 13853.357687234879\n",
"2024-06-04 06:54:54,389 - INFO - Epoch 4 loss: 13853.240978479385\n",
"2024-06-04 06:54:56,519 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_98.pth\n",
"2024-06-04 06:54:57,057 - INFO - Starting run_99 with hparams {\n",
"2024-06-04 22:42:58,542 - INFO - Starting run_6 with hparams {\n",
" \"batch_size\": 32,\n",
" \"bin_count\": 32,\n",
" \"clip_gradients\": false,\n",
" \"bin_count\": 24,\n",
" \"clip_gradients\": true,\n",
" \"edit_count\": 16,\n",
" \"learning_rate\": 0.00041782149104212284,\n",
" \"loss\": \"progressive\",\n",
" \"loss_damping\": 2.393572363792762,\n",
" \"loss_sizes\": [\n",
" 8,\n",
" 16,\n",
" 32\n",
" ],\n",
" \"model_type\": \"attention2\",\n",
" \"learning_rate\": 0.00043772328325342977,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"Residual\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.3478968531660309\n",
" \"scheduler_gamma\": 0.7839153264198727\n",
"}\n",
"2024-06-04 06:54:57,082 - INFO - Loaded 359668 training images and 39964 test images\n",
"2024-06-04 07:28:59,180 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_99.pth\n",
"2024-06-04 07:28:59,341 - INFO - Interrupted, stopping\n"
"2024-06-04 22:42:58,546 - INFO - Loaded 800 original images\n",
"2024-06-04 22:42:58,548 - INFO - Loaded 200 original images\n",
"2024-06-04 22:43:28,238 - INFO - Epoch 0 train loss: 1092.8763027191162\n",
"2024-06-04 22:43:32,819 - INFO - Epoch 0 test loss: 378.1922540664673\n",
"2024-06-04 22:43:58,916 - INFO - Epoch 1 train loss: 833.7781882286072\n",
"2024-06-04 22:44:03,764 - INFO - Epoch 1 test loss: 415.40036368370056\n",
"2024-06-04 22:44:29,857 - INFO - Epoch 2 train loss: 767.9900119304657\n",
"2024-06-04 22:44:34,696 - INFO - Epoch 2 test loss: 297.7264175415039\n",
"2024-06-04 22:45:00,728 - INFO - Epoch 3 train loss: 726.7450095415115\n",
"2024-06-04 22:45:05,531 - INFO - Epoch 3 test loss: 265.4522042274475\n",
"2024-06-04 22:45:31,520 - INFO - Epoch 4 train loss: 700.53100669384\n",
"2024-06-04 22:45:36,503 - INFO - Epoch 4 test loss: 285.5881419181824\n",
"2024-06-04 22:46:02,685 - INFO - Epoch 5 train loss: 675.3199667930603\n",
"2024-06-04 22:46:07,717 - INFO - Epoch 5 test loss: 246.0367443561554\n",
"2024-06-04 22:46:33,992 - INFO - Epoch 6 train loss: 656.3262705802917\n",
"2024-06-04 22:46:38,921 - INFO - Epoch 6 test loss: 242.63626527786255\n"
]
}
],