Improve training

This commit is contained in:
Andras Schmelczer 2024-06-03 07:49:09 +01:00
parent a6a15ec650
commit ae9b19d2db
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 281 additions and 203 deletions

View file

@ -2,18 +2,25 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {
"metadata": {}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-03 07:46:08,999 - INFO - PyTorch version: 2.2.2\n"
]
},
{
"data": {
"text/plain": [
"'Using device cuda:0'"
]
},
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@ -21,11 +28,19 @@
"source": [
"import torch\n",
"import logging\n",
"import os\n",
"from datetime import datetime\n",
"\n",
"logging.basicConfig(\n",
" level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\"\n",
" level=logging.INFO,\n",
" format=\"%(asctime)s - %(levelname)s - %(message)s\",\n",
" handlers=[\n",
" logging.StreamHandler(),\n",
" logging.FileHandler(f\"train-{datetime.now().date()}.log\"),\n",
" ],\n",
")\n",
"\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"f\"Using device {device}\""
]
@ -37,23 +52,17 @@
"outputs": [],
"source": [
"from scipy.stats import loguniform, uniform\n",
"\n",
"from editor.models import MODELS\n",
"\n",
"common_hyperparameters = {\n",
" \"batch_size\": [8, 16, 32, 64],\n",
" \"edit_count\": [8, 16, 32],\n",
" \"bin_count\": [16, 32, 64],\n",
" \"batch_size\": [16, 32, 64],\n",
" \"edit_count\": [8, 16],\n",
" \"bin_count\": [16, 32],\n",
" \"clip_gradients\": [True, False],\n",
" \"learning_rate\": loguniform(0.00001, 0.005),\n",
" \"learning_rate\": loguniform(0.0001, 0.005),\n",
" \"scheduler_gamma\": uniform(0.1, 0.9),\n",
" \"num_epochs\": [10],\n",
" \"model_type\": [\n",
" \"NormalisedCNN\",\n",
" \"SimpleCNN\",\n",
" \"Residual\",\n",
" \"SmartRes\",\n",
" \"Res2\",\n",
" ],\n",
" \"num_epochs\": [5],\n",
" \"model_type\": list(MODELS.keys()),\n",
"}\n",
"hyperparameters = [\n",
" # {\n",
@ -73,7 +82,30 @@
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing model SimpleCNN\n",
"Test passed! Output shape matches input shape.\n",
"Testing model Residual\n",
"Test passed! Output shape matches input shape.\n",
"Testing model NormalisedCNN\n",
"Test passed! Output shape matches input shape.\n",
"Testing model SmartRes\n",
"Test passed! Output shape matches input shape.\n",
"Testing model attention2\n",
"Test passed! Output shape matches input shape.\n",
"Testing model advanced_attention\n",
"Test passed! Output shape matches input shape.\n",
"Testing model Res2\n",
"Test passed! Output shape matches input shape.\n",
"Testing model attention1\n",
"Test passed! Output shape matches input shape.\n"
]
}
],
"source": [
"from typing import Any, Dict\n",
"from torch.utils.tensorboard import SummaryWriter\n",
@ -85,11 +117,17 @@
"from editor.utils import get_next_run_name\n",
"from editor.visualisation import plot_histograms_in_2d\n",
"from editor.training import create_data_loaders\n",
"from editor.models import create_model\n",
"from editor.models import create_model, test_models\n",
"from config import DATA, MODELS_PATH\n",
"from datetime import timedelta, datetime\n",
"\n",
"test_models()\n",
"\n",
"\n",
"def train(hyperparameters: Dict[str, Any]) -> Path:\n",
"def train(\n",
" hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool\n",
") -> Path:\n",
" start_time = datetime.now()\n",
" model_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\".pth\")\n",
"\n",
" log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n",
@ -132,7 +170,12 @@
" )\n",
" for batch_id, (edited_histogram, original_histogram) in enumerate(\n",
" tqdm(train_data_loader, desc=f\"Epoch {epoch}\", unit=\"batch\")\n",
" if use_tqdm\n",
" else train_data_loader\n",
" ):\n",
" current_time = datetime.now()\n",
" if current_time - start_time > max_duration:\n",
" raise TimeoutError(f\"Time limit {max_duration} exceeded\")\n",
" edited_histogram = edited_histogram.to(device)\n",
" original_histogram = original_histogram.to(device)\n",
"\n",
@ -246,189 +289,227 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-05-12 21:54:49,789 - INFO - Starting run_71 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'Res2', 'loss': 'kl'}\n",
"2024-05-12 21:54:49,792 - INFO - Loaded 72 training images and 8 test images\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "461659956f3944a085b5cc3a5af6ec31",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 0: 0%| | 0/9 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-02 21:42:49,762 - INFO - Starting run_51 with hparams {\n",
" \"batch_size\": 16,\n",
" \"bin_count\": 64,\n",
" \"clip_gradients\": true,\n",
" \"edit_count\": 16,\n",
" \"learning_rate\": 0.0019018860481580008,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"Residual\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.5124233085818609\n",
"}\n",
"2024-06-02 21:42:49,787 - INFO - Loaded 359668 training images and 39964 test images\n",
"2024-06-02 23:43:03,467 - WARNING - Timeout, aborting experiment\n",
"2024-06-02 23:43:03,698 - INFO - Starting run_52 with hparams {\n",
" \"batch_size\": 16,\n",
" \"bin_count\": 16,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 8,\n",
" \"learning_rate\": 2.9976475506468536e-05,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"SmartRes\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.8138813825657673\n",
"}\n",
"2024-06-02 23:43:03,991 - INFO - Loaded 179834 training images and 19982 test images\n",
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/matplotlib/collections.py:996: RuntimeWarning: invalid value encountered in sqrt\n",
" scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor\n"
" scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor\n",
"2024-06-02 23:52:17,393 - INFO - Starting run_53 with hparams {\n",
" \"batch_size\": 8,\n",
" \"bin_count\": 32,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 8,\n",
" \"learning_rate\": 0.0002765101396434423,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"SmartRes\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.8393595799921102\n",
"}\n",
"2024-06-02 23:52:17,413 - INFO - Loaded 179834 training images and 19982 test images\n",
"2024-06-03 00:48:49,485 - INFO - Starting run_54 with hparams {\n",
" \"batch_size\": 16,\n",
" \"bin_count\": 16,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 16,\n",
" \"learning_rate\": 0.00040493280785202865,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"SmartRes\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.6647838946959123\n",
"}\n",
"2024-06-03 00:48:49,509 - INFO - Loaded 359668 training images and 39964 test images\n",
"2024-06-03 01:10:40,678 - INFO - Starting run_55 with hparams {\n",
" \"batch_size\": 32,\n",
" \"bin_count\": 16,\n",
" \"clip_gradients\": true,\n",
" \"edit_count\": 16,\n",
" \"learning_rate\": 0.000989324245186775,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"SmartRes\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.6779989111474544\n",
"}\n",
"2024-06-03 01:10:40,704 - INFO - Loaded 359668 training images and 39964 test images\n",
"2024-06-03 01:26:06,028 - INFO - Starting run_56 with hparams {\n",
" \"batch_size\": 8,\n",
" \"bin_count\": 16,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 16,\n",
" \"learning_rate\": 1.0695951486573912e-05,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"Residual\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.3619561054933521\n",
"}\n",
"2024-06-03 01:26:06,052 - INFO - Loaded 359668 training images and 39964 test images\n",
"2024-06-03 02:03:39,558 - INFO - Starting run_57 with hparams {\n",
" \"batch_size\": 32,\n",
" \"bin_count\": 64,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 16,\n",
" \"learning_rate\": 0.00024721579172106914,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"attention1\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.7999479970967494\n",
"}\n",
"2024-06-03 02:03:39,585 - INFO - Loaded 359668 training images and 39964 test images\n",
"2024-06-03 02:05:40,747 - ERROR - Error with hparams {'batch_size': 32, 'edit_count': 16, 'bin_count': 64, 'clip_gradients': False, 'learning_rate': 0.00024721579172106914, 'scheduler_gamma': 0.7999479970967494, 'num_epochs': 10, 'model_type': 'attention1', 'loss': 'kl'}:\n",
"\tCUDA out of memory. Tried to allocate 6.00 GiB. GPU 0 has a total capacity of 15.99 GiB of which 0 bytes is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 39.04 GiB is allocated by PyTorch, and 2.02 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)\n",
"Stack (most recent call last):\n",
" File \"<frozen runpy>\", line 198, in _run_module_as_main\n",
" File \"<frozen runpy>\", line 88, in _run_code\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel_launcher.py\", line 18, in <module>\n",
" app.launch_new_instance()\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n",
" app.start()\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/kernelapp.py\", line 739, in start\n",
" self.io_loop.start()\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/tornado/platform/asyncio.py\", line 195, in start\n",
" self.asyncio_loop.run_forever()\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/asyncio/base_events.py\", line 639, in run_forever\n",
" self._run_once()\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/asyncio/base_events.py\", line 1985, in _run_once\n",
" handle._run()\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/asyncio/events.py\", line 88, in _run\n",
" self._context.run(self._callback, *self._args)\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 545, in dispatch_queue\n",
" await self.process_one()\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 534, in process_one\n",
" await dispatch(*args)\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 437, in dispatch_shell\n",
" await result\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 359, in execute_request\n",
" await super().execute_request(stream, ident, parent)\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 778, in execute_request\n",
" reply_content = await reply_content\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 446, in do_execute\n",
" res = shell.run_cell(\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n",
" return super().run_cell(*args, **kwargs)\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3075, in run_cell\n",
" result = self._run_cell(\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3130, in _run_cell\n",
" result = runner(coro)\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n",
" coro.send(None)\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3334, in run_cell_async\n",
" has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3517, in run_ast_nodes\n",
" if await self.run_code(code, result, async_=asy):\n",
" File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3577, in run_code\n",
" exec(code_obj, self.user_global_ns, self.user_ns)\n",
" File \"/tmp/ipykernel_141525/1542138470.py\", line 28, in <module>\n",
" logging.error(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1ea8edb16ad742d0ab6673560f231b20",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 1: 0%| | 0/9 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
"name": "stdout",
"output_type": "stream",
"text": [
"CUDA out of memory. Tried to allocate 6.00 GiB. GPU 0 has a total capacity of 15.99 GiB of which 0 bytes is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 39.04 GiB is allocated by PyTorch, and 2.02 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)\n",
"Error occurs, No graph saved\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-05-12 21:55:07,098 - INFO - Starting run_72 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'Residual', 'loss': 'kl'}\n",
"2024-05-12 21:55:07,100 - INFO - Loaded 72 training images and 8 test images\n"
"2024-06-03 02:05:41,071 - INFO - Starting run_58 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 16,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 16,\n",
" \"learning_rate\": 5.8262398455352215e-05,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"attention2\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.17181073763193916\n",
"}\n",
"2024-06-03 02:05:41,262 - INFO - Loaded 359668 training images and 39964 test images\n",
"2024-06-03 03:49:02,268 - INFO - Starting run_59 with hparams {\n",
" \"batch_size\": 16,\n",
" \"bin_count\": 16,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 32,\n",
" \"learning_rate\": 0.00017213076448986518,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"NormalisedCNN\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.1302383221350669\n",
"}\n",
"2024-06-03 03:49:02,397 - INFO - Loaded 719337 training images and 79927 test images\n",
"2024-06-03 04:28:45,612 - INFO - Starting run_60 with hparams {\n",
" \"batch_size\": 16,\n",
" \"bin_count\": 16,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 32,\n",
" \"learning_rate\": 0.00010975854085067054,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"NormalisedCNN\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.5457536006732233\n",
"}\n",
"2024-06-03 04:28:45,645 - INFO - Loaded 719337 training images and 79927 test images\n",
"2024-06-03 05:07:36,501 - INFO - Starting run_61 with hparams {\n",
" \"batch_size\": 16,\n",
" \"bin_count\": 32,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 16,\n",
" \"learning_rate\": 7.977966217588004e-05,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"Res2\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.5539449021909474\n",
"}\n",
"2024-06-03 05:07:36,526 - INFO - Loaded 359668 training images and 39964 test images\n",
"2024-06-03 06:53:48,871 - INFO - Starting run_62 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 16,\n",
" \"clip_gradients\": true,\n",
" \"edit_count\": 32,\n",
" \"learning_rate\": 0.0014725778411180288,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"SimpleCNN\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.981077298963819\n",
"}\n",
"2024-06-03 06:53:49,078 - INFO - Loaded 719337 training images and 79927 test images\n",
"2024-06-03 07:26:40,577 - INFO - Starting run_63 with hparams {\n",
" \"batch_size\": 32,\n",
" \"bin_count\": 64,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 16,\n",
" \"learning_rate\": 0.0002723042772767375,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"attention2\",\n",
" \"num_epochs\": 10,\n",
" \"scheduler_gamma\": 0.9651950429647194\n",
"}\n",
"2024-06-03 07:26:40,602 - INFO - Loaded 359668 training images and 39964 test images\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ede7a736204f4198868a98381b8862a7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 0: 0%| | 0/9 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "707b28e45e3448fc9e449624c9a8467e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 1: 0%| | 0/9 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-05-12 21:55:21,888 - INFO - Starting run_73 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'SimpleCNN', 'loss': 'kl'}\n",
"2024-05-12 21:55:21,890 - INFO - Loaded 72 training images and 8 test images\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8f94f38fd9c245bd9cc99ffe9cb0c058",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 0: 0%| | 0/9 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9bdbfb83837640099d2c1b5e8ac0cfe2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 1: 0%| | 0/9 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-05-12 21:55:36,871 - INFO - Starting run_74 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'NormalisedCNN', 'loss': 'kl'}\n",
"2024-05-12 21:55:36,872 - INFO - Loaded 72 training images and 8 test images\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b415e0ac9334aaf89b0fd8e234aa7c4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 0: 0%| | 0/9 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7e632fae2e4b4929b6e70dffcad6a341",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 1: 0%| | 0/9 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-05-12 21:55:52,083 - INFO - Starting run_75 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'SmartRes', 'loss': 'kl'}\n",
"2024-05-12 21:55:52,084 - INFO - Loaded 72 training images and 8 test images\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "46909535e98543ccb700caeb718dde48",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 0: 0%| | 0/9 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "36f757acc512484e9ac0f34efcd4c1c4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 1: 0%| | 0/9 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
@ -444,30 +525,25 @@
" k: v.rvs() if hasattr(v, \"rvs\") else choice(v)\n",
" for k, v in choice(hyperparameters).items()\n",
" }\n",
" key = json.dumps(current_hyperparameters)\n",
" key = json.dumps(current_hyperparameters, indent=2, sort_keys=True)\n",
" if key in tried:\n",
" continue\n",
" tried.add(key)\n",
" logging.info(\n",
" f\"Starting {get_next_run_name(Path(\"runs\"))} with hparams {current_hyperparameters}\"\n",
" f\"Starting {get_next_run_name(Path(\"runs\"))} with hparams {key}\"\n",
" )\n",
" try:\n",
" train(current_hyperparameters)\n",
" train(current_hyperparameters, max_duration=timedelta(hours=2), use_tqdm=False)\n",
" except KeyboardInterrupt as e:\n",
" logging.info(\"Interrupted, stopping\")\n",
" break\n",
" except TimeoutError as e:\n",
" logging.warning(f\"Timeout, aborting experiment\")\n",
" except Exception as e:\n",
" logging.error(\n",
" f\"Error with hparams {current_hyperparameters}:\\n\\t{e}\", stack_info=True\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"todo: try different colour spaces, see the results applied to images\n"
]
}
],
"metadata": {