diff --git a/.gitignore b/.gitignore index ada1975..249f81b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ __pycache__ runs* *.log -saved_models/* +saved_models* train.py diff --git a/src/train.ipynb b/src/train.ipynb index 46d07a7..84be519 100644 --- a/src/train.ipynb +++ b/src/train.ipynb @@ -2,22 +2,32 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "metadata": {} }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'Using device cuda:0'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "import torch\n", "import os\n", - "from utils import set_up_logging\n", + "from utils import set_up_logging, get_device\n", "from training import train, random_hparam_search\n", "from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n", "\n", "set_up_logging(LOGS_PATH)\n", "\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n", - "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "device = get_device()\n", "f\"Using device {device}\"" ] }, @@ -49,6 +59,144 @@ "# )" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-29 10:45:18,539 - INFO - Loaded 22479 original images\n", + "2024-06-29 10:45:18,544 - INFO - Loaded 2498 original images\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "61d4288a87a24255824e1430d8e69804", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 0: 0%| | 0/4215 [00:00 20\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mhparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTRAIN_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTEST_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mRUNS_PATH\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 25\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_duration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 28\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhparams\u001b[49m\n\u001b[1;32m 29\u001b[0m \u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/mnt/wsl/PHYSICALDRIVE1/projects/bipolaroid/src/training/train.py:55\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(hyperparameters, train_data_paths, test_data_paths, log_dir, use_tqdm, device, model_type, learning_rate, scheduler_gamma, num_epochs, **_)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_id, (edited_histogram, original_histogram) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\n\u001b[1;32m 50\u001b[0m tqdm(train_data_loader, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_tqdm\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m train_data_loader\n\u001b[1;32m 53\u001b[0m ):\n\u001b[1;32m 54\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 55\u001b[0m predicted_original \u001b[38;5;241m=\u001b[39m model(\u001b[43medited_histogram\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 56\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_function(\n\u001b[1;32m 57\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(predicted_original \u001b[38;5;241m+\u001b[39m EPSILON),\n\u001b[1;32m 58\u001b[0m original_histogram\u001b[38;5;241m.\u001b[39mto(device),\n\u001b[1;32m 59\u001b[0m )\n\u001b[1;32m 61\u001b[0m epoch_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# hparams = {\n", + "# \"batch_size\": 64,\n", + "# \"edit_count\": 12,\n", + "# \"bin_count\": 16,\n", + "# \"learning_rate\": 0.0006126108207352808,\n", + "# \"scheduler_gamma\": 0.9382286228762693,\n", + "# \"num_epochs\": 10,\n", + "# \"elu_alpha\": 1.3092260477215776,\n", + "# \"leaky_relu_slope\": 0.029438156325552762,\n", + "# \"dropout_prob\": 0.06261255195786307,\n", + "# \"features\": [8, 16, 32],\n", + "# \"use_residual\": True,\n", + "# \"kernel_size\": 5,\n", + "# \"model_type\": \"HistogramNet\",\n", + "# \"use_instance_norm\": True,\n", + "# \"use_elu\": False,\n", + "# \"leaky_relu_alpha\": 0.03745605986732464,\n", + "# }\n", + "\n", + "# train(\n", + "# hparams,\n", + "# train_data_paths=TRAIN_DATA,\n", + "# test_data_paths=TEST_DATA,\n", + "# log_dir=RUNS_PATH,\n", + "# max_duration=None,\n", + "# use_tqdm=True,\n", + "# device=device,\n", + "# **hparams\n", + "# )" + ] + }, { "cell_type": "code", "execution_count": null, @@ -61,37 +209,34 @@ "\n", "hyperparameters = [\n", " {\n", - " \"batch_size\": [32, 64, 128],\n", + " \"batch_size\": [64, 128, 256],\n", " \"edit_count\": [12],\n", " \"bin_count\": [16],\n", " \"learning_rate\": loguniform(5e-4, 5e-3),\n", " \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n", - " \"num_epochs\": [12],\n", + " \"num_epochs\": [10],\n", " \"elu_alpha\": uniform(0.5, 1.5),\n", " \"leaky_relu_slope\": uniform(0, 0.03),\n", " \"dropout_prob\": uniform(0, 0.1),\n", " \"features\": [\n", - " [16, 32],\n", " [16, 32, 64],\n", " [16, 32, 64, 128],\n", " [32, 64],\n", " [32, 64, 128],\n", " [8, 16, 32],\n", - " [8, 8, 8],\n", " [8, 8, 8, 8, 8],\n", " [8, 8, 8, 8, 8, 8, 8],\n", - " [16, 16, 16, 16, 16],\n", " [16, 16, 16],\n", - " [32, 32], \n", + " [16, 16, 16, 16, 16],\n", " [32, 32, 32],\n", " [32, 32, 32, 32],\n", " [64, 64],\n", - " [64, 64, 64]\n", + " [64, 64, 64],\n", " ],\n", " \"use_residual\": [True, False],\n", " \"kernel_size\": [3, 5],\n", " \"model_type\": [\"HistogramNet\"],\n", - " \"use_instance_norm\": [True, False],\n", + " \"use_instance_norm\": [True],\n", " \"use_elu\": [True, False],\n", " \"leaky_relu_alpha\": uniform(0, 0.05),\n", " }\n", @@ -127,7 +272,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.1.-1" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/src/training/histogram_dataset.py b/src/training/histogram_dataset.py index 7aa1b56..306112f 100644 --- a/src/training/histogram_dataset.py +++ b/src/training/histogram_dataset.py @@ -18,8 +18,8 @@ class HistogramDataset(Dataset): self, /, paths: List[Path], - bin_count: int, - edit_count: int = 5, + bin_count: int = 16, + edit_count: int = 12, target_size=(240, 240), delete_corrupt_images: bool = False, cache_path: Optional[Path] = None, diff --git a/src/utils/compute_histogram.py b/src/utils/compute_histogram.py index 37abbc1..abcbcd1 100644 --- a/src/utils/compute_histogram.py +++ b/src/utils/compute_histogram.py @@ -4,7 +4,7 @@ import numpy as np def compute_histogram( image: Image.Image | np.ndarray, - bins: int, + bins: int = 16, value_range=(0, 256), normalize: bool = True, ) -> np.ndarray: