diff --git a/src/train.ipynb b/src/train.ipynb index 04c57df..62e8bb6 100644 --- a/src/train.ipynb +++ b/src/train.ipynb @@ -20,9 +20,8 @@ ], "source": [ "import torch\n", - "import logging\n", "import os\n", - "from editor.utils import set_up_logging\n", + "from utils import set_up_logging\n", "from config import LOGS_PATH\n", "\n", "set_up_logging(LOGS_PATH)\n", @@ -41,42 +40,42 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-06-22 15:59:06,999 - INFO - Testing model Dummy\n", - "2024-06-22 15:59:07,004 - INFO - Test passed! Output shape matches input shape.\n", - "2024-06-22 15:59:07,004 - INFO - Testing model SimpleCNN\n", - "2024-06-22 15:59:07,478 - INFO - Test passed! Output shape matches input shape.\n", - "2024-06-22 15:59:07,481 - INFO - Testing model Residual\n", - "2024-06-22 15:59:08,560 - INFO - Test passed! Output shape matches input shape.\n", - "2024-06-22 15:59:08,566 - INFO - Testing model Residual2\n", - "2024-06-22 15:59:09,671 - INFO - Test passed! Output shape matches input shape.\n", - "2024-06-22 15:59:09,676 - INFO - Testing model Residual3\n", - "2024-06-22 15:59:11,272 - INFO - Test passed! Output shape matches input shape.\n" + "2024-06-25 08:59:52,244 - INFO - Testing model Dummy\n", + "2024-06-25 08:59:52,249 - INFO - Test passed! Output shape matches input shape.\n", + "2024-06-25 08:59:52,249 - INFO - Testing model SimpleCNN\n", + "2024-06-25 08:59:52,746 - INFO - Test passed! Output shape matches input shape.\n", + "2024-06-25 08:59:52,752 - INFO - Testing model Residual\n", + "2024-06-25 08:59:53,853 - INFO - Test passed! Output shape matches input shape.\n", + "2024-06-25 08:59:53,917 - INFO - Testing model Residual3\n", + "2024-06-25 08:59:55,590 - INFO - Test passed! Output shape matches input shape.\n" ] } ], "source": [ "from scipy.stats import loguniform, uniform, randint\n", - "from editor.models import MODELS, test_models\n", + "from models import MODELS, test_models\n", "\n", "\n", - "hyperparameters = [{\n", - " \"batch_size\": [32, 64, 128],\n", - " \"edit_count\": [12],\n", - " \"bin_count\": [16],\n", - " \"learning_rate\": loguniform(5e-4, 5e-3),\n", - " \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n", - " \"num_epochs\": [12],\n", - " \"elu_alpha\": uniform(0.5, 1.5),\n", - " \"leaky_relu_slope\": uniform(0, 0.03),\n", - " \"dropout_prob\": uniform(0, 0.1),\n", - " \"features\": [[16, 32, 64], [32, 64, 128], [8, 16, 32], [8, 8, 8], [16, 16, 16]],\n", - " \"kernel_sizes\": [[3, 3, 3]],\n", - " \"model_type\": [\"Residual3\"], # list(MODELS.keys()),\n", - " \"clip_gradients\": [True, False],\n", - " \"use_instance_norm\": [True, False],\n", - " \"use_elu\": [True, False],\n", - " \"leaky_relu_alpha\": uniform(0, 0.05),\n", - "}]\n", + "hyperparameters = [\n", + " {\n", + " \"batch_size\": [32, 64, 128],\n", + " \"edit_count\": [12],\n", + " \"bin_count\": [16],\n", + " \"learning_rate\": loguniform(5e-4, 5e-3),\n", + " \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n", + " \"num_epochs\": [12],\n", + " \"elu_alpha\": uniform(0.5, 1.5),\n", + " \"leaky_relu_slope\": uniform(0, 0.03),\n", + " \"dropout_prob\": uniform(0, 0.1),\n", + " \"features\": [[16, 32, 64], [32, 64, 128], [8, 16, 32], [8, 8, 8], [16, 16, 16]],\n", + " \"kernel_sizes\": [[3, 3, 3]],\n", + " \"model_type\": [\"Residual3\"], # list(MODELS.keys()),\n", + " \"clip_gradients\": [True, False],\n", + " \"use_instance_norm\": [True, False],\n", + " \"use_elu\": [True, False],\n", + " \"leaky_relu_alpha\": uniform(0, 0.05),\n", + " }\n", + "]\n", "\n", "test_models()" ] @@ -86,451 +85,6 @@ "execution_count": 3, "metadata": {}, "outputs": [], - "source": [ - "from pathlib import Path\n", - "from typing import List, Any, Dict\n", - "from torch.utils.data import DataLoader\n", - "from config import CACHE_PATH\n", - "from editor.training import HistogramDataset\n", - "\n", - "\n", - "def get_data_loader(data: List[Path], hyperparameters: Dict[str, Any]) -> DataLoader:\n", - " return DataLoader(\n", - " dataset=HistogramDataset(\n", - " paths=data,\n", - " edit_count=hyperparameters[\"edit_count\"],\n", - " bin_count=hyperparameters[\"bin_count\"],\n", - " delete_corrupt_images=False,\n", - " cache_path=CACHE_PATH,\n", - " ),\n", - " batch_size=hyperparameters[\"batch_size\"],\n", - " shuffle=True,\n", - " num_workers=os.cpu_count(),\n", - " )\n", - "\n", - "\n", - "def serialise_hparams(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:\n", - " return {k: str(v) if isinstance(v, list) else v for k, v in hyperparameters.items()}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Optional\n", - "from torch.utils.tensorboard import SummaryWriter\n", - "from pathlib import Path\n", - "from torch.optim import Adam\n", - "from tqdm.notebook import tqdm\n", - "from editor.utils import get_next_run_name\n", - "from editor.visualisation import plot_histograms_in_2d\n", - "from editor.models import create_model, save_model\n", - "from datetime import timedelta, datetime\n", - "from config import MODELS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n", - "\n", - "\n", - "def train(\n", - " hyperparameters: Dict[str, Any],\n", - " max_duration: Optional[timedelta] = None,\n", - " use_tqdm: bool = True,\n", - ") -> Path:\n", - " start_time = datetime.now()\n", - "\n", - " log_dir = RUNS_PATH / get_next_run_name(RUNS_PATH)\n", - " with SummaryWriter(log_dir) as writer:\n", - " train_data_loader = get_data_loader(TRAIN_DATA, hyperparameters)\n", - " test_data_loader = get_data_loader(TEST_DATA, hyperparameters)\n", - "\n", - " model = create_model(\n", - " type=hyperparameters[\"model_type\"],\n", - " bin_count=hyperparameters[\"bin_count\"],\n", - " device=device,\n", - " ).train()\n", - " writer.add_graph(model, next(iter(train_data_loader))[0].to(device))\n", - "\n", - " optimizer = Adam(model.parameters(), lr=hyperparameters[\"learning_rate\"])\n", - " scheduler = torch.optim.lr_scheduler.StepLR(\n", - " optimizer, step_size=1, gamma=hyperparameters[\"scheduler_gamma\"]\n", - " )\n", - " loss_function = torch.nn.KLDivLoss(reduction=\"batchmean\").to(device)\n", - "\n", - " try:\n", - " for epoch in range(hyperparameters[\"num_epochs\"]):\n", - " epoch_loss = 0\n", - " writer.add_scalar(\n", - " \"Actual learning rate\", scheduler.get_last_lr()[0], epoch\n", - " )\n", - " for batch_id, (edited_histogram, original_histogram) in enumerate(\n", - " tqdm(train_data_loader, desc=f\"Epoch {epoch}\", unit=\"batch\")\n", - " if use_tqdm\n", - " else train_data_loader\n", - " ):\n", - " current_time = datetime.now()\n", - " if (\n", - " max_duration is not None\n", - " and current_time - start_time > max_duration\n", - " ):\n", - " raise TimeoutError(f\"Time limit {max_duration} exceeded\")\n", - "\n", - " optimizer.zero_grad()\n", - " predicted_original = model(edited_histogram.to(device))\n", - " loss = loss_function(\n", - " torch.log(torch.clamp(predicted_original, 1e-10, 1)),\n", - " original_histogram.to(device),\n", - " )\n", - "\n", - " epoch_loss += loss.item()\n", - " writer.add_scalar(\n", - " \"Loss/train/batch\",\n", - " loss,\n", - " global_step=epoch * len(train_data_loader) + batch_id,\n", - " )\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " logging.info(f\"Epoch {epoch} train loss: {epoch_loss}\")\n", - " with torch.no_grad():\n", - " model.eval()\n", - " loader = iter(test_data_loader)\n", - " edited_histogram, original_histogram = next(loader)\n", - " predicted_original = model(edited_histogram.to(device))\n", - " writer.add_figure(\n", - " \"histogram\",\n", - " plot_histograms_in_2d(\n", - " {\n", - " \"original\": original_histogram[0].numpy().squeeze(),\n", - " \"edited\": edited_histogram.cpu()[0].numpy().squeeze(),\n", - " \"predicted\": predicted_original.cpu()[0]\n", - " .numpy()\n", - " .squeeze(),\n", - " }\n", - " ),\n", - " epoch,\n", - " )\n", - "\n", - " epoch_test_loss = 0\n", - " for batch_id, (edited_histogram, original_histogram) in enumerate(\n", - " test_data_loader\n", - " ):\n", - " predicted_original = model(edited_histogram.to(device))\n", - " epoch_test_loss += loss_function(\n", - " torch.log(torch.clamp(predicted_original, 1e-10, 1)),\n", - " original_histogram.to(device),\n", - " ).item()\n", - " writer.add_hparams(\n", - " serialise_hparams(hyperparameters),\n", - " {\n", - " \"Loss/test/epoch\": epoch_test_loss,\n", - " \"Loss/train/epoch\": epoch_loss,\n", - " },\n", - " global_step=epoch,\n", - " run_name=log_dir.absolute(),\n", - " )\n", - " logging.info(f\"Epoch {epoch} test loss: {epoch_test_loss}\")\n", - "\n", - " model.train()\n", - " scheduler.step()\n", - " finally:\n", - " model_path = MODELS_PATH / get_next_run_name(MODELS_PATH)\n", - " save_model(model, hyperparameters, model_path)\n", - " del model\n", - " return model_path" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-06-22 16:57:28,986 - INFO - Loaded 22479 original images\n", - "2024-06-22 16:57:28,991 - INFO - Loaded 2498 original images\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "63e068d857484e59afb823ca4b5c3a58", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Epoch 0: 0%| | 0/2108 [00:00 None: - for _ in range(1): + for _ in count(): current_hyperparameters = { k: v.rvs() if hasattr(v, "rvs") else choice(v) for k, v in choice(hyperparameters).items() diff --git a/src/training/train.py b/src/training/train.py index e11ad14..7de5b33 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -64,7 +64,7 @@ def train( optimizer.zero_grad() predicted_original = model(edited_histogram.to(device)) loss = loss_function( - torch.log(torch.clamp(predicted_original, EPSILON, 1)), + torch.log(predicted_original + EPSILON), original_histogram.to(device), ) @@ -101,7 +101,7 @@ def train( ): predicted_original = model(edited_histogram.to(device)) epoch_test_loss += loss_function( - torch.log(torch.clamp(predicted_original, EPSILON, 1)), + torch.log(predicted_original + EPSILON), original_histogram.to(device), ).item() writer.add_hparams(