From ae9b19d2db5ab4f4533f1f1a96479e0e7f40e5ba Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Mon, 3 Jun 2024 07:49:09 +0100 Subject: [PATCH] Improve training --- .gitignore | 4 +- src/train.ipynb | 480 ++++++++++++++++++++++++++++-------------------- 2 files changed, 281 insertions(+), 203 deletions(-) diff --git a/.gitignore b/.gitignore index 2f57af6..afe7f26 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ __pycache__ -runs \ No newline at end of file +runs +*.log +models/*.pth diff --git a/src/train.ipynb b/src/train.ipynb index 924b80c..d10fc12 100644 --- a/src/train.ipynb +++ b/src/train.ipynb @@ -2,18 +2,25 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": { "metadata": {} }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-03 07:46:08,999 - INFO - PyTorch version: 2.2.2\n" + ] + }, { "data": { "text/plain": [ "'Using device cuda:0'" ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -21,11 +28,19 @@ "source": [ "import torch\n", "import logging\n", + "import os\n", + "from datetime import datetime\n", "\n", "logging.basicConfig(\n", - " level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\"\n", + " level=logging.INFO,\n", + " format=\"%(asctime)s - %(levelname)s - %(message)s\",\n", + " handlers=[\n", + " logging.StreamHandler(),\n", + " logging.FileHandler(f\"train-{datetime.now().date()}.log\"),\n", + " ],\n", ")\n", "\n", + "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n", "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "f\"Using device {device}\"" ] @@ -37,23 +52,17 @@ "outputs": [], "source": [ "from scipy.stats import loguniform, uniform\n", - "\n", + "from editor.models import MODELS\n", "\n", "common_hyperparameters = {\n", - " \"batch_size\": [8, 16, 32, 64],\n", - " \"edit_count\": [8, 16, 32],\n", - " \"bin_count\": [16, 32, 64],\n", + " \"batch_size\": [16, 32, 64],\n", + " \"edit_count\": [8, 16],\n", + " \"bin_count\": [16, 32],\n", " \"clip_gradients\": [True, False],\n", - " \"learning_rate\": loguniform(0.00001, 0.005),\n", + " \"learning_rate\": loguniform(0.0001, 0.005),\n", " \"scheduler_gamma\": uniform(0.1, 0.9),\n", - " \"num_epochs\": [10],\n", - " \"model_type\": [\n", - " \"NormalisedCNN\",\n", - " \"SimpleCNN\",\n", - " \"Residual\",\n", - " \"SmartRes\",\n", - " \"Res2\",\n", - " ],\n", + " \"num_epochs\": [5],\n", + " \"model_type\": list(MODELS.keys()),\n", "}\n", "hyperparameters = [\n", " # {\n", @@ -73,7 +82,30 @@ "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing model SimpleCNN\n", + "Test passed! Output shape matches input shape.\n", + "Testing model Residual\n", + "Test passed! Output shape matches input shape.\n", + "Testing model NormalisedCNN\n", + "Test passed! Output shape matches input shape.\n", + "Testing model SmartRes\n", + "Test passed! Output shape matches input shape.\n", + "Testing model attention2\n", + "Test passed! Output shape matches input shape.\n", + "Testing model advanced_attention\n", + "Test passed! Output shape matches input shape.\n", + "Testing model Res2\n", + "Test passed! Output shape matches input shape.\n", + "Testing model attention1\n", + "Test passed! Output shape matches input shape.\n" + ] + } + ], "source": [ "from typing import Any, Dict\n", "from torch.utils.tensorboard import SummaryWriter\n", @@ -85,11 +117,17 @@ "from editor.utils import get_next_run_name\n", "from editor.visualisation import plot_histograms_in_2d\n", "from editor.training import create_data_loaders\n", - "from editor.models import create_model\n", + "from editor.models import create_model, test_models\n", "from config import DATA, MODELS_PATH\n", + "from datetime import timedelta, datetime\n", + "\n", + "test_models()\n", "\n", "\n", - "def train(hyperparameters: Dict[str, Any]) -> Path:\n", + "def train(\n", + " hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool\n", + ") -> Path:\n", + " start_time = datetime.now()\n", " model_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\".pth\")\n", "\n", " log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n", @@ -132,7 +170,12 @@ " )\n", " for batch_id, (edited_histogram, original_histogram) in enumerate(\n", " tqdm(train_data_loader, desc=f\"Epoch {epoch}\", unit=\"batch\")\n", + " if use_tqdm\n", + " else train_data_loader\n", " ):\n", + " current_time = datetime.now()\n", + " if current_time - start_time > max_duration:\n", + " raise TimeoutError(f\"Time limit {max_duration} exceeded\")\n", " edited_histogram = edited_histogram.to(device)\n", " original_histogram = original_histogram.to(device)\n", "\n", @@ -246,189 +289,227 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-05-12 21:54:49,789 - INFO - Starting run_71 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'Res2', 'loss': 'kl'}\n", - "2024-05-12 21:54:49,792 - INFO - Loaded 72 training images and 8 test images\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "461659956f3944a085b5cc3a5af6ec31", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Epoch 0: 0%| | 0/9 [00:00\", line 198, in _run_module_as_main\n", + " File \"\", line 88, in _run_code\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel_launcher.py\", line 18, in \n", + " app.launch_new_instance()\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n", + " app.start()\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/kernelapp.py\", line 739, in start\n", + " self.io_loop.start()\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/tornado/platform/asyncio.py\", line 195, in start\n", + " self.asyncio_loop.run_forever()\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/asyncio/base_events.py\", line 639, in run_forever\n", + " self._run_once()\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/asyncio/base_events.py\", line 1985, in _run_once\n", + " handle._run()\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/asyncio/events.py\", line 88, in _run\n", + " self._context.run(self._callback, *self._args)\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 545, in dispatch_queue\n", + " await self.process_one()\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 534, in process_one\n", + " await dispatch(*args)\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 437, in dispatch_shell\n", + " await result\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 359, in execute_request\n", + " await super().execute_request(stream, ident, parent)\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/kernelbase.py\", line 778, in execute_request\n", + " reply_content = await reply_content\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/ipkernel.py\", line 446, in do_execute\n", + " res = shell.run_cell(\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n", + " return super().run_cell(*args, **kwargs)\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3075, in run_cell\n", + " result = self._run_cell(\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3130, in _run_cell\n", + " result = runner(coro)\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n", + " coro.send(None)\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3334, in run_cell_async\n", + " has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3517, in run_ast_nodes\n", + " if await self.run_code(code, result, async_=asy):\n", + " File \"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/IPython/core/interactiveshell.py\", line 3577, in run_code\n", + " exec(code_obj, self.user_global_ns, self.user_ns)\n", + " File \"/tmp/ipykernel_141525/1542138470.py\", line 28, in \n", + " logging.error(\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1ea8edb16ad742d0ab6673560f231b20", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Epoch 1: 0%| | 0/9 [00:00