Fine tune training

This commit is contained in:
Andras Schmelczer 2024-06-06 08:21:34 +01:00
parent 9aac55f62d
commit 7cfc201229
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
3 changed files with 21677 additions and 14350 deletions

View file

@ -13,7 +13,7 @@ import torch
MODELS = { MODELS = {
# "v1": v1, # "v1": v1,
"SimpleCNN": SimpleCNN, # "SimpleCNN": SimpleCNN,
"Residual": Residual, "Residual": Residual,
# "NormalisedCNN": NormalisedCNN, # "NormalisedCNN": NormalisedCNN,
# "SmartRes": SmartRes, # "SmartRes": SmartRes,

File diff suppressed because one or more lines are too long

View file

@ -2,11 +2,22 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"metadata": { "metadata": {
"metadata": {} "metadata": {}
}, },
"outputs": [], "outputs": [
{
"data": {
"text/plain": [
"'Using device cuda:0'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"import torch\n", "import torch\n",
"import logging\n", "import logging\n",
@ -31,7 +42,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -39,13 +50,13 @@
"from editor.models import MODELS\n", "from editor.models import MODELS\n",
"\n", "\n",
"common_hyperparameters = {\n", "common_hyperparameters = {\n",
" \"batch_size\": [16, 32, 64],\n", " \"batch_size\": [64],\n",
" \"edit_count\": [8, 16],\n", " \"edit_count\": [8],\n",
" \"bin_count\": [16, 24, 32],\n", " \"bin_count\": [16, 24, 32],\n",
" \"clip_gradients\": [True, False],\n", " \"clip_gradients\": [False],\n",
" \"learning_rate\": loguniform(0.00001, 0.01),\n", " \"learning_rate\": loguniform(3e-4, 3e-3),\n",
" \"scheduler_gamma\": uniform(0, 1),\n", " \"scheduler_gamma\": uniform(0.5, 1),\n",
" \"num_epochs\": [10],\n", " \"num_epochs\": [16],\n",
" # \"num_epochs\": randint(5, 10),\n", " # \"num_epochs\": randint(5, 10),\n",
" \"model_type\": list(MODELS.keys()),\n", " \"model_type\": list(MODELS.keys()),\n",
"}\n", "}\n",
@ -65,7 +76,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -97,9 +108,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing model SimpleCNN\n",
"Test passed! Output shape matches input shape.\n",
"Testing model Residual\n",
"Test passed! Output shape matches input shape.\n"
]
}
],
"source": [ "source": [
"from torch.utils.tensorboard import SummaryWriter\n", "from torch.utils.tensorboard import SummaryWriter\n",
"from pathlib import Path\n", "from pathlib import Path\n",
@ -123,10 +145,6 @@
" hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool\n", " hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool\n",
") -> Path:\n", ") -> Path:\n",
" start_time = datetime.now()\n", " start_time = datetime.now()\n",
" model_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\".pth\")\n",
" params_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\".json\")\n",
" with open(params_path, \"w\") as f:\n",
" json.dump(hyperparameters, f, indent=2)\n",
"\n", "\n",
" log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n", " log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n",
" with SummaryWriter(log_dir) as writer:\n", " with SummaryWriter(log_dir) as writer:\n",
@ -274,8 +292,15 @@
" except Exception:\n", " except Exception:\n",
" raise\n", " raise\n",
" finally:\n", " finally:\n",
" run_name = get_next_run_name(MODELS_PATH)\n",
" model_path = (MODELS_PATH / run_name).with_suffix(\".pth\")\n",
" params_path = (MODELS_PATH / run_name).with_suffix(\".json\")\n",
"\n",
" logging.info(f\"Saving model to {model_path}\")\n", " logging.info(f\"Saving model to {model_path}\")\n",
" torch.save(model.state_dict(), model_path)\n", " with open(model_path, \"wb\") as f:\n",
" torch.save(model.state_dict(), f)\n",
" with open(params_path, \"w\") as f:\n",
" json.dump(hyperparameters, f, indent=2)\n",
" del model\n", " del model\n",
" torch.cuda.empty_cache()\n", " torch.cuda.empty_cache()\n",
" return model_path" " return model_path"
@ -283,7 +308,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -306,9 +331,31 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-06 08:19:45,651 - INFO - Starting run_26 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 32,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 8,\n",
" \"learning_rate\": 0.0020871393198725404,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"Residual\",\n",
" \"num_epochs\": 16,\n",
" \"scheduler_gamma\": 1.2086440138363033\n",
"}\n",
"2024-06-06 08:19:45,768 - INFO - Loaded 22479 original images\n",
"2024-06-06 08:19:45,778 - INFO - Loaded 2498 original images\n",
"2024-06-06 08:20:02,104 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_26.pth\n",
"2024-06-06 08:20:02,248 - INFO - Interrupted, stopping\n"
]
}
],
"source": [ "source": [
"from random import choice\n", "from random import choice\n",
"from itertools import count\n", "from itertools import count\n",
@ -325,7 +372,7 @@
" f\"Starting {get_next_run_name(Path(\"runs\"))} with hparams {key}\"\n", " f\"Starting {get_next_run_name(Path(\"runs\"))} with hparams {key}\"\n",
" )\n", " )\n",
" try:\n", " try:\n",
" train(current_hyperparameters, max_duration=timedelta(hours=4), use_tqdm=False)\n", " train(current_hyperparameters, max_duration=timedelta(hours=8), use_tqdm=False)\n",
" except KeyboardInterrupt as e:\n", " except KeyboardInterrupt as e:\n",
" logging.info(\"Interrupted, stopping\")\n", " logging.info(\"Interrupted, stopping\")\n",
" break\n", " break\n",