Fix train/test split
This commit is contained in:
parent
af56ec3fec
commit
edeac12e37
7 changed files with 9738 additions and 26131 deletions
|
|
@ -1,6 +1,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE1/data/unsplash").glob("*.jpg"))
|
||||
TRAIN_SIZE = 0.8
|
||||
|
||||
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE1/data/cache")
|
||||
CACHE_PATH.mkdir(exist_ok=True, parents=True)
|
||||
|
|
|
|||
14
src/data.py
Normal file
14
src/data.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
import random
|
||||
from config import DATA, TRAIN_SIZE
|
||||
|
||||
random.seed(42)
|
||||
length = len(DATA)
|
||||
indices = list(range(length))
|
||||
|
||||
random.shuffle(indices)
|
||||
|
||||
train_indices = indices[: int(length * TRAIN_SIZE)]
|
||||
test_indices = indices[int(length * TRAIN_SIZE) :]
|
||||
|
||||
TRAIN_DATA = [DATA[i] for i in train_indices]
|
||||
TEST_DATA = [DATA[i] for i in test_indices]
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
from .histogram_dataset import HistogramDataset
|
||||
from .random_edit import random_edit
|
||||
from .progressive_pooling_loss import ProgressivePoolingLoss
|
||||
from .create_data_loaders import create_data_loaders
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from editor.training import HistogramDataset
|
||||
import logging
|
||||
import torch
|
||||
from config import CACHE_PATH
|
||||
import os
|
||||
|
||||
|
||||
def create_data_loaders(
|
||||
data: List[Path],
|
||||
edit_count: int,
|
||||
bin_count: int,
|
||||
training_batch_size: int,
|
||||
train_size=0.9,
|
||||
delete_corrupt_images: bool = False,
|
||||
) -> Tuple[DataLoader, DataLoader]:
|
||||
dataset = HistogramDataset(
|
||||
data,
|
||||
edit_count=edit_count,
|
||||
bin_count=bin_count,
|
||||
delete_corrupt_images=delete_corrupt_images,
|
||||
cache_path=CACHE_PATH,
|
||||
)
|
||||
total_size = len(dataset)
|
||||
train_size = int(train_size * total_size)
|
||||
test_size = total_size - train_size
|
||||
train_dataset, test_dataset = random_split(
|
||||
dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
|
||||
train_data_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=training_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=os.cpu_count(),
|
||||
)
|
||||
test_data_loader = DataLoader(
|
||||
test_dataset, batch_size=1, shuffle=False, num_workers=os.cpu_count()
|
||||
)
|
||||
logging.info(
|
||||
f"Loaded {len(train_dataset)} training images and {len(test_dataset)} test images"
|
||||
)
|
||||
|
||||
return train_data_loader, test_data_loader
|
||||
|
|
@ -24,6 +24,8 @@ class HistogramDataset(Dataset):
|
|||
cache_path: Optional[Path] = None,
|
||||
):
|
||||
self._paths = sorted(paths)
|
||||
logging.info(f"Loaded {len(self._paths)} original images")
|
||||
|
||||
self._edit_count = edit_count
|
||||
self._bin_count = bin_count
|
||||
self._target_size = target_size
|
||||
|
|
|
|||
35592
src/inference.ipynb
35592
src/inference.ipynb
File diff suppressed because one or more lines are too long
213
src/train.ipynb
213
src/train.ipynb
|
|
@ -54,7 +54,8 @@
|
|||
" \"clip_gradients\": [True, False],\n",
|
||||
" \"learning_rate\": loguniform(0.00001, 0.01),\n",
|
||||
" \"scheduler_gamma\": uniform(0, 1),\n",
|
||||
" \"num_epochs\": randint(5, 10),\n",
|
||||
" \"num_epochs\": [10],\n",
|
||||
" # \"num_epochs\": randint(5, 10),\n",
|
||||
" \"model_type\": list(MODELS.keys()),\n",
|
||||
"}\n",
|
||||
"hyperparameters = [\n",
|
||||
|
|
@ -77,7 +78,38 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Any, Dict\n",
|
||||
"from pathlib import Path\n",
|
||||
"from typing import List, Any, Dict\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"from config import CACHE_PATH\n",
|
||||
"from editor.training import HistogramDataset\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_data_loader(data: List[Path], hyperparameters: Dict[str, Any]) -> DataLoader:\n",
|
||||
" return DataLoader(\n",
|
||||
" dataset=HistogramDataset(\n",
|
||||
" paths=data,\n",
|
||||
" edit_count=hyperparameters[\"edit_count\"],\n",
|
||||
" bin_count=hyperparameters[\"bin_count\"],\n",
|
||||
" delete_corrupt_images=False,\n",
|
||||
" cache_path=CACHE_PATH,\n",
|
||||
" ),\n",
|
||||
" batch_size=hyperparameters[\"batch_size\"],\n",
|
||||
" shuffle=True,\n",
|
||||
" num_workers=os.cpu_count(),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def serialise_hparams(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:\n",
|
||||
" return {k: str(v) if isinstance(v, list) else v for k, v in hyperparameters.items()}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch.utils.tensorboard import SummaryWriter\n",
|
||||
"from pathlib import Path\n",
|
||||
"from torch.optim import Adam\n",
|
||||
|
|
@ -86,14 +118,14 @@
|
|||
"from editor.training import ProgressivePoolingLoss\n",
|
||||
"from editor.utils import get_next_run_name\n",
|
||||
"from editor.visualisation import plot_histograms_in_2d\n",
|
||||
"from editor.training import create_data_loaders\n",
|
||||
"from editor.models import create_model, test_models\n",
|
||||
"from config import DATA, MODELS_PATH\n",
|
||||
"from data import TRAIN_DATA, TEST_DATA\n",
|
||||
"from datetime import timedelta, datetime\n",
|
||||
"import json\n",
|
||||
"from config import MODELS_PATH\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# test_models()\n",
|
||||
"test_models()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def train(\n",
|
||||
|
|
@ -107,12 +139,8 @@
|
|||
"\n",
|
||||
" log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n",
|
||||
" with SummaryWriter(log_dir) as writer:\n",
|
||||
" train_data_loader, test_data_loader = create_data_loaders(\n",
|
||||
" data=DATA,\n",
|
||||
" edit_count=hyperparameters[\"edit_count\"],\n",
|
||||
" bin_count=hyperparameters[\"bin_count\"],\n",
|
||||
" training_batch_size=hyperparameters[\"batch_size\"],\n",
|
||||
" )\n",
|
||||
" train_data_loader = get_data_loader(TRAIN_DATA, hyperparameters)\n",
|
||||
" test_data_loader = get_data_loader(TEST_DATA, hyperparameters)\n",
|
||||
"\n",
|
||||
" model = (\n",
|
||||
" create_model(\n",
|
||||
|
|
@ -178,7 +206,7 @@
|
|||
" writer.add_scalar(\n",
|
||||
" \"Loss/train/batch\",\n",
|
||||
" loss,\n",
|
||||
" epoch * len(train_data_loader) + batch_id,\n",
|
||||
" global_step=epoch * len(train_data_loader) + batch_id,\n",
|
||||
" )\n",
|
||||
" loss.backward()\n",
|
||||
"\n",
|
||||
|
|
@ -186,18 +214,7 @@
|
|||
" clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" writer.add_hparams(\n",
|
||||
" {\n",
|
||||
" k: str(v) if isinstance(v, list) else v\n",
|
||||
" for k, v in hyperparameters.items()\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"Loss/train/epoch\": epoch_loss,\n",
|
||||
" },\n",
|
||||
" global_step=epoch,\n",
|
||||
" run_name=log_dir.absolute(),\n",
|
||||
" )\n",
|
||||
" logging.info(f\"Epoch {epoch} loss: {epoch_loss}\")\n",
|
||||
" logging.info(f\"Epoch {epoch} train loss: {epoch_loss}\")\n",
|
||||
" with torch.no_grad():\n",
|
||||
" model.eval()\n",
|
||||
" loader = iter(test_data_loader)\n",
|
||||
|
|
@ -222,9 +239,48 @@
|
|||
" ),\n",
|
||||
" epoch,\n",
|
||||
" )\n",
|
||||
" model.train()\n",
|
||||
"\n",
|
||||
" epoch_test_loss = 0\n",
|
||||
" for batch_id, (edited_histogram, original_histogram) in enumerate(\n",
|
||||
" test_data_loader\n",
|
||||
" ):\n",
|
||||
" edited_histogram = edited_histogram.to(device)\n",
|
||||
" original_histogram = original_histogram.to(device)\n",
|
||||
"\n",
|
||||
" predicted_original = model(edited_histogram)\n",
|
||||
" sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n",
|
||||
" predicted_original = predicted_original / sum\n",
|
||||
"\n",
|
||||
" if hyperparameters[\"loss\"] == \"kl\":\n",
|
||||
" predicted_original = torch.clamp(\n",
|
||||
" predicted_original, 0.0000000000000000000001, 1\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" loss = {\n",
|
||||
" \"kl\": lambda: loss_function(\n",
|
||||
" torch.log(predicted_original),\n",
|
||||
" original_histogram,\n",
|
||||
" ),\n",
|
||||
" \"progressive\": lambda: loss_function(\n",
|
||||
" predicted_original, original_histogram\n",
|
||||
" ),\n",
|
||||
" }[hyperparameters[\"loss\"]]()\n",
|
||||
"\n",
|
||||
" epoch_test_loss += loss.item()\n",
|
||||
" writer.add_hparams(\n",
|
||||
" serialise_hparams(hyperparameters),\n",
|
||||
" {\n",
|
||||
" \"Loss/test/epoch\": epoch_test_loss,\n",
|
||||
" \"Loss/train/epoch\": epoch_loss,\n",
|
||||
" },\n",
|
||||
" global_step=epoch,\n",
|
||||
" run_name=log_dir.absolute(),\n",
|
||||
" )\n",
|
||||
" logging.info(f\"Epoch {epoch} test loss: {epoch_test_loss}\")\n",
|
||||
"\n",
|
||||
" model.train()\n",
|
||||
" scheduler.step()\n",
|
||||
" except Exception as e:\n",
|
||||
" except Exception:\n",
|
||||
" raise\n",
|
||||
" finally:\n",
|
||||
" logging.info(f\"Saving model to {model_path}\")\n",
|
||||
|
|
@ -236,7 +292,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
|
@ -259,95 +315,40 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-03 22:46:07,734 - INFO - Starting run_96 with hparams {\n",
|
||||
" \"batch_size\": 16,\n",
|
||||
" \"bin_count\": 32,\n",
|
||||
" \"clip_gradients\": false,\n",
|
||||
" \"edit_count\": 16,\n",
|
||||
" \"learning_rate\": 3.291322467520231e-05,\n",
|
||||
" \"loss\": \"progressive\",\n",
|
||||
" \"loss_damping\": 1.1967321790868395,\n",
|
||||
" \"loss_sizes\": [\n",
|
||||
" 8,\n",
|
||||
" 32\n",
|
||||
" ],\n",
|
||||
" \"model_type\": \"attention2\",\n",
|
||||
" \"num_epochs\": 5,\n",
|
||||
" \"scheduler_gamma\": 0.4573812925553331\n",
|
||||
"}\n",
|
||||
"2024-06-03 22:46:07,762 - INFO - Loaded 359668 training images and 39964 test images\n",
|
||||
"2024-06-04 00:08:08,259 - INFO - Epoch 0 loss: 38.42179161275271\n",
|
||||
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/matplotlib/collections.py:996: RuntimeWarning: invalid value encountered in sqrt\n",
|
||||
" scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor\n",
|
||||
"2024-06-04 01:30:02,938 - INFO - Epoch 1 loss: 34.078268383513205\n",
|
||||
"2024-06-04 02:46:08,066 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_96.pth\n",
|
||||
"2024-06-04 02:46:08,182 - WARNING - Timeout, aborting experiment\n",
|
||||
"2024-06-04 02:46:08,479 - INFO - Starting run_97 with hparams {\n",
|
||||
" \"batch_size\": 64,\n",
|
||||
" \"bin_count\": 32,\n",
|
||||
" \"clip_gradients\": true,\n",
|
||||
" \"edit_count\": 8,\n",
|
||||
" \"learning_rate\": 5.96886240713341e-05,\n",
|
||||
" \"loss\": \"progressive\",\n",
|
||||
" \"loss_damping\": 2.8893045711729517,\n",
|
||||
" \"loss_sizes\": [\n",
|
||||
" 4,\n",
|
||||
" 8,\n",
|
||||
" 16,\n",
|
||||
" 32\n",
|
||||
" ],\n",
|
||||
" \"model_type\": \"attention2\",\n",
|
||||
" \"num_epochs\": 10,\n",
|
||||
" \"scheduler_gamma\": 0.9315193474157711\n",
|
||||
"}\n",
|
||||
"2024-06-04 02:46:08,500 - INFO - Loaded 179834 training images and 19982 test images\n",
|
||||
"2024-06-04 06:46:16,877 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_97.pth\n",
|
||||
"2024-06-04 06:46:28,422 - WARNING - Timeout, aborting experiment\n",
|
||||
"2024-06-04 06:46:28,437 - INFO - Starting run_98 with hparams {\n",
|
||||
" \"batch_size\": 64,\n",
|
||||
" \"bin_count\": 16,\n",
|
||||
" \"clip_gradients\": false,\n",
|
||||
" \"edit_count\": 16,\n",
|
||||
" \"learning_rate\": 0.0019552772361485543,\n",
|
||||
" \"loss\": \"kl\",\n",
|
||||
" \"model_type\": \"SimpleCNN\",\n",
|
||||
" \"num_epochs\": 5,\n",
|
||||
" \"scheduler_gamma\": 0.022346077394851838\n",
|
||||
"}\n",
|
||||
"2024-06-04 06:46:28,475 - INFO - Loaded 359668 training images and 39964 test images\n",
|
||||
"2024-06-04 06:48:11,407 - INFO - Epoch 0 loss: 20430.093976140022\n",
|
||||
"2024-06-04 06:49:52,288 - INFO - Epoch 1 loss: 14717.722860097885\n",
|
||||
"2024-06-04 06:51:32,993 - INFO - Epoch 2 loss: 13855.800803661346\n",
|
||||
"2024-06-04 06:53:13,588 - INFO - Epoch 3 loss: 13853.357687234879\n",
|
||||
"2024-06-04 06:54:54,389 - INFO - Epoch 4 loss: 13853.240978479385\n",
|
||||
"2024-06-04 06:54:56,519 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_98.pth\n",
|
||||
"2024-06-04 06:54:57,057 - INFO - Starting run_99 with hparams {\n",
|
||||
"2024-06-04 22:42:58,542 - INFO - Starting run_6 with hparams {\n",
|
||||
" \"batch_size\": 32,\n",
|
||||
" \"bin_count\": 32,\n",
|
||||
" \"clip_gradients\": false,\n",
|
||||
" \"bin_count\": 24,\n",
|
||||
" \"clip_gradients\": true,\n",
|
||||
" \"edit_count\": 16,\n",
|
||||
" \"learning_rate\": 0.00041782149104212284,\n",
|
||||
" \"loss\": \"progressive\",\n",
|
||||
" \"loss_damping\": 2.393572363792762,\n",
|
||||
" \"loss_sizes\": [\n",
|
||||
" 8,\n",
|
||||
" 16,\n",
|
||||
" 32\n",
|
||||
" ],\n",
|
||||
" \"model_type\": \"attention2\",\n",
|
||||
" \"learning_rate\": 0.00043772328325342977,\n",
|
||||
" \"loss\": \"kl\",\n",
|
||||
" \"model_type\": \"Residual\",\n",
|
||||
" \"num_epochs\": 10,\n",
|
||||
" \"scheduler_gamma\": 0.3478968531660309\n",
|
||||
" \"scheduler_gamma\": 0.7839153264198727\n",
|
||||
"}\n",
|
||||
"2024-06-04 06:54:57,082 - INFO - Loaded 359668 training images and 39964 test images\n",
|
||||
"2024-06-04 07:28:59,180 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_99.pth\n",
|
||||
"2024-06-04 07:28:59,341 - INFO - Interrupted, stopping\n"
|
||||
"2024-06-04 22:42:58,546 - INFO - Loaded 800 original images\n",
|
||||
"2024-06-04 22:42:58,548 - INFO - Loaded 200 original images\n",
|
||||
"2024-06-04 22:43:28,238 - INFO - Epoch 0 train loss: 1092.8763027191162\n",
|
||||
"2024-06-04 22:43:32,819 - INFO - Epoch 0 test loss: 378.1922540664673\n",
|
||||
"2024-06-04 22:43:58,916 - INFO - Epoch 1 train loss: 833.7781882286072\n",
|
||||
"2024-06-04 22:44:03,764 - INFO - Epoch 1 test loss: 415.40036368370056\n",
|
||||
"2024-06-04 22:44:29,857 - INFO - Epoch 2 train loss: 767.9900119304657\n",
|
||||
"2024-06-04 22:44:34,696 - INFO - Epoch 2 test loss: 297.7264175415039\n",
|
||||
"2024-06-04 22:45:00,728 - INFO - Epoch 3 train loss: 726.7450095415115\n",
|
||||
"2024-06-04 22:45:05,531 - INFO - Epoch 3 test loss: 265.4522042274475\n",
|
||||
"2024-06-04 22:45:31,520 - INFO - Epoch 4 train loss: 700.53100669384\n",
|
||||
"2024-06-04 22:45:36,503 - INFO - Epoch 4 test loss: 285.5881419181824\n",
|
||||
"2024-06-04 22:46:02,685 - INFO - Epoch 5 train loss: 675.3199667930603\n",
|
||||
"2024-06-04 22:46:07,717 - INFO - Epoch 5 test loss: 246.0367443561554\n",
|
||||
"2024-06-04 22:46:33,992 - INFO - Epoch 6 train loss: 656.3262705802917\n",
|
||||
"2024-06-04 22:46:38,921 - INFO - Epoch 6 test loss: 242.63626527786255\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue