diff --git a/src/train.ipynb b/src/train.ipynb
index 84be519..c424838 100644
--- a/src/train.ipynb
+++ b/src/train.ipynb
@@ -21,66 +21,98 @@
"source": [
"import os\n",
"from utils import set_up_logging, get_device\n",
- "from training import train, random_hparam_search\n",
- "from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
+ "from training import train_with_ray_factory\n",
+ "from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n",
+ "from ray import tune\n",
+ "from ray.tune.schedulers import ASHAScheduler\n",
+ "import os\n",
+ "from ray.air import RunConfig\n",
"\n",
"set_up_logging(LOGS_PATH)\n",
"\n",
+ "TRIAL_COUNT = 100\n",
+ "CHUNK_COUNT = 40\n",
+ "EPOCH_COUNT = 2\n",
+ "\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
"device = get_device()\n",
"f\"Using device {device}\""
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# hparams = {\n",
- "# \"batch_size\": 64,\n",
- "# \"edit_count\": 12,\n",
- "# \"bin_count\": 16,\n",
- "# \"learning_rate\": 0.001,\n",
- "# \"scheduler_gamma\": 0.9,\n",
- "# \"num_epochs\": 12,\n",
- "# \"model_type\": \"SimpleCNN\",\n",
- "# }\n",
- "\n",
- "# train(\n",
- "# hparams,\n",
- "# train_data_paths=TRAIN_DATA,\n",
- "# test_data_paths=TEST_DATA,\n",
- "# log_dir=RUNS_PATH,\n",
- "# max_duration=None,\n",
- "# use_tqdm=True,\n",
- "# device=device,\n",
- "# **hparams\n",
- "# )"
- ]
- },
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2024-06-29 10:45:18,539 - INFO - Loaded 22479 original images\n",
- "2024-06-29 10:45:18,544 - INFO - Loaded 2498 original images\n"
- ]
- },
{
"data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "61d4288a87a24255824e1430d8e69804",
- "version_major": 2,
- "version_minor": 0
- },
+ "text/html": [
+ "
\n",
+ "
\n",
+ "
\n",
+ "
Tune Status
\n",
+ "
\n",
+ "\n",
+ "| Current time: | 2024-09-01 22:06:10 |
\n",
+ "| Running for: | 00:02:06.64 |
\n",
+ "| Memory: | 22.4/47.0 GiB |
\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
System Info
\n",
+ " Using AsyncHyperBand: num_stopped=1
Bracket: Iter 2.000: -5.516964137554169
Logical resource usage: 32.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)\n",
+ " \n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
Trial Status
\n",
+ "
\n",
+ "\n",
+ "| Trial name | status | loc | batch_size | dropout_prob | elu_alpha | features | kernel_size | leaky_relu_alpha | leaky_relu_slope | learning_rate | model_type | scheduler_gamma | use_elu | use_residual | iter | total time (s) | chunk_test_loss | chunk_training_loss |
\n",
+ "\n",
+ "\n",
+ "| train_with_ray_b7d3c_00000 | TERMINATED | 172.29.235.222:1134109 | 32 | 0.00395892 | 1.35107 | [8, 16, 32] | 5 | 0.0448715 | 0.00664526 | 0.0029226 | HistogramNet | 0.816752 | True | True | 3 | 70.652 | 5.22149 | 56.5472 |
\n",
+ "| train_with_ray_b7d3c_00001 | TERMINATED | 172.29.235.222:1140440 | 32 | 0.0439061 | 1.74642 | [8, 8, 8, 8, 8,_c140 | 3 | 0.00579656 | 0.0125715 | 0.00155182 | HistogramNet | 0.898059 | True | False | 2 | 45.4228 | 5.79311 | 64.0481 |
\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n"
+ ],
"text/plain": [
- "Epoch 0: 0%| | 0/4215 [00:00, ?batch/s]"
+ ""
]
},
"metadata": {},
@@ -90,81 +122,113 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-06-29 10:50:49,231 - INFO - Epoch 0 train loss: 5050.157035887241\n",
- "2024-06-29 10:51:01,594 - INFO - Epoch 0 test loss: 460.928723692894\n"
+ "\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000000)\n",
+ "\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000001)\n",
+ "\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000002)\n",
+ "\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000000)\n",
+ "2024-09-01 22:06:10,132\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/andras/projects/bipolaroid/runs5/tune' in 0.0020s.\n",
+ "2024-09-01 22:06:10,135\tINFO tune.py:1041 -- Total run time: 126.67 seconds (126.64 seconds for the tuning loop).\n"
]
},
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "55aa7c6adc41427081af06330d541c2d",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Epoch 1: 0%| | 0/4215 [00:00, ?batch/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
{
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-06-29 10:56:26,321 - INFO - Epoch 1 train loss: 4134.876303792\n",
- "2024-06-29 10:56:39,064 - INFO - Epoch 1 test loss: 432.3766641020775\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "2b7d28fc6f62489ebea844441a57178f",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Epoch 2: 0%| | 0/4215 [00:00, ?batch/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2024-06-29 11:02:03,858 - INFO - Epoch 2 train loss: 3871.2984607219696\n",
- "2024-06-29 11:02:16,040 - INFO - Epoch 2 test loss: 408.74887996912\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "1395274c7a38465099c93c2d241ea1da",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Epoch 3: 0%| | 0/4215 [00:00, ?batch/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[2], line 20\u001b[0m\n\u001b[1;32m 1\u001b[0m hparams \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 2\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124medit_count\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m12\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mleaky_relu_alpha\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0.03745605986732464\u001b[39m,\n\u001b[1;32m 18\u001b[0m }\n\u001b[0;32m---> 20\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mhparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTRAIN_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTEST_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mRUNS_PATH\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 25\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_duration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 28\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhparams\u001b[49m\n\u001b[1;32m 29\u001b[0m \u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m/mnt/wsl/PHYSICALDRIVE1/projects/bipolaroid/src/training/train.py:55\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(hyperparameters, train_data_paths, test_data_paths, log_dir, use_tqdm, device, model_type, learning_rate, scheduler_gamma, num_epochs, **_)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_id, (edited_histogram, original_histogram) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\n\u001b[1;32m 50\u001b[0m tqdm(train_data_loader, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_tqdm\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m train_data_loader\n\u001b[1;32m 53\u001b[0m ):\n\u001b[1;32m 54\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 55\u001b[0m predicted_original \u001b[38;5;241m=\u001b[39m model(\u001b[43medited_histogram\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 56\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_function(\n\u001b[1;32m 57\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(predicted_original \u001b[38;5;241m+\u001b[39m EPSILON),\n\u001b[1;32m 58\u001b[0m original_histogram\u001b[38;5;241m.\u001b[39mto(device),\n\u001b[1;32m 59\u001b[0m )\n\u001b[1;32m 61\u001b[0m epoch_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ "\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000001)\n"
]
}
],
+ "source": [
+ "config = {\n",
+ " \"batch_size\": 32,\n",
+ " \"edit_count\": EPOCH_COUNT,\n",
+ " \"bin_count\": 32,\n",
+ " \"learning_rate\": tune.loguniform(5e-4, 5e-3),\n",
+ " \"scheduler_gamma\": tune.uniform(0.8, 0.95),\n",
+ " \"elu_alpha\": tune.uniform(0.5, 2),\n",
+ " \"leaky_relu_slope\": tune.uniform(0, 0.03),\n",
+ " \"dropout_prob\": tune.uniform(0, 0.1),\n",
+ " \"chunk_count\": CHUNK_COUNT,\n",
+ " \"features\": tune.choice(\n",
+ " [\n",
+ " [16, 32, 64],\n",
+ " [16, 32, 64, 128],\n",
+ " [32, 64],\n",
+ " [32, 128],\n",
+ " [8, 16, 32],\n",
+ " [8, 8, 8, 8, 8],\n",
+ " [8, 8, 8, 8, 8, 8, 8],\n",
+ " [16, 16, 16],\n",
+ " [16, 16, 16, 16, 16],\n",
+ " [32, 32, 32],\n",
+ " [32, 32, 32, 32],\n",
+ " [64, 64],\n",
+ " [64, 64, 64],\n",
+ " ]\n",
+ " ),\n",
+ " \"use_residual\": tune.choice([True, False]),\n",
+ " \"kernel_size\": tune.choice([3, 5]),\n",
+ " \"model_type\": tune.choice([\"HistogramNet\"]),\n",
+ " \"use_instance_norm\": True,\n",
+ " \"use_elu\": tune.choice([True, False]),\n",
+ " \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n",
+ "}\n",
+ "scheduler = ASHAScheduler(max_t=config[\"chunk_count\"], grace_period=2)\n",
+ "\n",
+ "tuner = tune.Tuner(\n",
+ " tune.with_resources(\n",
+ " tune.with_parameters(\n",
+ " train_with_ray_factory(\n",
+ " train_data_paths=TRAIN_DATA,\n",
+ " test_data_paths=TEST_DATA,\n",
+ " device=device,\n",
+ " log_dir=RUNS_PATH / \"custom\",\n",
+ " )\n",
+ " ),\n",
+ " resources={\"cpu\": 32, \"gpu\": 1},\n",
+ " ),\n",
+ " run_config=RunConfig(storage_path=RUNS_PATH, name=\"tune\"),\n",
+ " tune_config=tune.TuneConfig(\n",
+ " metric=\"chunk_test_loss\",\n",
+ " mode=\"min\",\n",
+ " scheduler=scheduler,\n",
+ " num_samples=TRIAL_COUNT,\n",
+ " ),\n",
+ " param_space=config,\n",
+ ")\n",
+ "results = tuner.fit()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Best trial config: {'batch_size': 32, 'edit_count': 3, 'bin_count': 32, 'learning_rate': 0.0029226033808016005, 'scheduler_gamma': 0.8167516482513361, 'elu_alpha': 1.3510723758569865, 'leaky_relu_slope': 0.0066452562138349025, 'dropout_prob': 0.0039589213934103865, 'chunk_count': 4, 'features': [8, 16, 32], 'use_residual': True, 'kernel_size': 5, 'model_type': 'HistogramNet', 'use_instance_norm': True, 'use_elu': True, 'leaky_relu_alpha': 0.04487148648446764}\n",
+ "Best trial final validation loss: 5.2214884757995605\n"
+ ]
+ }
+ ],
+ "source": [
+ "best_result = results.get_best_result(\"chunk_test_loss\", \"min\")\n",
+ "\n",
+ "print(\"Best trial config: {}\".format(best_result.config))\n",
+ "print(\n",
+ " \"Best trial final validation loss: {}\".format(\n",
+ " best_result.metrics[\"chunk_test_loss\"]\n",
+ " )\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
"source": [
"# hparams = {\n",
"# \"batch_size\": 64,\n",
@@ -196,64 +260,6 @@
"# **hparams\n",
"# )"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from scipy.stats import loguniform, uniform, randint\n",
- "from models import MODELS, test_models\n",
- "\n",
- "\n",
- "hyperparameters = [\n",
- " {\n",
- " \"batch_size\": [64, 128, 256],\n",
- " \"edit_count\": [12],\n",
- " \"bin_count\": [16],\n",
- " \"learning_rate\": loguniform(5e-4, 5e-3),\n",
- " \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n",
- " \"num_epochs\": [10],\n",
- " \"elu_alpha\": uniform(0.5, 1.5),\n",
- " \"leaky_relu_slope\": uniform(0, 0.03),\n",
- " \"dropout_prob\": uniform(0, 0.1),\n",
- " \"features\": [\n",
- " [16, 32, 64],\n",
- " [16, 32, 64, 128],\n",
- " [32, 64],\n",
- " [32, 64, 128],\n",
- " [8, 16, 32],\n",
- " [8, 8, 8, 8, 8],\n",
- " [8, 8, 8, 8, 8, 8, 8],\n",
- " [16, 16, 16],\n",
- " [16, 16, 16, 16, 16],\n",
- " [32, 32, 32],\n",
- " [32, 32, 32, 32],\n",
- " [64, 64],\n",
- " [64, 64, 64],\n",
- " ],\n",
- " \"use_residual\": [True, False],\n",
- " \"kernel_size\": [3, 5],\n",
- " \"model_type\": [\"HistogramNet\"],\n",
- " \"use_instance_norm\": [True],\n",
- " \"use_elu\": [True, False],\n",
- " \"leaky_relu_alpha\": uniform(0, 0.05),\n",
- " }\n",
- "]\n",
- "\n",
- "test_models()\n",
- "\n",
- "random_hparam_search(\n",
- " hyperparameters=hyperparameters,\n",
- " train_data_paths=TRAIN_DATA,\n",
- " test_data_paths=TEST_DATA,\n",
- " models_path=MODELS_PATH,\n",
- " tensorboard_path=RUNS_PATH,\n",
- " timeout_hours=4,\n",
- " device=device,\n",
- ")"
- ]
}
],
"metadata": {
diff --git a/src/training/train_with_ray.py b/src/training/train_with_ray.py
new file mode 100644
index 0000000..5ddcc1d
--- /dev/null
+++ b/src/training/train_with_ray.py
@@ -0,0 +1,205 @@
+from typing import Any, Dict, List, Tuple
+from torch.utils.tensorboard import SummaryWriter
+from pathlib import Path
+from torch.optim import Adam
+from .get_next_run_name import get_next_run_name
+from utils import serialise_hparams
+from visualisation import plot_histograms_in_2d
+from models import create_model, load_model, save_model
+import torch
+from .get_data_loader import get_data_loader
+from ray import train
+from ray.train import Checkpoint
+import os
+import tempfile
+from more_itertools import distribute
+
+
+EPSILON = 1e-5
+
+
+def train_with_ray_factory(
+ train_data_paths: List[Path],
+ test_data_paths: List[Path],
+ device: torch.device,
+ log_dir: Path,
+):
+ def train_with_ray(hyperparameters: Dict[str, Any]):
+ def inner(
+ hyperparameters: Dict[str, Any],
+ chunk_count: int,
+ **_,
+ ) -> torch.nn.Module:
+ train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
+ test_data_loader = get_data_loader(
+ test_data_paths, **{**hyperparameters, "edit_count": 1}
+ )
+ examples = next(iter(test_data_loader))
+
+ model, optimizer, scheduler, start_chunk_id, run_name = (
+ load_or_create_state(
+ device=device,
+ log_dir=log_dir,
+ **hyperparameters,
+ )
+ )
+ loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
+
+ with SummaryWriter(log_dir=log_dir / run_name) as writer:
+ writer.add_graph(model, examples[0].to(device))
+ for chunk_id, chunk in enumerate(
+ distribute(chunk_count, train_data_loader)[start_chunk_id:-1],
+ start=start_chunk_id,
+ ):
+ chunk_training_loss = 0
+ writer.add_scalar(
+ "Actual learning rate",
+ scheduler.get_last_lr()[0],
+ chunk_id,
+ )
+ for batch_id, (edited_histogram, original_histogram) in enumerate(
+ chunk
+ ):
+ global_step = (
+ chunk_id * (len(train_data_loader) // chunk_count)
+ + batch_id
+ )
+ optimizer.zero_grad()
+ predicted_original = model(edited_histogram.to(device))
+ loss = loss_function(
+ torch.log(predicted_original + EPSILON),
+ original_histogram.to(device),
+ )
+
+ chunk_training_loss += loss.item()
+ writer.add_scalar(
+ "Loss/train/batch", loss, global_step=global_step
+ )
+ loss.backward()
+ optimizer.step()
+
+ with torch.no_grad():
+ model.eval()
+ write_histograms(
+ model=model,
+ examples=examples,
+ writer=writer,
+ device=device,
+ global_step=global_step,
+ )
+ chunk_test_loss = 0
+ for (
+ edited_histogram,
+ original_histogram,
+ ) in test_data_loader:
+ predicted_original = model(edited_histogram.to(device))
+ chunk_test_loss += loss_function(
+ torch.log(predicted_original + EPSILON),
+ original_histogram.to(device),
+ ).item()
+ model.train()
+
+ with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+ temp_checkpoint_dir = Path(temp_checkpoint_dir)
+ checkpoint_path = temp_checkpoint_dir / "checkpoint.pt"
+ torch.save(
+ (
+ optimizer.state_dict(),
+ scheduler.state_dict(),
+ chunk_id,
+ run_name,
+ ),
+ checkpoint_path,
+ )
+ save_model(
+ model, hyperparameters, temp_checkpoint_dir / "model"
+ )
+ writer.add_hparams(
+ serialise_hparams(hyperparameters),
+ {
+ "Loss/test/epoch": chunk_test_loss,
+ "Loss/train/epoch": chunk_training_loss,
+ },
+ global_step=global_step,
+ run_name=(log_dir / run_name).absolute(),
+ )
+ train.report(
+ {
+ "chunk_test_loss": chunk_test_loss,
+ "chunk_training_loss": chunk_training_loss,
+ },
+ checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
+ )
+
+ scheduler.step()
+
+ return inner(hyperparameters=hyperparameters, **hyperparameters)
+
+ return train_with_ray
+
+
+def load_or_create_state(
+ device, log_dir, model_type, learning_rate, scheduler_gamma, **hyperparameters
+) -> Tuple[
+ torch.nn.Module,
+ torch.optim.Optimizer,
+ torch.optim.lr_scheduler.LRScheduler,
+ int,
+ str,
+]:
+ loaded_checkpoint = train.get_checkpoint()
+ if loaded_checkpoint:
+ with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
+ loaded_checkpoint_dir = Path(loaded_checkpoint_dir)
+ model, hyperparameters = load_model(
+ loaded_checkpoint_dir / "model", device=device
+ )
+ optimizer = Adam(model.parameters(), lr=learning_rate)
+
+ optimizer_state, scheduler_state, start_chunk_id, run_name = torch.load(
+ loaded_checkpoint_dir / "checkpoint.pt"
+ )
+ optimizer.load_state_dict(optimizer_state)
+ scheduler = torch.optim.lr_scheduler.StepLR(
+ optimizer, step_size=1, gamma=scheduler_gamma
+ )
+ scheduler.load_state_dict(scheduler_state)
+ else:
+ run_name = get_next_run_name(log_dir)
+ model = create_model(
+ type=model_type,
+ hyperparameters=hyperparameters,
+ device=device,
+ ).train()
+ optimizer = Adam(model.parameters(), lr=learning_rate)
+ scheduler = torch.optim.lr_scheduler.StepLR(
+ optimizer, step_size=1, gamma=scheduler_gamma
+ )
+ start_chunk_id = 0
+
+ return model, optimizer, scheduler, start_chunk_id, run_name
+
+
+def write_histograms(
+ model: torch.nn.Module,
+ examples: List[Tuple[torch.Tensor, torch.Tensor]],
+ writer: SummaryWriter,
+ device: torch.device,
+ global_step: int,
+):
+ edited_histograms, original_histograms = examples
+ predicted_originals = model(edited_histograms.to(device))
+ for i, (original, edited, predicted) in enumerate(
+ zip(original_histograms, edited_histograms, predicted_originals)
+ ):
+ writer.add_figure(
+ f"histogram_{i}",
+ plot_histograms_in_2d(
+ {
+ "original": original[0].numpy().squeeze(),
+ "edited": edited.cpu()[0].numpy().squeeze(),
+ "predicted": predicted.cpu()[0].numpy().squeeze(),
+ }
+ ),
+ global_step=global_step,
+ )