Fine tune training
This commit is contained in:
parent
9aac55f62d
commit
7cfc201229
3 changed files with 21677 additions and 14350 deletions
|
|
@ -13,7 +13,7 @@ import torch
|
|||
|
||||
MODELS = {
|
||||
# "v1": v1,
|
||||
"SimpleCNN": SimpleCNN,
|
||||
# "SimpleCNN": SimpleCNN,
|
||||
"Residual": Residual,
|
||||
# "NormalisedCNN": NormalisedCNN,
|
||||
# "SmartRes": SmartRes,
|
||||
|
|
|
|||
35936
src/inference.ipynb
35936
src/inference.ipynb
File diff suppressed because one or more lines are too long
|
|
@ -2,11 +2,22 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"metadata": {}
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Using device cuda:0'"
|
||||
]
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import logging\n",
|
||||
|
|
@ -31,7 +42,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
|
@ -39,13 +50,13 @@
|
|||
"from editor.models import MODELS\n",
|
||||
"\n",
|
||||
"common_hyperparameters = {\n",
|
||||
" \"batch_size\": [16, 32, 64],\n",
|
||||
" \"edit_count\": [8, 16],\n",
|
||||
" \"batch_size\": [64],\n",
|
||||
" \"edit_count\": [8],\n",
|
||||
" \"bin_count\": [16, 24, 32],\n",
|
||||
" \"clip_gradients\": [True, False],\n",
|
||||
" \"learning_rate\": loguniform(0.00001, 0.01),\n",
|
||||
" \"scheduler_gamma\": uniform(0, 1),\n",
|
||||
" \"num_epochs\": [10],\n",
|
||||
" \"clip_gradients\": [False],\n",
|
||||
" \"learning_rate\": loguniform(3e-4, 3e-3),\n",
|
||||
" \"scheduler_gamma\": uniform(0.5, 1),\n",
|
||||
" \"num_epochs\": [16],\n",
|
||||
" # \"num_epochs\": randint(5, 10),\n",
|
||||
" \"model_type\": list(MODELS.keys()),\n",
|
||||
"}\n",
|
||||
|
|
@ -65,7 +76,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
|
@ -97,9 +108,20 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Testing model SimpleCNN\n",
|
||||
"Test passed! Output shape matches input shape.\n",
|
||||
"Testing model Residual\n",
|
||||
"Test passed! Output shape matches input shape.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from torch.utils.tensorboard import SummaryWriter\n",
|
||||
"from pathlib import Path\n",
|
||||
|
|
@ -123,10 +145,6 @@
|
|||
" hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool\n",
|
||||
") -> Path:\n",
|
||||
" start_time = datetime.now()\n",
|
||||
" model_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\".pth\")\n",
|
||||
" params_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\".json\")\n",
|
||||
" with open(params_path, \"w\") as f:\n",
|
||||
" json.dump(hyperparameters, f, indent=2)\n",
|
||||
"\n",
|
||||
" log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n",
|
||||
" with SummaryWriter(log_dir) as writer:\n",
|
||||
|
|
@ -274,8 +292,15 @@
|
|||
" except Exception:\n",
|
||||
" raise\n",
|
||||
" finally:\n",
|
||||
" run_name = get_next_run_name(MODELS_PATH)\n",
|
||||
" model_path = (MODELS_PATH / run_name).with_suffix(\".pth\")\n",
|
||||
" params_path = (MODELS_PATH / run_name).with_suffix(\".json\")\n",
|
||||
"\n",
|
||||
" logging.info(f\"Saving model to {model_path}\")\n",
|
||||
" torch.save(model.state_dict(), model_path)\n",
|
||||
" with open(model_path, \"wb\") as f:\n",
|
||||
" torch.save(model.state_dict(), f)\n",
|
||||
" with open(params_path, \"w\") as f:\n",
|
||||
" json.dump(hyperparameters, f, indent=2)\n",
|
||||
" del model\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" return model_path"
|
||||
|
|
@ -283,7 +308,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
|
@ -306,9 +331,31 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-06 08:19:45,651 - INFO - Starting run_26 with hparams {\n",
|
||||
" \"batch_size\": 64,\n",
|
||||
" \"bin_count\": 32,\n",
|
||||
" \"clip_gradients\": false,\n",
|
||||
" \"edit_count\": 8,\n",
|
||||
" \"learning_rate\": 0.0020871393198725404,\n",
|
||||
" \"loss\": \"kl\",\n",
|
||||
" \"model_type\": \"Residual\",\n",
|
||||
" \"num_epochs\": 16,\n",
|
||||
" \"scheduler_gamma\": 1.2086440138363033\n",
|
||||
"}\n",
|
||||
"2024-06-06 08:19:45,768 - INFO - Loaded 22479 original images\n",
|
||||
"2024-06-06 08:19:45,778 - INFO - Loaded 2498 original images\n",
|
||||
"2024-06-06 08:20:02,104 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_26.pth\n",
|
||||
"2024-06-06 08:20:02,248 - INFO - Interrupted, stopping\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from random import choice\n",
|
||||
"from itertools import count\n",
|
||||
|
|
@ -325,7 +372,7 @@
|
|||
" f\"Starting {get_next_run_name(Path(\"runs\"))} with hparams {key}\"\n",
|
||||
" )\n",
|
||||
" try:\n",
|
||||
" train(current_hyperparameters, max_duration=timedelta(hours=4), use_tqdm=False)\n",
|
||||
" train(current_hyperparameters, max_duration=timedelta(hours=8), use_tqdm=False)\n",
|
||||
" except KeyboardInterrupt as e:\n",
|
||||
" logging.info(\"Interrupted, stopping\")\n",
|
||||
" break\n",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue