Integrate ray tune
This commit is contained in:
parent
49d9ece2ec
commit
fb07fc674e
2 changed files with 379 additions and 168 deletions
342
src/train.ipynb
342
src/train.ipynb
|
|
@ -21,66 +21,98 @@
|
|||
"source": [
|
||||
"import os\n",
|
||||
"from utils import set_up_logging, get_device\n",
|
||||
"from training import train, random_hparam_search\n",
|
||||
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
|
||||
"from training import train_with_ray_factory\n",
|
||||
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n",
|
||||
"from ray import tune\n",
|
||||
"from ray.tune.schedulers import ASHAScheduler\n",
|
||||
"import os\n",
|
||||
"from ray.air import RunConfig\n",
|
||||
"\n",
|
||||
"set_up_logging(LOGS_PATH)\n",
|
||||
"\n",
|
||||
"TRIAL_COUNT = 100\n",
|
||||
"CHUNK_COUNT = 40\n",
|
||||
"EPOCH_COUNT = 2\n",
|
||||
"\n",
|
||||
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
|
||||
"device = get_device()\n",
|
||||
"f\"Using device {device}\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# hparams = {\n",
|
||||
"# \"batch_size\": 64,\n",
|
||||
"# \"edit_count\": 12,\n",
|
||||
"# \"bin_count\": 16,\n",
|
||||
"# \"learning_rate\": 0.001,\n",
|
||||
"# \"scheduler_gamma\": 0.9,\n",
|
||||
"# \"num_epochs\": 12,\n",
|
||||
"# \"model_type\": \"SimpleCNN\",\n",
|
||||
"# }\n",
|
||||
"\n",
|
||||
"# train(\n",
|
||||
"# hparams,\n",
|
||||
"# train_data_paths=TRAIN_DATA,\n",
|
||||
"# test_data_paths=TEST_DATA,\n",
|
||||
"# log_dir=RUNS_PATH,\n",
|
||||
"# max_duration=None,\n",
|
||||
"# use_tqdm=True,\n",
|
||||
"# device=device,\n",
|
||||
"# **hparams\n",
|
||||
"# )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-29 10:45:18,539 - INFO - Loaded 22479 original images\n",
|
||||
"2024-06-29 10:45:18,544 - INFO - Loaded 2498 original images\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "61d4288a87a24255824e1430d8e69804",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/html": [
|
||||
"<div class=\"tuneStatus\">\n",
|
||||
" <div style=\"display: flex;flex-direction: row\">\n",
|
||||
" <div style=\"display: flex;flex-direction: column;\">\n",
|
||||
" <h3>Tune Status</h3>\n",
|
||||
" <table>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>Current time:</td><td>2024-09-01 22:06:10</td></tr>\n",
|
||||
"<tr><td>Running for: </td><td>00:02:06.64 </td></tr>\n",
|
||||
"<tr><td>Memory: </td><td>22.4/47.0 GiB </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>\n",
|
||||
" </div>\n",
|
||||
" <div class=\"vDivider\"></div>\n",
|
||||
" <div class=\"systemInfo\">\n",
|
||||
" <h3>System Info</h3>\n",
|
||||
" Using AsyncHyperBand: num_stopped=1<br>Bracket: Iter 2.000: -5.516964137554169<br>Logical resource usage: 32.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)\n",
|
||||
" </div>\n",
|
||||
" \n",
|
||||
" </div>\n",
|
||||
" <div class=\"hDivider\"></div>\n",
|
||||
" <div class=\"trialStatus\">\n",
|
||||
" <h3>Trial Status</h3>\n",
|
||||
" <table>\n",
|
||||
"<thead>\n",
|
||||
"<tr><th>Trial name </th><th>status </th><th>loc </th><th style=\"text-align: right;\"> batch_size</th><th style=\"text-align: right;\"> dropout_prob</th><th style=\"text-align: right;\"> elu_alpha</th><th>features </th><th style=\"text-align: right;\"> kernel_size</th><th style=\"text-align: right;\"> leaky_relu_alpha</th><th style=\"text-align: right;\"> leaky_relu_slope</th><th style=\"text-align: right;\"> learning_rate</th><th>model_type </th><th style=\"text-align: right;\"> scheduler_gamma</th><th>use_elu </th><th>use_residual </th><th style=\"text-align: right;\"> iter</th><th style=\"text-align: right;\"> total time (s)</th><th style=\"text-align: right;\"> chunk_test_loss</th><th style=\"text-align: right;\"> chunk_training_loss</th></tr>\n",
|
||||
"</thead>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>train_with_ray_b7d3c_00000</td><td>TERMINATED</td><td>172.29.235.222:1134109</td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\"> 0.00395892</td><td style=\"text-align: right;\"> 1.35107</td><td>[8, 16, 32] </td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0448715 </td><td style=\"text-align: right;\"> 0.00664526</td><td style=\"text-align: right;\"> 0.0029226 </td><td>HistogramNet</td><td style=\"text-align: right;\"> 0.816752</td><td>True </td><td>True </td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 70.652 </td><td style=\"text-align: right;\"> 5.22149</td><td style=\"text-align: right;\"> 56.5472</td></tr>\n",
|
||||
"<tr><td>train_with_ray_b7d3c_00001</td><td>TERMINATED</td><td>172.29.235.222:1140440</td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\"> 0.0439061 </td><td style=\"text-align: right;\"> 1.74642</td><td>[8, 8, 8, 8, 8,_c140</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.00579656</td><td style=\"text-align: right;\"> 0.0125715 </td><td style=\"text-align: right;\"> 0.00155182</td><td>HistogramNet</td><td style=\"text-align: right;\"> 0.898059</td><td>True </td><td>False </td><td style=\"text-align: right;\"> 2</td><td style=\"text-align: right;\"> 45.4228</td><td style=\"text-align: right;\"> 5.79311</td><td style=\"text-align: right;\"> 64.0481</td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>\n",
|
||||
" </div>\n",
|
||||
"</div>\n",
|
||||
"<style>\n",
|
||||
".tuneStatus {\n",
|
||||
" color: var(--jp-ui-font-color1);\n",
|
||||
"}\n",
|
||||
".tuneStatus .systemInfo {\n",
|
||||
" display: flex;\n",
|
||||
" flex-direction: column;\n",
|
||||
"}\n",
|
||||
".tuneStatus td {\n",
|
||||
" white-space: nowrap;\n",
|
||||
"}\n",
|
||||
".tuneStatus .trialStatus {\n",
|
||||
" display: flex;\n",
|
||||
" flex-direction: column;\n",
|
||||
"}\n",
|
||||
".tuneStatus h3 {\n",
|
||||
" font-weight: bold;\n",
|
||||
"}\n",
|
||||
".tuneStatus .hDivider {\n",
|
||||
" border-bottom-width: var(--jp-border-width);\n",
|
||||
" border-bottom-color: var(--jp-border-color0);\n",
|
||||
" border-bottom-style: solid;\n",
|
||||
"}\n",
|
||||
".tuneStatus .vDivider {\n",
|
||||
" border-left-width: var(--jp-border-width);\n",
|
||||
" border-left-color: var(--jp-border-color0);\n",
|
||||
" border-left-style: solid;\n",
|
||||
" margin: 0.5em 1em 0.5em 1em;\n",
|
||||
"}\n",
|
||||
"</style>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"Epoch 0: 0%| | 0/4215 [00:00<?, ?batch/s]"
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
|
|
@ -90,81 +122,113 @@
|
|||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-29 10:50:49,231 - INFO - Epoch 0 train loss: 5050.157035887241\n",
|
||||
"2024-06-29 10:51:01,594 - INFO - Epoch 0 test loss: 460.928723692894\n"
|
||||
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000000)\n",
|
||||
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000001)\n",
|
||||
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000002)\n",
|
||||
"\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000000)\n",
|
||||
"2024-09-01 22:06:10,132\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/andras/projects/bipolaroid/runs5/tune' in 0.0020s.\n",
|
||||
"2024-09-01 22:06:10,135\tINFO tune.py:1041 -- Total run time: 126.67 seconds (126.64 seconds for the tuning loop).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "55aa7c6adc41427081af06330d541c2d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 1: 0%| | 0/4215 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-29 10:56:26,321 - INFO - Epoch 1 train loss: 4134.876303792\n",
|
||||
"2024-06-29 10:56:39,064 - INFO - Epoch 1 test loss: 432.3766641020775\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "2b7d28fc6f62489ebea844441a57178f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 2: 0%| | 0/4215 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-29 11:02:03,858 - INFO - Epoch 2 train loss: 3871.2984607219696\n",
|
||||
"2024-06-29 11:02:16,040 - INFO - Epoch 2 test loss: 408.74887996912\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "1395274c7a38465099c93c2d241ea1da",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 3: 0%| | 0/4215 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[2], line 20\u001b[0m\n\u001b[1;32m 1\u001b[0m hparams \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 2\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124medit_count\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m12\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mleaky_relu_alpha\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0.03745605986732464\u001b[39m,\n\u001b[1;32m 18\u001b[0m }\n\u001b[0;32m---> 20\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mhparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTRAIN_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTEST_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mRUNS_PATH\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 25\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_duration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 28\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhparams\u001b[49m\n\u001b[1;32m 29\u001b[0m \u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m/mnt/wsl/PHYSICALDRIVE1/projects/bipolaroid/src/training/train.py:55\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(hyperparameters, train_data_paths, test_data_paths, log_dir, use_tqdm, device, model_type, learning_rate, scheduler_gamma, num_epochs, **_)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_id, (edited_histogram, original_histogram) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\n\u001b[1;32m 50\u001b[0m tqdm(train_data_loader, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_tqdm\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m train_data_loader\n\u001b[1;32m 53\u001b[0m ):\n\u001b[1;32m 54\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 55\u001b[0m predicted_original \u001b[38;5;241m=\u001b[39m model(\u001b[43medited_histogram\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 56\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_function(\n\u001b[1;32m 57\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(predicted_original \u001b[38;5;241m+\u001b[39m EPSILON),\n\u001b[1;32m 58\u001b[0m original_histogram\u001b[38;5;241m.\u001b[39mto(device),\n\u001b[1;32m 59\u001b[0m )\n\u001b[1;32m 61\u001b[0m epoch_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||||
"\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000001)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"config = {\n",
|
||||
" \"batch_size\": 32,\n",
|
||||
" \"edit_count\": EPOCH_COUNT,\n",
|
||||
" \"bin_count\": 32,\n",
|
||||
" \"learning_rate\": tune.loguniform(5e-4, 5e-3),\n",
|
||||
" \"scheduler_gamma\": tune.uniform(0.8, 0.95),\n",
|
||||
" \"elu_alpha\": tune.uniform(0.5, 2),\n",
|
||||
" \"leaky_relu_slope\": tune.uniform(0, 0.03),\n",
|
||||
" \"dropout_prob\": tune.uniform(0, 0.1),\n",
|
||||
" \"chunk_count\": CHUNK_COUNT,\n",
|
||||
" \"features\": tune.choice(\n",
|
||||
" [\n",
|
||||
" [16, 32, 64],\n",
|
||||
" [16, 32, 64, 128],\n",
|
||||
" [32, 64],\n",
|
||||
" [32, 128],\n",
|
||||
" [8, 16, 32],\n",
|
||||
" [8, 8, 8, 8, 8],\n",
|
||||
" [8, 8, 8, 8, 8, 8, 8],\n",
|
||||
" [16, 16, 16],\n",
|
||||
" [16, 16, 16, 16, 16],\n",
|
||||
" [32, 32, 32],\n",
|
||||
" [32, 32, 32, 32],\n",
|
||||
" [64, 64],\n",
|
||||
" [64, 64, 64],\n",
|
||||
" ]\n",
|
||||
" ),\n",
|
||||
" \"use_residual\": tune.choice([True, False]),\n",
|
||||
" \"kernel_size\": tune.choice([3, 5]),\n",
|
||||
" \"model_type\": tune.choice([\"HistogramNet\"]),\n",
|
||||
" \"use_instance_norm\": True,\n",
|
||||
" \"use_elu\": tune.choice([True, False]),\n",
|
||||
" \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n",
|
||||
"}\n",
|
||||
"scheduler = ASHAScheduler(max_t=config[\"chunk_count\"], grace_period=2)\n",
|
||||
"\n",
|
||||
"tuner = tune.Tuner(\n",
|
||||
" tune.with_resources(\n",
|
||||
" tune.with_parameters(\n",
|
||||
" train_with_ray_factory(\n",
|
||||
" train_data_paths=TRAIN_DATA,\n",
|
||||
" test_data_paths=TEST_DATA,\n",
|
||||
" device=device,\n",
|
||||
" log_dir=RUNS_PATH / \"custom\",\n",
|
||||
" )\n",
|
||||
" ),\n",
|
||||
" resources={\"cpu\": 32, \"gpu\": 1},\n",
|
||||
" ),\n",
|
||||
" run_config=RunConfig(storage_path=RUNS_PATH, name=\"tune\"),\n",
|
||||
" tune_config=tune.TuneConfig(\n",
|
||||
" metric=\"chunk_test_loss\",\n",
|
||||
" mode=\"min\",\n",
|
||||
" scheduler=scheduler,\n",
|
||||
" num_samples=TRIAL_COUNT,\n",
|
||||
" ),\n",
|
||||
" param_space=config,\n",
|
||||
")\n",
|
||||
"results = tuner.fit()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Best trial config: {'batch_size': 32, 'edit_count': 3, 'bin_count': 32, 'learning_rate': 0.0029226033808016005, 'scheduler_gamma': 0.8167516482513361, 'elu_alpha': 1.3510723758569865, 'leaky_relu_slope': 0.0066452562138349025, 'dropout_prob': 0.0039589213934103865, 'chunk_count': 4, 'features': [8, 16, 32], 'use_residual': True, 'kernel_size': 5, 'model_type': 'HistogramNet', 'use_instance_norm': True, 'use_elu': True, 'leaky_relu_alpha': 0.04487148648446764}\n",
|
||||
"Best trial final validation loss: 5.2214884757995605\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"best_result = results.get_best_result(\"chunk_test_loss\", \"min\")\n",
|
||||
"\n",
|
||||
"print(\"Best trial config: {}\".format(best_result.config))\n",
|
||||
"print(\n",
|
||||
" \"Best trial final validation loss: {}\".format(\n",
|
||||
" best_result.metrics[\"chunk_test_loss\"]\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# hparams = {\n",
|
||||
"# \"batch_size\": 64,\n",
|
||||
|
|
@ -196,64 +260,6 @@
|
|||
"# **hparams\n",
|
||||
"# )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from scipy.stats import loguniform, uniform, randint\n",
|
||||
"from models import MODELS, test_models\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"hyperparameters = [\n",
|
||||
" {\n",
|
||||
" \"batch_size\": [64, 128, 256],\n",
|
||||
" \"edit_count\": [12],\n",
|
||||
" \"bin_count\": [16],\n",
|
||||
" \"learning_rate\": loguniform(5e-4, 5e-3),\n",
|
||||
" \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n",
|
||||
" \"num_epochs\": [10],\n",
|
||||
" \"elu_alpha\": uniform(0.5, 1.5),\n",
|
||||
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
|
||||
" \"dropout_prob\": uniform(0, 0.1),\n",
|
||||
" \"features\": [\n",
|
||||
" [16, 32, 64],\n",
|
||||
" [16, 32, 64, 128],\n",
|
||||
" [32, 64],\n",
|
||||
" [32, 64, 128],\n",
|
||||
" [8, 16, 32],\n",
|
||||
" [8, 8, 8, 8, 8],\n",
|
||||
" [8, 8, 8, 8, 8, 8, 8],\n",
|
||||
" [16, 16, 16],\n",
|
||||
" [16, 16, 16, 16, 16],\n",
|
||||
" [32, 32, 32],\n",
|
||||
" [32, 32, 32, 32],\n",
|
||||
" [64, 64],\n",
|
||||
" [64, 64, 64],\n",
|
||||
" ],\n",
|
||||
" \"use_residual\": [True, False],\n",
|
||||
" \"kernel_size\": [3, 5],\n",
|
||||
" \"model_type\": [\"HistogramNet\"],\n",
|
||||
" \"use_instance_norm\": [True],\n",
|
||||
" \"use_elu\": [True, False],\n",
|
||||
" \"leaky_relu_alpha\": uniform(0, 0.05),\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"test_models()\n",
|
||||
"\n",
|
||||
"random_hparam_search(\n",
|
||||
" hyperparameters=hyperparameters,\n",
|
||||
" train_data_paths=TRAIN_DATA,\n",
|
||||
" test_data_paths=TEST_DATA,\n",
|
||||
" models_path=MODELS_PATH,\n",
|
||||
" tensorboard_path=RUNS_PATH,\n",
|
||||
" timeout_hours=4,\n",
|
||||
" device=device,\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
|||
205
src/training/train_with_ray.py
Normal file
205
src/training/train_with_ray.py
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
from typing import Any, Dict, List, Tuple
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from pathlib import Path
|
||||
from torch.optim import Adam
|
||||
from .get_next_run_name import get_next_run_name
|
||||
from utils import serialise_hparams
|
||||
from visualisation import plot_histograms_in_2d
|
||||
from models import create_model, load_model, save_model
|
||||
import torch
|
||||
from .get_data_loader import get_data_loader
|
||||
from ray import train
|
||||
from ray.train import Checkpoint
|
||||
import os
|
||||
import tempfile
|
||||
from more_itertools import distribute
|
||||
|
||||
|
||||
EPSILON = 1e-5
|
||||
|
||||
|
||||
def train_with_ray_factory(
    train_data_paths: List[Path],
    test_data_paths: List[Path],
    device: torch.device,
    log_dir: Path,
):
    """Build a Ray Tune function trainable bound to fixed data paths and device.

    Args:
        train_data_paths: Paths handed to ``get_data_loader`` for training.
        test_data_paths: Paths handed to ``get_data_loader`` for evaluation
            (always with ``edit_count`` overridden to 1).
        device: Device the model, loss and batches are moved to.
        log_dir: Directory under which each run gets its own TensorBoard
            sub-directory (``log_dir / run_name``).

    Returns:
        A function taking the sampled hyperparameter dict. It trains chunk by
        chunk and, after every chunk, reports ``chunk_test_loss`` and
        ``chunk_training_loss`` to Ray together with a checkpoint.
    """

    def train_with_ray(hyperparameters: Dict[str, Any]):
        def inner(
            hyperparameters: Dict[str, Any],
            chunk_count: int,
            **_,
        ) -> None:
            train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
            test_data_loader = get_data_loader(
                test_data_paths, **{**hyperparameters, "edit_count": 1}
            )
            # One fixed batch, reused for the model graph and the per-chunk
            # histogram figures so that figures are comparable across chunks.
            examples = next(iter(test_data_loader))

            model, optimizer, scheduler, start_chunk_id, run_name = (
                load_or_create_state(
                    device=device,
                    log_dir=log_dir,
                    **hyperparameters,
                )
            )
            loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)

            with SummaryWriter(log_dir=log_dir / run_name) as writer:
                writer.add_graph(model, examples[0].to(device))

                # NOTE(review): the ``:-1`` slice means the final chunk is
                # never trained on, so at most ``chunk_count - 1`` results are
                # reported per trial. Kept as-is; confirm this is intentional.
                chunks = distribute(chunk_count, train_data_loader)[
                    start_chunk_id:-1
                ]
                for chunk_id, chunk in enumerate(chunks, start=start_chunk_id):
                    chunk_training_loss = 0
                    first_step = chunk_id * (len(train_data_loader) // chunk_count)
                    # Defined up front so the evaluation below still has a
                    # step to log against if the chunk yields no batches.
                    global_step = first_step
                    writer.add_scalar(
                        "Actual learning rate",
                        scheduler.get_last_lr()[0],
                        chunk_id,
                    )

                    for batch_id, (edited_histogram, original_histogram) in enumerate(
                        chunk
                    ):
                        global_step = first_step + batch_id
                        optimizer.zero_grad()
                        predicted_original = model(edited_histogram.to(device))
                        # KLDivLoss expects log-probabilities as input;
                        # EPSILON keeps the log finite for zero predictions.
                        loss = loss_function(
                            torch.log(predicted_original + EPSILON),
                            original_histogram.to(device),
                        )

                        batch_loss = loss.item()
                        chunk_training_loss += batch_loss
                        writer.add_scalar(
                            "Loss/train/batch", batch_loss, global_step=global_step
                        )
                        loss.backward()
                        optimizer.step()

                    with torch.no_grad():
                        model.eval()
                        write_histograms(
                            model=model,
                            examples=examples,
                            writer=writer,
                            device=device,
                            global_step=global_step,
                        )
                        chunk_test_loss = 0
                        for edited_histogram, original_histogram in test_data_loader:
                            predicted_original = model(edited_histogram.to(device))
                            chunk_test_loss += loss_function(
                                torch.log(predicted_original + EPSILON),
                                original_histogram.to(device),
                            ).item()
                        model.train()

                    # Advance the learning rate *before* checkpointing so the
                    # saved scheduler state already belongs to the next chunk.
                    scheduler.step()

                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        temp_checkpoint_dir = Path(temp_checkpoint_dir)
                        checkpoint_path = temp_checkpoint_dir / "checkpoint.pt"
                        torch.save(
                            (
                                optimizer.state_dict(),
                                scheduler.state_dict(),
                                # Index of the next chunk to train on: a trial
                                # restored from this checkpoint must not
                                # repeat the chunk that was just completed.
                                chunk_id + 1,
                                run_name,
                            ),
                            checkpoint_path,
                        )
                        save_model(
                            model, hyperparameters, temp_checkpoint_dir / "model"
                        )
                        writer.add_hparams(
                            serialise_hparams(hyperparameters),
                            {
                                "Loss/test/epoch": chunk_test_loss,
                                "Loss/train/epoch": chunk_training_loss,
                            },
                            global_step=global_step,
                            run_name=(log_dir / run_name).absolute(),
                        )
                        # Must happen while the temporary directory still
                        # exists: Ray copies the checkpoint from it here.
                        train.report(
                            {
                                "chunk_test_loss": chunk_test_loss,
                                "chunk_training_loss": chunk_training_loss,
                            },
                            checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
                        )

        # ``chunk_count`` is picked out of the dict by name; every other key is
        # swallowed by ``**_`` while the full dict is still passed explicitly.
        return inner(hyperparameters=hyperparameters, **hyperparameters)

    return train_with_ray
|
||||
|
||||
|
||||
def load_or_create_state(
    device, log_dir, model_type, learning_rate, scheduler_gamma, **hyperparameters
) -> Tuple[
    torch.nn.Module,
    torch.optim.Optimizer,
    torch.optim.lr_scheduler.LRScheduler,
    int,
    str,
]:
    """Restore training state from the Ray checkpoint, or create it afresh.

    Args:
        device: Device the model is created on / loaded to.
        log_dir: Directory used to pick the next free run name for a new run.
        model_type: Model type forwarded to ``create_model`` for a new run.
        learning_rate: Learning rate for the Adam optimizer.
        scheduler_gamma: Decay factor of the ``StepLR`` scheduler
            (``step_size=1``).
        **hyperparameters: Remaining hyperparameters, forwarded to
            ``create_model`` when no checkpoint exists.

    Returns:
        ``(model, optimizer, scheduler, start_chunk_id, run_name)``. With a
        checkpoint, ``start_chunk_id`` and ``run_name`` are the values stored
        in ``checkpoint.pt``; otherwise ``start_chunk_id`` is 0 and
        ``run_name`` comes from ``get_next_run_name(log_dir)``.
    """
    loaded_checkpoint = train.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            loaded_checkpoint_dir = Path(loaded_checkpoint_dir)
            # The hyperparameters stored next to the model are not needed
            # here; the caller already holds the trial's configuration.
            model, _ = load_model(loaded_checkpoint_dir / "model", device=device)
            # Match the fresh-model branch, which hands back a model in
            # training mode regardless of how ``load_model`` leaves it.
            model.train()
            optimizer = Adam(model.parameters(), lr=learning_rate)

            # map_location lets a checkpoint written on one device be
            # restored on another (e.g. saved on GPU, resumed on CPU).
            optimizer_state, scheduler_state, start_chunk_id, run_name = torch.load(
                loaded_checkpoint_dir / "checkpoint.pt", map_location=device
            )
            optimizer.load_state_dict(optimizer_state)
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=1, gamma=scheduler_gamma
            )
            scheduler.load_state_dict(scheduler_state)
    else:
        run_name = get_next_run_name(log_dir)
        model = create_model(
            type=model_type,
            hyperparameters=hyperparameters,
            device=device,
        ).train()
        optimizer = Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=scheduler_gamma
        )
        start_chunk_id = 0

    return model, optimizer, scheduler, start_chunk_id, run_name
|
||||
|
||||
|
||||
def write_histograms(
    model: torch.nn.Module,
    examples: List[Tuple[torch.Tensor, torch.Tensor]],
    writer: SummaryWriter,
    device: torch.device,
    global_step: int,
):
    """Log one comparison figure per example to TensorBoard.

    For every example the original, edited and model-predicted histograms are
    passed to ``plot_histograms_in_2d`` and the resulting figure is written
    under the tag ``histogram_<i>``.

    Args:
        model: Model used to predict the original histogram from the edited one.
        examples: Unpacked as ``(edited_histograms, original_histograms)``.
            NOTE(review): the annotation reads as a list of pairs, but the
            value is used as a single pair of batched tensors — confirm.
        writer: TensorBoard writer receiving the figures.
        device: Device the edited histograms are moved to for the forward pass.
        global_step: Step the figures are logged at.
    """
    edited_histograms, original_histograms = examples
    predicted_originals = model(edited_histograms.to(device))

    def to_array(sample: torch.Tensor):
        # detach() + cpu() make the conversion safe regardless of the device
        # the tensor lives on or whether autograd is tracking it; index 0 and
        # squeeze() are applied exactly as before.
        return sample.detach().cpu()[0].numpy().squeeze()

    for i, (original, edited, predicted) in enumerate(
        zip(original_histograms, edited_histograms, predicted_originals)
    ):
        writer.add_figure(
            f"histogram_{i}",
            plot_histograms_in_2d(
                {
                    "original": to_array(original),
                    "edited": to_array(edited),
                    "predicted": to_array(predicted),
                }
            ),
            global_step=global_step,
        )
|
||||
Loading…
Add table
Add a link
Reference in a new issue