Update training

This commit is contained in:
Andras Schmelczer 2024-06-27 22:30:27 +01:00
parent c7c0f292c6
commit 7863611f86
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
4 changed files with 97 additions and 137 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@ __pycache__
runs* runs*
*.log *.log
saved_models/* saved_models/*
train.py

View file

@ -35,24 +35,85 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [],
"source": [
"# from training import train\n",
"from config import RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
"\n",
"\n",
"# hparams = {\n",
"# \"batch_size\": 64,\n",
"# \"edit_count\": 12,\n",
"# \"bin_count\": 16,\n",
"# \"learning_rate\": 0.001,\n",
"# \"scheduler_gamma\": 0.9,\n",
"# \"num_epochs\": 12,\n",
"# \"model_type\": \"SimpleCNN\",\n",
"# }\n",
"\n",
"# train(\n",
"# hparams,\n",
"# train_data_paths=TRAIN_DATA,\n",
"# test_data_paths=TEST_DATA,\n",
"# log_dir=RUNS_PATH,\n",
"# max_duration=None,\n",
"# use_tqdm=True,\n",
"# device=device,\n",
"# **hparams\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"2024-06-25 08:59:52,244 - INFO - Testing model Dummy\n", "2024-06-27 22:27:47,571 - INFO - Testing model Dummy\n",
"2024-06-25 08:59:52,249 - INFO - Test passed! Output shape matches input shape.\n", "2024-06-27 22:27:47,576 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-25 08:59:52,249 - INFO - Testing model SimpleCNN\n", "2024-06-27 22:27:47,576 - INFO - Testing model SimpleCNN\n",
"2024-06-25 08:59:52,746 - INFO - Test passed! Output shape matches input shape.\n", "2024-06-27 22:27:48,012 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-25 08:59:52,752 - INFO - Testing model Residual\n", "2024-06-27 22:27:48,014 - INFO - Testing model Residual\n",
"2024-06-25 08:59:53,853 - INFO - Test passed! Output shape matches input shape.\n", "2024-06-27 22:27:49,044 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-25 08:59:53,917 - INFO - Testing model Residual3\n", "2024-06-27 22:27:49,048 - INFO - Testing model HistogramNet\n",
"2024-06-25 08:59:55,590 - INFO - Test passed! Output shape matches input shape.\n" "2024-06-27 22:27:50,526 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-27 22:27:50,536 - INFO - Starting run_0 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 16,\n",
" \"dropout_prob\": 0.07931334001160453,\n",
" \"edit_count\": 12,\n",
" \"elu_alpha\": 1.4632917556519958,\n",
" \"features\": [\n",
" 8,\n",
" 16,\n",
" 32\n",
" ],\n",
" \"kernel_size\": 5,\n",
" \"leaky_relu_alpha\": 0.04009731835331415,\n",
" \"leaky_relu_slope\": 0.005283501654875031,\n",
" \"learning_rate\": 0.000798531032420656,\n",
" \"model_type\": \"HistogramNet\",\n",
" \"num_epochs\": 12,\n",
" \"scheduler_gamma\": 0.8420623161905719,\n",
" \"use_elu\": true,\n",
" \"use_instance_norm\": false,\n",
" \"use_residual\": false\n",
"}\n",
"2024-06-27 22:27:50,584 - INFO - Loaded 22479 original images\n",
"2024-06-27 22:27:50,588 - INFO - Loaded 2498 original images\n",
"2024-06-27 22:27:55,003 - INFO - Original result 322461.46875\n",
"2024-06-27 22:28:40,001 - INFO - Epoch 0 train loss: 5343.839758455753\n",
"2024-06-27 22:28:44,685 - INFO - Epoch 0 test loss: 531.2615858912468\n",
"2024-06-27 22:28:45,500 - INFO - Original result 404787.25\n"
] ]
} }
], ],
"source": [ "source": [
"from scipy.stats import loguniform, uniform, randint\n", "from scipy.stats import loguniform, uniform, randint\n",
"from training import random_hparam_search, get_next_run_name\n",
"from models import MODELS, test_models\n", "from models import MODELS, test_models\n",
"\n", "\n",
"\n", "\n",
@ -67,133 +128,36 @@
" \"elu_alpha\": uniform(0.5, 1.5),\n", " \"elu_alpha\": uniform(0.5, 1.5),\n",
" \"leaky_relu_slope\": uniform(0, 0.03),\n", " \"leaky_relu_slope\": uniform(0, 0.03),\n",
" \"dropout_prob\": uniform(0, 0.1),\n", " \"dropout_prob\": uniform(0, 0.1),\n",
" \"features\": [[16, 32, 64], [32, 64, 128], [8, 16, 32], [8, 8, 8], [16, 16, 16]],\n", " \"features\": [\n",
" \"kernel_sizes\": [[3, 3, 3]],\n", " [16, 32, 64],\n",
" \"model_type\": [\"Residual3\"], # list(MODELS.keys()),\n", " [32, 64],\n",
" \"clip_gradients\": [True, False],\n", " [16, 32],\n",
" [32, 64, 128],\n",
" [8, 16, 32],\n",
" [8, 8, 8],\n",
" [8, 8, 8, 8, 8],\n",
" [16, 16, 16],\n",
" [32, 32],\n",
" [64, 64],\n",
" ],\n",
" \"use_residual\": [True, False],\n",
" \"kernel_size\": [3, 5],\n",
" \"model_type\": [\"HistogramNet\"],\n",
" \"use_instance_norm\": [True, False],\n", " \"use_instance_norm\": [True, False],\n",
" \"use_elu\": [True, False],\n", " \"use_elu\": [True, False],\n",
" \"leaky_relu_alpha\": uniform(0, 0.05),\n", " \"leaky_relu_alpha\": uniform(0, 0.05),\n",
" }\n", " }\n",
"]\n", "]\n",
"\n", "\n",
"test_models()" "test_models()\n",
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# train(\n",
"# {\n",
"# \"batch_size\": 128,\n",
"# \"edit_count\": 12,\n",
"# \"bin_count\": 16,\n",
"# \"learning_rate\": 1e-3,\n",
"# \"scheduler_gamma\": 0.8,\n",
"# \"elu_alpha\": 1,\n",
"# \"dropout_prob\": 0.05,\n",
"# \"features\": [8, 16, 32],\n",
"# \"kernel_sizes\": [3, 3, 3],\n",
"# \"num_epochs\": 12,\n",
"# \"model_type\": \"Residual3\",\n",
"# \"clip_gradients\": True,\n",
"# \"use_instance_norm\": True,\n",
"# \"use_elu\": False,\n",
"# \"leaky_relu_alpha\": 0.01,\n",
"# }\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-25 08:59:55,973 - INFO - Starting run_170 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 16,\n",
" \"clip_gradients\": true,\n",
" \"dropout_prob\": 0.09784778880383105,\n",
" \"edit_count\": 12,\n",
" \"elu_alpha\": 0.5588538605400805,\n",
" \"features\": [\n",
" 8,\n",
" 16,\n",
" 32\n",
" ],\n",
" \"kernel_sizes\": [\n",
" 3,\n",
" 3,\n",
" 3\n",
" ],\n",
" \"leaky_relu_alpha\": 0.012913890161555076,\n",
" \"leaky_relu_slope\": 0.022615416455484896,\n",
" \"learning_rate\": 0.002130094098871897,\n",
" \"model_type\": \"Residual3\",\n",
" \"num_epochs\": 12,\n",
" \"scheduler_gamma\": 0.8142448793722726,\n",
" \"use_elu\": true,\n",
" \"use_instance_norm\": false\n",
"}\n",
"2024-06-25 08:59:55,976 - INFO - Loaded 1000 original images\n",
"2024-06-25 08:59:55,979 - INFO - Loaded 1000 original images\n",
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/torch/jit/_trace.py:1102: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:\n",
"Tensor-likes are not close!\n",
"\n",
"Mismatched elements: 442 / 262144 (0.2%)\n",
"Greatest absolute difference: 0.00028640031814575195 at index (14, 0, 3, 6, 1) (up to 1e-05 allowed)\n",
"Greatest relative difference: 0.03313953488372093 at index (52, 0, 2, 8, 3) (up to 1e-05 allowed)\n",
" _check_trace(\n",
"2024-06-25 09:00:10,492 - INFO - Epoch 0 train loss: -668.9345343112946\n",
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/matplotlib/collections.py:996: RuntimeWarning: invalid value encountered in sqrt\n",
" scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor\n",
"2024-06-25 09:00:17,910 - INFO - Epoch 0 test loss: -677.4099225997925\n",
"2024-06-25 09:00:27,392 - INFO - Epoch 1 train loss: -677.2842514514923\n",
"2024-06-25 09:00:35,272 - INFO - Epoch 1 test loss: -677.350163936615\n",
"2024-06-25 09:00:44,703 - INFO - Epoch 2 train loss: -677.5818197727203\n",
"2024-06-25 09:00:52,274 - INFO - Epoch 2 test loss: -677.3433697223663\n",
"2024-06-25 09:01:01,680 - INFO - Epoch 3 train loss: -677.3612790107727\n",
"2024-06-25 09:01:09,430 - INFO - Epoch 3 test loss: -677.4993708133698\n",
"2024-06-25 09:01:18,837 - INFO - Epoch 4 train loss: -677.3555946350098\n",
"2024-06-25 09:01:26,721 - INFO - Epoch 4 test loss: -677.2696735858917\n",
"2024-06-25 09:01:36,224 - INFO - Epoch 5 train loss: -677.4827179908752\n",
"2024-06-25 09:01:44,065 - INFO - Epoch 5 test loss: -677.4189476966858\n",
"2024-06-25 09:01:53,537 - INFO - Epoch 6 train loss: -677.5993602275848\n",
"2024-06-25 09:02:01,438 - INFO - Epoch 6 test loss: -677.5275483131409\n",
"2024-06-25 09:02:10,913 - INFO - Epoch 7 train loss: -677.417388677597\n",
"2024-06-25 09:02:18,622 - INFO - Epoch 7 test loss: -677.5215902328491\n",
"2024-06-25 09:02:28,085 - INFO - Epoch 8 train loss: -677.415346622467\n",
"2024-06-25 09:02:36,597 - INFO - Epoch 8 test loss: -677.5785489082336\n",
"2024-06-25 09:02:46,112 - INFO - Epoch 9 train loss: -677.4984295368195\n",
"2024-06-25 09:02:54,038 - INFO - Epoch 9 test loss: -677.5197842121124\n",
"2024-06-25 09:03:03,558 - INFO - Epoch 10 train loss: -677.4539258480072\n",
"2024-06-25 09:03:11,486 - INFO - Epoch 10 test loss: -677.5460705757141\n",
"2024-06-25 09:03:21,255 - INFO - Epoch 11 train loss: -677.54083776474\n",
"2024-06-25 09:03:29,637 - INFO - Epoch 11 test loss: -677.658331155777\n",
"2024-06-25 09:03:29,818 - INFO - Saving model to /home/andras/projects/bipolaroid/saved_models/run_143.pth\n",
"2024-06-25 09:03:29,819 - INFO - Parameter count: 429457\n"
]
}
],
"source": [
"from training import random_hparam_search\n",
"from config import RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
"\n",
"\n", "\n",
"random_hparam_search(\n", "random_hparam_search(\n",
" hyperparameters=hyperparameters,\n", " hyperparameters=hyperparameters,\n",
" train_data_paths=TRAIN_DATA,\n", " train_data_paths=TRAIN_DATA,\n",
" test_data_paths=TEST_DATA,\n", " test_data_paths=TEST_DATA,\n",
" models_path=MODELS_PATH,\n", " models_path=MODELS_PATH,\n",
" tensorboard_path=RUNS_PATH,\n", " tensorboard_path=RUNS_PATH / get_next_run_name(RUNS_PATH),\n",
" timeout_hours=8,\n", " timeout_hours=4,\n",
" device=device,\n", " device=device,\n",
")" ")"
] ]

View file

@ -36,16 +36,10 @@ def random_hparam_search(
log_dir = tensorboard_path / get_next_run_name(tensorboard_path) log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
try: try:
train_data_loader = get_data_loader(
train_data_paths, **current_hyperparameters
)
test_data_loader = get_data_loader(
test_data_paths, **current_hyperparameters
)
model = train( model = train(
hyperparameters=current_hyperparameters, hyperparameters=current_hyperparameters,
train_data_loader=train_data_loader, train_data_paths=train_data_paths,
test_data_loader=test_data_loader, test_data_paths=test_data_paths,
max_duration=timedelta(hours=timeout_hours), max_duration=timedelta(hours=timeout_hours),
log_dir=log_dir, log_dir=log_dir,
use_tqdm=False, use_tqdm=False,

View file

@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from pathlib import Path from pathlib import Path
from torch.optim import Adam from torch.optim import Adam
@ -10,32 +10,33 @@ from datetime import timedelta, datetime
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import torch import torch
from utils import serialise_hparams from utils import serialise_hparams
from .get_data_loader import get_data_loader
EPSILON = 1e-5 EPSILON = 1e-5
def train( def train(
hyperparameters: Dict[str, Any], hyperparameters: Dict[str, Any],
train_data_loader: DataLoader, train_data_paths: List[Path],
test_data_loader: DataLoader, test_data_paths: List[Path],
log_dir: Path, log_dir: Path,
max_duration: Optional[timedelta], max_duration: Optional[timedelta],
use_tqdm: bool, use_tqdm: bool,
device: torch.device, device: torch.device,
model_type: str, model_type: str,
bin_count: int,
learning_rate: float, learning_rate: float,
scheduler_gamma: float, scheduler_gamma: float,
num_epochs: int, num_epochs: int,
**_, **_,
) -> torch.nn.Module: ) -> torch.nn.Module:
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
test_data_loader = get_data_loader(test_data_paths, **hyperparameters)
start_time = datetime.now() start_time = datetime.now()
with SummaryWriter(log_dir) as writer: with SummaryWriter(log_dir) as writer:
model = create_model( model = create_model(
type=model_type, type=model_type,
bin_count=bin_count, hyperparameters=hyperparameters,
device=device, device=device,
).train() ).train()
writer.add_graph(model, next(iter(train_data_loader))[0].to(device)) writer.add_graph(model, next(iter(train_data_loader))[0].to(device))