Update training
This commit is contained in:
parent
c7c0f292c6
commit
7863611f86
4 changed files with 97 additions and 137 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -2,3 +2,4 @@ __pycache__
|
||||||
runs*
|
runs*
|
||||||
*.log
|
*.log
|
||||||
saved_models/*
|
saved_models/*
|
||||||
|
train.py
|
||||||
|
|
|
||||||
210
src/train.ipynb
210
src/train.ipynb
|
|
@ -35,24 +35,85 @@
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# from training import train\n",
|
||||||
|
"from config import RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# hparams = {\n",
|
||||||
|
"# \"batch_size\": 64,\n",
|
||||||
|
"# \"edit_count\": 12,\n",
|
||||||
|
"# \"bin_count\": 16,\n",
|
||||||
|
"# \"learning_rate\": 0.001,\n",
|
||||||
|
"# \"scheduler_gamma\": 0.9,\n",
|
||||||
|
"# \"num_epochs\": 12,\n",
|
||||||
|
"# \"model_type\": \"SimpleCNN\",\n",
|
||||||
|
"# }\n",
|
||||||
|
"\n",
|
||||||
|
"# train(\n",
|
||||||
|
"# hparams,\n",
|
||||||
|
"# train_data_paths=TRAIN_DATA,\n",
|
||||||
|
"# test_data_paths=TEST_DATA,\n",
|
||||||
|
"# log_dir=RUNS_PATH,\n",
|
||||||
|
"# max_duration=None,\n",
|
||||||
|
"# use_tqdm=True,\n",
|
||||||
|
"# device=device,\n",
|
||||||
|
"# **hparams\n",
|
||||||
|
"# )"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"2024-06-25 08:59:52,244 - INFO - Testing model Dummy\n",
|
"2024-06-27 22:27:47,571 - INFO - Testing model Dummy\n",
|
||||||
"2024-06-25 08:59:52,249 - INFO - Test passed! Output shape matches input shape.\n",
|
"2024-06-27 22:27:47,576 - INFO - Test passed! Output shape matches input shape.\n",
|
||||||
"2024-06-25 08:59:52,249 - INFO - Testing model SimpleCNN\n",
|
"2024-06-27 22:27:47,576 - INFO - Testing model SimpleCNN\n",
|
||||||
"2024-06-25 08:59:52,746 - INFO - Test passed! Output shape matches input shape.\n",
|
"2024-06-27 22:27:48,012 - INFO - Test passed! Output shape matches input shape.\n",
|
||||||
"2024-06-25 08:59:52,752 - INFO - Testing model Residual\n",
|
"2024-06-27 22:27:48,014 - INFO - Testing model Residual\n",
|
||||||
"2024-06-25 08:59:53,853 - INFO - Test passed! Output shape matches input shape.\n",
|
"2024-06-27 22:27:49,044 - INFO - Test passed! Output shape matches input shape.\n",
|
||||||
"2024-06-25 08:59:53,917 - INFO - Testing model Residual3\n",
|
"2024-06-27 22:27:49,048 - INFO - Testing model HistogramNet\n",
|
||||||
"2024-06-25 08:59:55,590 - INFO - Test passed! Output shape matches input shape.\n"
|
"2024-06-27 22:27:50,526 - INFO - Test passed! Output shape matches input shape.\n",
|
||||||
|
"2024-06-27 22:27:50,536 - INFO - Starting run_0 with hparams {\n",
|
||||||
|
" \"batch_size\": 64,\n",
|
||||||
|
" \"bin_count\": 16,\n",
|
||||||
|
" \"dropout_prob\": 0.07931334001160453,\n",
|
||||||
|
" \"edit_count\": 12,\n",
|
||||||
|
" \"elu_alpha\": 1.4632917556519958,\n",
|
||||||
|
" \"features\": [\n",
|
||||||
|
" 8,\n",
|
||||||
|
" 16,\n",
|
||||||
|
" 32\n",
|
||||||
|
" ],\n",
|
||||||
|
" \"kernel_size\": 5,\n",
|
||||||
|
" \"leaky_relu_alpha\": 0.04009731835331415,\n",
|
||||||
|
" \"leaky_relu_slope\": 0.005283501654875031,\n",
|
||||||
|
" \"learning_rate\": 0.000798531032420656,\n",
|
||||||
|
" \"model_type\": \"HistogramNet\",\n",
|
||||||
|
" \"num_epochs\": 12,\n",
|
||||||
|
" \"scheduler_gamma\": 0.8420623161905719,\n",
|
||||||
|
" \"use_elu\": true,\n",
|
||||||
|
" \"use_instance_norm\": false,\n",
|
||||||
|
" \"use_residual\": false\n",
|
||||||
|
"}\n",
|
||||||
|
"2024-06-27 22:27:50,584 - INFO - Loaded 22479 original images\n",
|
||||||
|
"2024-06-27 22:27:50,588 - INFO - Loaded 2498 original images\n",
|
||||||
|
"2024-06-27 22:27:55,003 - INFO - Original result 322461.46875\n",
|
||||||
|
"2024-06-27 22:28:40,001 - INFO - Epoch 0 train loss: 5343.839758455753\n",
|
||||||
|
"2024-06-27 22:28:44,685 - INFO - Epoch 0 test loss: 531.2615858912468\n",
|
||||||
|
"2024-06-27 22:28:45,500 - INFO - Original result 404787.25\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from scipy.stats import loguniform, uniform, randint\n",
|
"from scipy.stats import loguniform, uniform, randint\n",
|
||||||
|
"from training import random_hparam_search, get_next_run_name\n",
|
||||||
"from models import MODELS, test_models\n",
|
"from models import MODELS, test_models\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -67,133 +128,36 @@
|
||||||
" \"elu_alpha\": uniform(0.5, 1.5),\n",
|
" \"elu_alpha\": uniform(0.5, 1.5),\n",
|
||||||
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
|
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
|
||||||
" \"dropout_prob\": uniform(0, 0.1),\n",
|
" \"dropout_prob\": uniform(0, 0.1),\n",
|
||||||
" \"features\": [[16, 32, 64], [32, 64, 128], [8, 16, 32], [8, 8, 8], [16, 16, 16]],\n",
|
" \"features\": [\n",
|
||||||
" \"kernel_sizes\": [[3, 3, 3]],\n",
|
" [16, 32, 64],\n",
|
||||||
" \"model_type\": [\"Residual3\"], # list(MODELS.keys()),\n",
|
" [32, 64],\n",
|
||||||
" \"clip_gradients\": [True, False],\n",
|
" [16, 32],\n",
|
||||||
|
" [32, 64, 128],\n",
|
||||||
|
" [8, 16, 32],\n",
|
||||||
|
" [8, 8, 8],\n",
|
||||||
|
" [8, 8, 8, 8, 8],\n",
|
||||||
|
" [16, 16, 16],\n",
|
||||||
|
" [32, 32],\n",
|
||||||
|
" [64, 64],\n",
|
||||||
|
" ],\n",
|
||||||
|
" \"use_residual\": [True, False],\n",
|
||||||
|
" \"kernel_size\": [3, 5],\n",
|
||||||
|
" \"model_type\": [\"HistogramNet\"],\n",
|
||||||
" \"use_instance_norm\": [True, False],\n",
|
" \"use_instance_norm\": [True, False],\n",
|
||||||
" \"use_elu\": [True, False],\n",
|
" \"use_elu\": [True, False],\n",
|
||||||
" \"leaky_relu_alpha\": uniform(0, 0.05),\n",
|
" \"leaky_relu_alpha\": uniform(0, 0.05),\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
"]\n",
|
"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"test_models()"
|
"test_models()\n",
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# train(\n",
|
|
||||||
"# {\n",
|
|
||||||
"# \"batch_size\": 128,\n",
|
|
||||||
"# \"edit_count\": 12,\n",
|
|
||||||
"# \"bin_count\": 16,\n",
|
|
||||||
"# \"learning_rate\": 1e-3,\n",
|
|
||||||
"# \"scheduler_gamma\": 0.8,\n",
|
|
||||||
"# \"elu_alpha\": 1,\n",
|
|
||||||
"# \"dropout_prob\": 0.05,\n",
|
|
||||||
"# \"features\": [8, 16, 32],\n",
|
|
||||||
"# \"kernel_sizes\": [3, 3, 3],\n",
|
|
||||||
"# \"num_epochs\": 12,\n",
|
|
||||||
"# \"model_type\": \"Residual3\",\n",
|
|
||||||
"# \"clip_gradients\": True,\n",
|
|
||||||
"# \"use_instance_norm\": True,\n",
|
|
||||||
"# \"use_elu\": False,\n",
|
|
||||||
"# \"leaky_relu_alpha\": 0.01,\n",
|
|
||||||
"# }\n",
|
|
||||||
"# )"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"2024-06-25 08:59:55,973 - INFO - Starting run_170 with hparams {\n",
|
|
||||||
" \"batch_size\": 64,\n",
|
|
||||||
" \"bin_count\": 16,\n",
|
|
||||||
" \"clip_gradients\": true,\n",
|
|
||||||
" \"dropout_prob\": 0.09784778880383105,\n",
|
|
||||||
" \"edit_count\": 12,\n",
|
|
||||||
" \"elu_alpha\": 0.5588538605400805,\n",
|
|
||||||
" \"features\": [\n",
|
|
||||||
" 8,\n",
|
|
||||||
" 16,\n",
|
|
||||||
" 32\n",
|
|
||||||
" ],\n",
|
|
||||||
" \"kernel_sizes\": [\n",
|
|
||||||
" 3,\n",
|
|
||||||
" 3,\n",
|
|
||||||
" 3\n",
|
|
||||||
" ],\n",
|
|
||||||
" \"leaky_relu_alpha\": 0.012913890161555076,\n",
|
|
||||||
" \"leaky_relu_slope\": 0.022615416455484896,\n",
|
|
||||||
" \"learning_rate\": 0.002130094098871897,\n",
|
|
||||||
" \"model_type\": \"Residual3\",\n",
|
|
||||||
" \"num_epochs\": 12,\n",
|
|
||||||
" \"scheduler_gamma\": 0.8142448793722726,\n",
|
|
||||||
" \"use_elu\": true,\n",
|
|
||||||
" \"use_instance_norm\": false\n",
|
|
||||||
"}\n",
|
|
||||||
"2024-06-25 08:59:55,976 - INFO - Loaded 1000 original images\n",
|
|
||||||
"2024-06-25 08:59:55,979 - INFO - Loaded 1000 original images\n",
|
|
||||||
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/torch/jit/_trace.py:1102: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:\n",
|
|
||||||
"Tensor-likes are not close!\n",
|
|
||||||
"\n",
|
|
||||||
"Mismatched elements: 442 / 262144 (0.2%)\n",
|
|
||||||
"Greatest absolute difference: 0.00028640031814575195 at index (14, 0, 3, 6, 1) (up to 1e-05 allowed)\n",
|
|
||||||
"Greatest relative difference: 0.03313953488372093 at index (52, 0, 2, 8, 3) (up to 1e-05 allowed)\n",
|
|
||||||
" _check_trace(\n",
|
|
||||||
"2024-06-25 09:00:10,492 - INFO - Epoch 0 train loss: -668.9345343112946\n",
|
|
||||||
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/matplotlib/collections.py:996: RuntimeWarning: invalid value encountered in sqrt\n",
|
|
||||||
" scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor\n",
|
|
||||||
"2024-06-25 09:00:17,910 - INFO - Epoch 0 test loss: -677.4099225997925\n",
|
|
||||||
"2024-06-25 09:00:27,392 - INFO - Epoch 1 train loss: -677.2842514514923\n",
|
|
||||||
"2024-06-25 09:00:35,272 - INFO - Epoch 1 test loss: -677.350163936615\n",
|
|
||||||
"2024-06-25 09:00:44,703 - INFO - Epoch 2 train loss: -677.5818197727203\n",
|
|
||||||
"2024-06-25 09:00:52,274 - INFO - Epoch 2 test loss: -677.3433697223663\n",
|
|
||||||
"2024-06-25 09:01:01,680 - INFO - Epoch 3 train loss: -677.3612790107727\n",
|
|
||||||
"2024-06-25 09:01:09,430 - INFO - Epoch 3 test loss: -677.4993708133698\n",
|
|
||||||
"2024-06-25 09:01:18,837 - INFO - Epoch 4 train loss: -677.3555946350098\n",
|
|
||||||
"2024-06-25 09:01:26,721 - INFO - Epoch 4 test loss: -677.2696735858917\n",
|
|
||||||
"2024-06-25 09:01:36,224 - INFO - Epoch 5 train loss: -677.4827179908752\n",
|
|
||||||
"2024-06-25 09:01:44,065 - INFO - Epoch 5 test loss: -677.4189476966858\n",
|
|
||||||
"2024-06-25 09:01:53,537 - INFO - Epoch 6 train loss: -677.5993602275848\n",
|
|
||||||
"2024-06-25 09:02:01,438 - INFO - Epoch 6 test loss: -677.5275483131409\n",
|
|
||||||
"2024-06-25 09:02:10,913 - INFO - Epoch 7 train loss: -677.417388677597\n",
|
|
||||||
"2024-06-25 09:02:18,622 - INFO - Epoch 7 test loss: -677.5215902328491\n",
|
|
||||||
"2024-06-25 09:02:28,085 - INFO - Epoch 8 train loss: -677.415346622467\n",
|
|
||||||
"2024-06-25 09:02:36,597 - INFO - Epoch 8 test loss: -677.5785489082336\n",
|
|
||||||
"2024-06-25 09:02:46,112 - INFO - Epoch 9 train loss: -677.4984295368195\n",
|
|
||||||
"2024-06-25 09:02:54,038 - INFO - Epoch 9 test loss: -677.5197842121124\n",
|
|
||||||
"2024-06-25 09:03:03,558 - INFO - Epoch 10 train loss: -677.4539258480072\n",
|
|
||||||
"2024-06-25 09:03:11,486 - INFO - Epoch 10 test loss: -677.5460705757141\n",
|
|
||||||
"2024-06-25 09:03:21,255 - INFO - Epoch 11 train loss: -677.54083776474\n",
|
|
||||||
"2024-06-25 09:03:29,637 - INFO - Epoch 11 test loss: -677.658331155777\n",
|
|
||||||
"2024-06-25 09:03:29,818 - INFO - Saving model to /home/andras/projects/bipolaroid/saved_models/run_143.pth\n",
|
|
||||||
"2024-06-25 09:03:29,819 - INFO - Parameter count: 429457\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"from training import random_hparam_search\n",
|
|
||||||
"from config import RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"random_hparam_search(\n",
|
"random_hparam_search(\n",
|
||||||
" hyperparameters=hyperparameters,\n",
|
" hyperparameters=hyperparameters,\n",
|
||||||
" train_data_paths=TRAIN_DATA,\n",
|
" train_data_paths=TRAIN_DATA,\n",
|
||||||
" test_data_paths=TEST_DATA,\n",
|
" test_data_paths=TEST_DATA,\n",
|
||||||
" models_path=MODELS_PATH,\n",
|
" models_path=MODELS_PATH,\n",
|
||||||
" tensorboard_path=RUNS_PATH,\n",
|
" tensorboard_path=RUNS_PATH / get_next_run_name(RUNS_PATH),\n",
|
||||||
" timeout_hours=8,\n",
|
" timeout_hours=4,\n",
|
||||||
" device=device,\n",
|
" device=device,\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -36,16 +36,10 @@ def random_hparam_search(
|
||||||
log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
|
log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
train_data_loader = get_data_loader(
|
|
||||||
train_data_paths, **current_hyperparameters
|
|
||||||
)
|
|
||||||
test_data_loader = get_data_loader(
|
|
||||||
test_data_paths, **current_hyperparameters
|
|
||||||
)
|
|
||||||
model = train(
|
model = train(
|
||||||
hyperparameters=current_hyperparameters,
|
hyperparameters=current_hyperparameters,
|
||||||
train_data_loader=train_data_loader,
|
train_data_paths=train_data_paths,
|
||||||
test_data_loader=test_data_loader,
|
test_data_paths=test_data_paths,
|
||||||
max_duration=timedelta(hours=timeout_hours),
|
max_duration=timedelta(hours=timeout_hours),
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
use_tqdm=False,
|
use_tqdm=False,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
|
|
@ -10,32 +10,33 @@ from datetime import timedelta, datetime
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import torch
|
import torch
|
||||||
from utils import serialise_hparams
|
from utils import serialise_hparams
|
||||||
|
from .get_data_loader import get_data_loader
|
||||||
|
|
||||||
EPSILON = 1e-5
|
EPSILON = 1e-5
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
hyperparameters: Dict[str, Any],
|
hyperparameters: Dict[str, Any],
|
||||||
train_data_loader: DataLoader,
|
train_data_paths: List[Path],
|
||||||
test_data_loader: DataLoader,
|
test_data_paths: List[Path],
|
||||||
log_dir: Path,
|
log_dir: Path,
|
||||||
max_duration: Optional[timedelta],
|
max_duration: Optional[timedelta],
|
||||||
use_tqdm: bool,
|
use_tqdm: bool,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
bin_count: int,
|
|
||||||
learning_rate: float,
|
learning_rate: float,
|
||||||
scheduler_gamma: float,
|
scheduler_gamma: float,
|
||||||
num_epochs: int,
|
num_epochs: int,
|
||||||
**_,
|
**_,
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
|
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
|
||||||
|
test_data_loader = get_data_loader(test_data_paths, **hyperparameters)
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
|
|
||||||
with SummaryWriter(log_dir) as writer:
|
with SummaryWriter(log_dir) as writer:
|
||||||
model = create_model(
|
model = create_model(
|
||||||
type=model_type,
|
type=model_type,
|
||||||
bin_count=bin_count,
|
hyperparameters=hyperparameters,
|
||||||
device=device,
|
device=device,
|
||||||
).train()
|
).train()
|
||||||
writer.add_graph(model, next(iter(train_data_loader))[0].to(device))
|
writer.add_graph(model, next(iter(train_data_loader))[0].to(device))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue