{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "metadata": {} }, "outputs": [ { "data": { "text/plain": [ "'Using device cuda:0'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import logging\n", "import os\n", "from datetime import datetime\n", "\n", "logging.basicConfig(\n", " level=logging.INFO,\n", " format=\"%(asctime)s - %(levelname)s - %(message)s\",\n", " handlers=[\n", " logging.StreamHandler(),\n", " logging.FileHandler(\n", " f\"train-{datetime.now().isoformat(timespec='minutes')}.log\"\n", " ),\n", " ],\n", ")\n", "\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n", "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "f\"Using device {device}\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from scipy.stats import loguniform, uniform, randint\n", "from editor.models import MODELS\n", "\n", "common_hyperparameters = {\n", " \"batch_size\": [64],\n", " \"edit_count\": [8],\n", " \"bin_count\": [16, 24, 32],\n", " \"clip_gradients\": [False],\n", " \"learning_rate\": loguniform(3e-4, 3e-3),\n", " \"scheduler_gamma\": uniform(0.5, 1),\n", " \"num_epochs\": [16],\n", " # \"num_epochs\": randint(5, 10),\n", " \"model_type\": list(MODELS.keys()),\n", "}\n", "hyperparameters = [\n", " # {\n", " # **common_hyperparameters,\n", " # \"loss\": [\"progressive\"],\n", " # \"loss_sizes\": [[4, 8, 16, 32], [8, 16, 32], [16, 32], [8, 32]],\n", " # \"loss_damping\": uniform(0.2, 5),\n", " # },\n", " {\n", " **common_hyperparameters,\n", " \"loss\": [\"kl\"],\n", " },\n", "]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "from typing import List, Any, Dict\n", "from torch.utils.data import DataLoader\n", "from config import CACHE_PATH\n", "from editor.training import HistogramDataset\n", "\n", "\n", "def get_data_loader(data: List[Path], hyperparameters: Dict[str, Any]) -> DataLoader:\n", " return DataLoader(\n", " dataset=HistogramDataset(\n", " paths=data,\n", " edit_count=hyperparameters[\"edit_count\"],\n", " bin_count=hyperparameters[\"bin_count\"],\n", " delete_corrupt_images=False,\n", " cache_path=CACHE_PATH,\n", " ),\n", " batch_size=hyperparameters[\"batch_size\"],\n", " shuffle=True,\n", " num_workers=os.cpu_count(),\n", " )\n", "\n", "\n", "def serialise_hparams(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:\n", " return {k: str(v) if isinstance(v, list) else v for k, v in hyperparameters.items()}" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing model SimpleCNN\n", "Test passed! Output shape matches input shape.\n", "Testing model Residual\n", "Test passed! Output shape matches input shape.\n" ] } ], "source": [ "from torch.utils.tensorboard import SummaryWriter\n", "from pathlib import Path\n", "from torch.optim import Adam\n", "from tqdm.notebook import tqdm\n", "from torch.nn.utils import clip_grad_norm_\n", "from editor.training import ProgressivePoolingLoss\n", "from editor.utils import get_next_run_name\n", "from editor.visualisation import plot_histograms_in_2d\n", "from editor.models import create_model, test_models\n", "from data import TRAIN_DATA, TEST_DATA\n", "from datetime import timedelta, datetime\n", "import json\n", "from config import MODELS_PATH\n", "\n", "\n", "test_models()\n", "\n", "\n", "def train(\n", " hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool\n", ") -> Path:\n", " start_time = datetime.now()\n", "\n", " log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n", " with SummaryWriter(log_dir) as writer:\n", " train_data_loader = get_data_loader(TRAIN_DATA, hyperparameters)\n", " test_data_loader = get_data_loader(TEST_DATA, hyperparameters)\n", "\n", " model = (\n", " create_model(\n", " type=hyperparameters[\"model_type\"],\n", " bin_count=hyperparameters[\"bin_count\"],\n", " )\n", " .train()\n", " .to(device)\n", " )\n", " writer.add_graph(model, next(iter(train_data_loader))[0].to(device))\n", "\n", " optimizer = Adam(model.parameters(), lr=hyperparameters[\"learning_rate\"])\n", " scheduler = torch.optim.lr_scheduler.StepLR(\n", " optimizer, step_size=1, gamma=hyperparameters[\"scheduler_gamma\"]\n", " )\n", "\n", " loss_function = {\n", " \"progressive\": lambda: ProgressivePoolingLoss(\n", " target_sizes=hyperparameters[\"loss_sizes\"],\n", " damping=hyperparameters[\"loss_damping\"],\n", " ),\n", " \"kl\": lambda: torch.nn.KLDivLoss(reduction=\"batchmean\"),\n", " }[hyperparameters[\"loss\"]]().to(device)\n", "\n", " try:\n", " for epoch in range(hyperparameters[\"num_epochs\"]):\n", " epoch_loss = 0\n", " writer.add_scalar(\n", " \"Actual learning rate\", scheduler.get_last_lr()[0], epoch\n", " )\n", " for batch_id, (edited_histogram, original_histogram) in enumerate(\n", " tqdm(train_data_loader, desc=f\"Epoch {epoch}\", unit=\"batch\")\n", " if use_tqdm\n", " else train_data_loader\n", " ):\n", " current_time = datetime.now()\n", " if current_time - start_time > max_duration:\n", " raise TimeoutError(f\"Time limit {max_duration} exceeded\")\n", " edited_histogram = edited_histogram.to(device)\n", " original_histogram = original_histogram.to(device)\n", "\n", " optimizer.zero_grad()\n", " predicted_original = model(edited_histogram)\n", " sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n", " predicted_original = predicted_original / sum\n", "\n", " if hyperparameters[\"loss\"] == \"kl\":\n", " predicted_original = torch.clamp(\n", " predicted_original, 0.0000000000000000000001, 1\n", " )\n", "\n", " loss = {\n", " \"kl\": lambda: loss_function(\n", " torch.log(predicted_original),\n", " original_histogram,\n", " ),\n", " \"progressive\": lambda: loss_function(\n", " predicted_original, original_histogram\n", " ),\n", " }[hyperparameters[\"loss\"]]()\n", "\n", " epoch_loss += loss.item()\n", " writer.add_scalar(\n", " \"Loss/train/batch\",\n", " loss,\n", " global_step=epoch * len(train_data_loader) + batch_id,\n", " )\n", " loss.backward()\n", "\n", " if hyperparameters[\"clip_gradients\"]:\n", " clip_grad_norm_(model.parameters(), max_norm=1.0)\n", " optimizer.step()\n", "\n", " logging.info(f\"Epoch {epoch} train loss: {epoch_loss}\")\n", " with torch.no_grad():\n", " model.eval()\n", " loader = iter(test_data_loader)\n", " edited_histogram, original_histogram = next(loader)\n", " edited_histogram = edited_histogram.to(device)\n", " original_histogram = original_histogram.to(device)\n", " predicted_original = model(edited_histogram)\n", " sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n", " predicted_original = predicted_original / sum\n", " writer.add_figure(\n", " \"histogram\",\n", " plot_histograms_in_2d(\n", " {\n", " \"original\": original_histogram.cpu()[0]\n", " .numpy()\n", " .squeeze(),\n", " \"edited\": edited_histogram.cpu()[0].numpy().squeeze(),\n", " \"predicted\": predicted_original.cpu()[0]\n", " .numpy()\n", " .squeeze(),\n", " }\n", " ),\n", " epoch,\n", " )\n", "\n", " epoch_test_loss = 0\n", " for batch_id, (edited_histogram, original_histogram) in enumerate(\n", " test_data_loader\n", " ):\n", " edited_histogram = edited_histogram.to(device)\n", " original_histogram = original_histogram.to(device)\n", "\n", " predicted_original = model(edited_histogram)\n", " sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n", " predicted_original = predicted_original / sum\n", "\n", " if hyperparameters[\"loss\"] == \"kl\":\n", " predicted_original = torch.clamp(\n", " predicted_original, 0.0000000000000000000001, 1\n", " )\n", "\n", " loss = {\n", " \"kl\": lambda: loss_function(\n", " torch.log(predicted_original),\n", " original_histogram,\n", " ),\n", " \"progressive\": lambda: loss_function(\n", " predicted_original, original_histogram\n", " ),\n", " }[hyperparameters[\"loss\"]]()\n", "\n", " epoch_test_loss += loss.item()\n", " writer.add_hparams(\n", " serialise_hparams(hyperparameters),\n", " {\n", " \"Loss/test/epoch\": epoch_test_loss,\n", " \"Loss/train/epoch\": epoch_loss,\n", " },\n", " global_step=epoch,\n", " run_name=log_dir.absolute(),\n", " )\n", " logging.info(f\"Epoch {epoch} test loss: {epoch_test_loss}\")\n", "\n", " model.train()\n", " scheduler.step()\n", " except Exception:\n", " raise\n", " finally:\n", " run_name = get_next_run_name(MODELS_PATH)\n", " model_path = (MODELS_PATH / run_name).with_suffix(\".pth\")\n", " params_path = (MODELS_PATH / run_name).with_suffix(\".json\")\n", "\n", " logging.info(f\"Saving model to {model_path}\")\n", " with open(model_path, \"wb\") as f:\n", " torch.save(model.state_dict(), f)\n", " with open(params_path, \"w\") as f:\n", " json.dump(hyperparameters, f, indent=2)\n", " del model\n", " torch.cuda.empty_cache()\n", " return model_path" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# train(\n", "# {\n", "# \"batch_size\": 64,\n", "# \"edit_count\": 25,\n", "# \"bin_count\": 32,\n", "# \"clip_gradients\": True,\n", "# \"learning_rate\": 0.005,\n", "# \"scheduler_gamma\": 0.7,\n", "# \"num_epochs\": 20,\n", "# \"model_type\": \"NormalisedCNN\",\n", "# \"loss\": \"progressive\",\n", "# \"loss_sizes\": [16, 32],\n", "# \"loss_damping\": 2,\n", "# }\n", "# )" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-06-06 08:19:45,651 - INFO - Starting run_26 with hparams {\n", " \"batch_size\": 64,\n", " \"bin_count\": 32,\n", " \"clip_gradients\": false,\n", " \"edit_count\": 8,\n", " \"learning_rate\": 0.0020871393198725404,\n", " \"loss\": \"kl\",\n", " \"model_type\": \"Residual\",\n", " \"num_epochs\": 16,\n", " \"scheduler_gamma\": 1.2086440138363033\n", "}\n", "2024-06-06 08:19:45,768 - INFO - Loaded 22479 original images\n", "2024-06-06 08:19:45,778 - INFO - Loaded 2498 original images\n", "2024-06-06 08:20:02,104 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_26.pth\n", "2024-06-06 08:20:02,248 - INFO - Interrupted, stopping\n" ] } ], "source": [ "from random import choice\n", "from itertools import count\n", "import json\n", "\n", "\n", "for _ in count():\n", " current_hyperparameters = {\n", " k: v.rvs() if hasattr(v, \"rvs\") else choice(v)\n", " for k, v in choice(hyperparameters).items()\n", " }\n", " key = json.dumps(current_hyperparameters, indent=2, sort_keys=True)\n", " logging.info(\n", " f\"Starting {get_next_run_name(Path(\"runs\"))} with hparams {key}\"\n", " )\n", " try:\n", " train(current_hyperparameters, max_duration=timedelta(hours=8), use_tqdm=False)\n", " except KeyboardInterrupt as e:\n", " logging.info(\"Interrupted, stopping\")\n", " break\n", " except TimeoutError as e:\n", " logging.warning(f\"Timeout, aborting experiment\")\n", " except Exception as e:\n", " logging.error(\n", " f\"Error with hparams {current_hyperparameters}:\\n\\t{e}\", stack_info=True\n", " )" ] } ], "metadata": { "kernelspec": { "display_name": "bipolaroid", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 2 }