Integrate ray tune

This commit is contained in:
Andras Schmelczer 2024-09-01 22:11:09 +01:00
parent 49d9ece2ec
commit fb07fc674e
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 379 additions and 168 deletions

View file

@ -21,66 +21,98 @@
"source": [
"import os\n",
"from utils import set_up_logging, get_device\n",
"from training import train, random_hparam_search\n",
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
"from training import train_with_ray_factory\n",
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n",
"from ray import tune\n",
"from ray.tune.schedulers import ASHAScheduler\n",
"import os\n",
"from ray.air import RunConfig\n",
"\n",
"set_up_logging(LOGS_PATH)\n",
"\n",
"TRIAL_COUNT = 100\n",
"CHUNK_COUNT = 40\n",
"EPOCH_COUNT = 2\n",
"\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
"device = get_device()\n",
"f\"Using device {device}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hparams = {\n",
"# \"batch_size\": 64,\n",
"# \"edit_count\": 12,\n",
"# \"bin_count\": 16,\n",
"# \"learning_rate\": 0.001,\n",
"# \"scheduler_gamma\": 0.9,\n",
"# \"num_epochs\": 12,\n",
"# \"model_type\": \"SimpleCNN\",\n",
"# }\n",
"\n",
"# train(\n",
"# hparams,\n",
"# train_data_paths=TRAIN_DATA,\n",
"# test_data_paths=TEST_DATA,\n",
"# log_dir=RUNS_PATH,\n",
"# max_duration=None,\n",
"# use_tqdm=True,\n",
"# device=device,\n",
"# **hparams\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-29 10:45:18,539 - INFO - Loaded 22479 original images\n",
"2024-06-29 10:45:18,544 - INFO - Loaded 2498 original images\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "61d4288a87a24255824e1430d8e69804",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"<div class=\"tuneStatus\">\n",
" <div style=\"display: flex;flex-direction: row\">\n",
" <div style=\"display: flex;flex-direction: column;\">\n",
" <h3>Tune Status</h3>\n",
" <table>\n",
"<tbody>\n",
"<tr><td>Current time:</td><td>2024-09-01 22:06:10</td></tr>\n",
"<tr><td>Running for: </td><td>00:02:06.64 </td></tr>\n",
"<tr><td>Memory: </td><td>22.4/47.0 GiB </td></tr>\n",
"</tbody>\n",
"</table>\n",
" </div>\n",
" <div class=\"vDivider\"></div>\n",
" <div class=\"systemInfo\">\n",
" <h3>System Info</h3>\n",
" Using AsyncHyperBand: num_stopped=1<br>Bracket: Iter 2.000: -5.516964137554169<br>Logical resource usage: 32.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)\n",
" </div>\n",
" \n",
" </div>\n",
" <div class=\"hDivider\"></div>\n",
" <div class=\"trialStatus\">\n",
" <h3>Trial Status</h3>\n",
" <table>\n",
"<thead>\n",
"<tr><th>Trial name </th><th>status </th><th>loc </th><th style=\"text-align: right;\"> batch_size</th><th style=\"text-align: right;\"> dropout_prob</th><th style=\"text-align: right;\"> elu_alpha</th><th>features </th><th style=\"text-align: right;\"> kernel_size</th><th style=\"text-align: right;\"> leaky_relu_alpha</th><th style=\"text-align: right;\"> leaky_relu_slope</th><th style=\"text-align: right;\"> learning_rate</th><th>model_type </th><th style=\"text-align: right;\"> scheduler_gamma</th><th>use_elu </th><th>use_residual </th><th style=\"text-align: right;\"> iter</th><th style=\"text-align: right;\"> total time (s)</th><th style=\"text-align: right;\"> chunk_test_loss</th><th style=\"text-align: right;\"> chunk_training_loss</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><td>train_with_ray_b7d3c_00000</td><td>TERMINATED</td><td>172.29.235.222:1134109</td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\"> 0.00395892</td><td style=\"text-align: right;\"> 1.35107</td><td>[8, 16, 32] </td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0448715 </td><td style=\"text-align: right;\"> 0.00664526</td><td style=\"text-align: right;\"> 0.0029226 </td><td>HistogramNet</td><td style=\"text-align: right;\"> 0.816752</td><td>True </td><td>True </td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 70.652 </td><td style=\"text-align: right;\"> 5.22149</td><td style=\"text-align: right;\"> 56.5472</td></tr>\n",
"<tr><td>train_with_ray_b7d3c_00001</td><td>TERMINATED</td><td>172.29.235.222:1140440</td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\"> 0.0439061 </td><td style=\"text-align: right;\"> 1.74642</td><td>[8, 8, 8, 8, 8,_c140</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.00579656</td><td style=\"text-align: right;\"> 0.0125715 </td><td style=\"text-align: right;\"> 0.00155182</td><td>HistogramNet</td><td style=\"text-align: right;\"> 0.898059</td><td>True </td><td>False </td><td style=\"text-align: right;\"> 2</td><td style=\"text-align: right;\"> 45.4228</td><td style=\"text-align: right;\"> 5.79311</td><td style=\"text-align: right;\"> 64.0481</td></tr>\n",
"</tbody>\n",
"</table>\n",
" </div>\n",
"</div>\n",
"<style>\n",
".tuneStatus {\n",
" color: var(--jp-ui-font-color1);\n",
"}\n",
".tuneStatus .systemInfo {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
".tuneStatus td {\n",
" white-space: nowrap;\n",
"}\n",
".tuneStatus .trialStatus {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
".tuneStatus h3 {\n",
" font-weight: bold;\n",
"}\n",
".tuneStatus .hDivider {\n",
" border-bottom-width: var(--jp-border-width);\n",
" border-bottom-color: var(--jp-border-color0);\n",
" border-bottom-style: solid;\n",
"}\n",
".tuneStatus .vDivider {\n",
" border-left-width: var(--jp-border-width);\n",
" border-left-color: var(--jp-border-color0);\n",
" border-left-style: solid;\n",
" margin: 0.5em 1em 0.5em 1em;\n",
"}\n",
"</style>\n"
],
"text/plain": [
"Epoch 0: 0%| | 0/4215 [00:00<?, ?batch/s]"
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
@ -90,81 +122,113 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-29 10:50:49,231 - INFO - Epoch 0 train loss: 5050.157035887241\n",
"2024-06-29 10:51:01,594 - INFO - Epoch 0 test loss: 460.928723692894\n"
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000000)\n",
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000001)\n",
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000002)\n",
"\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000000)\n",
"2024-09-01 22:06:10,132\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/andras/projects/bipolaroid/runs5/tune' in 0.0020s.\n",
"2024-09-01 22:06:10,135\tINFO tune.py:1041 -- Total run time: 126.67 seconds (126.64 seconds for the tuning loop).\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "55aa7c6adc41427081af06330d541c2d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 1: 0%| | 0/4215 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-29 10:56:26,321 - INFO - Epoch 1 train loss: 4134.876303792\n",
"2024-06-29 10:56:39,064 - INFO - Epoch 1 test loss: 432.3766641020775\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2b7d28fc6f62489ebea844441a57178f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 2: 0%| | 0/4215 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-29 11:02:03,858 - INFO - Epoch 2 train loss: 3871.2984607219696\n",
"2024-06-29 11:02:16,040 - INFO - Epoch 2 test loss: 408.74887996912\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1395274c7a38465099c93c2d241ea1da",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 3: 0%| | 0/4215 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 20\u001b[0m\n\u001b[1;32m 1\u001b[0m hparams \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 2\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124medit_count\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m12\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mleaky_relu_alpha\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0.03745605986732464\u001b[39m,\n\u001b[1;32m 18\u001b[0m }\n\u001b[0;32m---> 20\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mhparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTRAIN_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTEST_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mRUNS_PATH\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 25\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_duration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 28\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhparams\u001b[49m\n\u001b[1;32m 29\u001b[0m \u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/mnt/wsl/PHYSICALDRIVE1/projects/bipolaroid/src/training/train.py:55\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(hyperparameters, train_data_paths, test_data_paths, log_dir, use_tqdm, device, model_type, learning_rate, scheduler_gamma, num_epochs, **_)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_id, (edited_histogram, original_histogram) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\n\u001b[1;32m 50\u001b[0m tqdm(train_data_loader, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_tqdm\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m train_data_loader\n\u001b[1;32m 53\u001b[0m ):\n\u001b[1;32m 54\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 55\u001b[0m predicted_original \u001b[38;5;241m=\u001b[39m model(\u001b[43medited_histogram\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 56\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_function(\n\u001b[1;32m 57\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(predicted_original \u001b[38;5;241m+\u001b[39m EPSILON),\n\u001b[1;32m 58\u001b[0m original_histogram\u001b[38;5;241m.\u001b[39mto(device),\n\u001b[1;32m 59\u001b[0m )\n\u001b[1;32m 61\u001b[0m epoch_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
"\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000001)\n"
]
}
],
"source": [
"config = {\n",
" \"batch_size\": 32,\n",
" \"edit_count\": EPOCH_COUNT,\n",
" \"bin_count\": 32,\n",
" \"learning_rate\": tune.loguniform(5e-4, 5e-3),\n",
" \"scheduler_gamma\": tune.uniform(0.8, 0.95),\n",
" \"elu_alpha\": tune.uniform(0.5, 2),\n",
" \"leaky_relu_slope\": tune.uniform(0, 0.03),\n",
" \"dropout_prob\": tune.uniform(0, 0.1),\n",
" \"chunk_count\": CHUNK_COUNT,\n",
" \"features\": tune.choice(\n",
" [\n",
" [16, 32, 64],\n",
" [16, 32, 64, 128],\n",
" [32, 64],\n",
" [32, 128],\n",
" [8, 16, 32],\n",
" [8, 8, 8, 8, 8],\n",
" [8, 8, 8, 8, 8, 8, 8],\n",
" [16, 16, 16],\n",
" [16, 16, 16, 16, 16],\n",
" [32, 32, 32],\n",
" [32, 32, 32, 32],\n",
" [64, 64],\n",
" [64, 64, 64],\n",
" ]\n",
" ),\n",
" \"use_residual\": tune.choice([True, False]),\n",
" \"kernel_size\": tune.choice([3, 5]),\n",
" \"model_type\": tune.choice([\"HistogramNet\"]),\n",
" \"use_instance_norm\": True,\n",
" \"use_elu\": tune.choice([True, False]),\n",
" \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n",
"}\n",
"scheduler = ASHAScheduler(max_t=config[\"chunk_count\"], grace_period=2)\n",
"\n",
"tuner = tune.Tuner(\n",
" tune.with_resources(\n",
" tune.with_parameters(\n",
" train_with_ray_factory(\n",
" train_data_paths=TRAIN_DATA,\n",
" test_data_paths=TEST_DATA,\n",
" device=device,\n",
" log_dir=RUNS_PATH / \"custom\",\n",
" )\n",
" ),\n",
" resources={\"cpu\": 32, \"gpu\": 1},\n",
" ),\n",
" run_config=RunConfig(storage_path=RUNS_PATH, name=\"tune\"),\n",
" tune_config=tune.TuneConfig(\n",
" metric=\"chunk_test_loss\",\n",
" mode=\"min\",\n",
" scheduler=scheduler,\n",
" num_samples=TRIAL_COUNT,\n",
" ),\n",
" param_space=config,\n",
")\n",
"results = tuner.fit()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best trial config: {'batch_size': 32, 'edit_count': 3, 'bin_count': 32, 'learning_rate': 0.0029226033808016005, 'scheduler_gamma': 0.8167516482513361, 'elu_alpha': 1.3510723758569865, 'leaky_relu_slope': 0.0066452562138349025, 'dropout_prob': 0.0039589213934103865, 'chunk_count': 4, 'features': [8, 16, 32], 'use_residual': True, 'kernel_size': 5, 'model_type': 'HistogramNet', 'use_instance_norm': True, 'use_elu': True, 'leaky_relu_alpha': 0.04487148648446764}\n",
"Best trial final validation loss: 5.2214884757995605\n"
]
}
],
"source": [
"best_result = results.get_best_result(\"chunk_test_loss\", \"min\")\n",
"\n",
"print(\"Best trial config: {}\".format(best_result.config))\n",
"print(\n",
" \"Best trial final validation loss: {}\".format(\n",
" best_result.metrics[\"chunk_test_loss\"]\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# hparams = {\n",
"# \"batch_size\": 64,\n",
@ -196,64 +260,6 @@
"# **hparams\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import loguniform, uniform, randint\n",
"from models import MODELS, test_models\n",
"\n",
"\n",
"hyperparameters = [\n",
" {\n",
" \"batch_size\": [64, 128, 256],\n",
" \"edit_count\": [12],\n",
" \"bin_count\": [16],\n",
" \"learning_rate\": loguniform(5e-4, 5e-3),\n",
" \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n",
" \"num_epochs\": [10],\n",
" \"elu_alpha\": uniform(0.5, 1.5),\n",
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
" \"dropout_prob\": uniform(0, 0.1),\n",
" \"features\": [\n",
" [16, 32, 64],\n",
" [16, 32, 64, 128],\n",
" [32, 64],\n",
" [32, 64, 128],\n",
" [8, 16, 32],\n",
" [8, 8, 8, 8, 8],\n",
" [8, 8, 8, 8, 8, 8, 8],\n",
" [16, 16, 16],\n",
" [16, 16, 16, 16, 16],\n",
" [32, 32, 32],\n",
" [32, 32, 32, 32],\n",
" [64, 64],\n",
" [64, 64, 64],\n",
" ],\n",
" \"use_residual\": [True, False],\n",
" \"kernel_size\": [3, 5],\n",
" \"model_type\": [\"HistogramNet\"],\n",
" \"use_instance_norm\": [True],\n",
" \"use_elu\": [True, False],\n",
" \"leaky_relu_alpha\": uniform(0, 0.05),\n",
" }\n",
"]\n",
"\n",
"test_models()\n",
"\n",
"random_hparam_search(\n",
" hyperparameters=hyperparameters,\n",
" train_data_paths=TRAIN_DATA,\n",
" test_data_paths=TEST_DATA,\n",
" models_path=MODELS_PATH,\n",
" tensorboard_path=RUNS_PATH,\n",
" timeout_hours=4,\n",
" device=device,\n",
")"
]
}
],
"metadata": {

View file

@ -0,0 +1,205 @@
from typing import Any, Dict, List, Tuple
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from torch.optim import Adam
from .get_next_run_name import get_next_run_name
from utils import serialise_hparams
from visualisation import plot_histograms_in_2d
from models import create_model, load_model, save_model
import torch
from .get_data_loader import get_data_loader
from ray import train
from ray.train import Checkpoint
import os
import tempfile
from more_itertools import distribute
# Small constant added to model outputs before torch.log so that the log
# is never evaluated at exactly zero.
EPSILON = 1e-5
def train_with_ray_factory(
    train_data_paths: List[Path],
    test_data_paths: List[Path],
    device: torch.device,
    log_dir: Path,
):
    """Build a Ray Tune trainable bound to fixed data paths, device and log dir.

    Args:
        train_data_paths: Paths handed to ``get_data_loader`` for training.
        test_data_paths: Paths handed to ``get_data_loader`` for evaluation
            (the loader is built with ``edit_count`` forced to 1).
        device: Device the model, the loss module and every batch are moved to.
        log_dir: Directory under which each run gets its own TensorBoard
            sub-directory (``log_dir / run_name``).

    Returns:
        A function taking the trial's hyperparameter dict. It trains chunk by
        chunk and, after every chunk, evaluates on the test loader and reports
        ``chunk_test_loss`` / ``chunk_training_loss`` together with a
        checkpoint through ``ray.train.report``.
    """

    def train_with_ray(hyperparameters: Dict[str, Any]):
        """Run one trial; ``hyperparameters`` must contain ``chunk_count``."""

        def inner(
            hyperparameters: Dict[str, Any],
            chunk_count: int,
            **_,
        ) -> None:
            train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
            test_data_loader = get_data_loader(
                test_data_paths, **{**hyperparameters, "edit_count": 1}
            )
            # One fixed test batch, reused for the graph trace and for the
            # histogram figures written after every chunk.
            examples = next(iter(test_data_loader))
            model, optimizer, scheduler, start_chunk_id, run_name = (
                load_or_create_state(
                    device=device,
                    log_dir=log_dir,
                    **hyperparameters,
                )
            )
            loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)

            with SummaryWriter(log_dir=log_dir / run_name) as writer:
                writer.add_graph(model, examples[0].to(device))

                # The slice ends at -1, so the final chunk is never trained on:
                # a full run reports at most ``chunk_count - 1`` times.
                # NOTE(review): more_itertools documents distribute() as
                # round-robin and built on itertools.tee; consuming the chunks
                # one after another may therefore buffer batches for the later
                # chunks in memory -- confirm this is acceptable.
                chunks = distribute(chunk_count, train_data_loader)[
                    start_chunk_id:-1
                ]
                # Loop-invariant stride that maps (chunk, batch) to a global step.
                batches_per_chunk = len(train_data_loader) // chunk_count

                for chunk_id, chunk in enumerate(chunks, start=start_chunk_id):
                    chunk_training_loss = 0
                    # Bound up front so an empty chunk still has a valid step
                    # for the logging calls that follow the batch loop.
                    global_step = chunk_id * batches_per_chunk
                    writer.add_scalar(
                        "Actual learning rate",
                        scheduler.get_last_lr()[0],
                        chunk_id,
                    )
                    for batch_id, (edited_histogram, original_histogram) in enumerate(
                        chunk
                    ):
                        global_step = chunk_id * batches_per_chunk + batch_id
                        optimizer.zero_grad()
                        predicted_original = model(edited_histogram.to(device))
                        # KLDivLoss expects log-probabilities as its input.
                        loss = loss_function(
                            torch.log(predicted_original + EPSILON),
                            original_histogram.to(device),
                        )
                        chunk_training_loss += loss.item()
                        writer.add_scalar(
                            "Loss/train/batch", loss, global_step=global_step
                        )
                        loss.backward()
                        optimizer.step()

                    with torch.no_grad():
                        model.eval()
                        write_histograms(
                            model=model,
                            examples=examples,
                            writer=writer,
                            device=device,
                            global_step=global_step,
                        )
                        # Summed (not averaged) over all test batches.
                        chunk_test_loss = 0
                        for (
                            edited_histogram,
                            original_histogram,
                        ) in test_data_loader:
                            predicted_original = model(edited_histogram.to(device))
                            chunk_test_loss += loss_function(
                                torch.log(predicted_original + EPSILON),
                                original_histogram.to(device),
                            ).item()
                        model.train()

                    # Advance the LR schedule *before* checkpointing so the
                    # saved scheduler state is the one the next chunk starts
                    # with; previously a restored trial lagged one step behind.
                    scheduler.step()

                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        temp_checkpoint_dir = Path(temp_checkpoint_dir)
                        checkpoint_path = temp_checkpoint_dir / "checkpoint.pt"
                        torch.save(
                            (
                                optimizer.state_dict(),
                                scheduler.state_dict(),
                                # load_or_create_state reads this slot as the
                                # chunk to START from, so store the next chunk
                                # rather than the one just finished (which
                                # would be trained twice after a restore).
                                chunk_id + 1,
                                run_name,
                            ),
                            checkpoint_path,
                        )
                        save_model(
                            model, hyperparameters, temp_checkpoint_dir / "model"
                        )
                        writer.add_hparams(
                            serialise_hparams(hyperparameters),
                            {
                                "Loss/test/epoch": chunk_test_loss,
                                "Loss/train/epoch": chunk_training_loss,
                            },
                            global_step=global_step,
                            run_name=(log_dir / run_name).absolute(),
                        )
                        # Must run while the temporary directory still exists:
                        # the checkpoint is created from that directory.
                        train.report(
                            {
                                "chunk_test_loss": chunk_test_loss,
                                "chunk_training_loss": chunk_training_loss,
                            },
                            checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
                        )

        # The dict is passed whole and also spread, so ``chunk_count`` arrives
        # as a named argument while the remaining keys land in ``**_``.
        return inner(hyperparameters=hyperparameters, **hyperparameters)

    return train_with_ray
def load_or_create_state(
    device, log_dir, model_type, learning_rate, scheduler_gamma, **hyperparameters
) -> Tuple[
    torch.nn.Module,
    torch.optim.Optimizer,
    torch.optim.lr_scheduler.LRScheduler,
    int,
    str,
]:
    """Restore training state from the trial's Ray checkpoint, or create it.

    Args:
        device: Device the model is created on / loaded onto.
        log_dir: Directory passed to ``get_next_run_name`` when a fresh run
            is started.
        model_type: Model identifier forwarded to ``create_model``.
        learning_rate: Learning rate used to construct the Adam optimizer.
        scheduler_gamma: Decay factor of the ``StepLR`` scheduler
            (``step_size=1``).
        **hyperparameters: Remaining hyperparameters, forwarded to
            ``create_model`` when no checkpoint exists.

    Returns:
        ``(model, optimizer, scheduler, start_chunk_id, run_name)``. Without a
        checkpoint ``start_chunk_id`` is 0 and ``run_name`` is newly
        generated; otherwise both are read back from ``checkpoint.pt``.
    """
    loaded_checkpoint = train.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            loaded_checkpoint_dir = Path(loaded_checkpoint_dir)
            # The hyperparameters stored next to the model are not needed
            # here, so they are discarded instead of shadowing the argument.
            model, _ = load_model(loaded_checkpoint_dir / "model", device=device)
            # Mirror the fresh-model branch: training resumes immediately, so
            # make sure the restored model is in training mode regardless of
            # the mode load_model leaves it in.
            model.train()
            optimizer = Adam(model.parameters(), lr=learning_rate)
            optimizer_state, scheduler_state, start_chunk_id, run_name = torch.load(
                loaded_checkpoint_dir / "checkpoint.pt"
            )
            optimizer.load_state_dict(optimizer_state)
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=1, gamma=scheduler_gamma
            )
            scheduler.load_state_dict(scheduler_state)
    else:
        run_name = get_next_run_name(log_dir)
        model = create_model(
            type=model_type,
            hyperparameters=hyperparameters,
            device=device,
        ).train()
        optimizer = Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=scheduler_gamma
        )
        start_chunk_id = 0
    return model, optimizer, scheduler, start_chunk_id, run_name
def write_histograms(
    model: torch.nn.Module,
    examples: List[torch.Tensor],
    writer: SummaryWriter,
    device: torch.device,
    global_step: int,
):
    """Log one comparison figure per example to TensorBoard.

    Args:
        model: Network used to predict the original histogram from the
            edited one.
        examples: Two-element batch ``(edited_histograms,
            original_histograms)`` as produced by the data loader.
        writer: TensorBoard writer receiving the ``histogram_{i}`` figures.
        device: Device the edited batch is moved to for the forward pass.
        global_step: Step value attached to every figure.
    """

    def as_plot_array(histogram: torch.Tensor):
        # detach() keeps this usable when the caller is not inside
        # torch.no_grad(); cpu() makes the conversion device-independent.
        # Index 0 selects the first entry of the sample's leading dimension
        # (presumably the channel axis -- TODO confirm).
        return histogram.detach().cpu()[0].numpy().squeeze()

    edited_histograms, original_histograms = examples
    # No gradients are needed for visualisation.
    with torch.no_grad():
        predicted_originals = model(edited_histograms.to(device))

    for i, (original, edited, predicted) in enumerate(
        zip(original_histograms, edited_histograms, predicted_originals)
    ):
        writer.add_figure(
            f"histogram_{i}",
            plot_histograms_in_2d(
                {
                    "original": as_plot_array(original),
                    "edited": as_plot_array(edited),
                    "predicted": as_plot_array(predicted),
                }
            ),
            global_step=global_step,
        )