Update training
This commit is contained in:
parent
c7c0f292c6
commit
7863611f86
4 changed files with 97 additions and 137 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -2,3 +2,4 @@ __pycache__
|
|||
runs*
|
||||
*.log
|
||||
saved_models/*
|
||||
train.py
|
||||
|
|
|
|||
210
src/train.ipynb
210
src/train.ipynb
|
|
@ -35,24 +35,85 @@
|
|||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# from training import train\n",
|
||||
"from config import RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# hparams = {\n",
|
||||
"# \"batch_size\": 64,\n",
|
||||
"# \"edit_count\": 12,\n",
|
||||
"# \"bin_count\": 16,\n",
|
||||
"# \"learning_rate\": 0.001,\n",
|
||||
"# \"scheduler_gamma\": 0.9,\n",
|
||||
"# \"num_epochs\": 12,\n",
|
||||
"# \"model_type\": \"SimpleCNN\",\n",
|
||||
"# }\n",
|
||||
"\n",
|
||||
"# train(\n",
|
||||
"# hparams,\n",
|
||||
"# train_data_paths=TRAIN_DATA,\n",
|
||||
"# test_data_paths=TEST_DATA,\n",
|
||||
"# log_dir=RUNS_PATH,\n",
|
||||
"# max_duration=None,\n",
|
||||
"# use_tqdm=True,\n",
|
||||
"# device=device,\n",
|
||||
"# **hparams\n",
|
||||
"# )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-25 08:59:52,244 - INFO - Testing model Dummy\n",
|
||||
"2024-06-25 08:59:52,249 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-25 08:59:52,249 - INFO - Testing model SimpleCNN\n",
|
||||
"2024-06-25 08:59:52,746 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-25 08:59:52,752 - INFO - Testing model Residual\n",
|
||||
"2024-06-25 08:59:53,853 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-25 08:59:53,917 - INFO - Testing model Residual3\n",
|
||||
"2024-06-25 08:59:55,590 - INFO - Test passed! Output shape matches input shape.\n"
|
||||
"2024-06-27 22:27:47,571 - INFO - Testing model Dummy\n",
|
||||
"2024-06-27 22:27:47,576 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-27 22:27:47,576 - INFO - Testing model SimpleCNN\n",
|
||||
"2024-06-27 22:27:48,012 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-27 22:27:48,014 - INFO - Testing model Residual\n",
|
||||
"2024-06-27 22:27:49,044 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-27 22:27:49,048 - INFO - Testing model HistogramNet\n",
|
||||
"2024-06-27 22:27:50,526 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-27 22:27:50,536 - INFO - Starting run_0 with hparams {\n",
|
||||
" \"batch_size\": 64,\n",
|
||||
" \"bin_count\": 16,\n",
|
||||
" \"dropout_prob\": 0.07931334001160453,\n",
|
||||
" \"edit_count\": 12,\n",
|
||||
" \"elu_alpha\": 1.4632917556519958,\n",
|
||||
" \"features\": [\n",
|
||||
" 8,\n",
|
||||
" 16,\n",
|
||||
" 32\n",
|
||||
" ],\n",
|
||||
" \"kernel_size\": 5,\n",
|
||||
" \"leaky_relu_alpha\": 0.04009731835331415,\n",
|
||||
" \"leaky_relu_slope\": 0.005283501654875031,\n",
|
||||
" \"learning_rate\": 0.000798531032420656,\n",
|
||||
" \"model_type\": \"HistogramNet\",\n",
|
||||
" \"num_epochs\": 12,\n",
|
||||
" \"scheduler_gamma\": 0.8420623161905719,\n",
|
||||
" \"use_elu\": true,\n",
|
||||
" \"use_instance_norm\": false,\n",
|
||||
" \"use_residual\": false\n",
|
||||
"}\n",
|
||||
"2024-06-27 22:27:50,584 - INFO - Loaded 22479 original images\n",
|
||||
"2024-06-27 22:27:50,588 - INFO - Loaded 2498 original images\n",
|
||||
"2024-06-27 22:27:55,003 - INFO - Original result 322461.46875\n",
|
||||
"2024-06-27 22:28:40,001 - INFO - Epoch 0 train loss: 5343.839758455753\n",
|
||||
"2024-06-27 22:28:44,685 - INFO - Epoch 0 test loss: 531.2615858912468\n",
|
||||
"2024-06-27 22:28:45,500 - INFO - Original result 404787.25\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from scipy.stats import loguniform, uniform, randint\n",
|
||||
"from training import random_hparam_search, get_next_run_name\n",
|
||||
"from models import MODELS, test_models\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
|
@ -67,133 +128,36 @@
|
|||
" \"elu_alpha\": uniform(0.5, 1.5),\n",
|
||||
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
|
||||
" \"dropout_prob\": uniform(0, 0.1),\n",
|
||||
" \"features\": [[16, 32, 64], [32, 64, 128], [8, 16, 32], [8, 8, 8], [16, 16, 16]],\n",
|
||||
" \"kernel_sizes\": [[3, 3, 3]],\n",
|
||||
" \"model_type\": [\"Residual3\"], # list(MODELS.keys()),\n",
|
||||
" \"clip_gradients\": [True, False],\n",
|
||||
" \"features\": [\n",
|
||||
" [16, 32, 64],\n",
|
||||
" [32, 64],\n",
|
||||
" [16, 32],\n",
|
||||
" [32, 64, 128],\n",
|
||||
" [8, 16, 32],\n",
|
||||
" [8, 8, 8],\n",
|
||||
" [8, 8, 8, 8, 8],\n",
|
||||
" [16, 16, 16],\n",
|
||||
" [32, 32],\n",
|
||||
" [64, 64],\n",
|
||||
" ],\n",
|
||||
" \"use_residual\": [True, False],\n",
|
||||
" \"kernel_size\": [3, 5],\n",
|
||||
" \"model_type\": [\"HistogramNet\"],\n",
|
||||
" \"use_instance_norm\": [True, False],\n",
|
||||
" \"use_elu\": [True, False],\n",
|
||||
" \"leaky_relu_alpha\": uniform(0, 0.05),\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"test_models()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# train(\n",
|
||||
"# {\n",
|
||||
"# \"batch_size\": 128,\n",
|
||||
"# \"edit_count\": 12,\n",
|
||||
"# \"bin_count\": 16,\n",
|
||||
"# \"learning_rate\": 1e-3,\n",
|
||||
"# \"scheduler_gamma\": 0.8,\n",
|
||||
"# \"elu_alpha\": 1,\n",
|
||||
"# \"dropout_prob\": 0.05,\n",
|
||||
"# \"features\": [8, 16, 32],\n",
|
||||
"# \"kernel_sizes\": [3, 3, 3],\n",
|
||||
"# \"num_epochs\": 12,\n",
|
||||
"# \"model_type\": \"Residual3\",\n",
|
||||
"# \"clip_gradients\": True,\n",
|
||||
"# \"use_instance_norm\": True,\n",
|
||||
"# \"use_elu\": False,\n",
|
||||
"# \"leaky_relu_alpha\": 0.01,\n",
|
||||
"# }\n",
|
||||
"# )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-25 08:59:55,973 - INFO - Starting run_170 with hparams {\n",
|
||||
" \"batch_size\": 64,\n",
|
||||
" \"bin_count\": 16,\n",
|
||||
" \"clip_gradients\": true,\n",
|
||||
" \"dropout_prob\": 0.09784778880383105,\n",
|
||||
" \"edit_count\": 12,\n",
|
||||
" \"elu_alpha\": 0.5588538605400805,\n",
|
||||
" \"features\": [\n",
|
||||
" 8,\n",
|
||||
" 16,\n",
|
||||
" 32\n",
|
||||
" ],\n",
|
||||
" \"kernel_sizes\": [\n",
|
||||
" 3,\n",
|
||||
" 3,\n",
|
||||
" 3\n",
|
||||
" ],\n",
|
||||
" \"leaky_relu_alpha\": 0.012913890161555076,\n",
|
||||
" \"leaky_relu_slope\": 0.022615416455484896,\n",
|
||||
" \"learning_rate\": 0.002130094098871897,\n",
|
||||
" \"model_type\": \"Residual3\",\n",
|
||||
" \"num_epochs\": 12,\n",
|
||||
" \"scheduler_gamma\": 0.8142448793722726,\n",
|
||||
" \"use_elu\": true,\n",
|
||||
" \"use_instance_norm\": false\n",
|
||||
"}\n",
|
||||
"2024-06-25 08:59:55,976 - INFO - Loaded 1000 original images\n",
|
||||
"2024-06-25 08:59:55,979 - INFO - Loaded 1000 original images\n",
|
||||
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/torch/jit/_trace.py:1102: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:\n",
|
||||
"Tensor-likes are not close!\n",
|
||||
"\n",
|
||||
"Mismatched elements: 442 / 262144 (0.2%)\n",
|
||||
"Greatest absolute difference: 0.00028640031814575195 at index (14, 0, 3, 6, 1) (up to 1e-05 allowed)\n",
|
||||
"Greatest relative difference: 0.03313953488372093 at index (52, 0, 2, 8, 3) (up to 1e-05 allowed)\n",
|
||||
" _check_trace(\n",
|
||||
"2024-06-25 09:00:10,492 - INFO - Epoch 0 train loss: -668.9345343112946\n",
|
||||
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/matplotlib/collections.py:996: RuntimeWarning: invalid value encountered in sqrt\n",
|
||||
" scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor\n",
|
||||
"2024-06-25 09:00:17,910 - INFO - Epoch 0 test loss: -677.4099225997925\n",
|
||||
"2024-06-25 09:00:27,392 - INFO - Epoch 1 train loss: -677.2842514514923\n",
|
||||
"2024-06-25 09:00:35,272 - INFO - Epoch 1 test loss: -677.350163936615\n",
|
||||
"2024-06-25 09:00:44,703 - INFO - Epoch 2 train loss: -677.5818197727203\n",
|
||||
"2024-06-25 09:00:52,274 - INFO - Epoch 2 test loss: -677.3433697223663\n",
|
||||
"2024-06-25 09:01:01,680 - INFO - Epoch 3 train loss: -677.3612790107727\n",
|
||||
"2024-06-25 09:01:09,430 - INFO - Epoch 3 test loss: -677.4993708133698\n",
|
||||
"2024-06-25 09:01:18,837 - INFO - Epoch 4 train loss: -677.3555946350098\n",
|
||||
"2024-06-25 09:01:26,721 - INFO - Epoch 4 test loss: -677.2696735858917\n",
|
||||
"2024-06-25 09:01:36,224 - INFO - Epoch 5 train loss: -677.4827179908752\n",
|
||||
"2024-06-25 09:01:44,065 - INFO - Epoch 5 test loss: -677.4189476966858\n",
|
||||
"2024-06-25 09:01:53,537 - INFO - Epoch 6 train loss: -677.5993602275848\n",
|
||||
"2024-06-25 09:02:01,438 - INFO - Epoch 6 test loss: -677.5275483131409\n",
|
||||
"2024-06-25 09:02:10,913 - INFO - Epoch 7 train loss: -677.417388677597\n",
|
||||
"2024-06-25 09:02:18,622 - INFO - Epoch 7 test loss: -677.5215902328491\n",
|
||||
"2024-06-25 09:02:28,085 - INFO - Epoch 8 train loss: -677.415346622467\n",
|
||||
"2024-06-25 09:02:36,597 - INFO - Epoch 8 test loss: -677.5785489082336\n",
|
||||
"2024-06-25 09:02:46,112 - INFO - Epoch 9 train loss: -677.4984295368195\n",
|
||||
"2024-06-25 09:02:54,038 - INFO - Epoch 9 test loss: -677.5197842121124\n",
|
||||
"2024-06-25 09:03:03,558 - INFO - Epoch 10 train loss: -677.4539258480072\n",
|
||||
"2024-06-25 09:03:11,486 - INFO - Epoch 10 test loss: -677.5460705757141\n",
|
||||
"2024-06-25 09:03:21,255 - INFO - Epoch 11 train loss: -677.54083776474\n",
|
||||
"2024-06-25 09:03:29,637 - INFO - Epoch 11 test loss: -677.658331155777\n",
|
||||
"2024-06-25 09:03:29,818 - INFO - Saving model to /home/andras/projects/bipolaroid/saved_models/run_143.pth\n",
|
||||
"2024-06-25 09:03:29,819 - INFO - Parameter count: 429457\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from training import random_hparam_search\n",
|
||||
"from config import RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
|
||||
"\n",
|
||||
"test_models()\n",
|
||||
"\n",
|
||||
"random_hparam_search(\n",
|
||||
" hyperparameters=hyperparameters,\n",
|
||||
" train_data_paths=TRAIN_DATA,\n",
|
||||
" test_data_paths=TEST_DATA,\n",
|
||||
" models_path=MODELS_PATH,\n",
|
||||
" tensorboard_path=RUNS_PATH,\n",
|
||||
" timeout_hours=8,\n",
|
||||
" tensorboard_path=RUNS_PATH / get_next_run_name(RUNS_PATH),\n",
|
||||
" timeout_hours=4,\n",
|
||||
" device=device,\n",
|
||||
")"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -36,16 +36,10 @@ def random_hparam_search(
|
|||
log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
|
||||
|
||||
try:
|
||||
train_data_loader = get_data_loader(
|
||||
train_data_paths, **current_hyperparameters
|
||||
)
|
||||
test_data_loader = get_data_loader(
|
||||
test_data_paths, **current_hyperparameters
|
||||
)
|
||||
model = train(
|
||||
hyperparameters=current_hyperparameters,
|
||||
train_data_loader=train_data_loader,
|
||||
test_data_loader=test_data_loader,
|
||||
train_data_paths=train_data_paths,
|
||||
test_data_paths=test_data_paths,
|
||||
max_duration=timedelta(hours=timeout_hours),
|
||||
log_dir=log_dir,
|
||||
use_tqdm=False,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from pathlib import Path
|
||||
from torch.optim import Adam
|
||||
|
|
@ -10,32 +10,33 @@ from datetime import timedelta, datetime
|
|||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
from utils import serialise_hparams
|
||||
|
||||
from .get_data_loader import get_data_loader
|
||||
|
||||
EPSILON = 1e-5
|
||||
|
||||
|
||||
def train(
|
||||
hyperparameters: Dict[str, Any],
|
||||
train_data_loader: DataLoader,
|
||||
test_data_loader: DataLoader,
|
||||
train_data_paths: List[Path],
|
||||
test_data_paths: List[Path],
|
||||
log_dir: Path,
|
||||
max_duration: Optional[timedelta],
|
||||
use_tqdm: bool,
|
||||
device: torch.device,
|
||||
model_type: str,
|
||||
bin_count: int,
|
||||
learning_rate: float,
|
||||
scheduler_gamma: float,
|
||||
num_epochs: int,
|
||||
**_,
|
||||
) -> torch.nn.Module:
|
||||
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
|
||||
test_data_loader = get_data_loader(test_data_paths, **hyperparameters)
|
||||
start_time = datetime.now()
|
||||
|
||||
with SummaryWriter(log_dir) as writer:
|
||||
model = create_model(
|
||||
type=model_type,
|
||||
bin_count=bin_count,
|
||||
hyperparameters=hyperparameters,
|
||||
device=device,
|
||||
).train()
|
||||
writer.add_graph(model, next(iter(train_data_loader))[0].to(device))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue