Fine tune training
This commit is contained in:
parent
9aac55f62d
commit
7cfc201229
3 changed files with 21677 additions and 14350 deletions
|
|
@ -13,7 +13,7 @@ import torch
|
||||||
|
|
||||||
MODELS = {
|
MODELS = {
|
||||||
# "v1": v1,
|
# "v1": v1,
|
||||||
"SimpleCNN": SimpleCNN,
|
# "SimpleCNN": SimpleCNN,
|
||||||
"Residual": Residual,
|
"Residual": Residual,
|
||||||
# "NormalisedCNN": NormalisedCNN,
|
# "NormalisedCNN": NormalisedCNN,
|
||||||
# "SmartRes": SmartRes,
|
# "SmartRes": SmartRes,
|
||||||
|
|
|
||||||
35936
src/inference.ipynb
35936
src/inference.ipynb
File diff suppressed because one or more lines are too long
|
|
@ -2,11 +2,22 @@
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 1,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"metadata": {}
|
"metadata": {}
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'Using device cuda:0'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"import logging\n",
|
"import logging\n",
|
||||||
|
|
@ -31,7 +42,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
|
@ -39,13 +50,13 @@
|
||||||
"from editor.models import MODELS\n",
|
"from editor.models import MODELS\n",
|
||||||
"\n",
|
"\n",
|
||||||
"common_hyperparameters = {\n",
|
"common_hyperparameters = {\n",
|
||||||
" \"batch_size\": [16, 32, 64],\n",
|
" \"batch_size\": [64],\n",
|
||||||
" \"edit_count\": [8, 16],\n",
|
" \"edit_count\": [8],\n",
|
||||||
" \"bin_count\": [16, 24, 32],\n",
|
" \"bin_count\": [16, 24, 32],\n",
|
||||||
" \"clip_gradients\": [True, False],\n",
|
" \"clip_gradients\": [False],\n",
|
||||||
" \"learning_rate\": loguniform(0.00001, 0.01),\n",
|
" \"learning_rate\": loguniform(3e-4, 3e-3),\n",
|
||||||
" \"scheduler_gamma\": uniform(0, 1),\n",
|
" \"scheduler_gamma\": uniform(0.5, 1),\n",
|
||||||
" \"num_epochs\": [10],\n",
|
" \"num_epochs\": [16],\n",
|
||||||
" # \"num_epochs\": randint(5, 10),\n",
|
" # \"num_epochs\": randint(5, 10),\n",
|
||||||
" \"model_type\": list(MODELS.keys()),\n",
|
" \"model_type\": list(MODELS.keys()),\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
|
|
@ -65,7 +76,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
|
@ -97,9 +108,20 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Testing model SimpleCNN\n",
|
||||||
|
"Test passed! Output shape matches input shape.\n",
|
||||||
|
"Testing model Residual\n",
|
||||||
|
"Test passed! Output shape matches input shape.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from torch.utils.tensorboard import SummaryWriter\n",
|
"from torch.utils.tensorboard import SummaryWriter\n",
|
||||||
"from pathlib import Path\n",
|
"from pathlib import Path\n",
|
||||||
|
|
@ -123,10 +145,6 @@
|
||||||
" hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool\n",
|
" hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool\n",
|
||||||
") -> Path:\n",
|
") -> Path:\n",
|
||||||
" start_time = datetime.now()\n",
|
" start_time = datetime.now()\n",
|
||||||
" model_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\".pth\")\n",
|
|
||||||
" params_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\".json\")\n",
|
|
||||||
" with open(params_path, \"w\") as f:\n",
|
|
||||||
" json.dump(hyperparameters, f, indent=2)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n",
|
" log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n",
|
||||||
" with SummaryWriter(log_dir) as writer:\n",
|
" with SummaryWriter(log_dir) as writer:\n",
|
||||||
|
|
@ -274,8 +292,15 @@
|
||||||
" except Exception:\n",
|
" except Exception:\n",
|
||||||
" raise\n",
|
" raise\n",
|
||||||
" finally:\n",
|
" finally:\n",
|
||||||
|
" run_name = get_next_run_name(MODELS_PATH)\n",
|
||||||
|
" model_path = (MODELS_PATH / run_name).with_suffix(\".pth\")\n",
|
||||||
|
" params_path = (MODELS_PATH / run_name).with_suffix(\".json\")\n",
|
||||||
|
"\n",
|
||||||
" logging.info(f\"Saving model to {model_path}\")\n",
|
" logging.info(f\"Saving model to {model_path}\")\n",
|
||||||
" torch.save(model.state_dict(), model_path)\n",
|
" with open(model_path, \"wb\") as f:\n",
|
||||||
|
" torch.save(model.state_dict(), f)\n",
|
||||||
|
" with open(params_path, \"w\") as f:\n",
|
||||||
|
" json.dump(hyperparameters, f, indent=2)\n",
|
||||||
" del model\n",
|
" del model\n",
|
||||||
" torch.cuda.empty_cache()\n",
|
" torch.cuda.empty_cache()\n",
|
||||||
" return model_path"
|
" return model_path"
|
||||||
|
|
@ -283,7 +308,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
|
@ -306,9 +331,31 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"2024-06-06 08:19:45,651 - INFO - Starting run_26 with hparams {\n",
|
||||||
|
" \"batch_size\": 64,\n",
|
||||||
|
" \"bin_count\": 32,\n",
|
||||||
|
" \"clip_gradients\": false,\n",
|
||||||
|
" \"edit_count\": 8,\n",
|
||||||
|
" \"learning_rate\": 0.0020871393198725404,\n",
|
||||||
|
" \"loss\": \"kl\",\n",
|
||||||
|
" \"model_type\": \"Residual\",\n",
|
||||||
|
" \"num_epochs\": 16,\n",
|
||||||
|
" \"scheduler_gamma\": 1.2086440138363033\n",
|
||||||
|
"}\n",
|
||||||
|
"2024-06-06 08:19:45,768 - INFO - Loaded 22479 original images\n",
|
||||||
|
"2024-06-06 08:19:45,778 - INFO - Loaded 2498 original images\n",
|
||||||
|
"2024-06-06 08:20:02,104 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_26.pth\n",
|
||||||
|
"2024-06-06 08:20:02,248 - INFO - Interrupted, stopping\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from random import choice\n",
|
"from random import choice\n",
|
||||||
"from itertools import count\n",
|
"from itertools import count\n",
|
||||||
|
|
@ -325,7 +372,7 @@
|
||||||
" f\"Starting {get_next_run_name(Path(\"runs\"))} with hparams {key}\"\n",
|
" f\"Starting {get_next_run_name(Path(\"runs\"))} with hparams {key}\"\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" train(current_hyperparameters, max_duration=timedelta(hours=4), use_tqdm=False)\n",
|
" train(current_hyperparameters, max_duration=timedelta(hours=8), use_tqdm=False)\n",
|
||||||
" except KeyboardInterrupt as e:\n",
|
" except KeyboardInterrupt as e:\n",
|
||||||
" logging.info(\"Interrupted, stopping\")\n",
|
" logging.info(\"Interrupted, stopping\")\n",
|
||||||
" break\n",
|
" break\n",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue