diff --git a/src/train.ipynb b/src/train.ipynb index 9e06e73..04c57df 100644 --- a/src/train.ipynb +++ b/src/train.ipynb @@ -22,19 +22,10 @@ "import torch\n", "import logging\n", "import os\n", - "from datetime import datetime\n", + "from editor.utils import set_up_logging\n", "from config import LOGS_PATH\n", "\n", - "logging.basicConfig(\n", - " level=logging.INFO,\n", - " format=\"%(asctime)s - %(levelname)s - %(message)s\",\n", - " handlers=[\n", - " logging.StreamHandler(),\n", - " logging.FileHandler(\n", - " LOGS_PATH / f\"train-{datetime.now().isoformat(timespec='minutes')}.log\"\n", - " ),\n", - " ],\n", - ")\n", + "set_up_logging(LOGS_PATH)\n", "\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n", "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", @@ -45,34 +36,49 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-22 15:59:06,999 - INFO - Testing model Dummy\n", + "2024-06-22 15:59:07,004 - INFO - Test passed! Output shape matches input shape.\n", + "2024-06-22 15:59:07,004 - INFO - Testing model SimpleCNN\n", + "2024-06-22 15:59:07,478 - INFO - Test passed! Output shape matches input shape.\n", + "2024-06-22 15:59:07,481 - INFO - Testing model Residual\n", + "2024-06-22 15:59:08,560 - INFO - Test passed! Output shape matches input shape.\n", + "2024-06-22 15:59:08,566 - INFO - Testing model Residual2\n", + "2024-06-22 15:59:09,671 - INFO - Test passed! Output shape matches input shape.\n", + "2024-06-22 15:59:09,676 - INFO - Testing model Residual3\n", + "2024-06-22 15:59:11,272 - INFO - Test passed! Output shape matches input shape.\n" + ] + } + ], "source": [ "from scipy.stats import loguniform, uniform, randint\n", - "from editor.models import MODELS\n", + "from editor.models import MODELS, test_models\n", "\n", - "common_hyperparameters = {\n", - " \"batch_size\": [64],\n", + "\n", + "hyperparameters = [{\n", + " \"batch_size\": [32, 64, 128],\n", " \"edit_count\": [12],\n", " \"bin_count\": [16],\n", - " \"clip_gradients\": [False],\n", - " \"learning_rate\": loguniform(1e-4, 5e-3),\n", - " \"scheduler_gamma\": uniform(loc=0.7, scale=0.3),\n", - " \"num_epochs\": [24],\n", - " # \"num_epochs\": randint(5, 10),\n", - " \"model_type\": list(MODELS.keys()),\n", - "}\n", - "hyperparameters = [\n", - " # {\n", - " # **common_hyperparameters,\n", - " # \"loss\": [\"progressive\"],\n", - " # \"loss_sizes\": [[4, 8, 16, 32], [8, 16, 32], [16, 32], [8, 32]],\n", - " # \"loss_damping\": uniform(0.2, 5),\n", - " # },\n", - " {\n", - " **common_hyperparameters,\n", - " \"loss\": [\"kl\"],\n", - " },\n", - "]" + " \"learning_rate\": loguniform(5e-4, 5e-3),\n", + " \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n", + " \"num_epochs\": [12],\n", + " \"elu_alpha\": uniform(0.5, 1.5),\n", + " \"leaky_relu_slope\": uniform(0, 0.03),\n", + " \"dropout_prob\": uniform(0, 0.1),\n", + " \"features\": [[16, 32, 64], [32, 64, 128], [8, 16, 32], [8, 8, 8], [16, 16, 16]],\n", + " \"kernel_sizes\": [[3, 3, 3]],\n", + " \"model_type\": [\"Residual3\"], # list(MODELS.keys()),\n", + " \"clip_gradients\": [True, False],\n", + " \"use_instance_norm\": [True, False],\n", + " \"use_elu\": [True, False],\n", + " \"leaky_relu_alpha\": uniform(0, 0.05),\n", + "}]\n", + "\n", + "test_models()" ] }, { @@ -111,36 +117,24 @@ "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Testing model Residual\n", - "Test passed! Output shape matches input shape.\n" - ] - } - ], + "outputs": [], "source": [ + "from typing import Optional\n", "from torch.utils.tensorboard import SummaryWriter\n", "from pathlib import Path\n", "from torch.optim import Adam\n", "from tqdm.notebook import tqdm\n", - "from torch.nn.utils import clip_grad_norm_\n", - "from editor.training import ProgressivePoolingLoss\n", "from editor.utils import get_next_run_name\n", "from editor.visualisation import plot_histograms_in_2d\n", - "from editor.models import create_model, test_models\n", + "from editor.models import create_model, save_model\n", "from datetime import timedelta, datetime\n", - "import json\n", "from config import MODELS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n", "\n", "\n", - "test_models()\n", - "\n", - "\n", "def train(\n", - " hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool\n", + " hyperparameters: Dict[str, Any],\n", + " max_duration: Optional[timedelta] = None,\n", + " use_tqdm: bool = True,\n", ") -> Path:\n", " start_time = datetime.now()\n", "\n", @@ -149,28 +143,18 @@ " train_data_loader = get_data_loader(TRAIN_DATA, hyperparameters)\n", " test_data_loader = get_data_loader(TEST_DATA, hyperparameters)\n", "\n", - " model = (\n", - " create_model(\n", - " type=hyperparameters[\"model_type\"],\n", - " bin_count=hyperparameters[\"bin_count\"],\n", - " )\n", - " .train()\n", - " .to(device)\n", - " )\n", + " model = create_model(\n", + " type=hyperparameters[\"model_type\"],\n", + " bin_count=hyperparameters[\"bin_count\"],\n", + " device=device,\n", + " ).train()\n", " writer.add_graph(model, next(iter(train_data_loader))[0].to(device))\n", "\n", " optimizer = Adam(model.parameters(), lr=hyperparameters[\"learning_rate\"])\n", " scheduler = torch.optim.lr_scheduler.StepLR(\n", " optimizer, step_size=1, gamma=hyperparameters[\"scheduler_gamma\"]\n", " )\n", - "\n", - " loss_function = {\n", - " \"progressive\": lambda: ProgressivePoolingLoss(\n", - " target_sizes=hyperparameters[\"loss_sizes\"],\n", - " damping=hyperparameters[\"loss_damping\"],\n", - " ),\n", - " \"kl\": lambda: torch.nn.KLDivLoss(reduction=\"batchmean\"),\n", - " }[hyperparameters[\"loss\"]]().to(device)\n", + " loss_function = torch.nn.KLDivLoss(reduction=\"batchmean\").to(device)\n", "\n", " try:\n", " for epoch in range(hyperparameters[\"num_epochs\"]):\n", @@ -184,30 +168,18 @@ " else train_data_loader\n", " ):\n", " current_time = datetime.now()\n", - " if current_time - start_time > max_duration:\n", + " if (\n", + " max_duration is not None\n", + " and current_time - start_time > max_duration\n", + " ):\n", " raise TimeoutError(f\"Time limit {max_duration} exceeded\")\n", - " edited_histogram = edited_histogram.to(device)\n", - " original_histogram = original_histogram.to(device)\n", "\n", " optimizer.zero_grad()\n", - " predicted_original = model(edited_histogram)\n", - " sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n", - " predicted_original = predicted_original / sum\n", - "\n", - " if hyperparameters[\"loss\"] == \"kl\":\n", - " predicted_original = torch.clamp(\n", - " predicted_original, 0.0000000000000000000001, 1\n", - " )\n", - "\n", - " loss = {\n", - " \"kl\": lambda: loss_function(\n", - " torch.log(predicted_original),\n", - " original_histogram,\n", - " ),\n", - " \"progressive\": lambda: loss_function(\n", - " predicted_original, original_histogram\n", - " ),\n", - " }[hyperparameters[\"loss\"]]()\n", + " predicted_original = model(edited_histogram.to(device))\n", + " loss = loss_function(\n", + " torch.log(torch.clamp(predicted_original, 1e-10, 1)),\n", + " original_histogram.to(device),\n", + " )\n", "\n", " epoch_loss += loss.item()\n", " writer.add_scalar(\n", @@ -216,9 +188,6 @@ " global_step=epoch * len(train_data_loader) + batch_id,\n", " )\n", " loss.backward()\n", - "\n", - " if hyperparameters[\"clip_gradients\"]:\n", - " clip_grad_norm_(model.parameters(), max_norm=1.0)\n", " optimizer.step()\n", "\n", " logging.info(f\"Epoch {epoch} train loss: {epoch_loss}\")\n", @@ -226,18 +195,12 @@ " model.eval()\n", " loader = iter(test_data_loader)\n", " edited_histogram, original_histogram = next(loader)\n", - " edited_histogram = edited_histogram.to(device)\n", - " original_histogram = original_histogram.to(device)\n", - " predicted_original = model(edited_histogram)\n", - " sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n", - " predicted_original = predicted_original / sum\n", + " predicted_original = model(edited_histogram.to(device))\n", " writer.add_figure(\n", " \"histogram\",\n", " plot_histograms_in_2d(\n", " {\n", - " \"original\": original_histogram.cpu()[0]\n", - " .numpy()\n", - " .squeeze(),\n", + " \"original\": original_histogram[0].numpy().squeeze(),\n", " \"edited\": edited_histogram.cpu()[0].numpy().squeeze(),\n", " \"predicted\": predicted_original.cpu()[0]\n", " .numpy()\n", @@ -251,29 +214,11 @@ " for batch_id, (edited_histogram, original_histogram) in enumerate(\n", " test_data_loader\n", " ):\n", - " edited_histogram = edited_histogram.to(device)\n", - " original_histogram = original_histogram.to(device)\n", - "\n", - " predicted_original = model(edited_histogram)\n", - " sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n", - " predicted_original = predicted_original / sum\n", - "\n", - " if hyperparameters[\"loss\"] == \"kl\":\n", - " predicted_original = torch.clamp(\n", - " predicted_original, 0.0000000000000000000001, 1\n", - " )\n", - "\n", - " loss = {\n", - " \"kl\": lambda: loss_function(\n", - " torch.log(predicted_original),\n", - " original_histogram,\n", - " ),\n", - " \"progressive\": lambda: loss_function(\n", - " predicted_original, original_histogram\n", - " ),\n", - " }[hyperparameters[\"loss\"]]()\n", - "\n", - " epoch_test_loss += loss.item()\n", + " predicted_original = model(edited_histogram.to(device))\n", + " epoch_test_loss += loss_function(\n", + " torch.log(torch.clamp(predicted_original, 1e-10, 1)),\n", + " original_histogram.to(device),\n", + " ).item()\n", " writer.add_hparams(\n", " serialise_hparams(hyperparameters),\n", " {\n", @@ -287,40 +232,323 @@ "\n", " model.train()\n", " scheduler.step()\n", - " except Exception:\n", - " raise\n", " finally:\n", - " run_name = get_next_run_name(MODELS_PATH)\n", - " model_path = (MODELS_PATH / run_name).with_suffix(\".pth\")\n", - " params_path = (MODELS_PATH / run_name).with_suffix(\".json\")\n", - "\n", - " logging.info(f\"Saving model to {model_path}\")\n", - " with open(model_path, \"wb\") as f:\n", - " torch.save(model.state_dict(), f)\n", - " with open(params_path, \"w\") as f:\n", - " json.dump(hyperparameters, f, indent=2)\n", + " model_path = MODELS_PATH / get_next_run_name(MODELS_PATH)\n", + " save_model(model, hyperparameters, model_path)\n", " del model\n", - " torch.cuda.empty_cache()\n", " return model_path" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-22 16:57:28,986 - INFO - Loaded 22479 original images\n", + "2024-06-22 16:57:28,991 - INFO - Loaded 2498 original images\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "63e068d857484e59afb823ca4b5c3a58", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 0: 0%| | 0/2108 [00:00