Update training

This commit is contained in:
Andras Schmelczer 2024-06-27 22:30:27 +01:00
parent c7c0f292c6
commit 7863611f86
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
4 changed files with 97 additions and 137 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@ __pycache__
runs*
*.log
saved_models/*
train.py

View file

@ -35,24 +35,85 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# from training import train\n",
"from config import RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
"\n",
"\n",
"# hparams = {\n",
"# \"batch_size\": 64,\n",
"# \"edit_count\": 12,\n",
"# \"bin_count\": 16,\n",
"# \"learning_rate\": 0.001,\n",
"# \"scheduler_gamma\": 0.9,\n",
"# \"num_epochs\": 12,\n",
"# \"model_type\": \"SimpleCNN\",\n",
"# }\n",
"\n",
"# train(\n",
"# hparams,\n",
"# train_data_paths=TRAIN_DATA,\n",
"# test_data_paths=TEST_DATA,\n",
"# log_dir=RUNS_PATH,\n",
"# max_duration=None,\n",
"# use_tqdm=True,\n",
"# device=device,\n",
"# **hparams\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-25 08:59:52,244 - INFO - Testing model Dummy\n",
"2024-06-25 08:59:52,249 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-25 08:59:52,249 - INFO - Testing model SimpleCNN\n",
"2024-06-25 08:59:52,746 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-25 08:59:52,752 - INFO - Testing model Residual\n",
"2024-06-25 08:59:53,853 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-25 08:59:53,917 - INFO - Testing model Residual3\n",
"2024-06-25 08:59:55,590 - INFO - Test passed! Output shape matches input shape.\n"
"2024-06-27 22:27:47,571 - INFO - Testing model Dummy\n",
"2024-06-27 22:27:47,576 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-27 22:27:47,576 - INFO - Testing model SimpleCNN\n",
"2024-06-27 22:27:48,012 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-27 22:27:48,014 - INFO - Testing model Residual\n",
"2024-06-27 22:27:49,044 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-27 22:27:49,048 - INFO - Testing model HistogramNet\n",
"2024-06-27 22:27:50,526 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-27 22:27:50,536 - INFO - Starting run_0 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 16,\n",
" \"dropout_prob\": 0.07931334001160453,\n",
" \"edit_count\": 12,\n",
" \"elu_alpha\": 1.4632917556519958,\n",
" \"features\": [\n",
" 8,\n",
" 16,\n",
" 32\n",
" ],\n",
" \"kernel_size\": 5,\n",
" \"leaky_relu_alpha\": 0.04009731835331415,\n",
" \"leaky_relu_slope\": 0.005283501654875031,\n",
" \"learning_rate\": 0.000798531032420656,\n",
" \"model_type\": \"HistogramNet\",\n",
" \"num_epochs\": 12,\n",
" \"scheduler_gamma\": 0.8420623161905719,\n",
" \"use_elu\": true,\n",
" \"use_instance_norm\": false,\n",
" \"use_residual\": false\n",
"}\n",
"2024-06-27 22:27:50,584 - INFO - Loaded 22479 original images\n",
"2024-06-27 22:27:50,588 - INFO - Loaded 2498 original images\n",
"2024-06-27 22:27:55,003 - INFO - Original result 322461.46875\n",
"2024-06-27 22:28:40,001 - INFO - Epoch 0 train loss: 5343.839758455753\n",
"2024-06-27 22:28:44,685 - INFO - Epoch 0 test loss: 531.2615858912468\n",
"2024-06-27 22:28:45,500 - INFO - Original result 404787.25\n"
]
}
],
"source": [
"from scipy.stats import loguniform, uniform, randint\n",
"from training import random_hparam_search, get_next_run_name\n",
"from models import MODELS, test_models\n",
"\n",
"\n",
@ -67,133 +128,36 @@
" \"elu_alpha\": uniform(0.5, 1.5),\n",
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
" \"dropout_prob\": uniform(0, 0.1),\n",
" \"features\": [[16, 32, 64], [32, 64, 128], [8, 16, 32], [8, 8, 8], [16, 16, 16]],\n",
" \"kernel_sizes\": [[3, 3, 3]],\n",
" \"model_type\": [\"Residual3\"], # list(MODELS.keys()),\n",
" \"clip_gradients\": [True, False],\n",
" \"features\": [\n",
" [16, 32, 64],\n",
" [32, 64],\n",
" [16, 32],\n",
" [32, 64, 128],\n",
" [8, 16, 32],\n",
" [8, 8, 8],\n",
" [8, 8, 8, 8, 8],\n",
" [16, 16, 16],\n",
" [32, 32],\n",
" [64, 64],\n",
" ],\n",
" \"use_residual\": [True, False],\n",
" \"kernel_size\": [3, 5],\n",
" \"model_type\": [\"HistogramNet\"],\n",
" \"use_instance_norm\": [True, False],\n",
" \"use_elu\": [True, False],\n",
" \"leaky_relu_alpha\": uniform(0, 0.05),\n",
" }\n",
"]\n",
"\n",
"test_models()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# train(\n",
"# {\n",
"# \"batch_size\": 128,\n",
"# \"edit_count\": 12,\n",
"# \"bin_count\": 16,\n",
"# \"learning_rate\": 1e-3,\n",
"# \"scheduler_gamma\": 0.8,\n",
"# \"elu_alpha\": 1,\n",
"# \"dropout_prob\": 0.05,\n",
"# \"features\": [8, 16, 32],\n",
"# \"kernel_sizes\": [3, 3, 3],\n",
"# \"num_epochs\": 12,\n",
"# \"model_type\": \"Residual3\",\n",
"# \"clip_gradients\": True,\n",
"# \"use_instance_norm\": True,\n",
"# \"use_elu\": False,\n",
"# \"leaky_relu_alpha\": 0.01,\n",
"# }\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-25 08:59:55,973 - INFO - Starting run_170 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 16,\n",
" \"clip_gradients\": true,\n",
" \"dropout_prob\": 0.09784778880383105,\n",
" \"edit_count\": 12,\n",
" \"elu_alpha\": 0.5588538605400805,\n",
" \"features\": [\n",
" 8,\n",
" 16,\n",
" 32\n",
" ],\n",
" \"kernel_sizes\": [\n",
" 3,\n",
" 3,\n",
" 3\n",
" ],\n",
" \"leaky_relu_alpha\": 0.012913890161555076,\n",
" \"leaky_relu_slope\": 0.022615416455484896,\n",
" \"learning_rate\": 0.002130094098871897,\n",
" \"model_type\": \"Residual3\",\n",
" \"num_epochs\": 12,\n",
" \"scheduler_gamma\": 0.8142448793722726,\n",
" \"use_elu\": true,\n",
" \"use_instance_norm\": false\n",
"}\n",
"2024-06-25 08:59:55,976 - INFO - Loaded 1000 original images\n",
"2024-06-25 08:59:55,979 - INFO - Loaded 1000 original images\n",
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/torch/jit/_trace.py:1102: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:\n",
"Tensor-likes are not close!\n",
"\n",
"Mismatched elements: 442 / 262144 (0.2%)\n",
"Greatest absolute difference: 0.00028640031814575195 at index (14, 0, 3, 6, 1) (up to 1e-05 allowed)\n",
"Greatest relative difference: 0.03313953488372093 at index (52, 0, 2, 8, 3) (up to 1e-05 allowed)\n",
" _check_trace(\n",
"2024-06-25 09:00:10,492 - INFO - Epoch 0 train loss: -668.9345343112946\n",
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/matplotlib/collections.py:996: RuntimeWarning: invalid value encountered in sqrt\n",
" scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor\n",
"2024-06-25 09:00:17,910 - INFO - Epoch 0 test loss: -677.4099225997925\n",
"2024-06-25 09:00:27,392 - INFO - Epoch 1 train loss: -677.2842514514923\n",
"2024-06-25 09:00:35,272 - INFO - Epoch 1 test loss: -677.350163936615\n",
"2024-06-25 09:00:44,703 - INFO - Epoch 2 train loss: -677.5818197727203\n",
"2024-06-25 09:00:52,274 - INFO - Epoch 2 test loss: -677.3433697223663\n",
"2024-06-25 09:01:01,680 - INFO - Epoch 3 train loss: -677.3612790107727\n",
"2024-06-25 09:01:09,430 - INFO - Epoch 3 test loss: -677.4993708133698\n",
"2024-06-25 09:01:18,837 - INFO - Epoch 4 train loss: -677.3555946350098\n",
"2024-06-25 09:01:26,721 - INFO - Epoch 4 test loss: -677.2696735858917\n",
"2024-06-25 09:01:36,224 - INFO - Epoch 5 train loss: -677.4827179908752\n",
"2024-06-25 09:01:44,065 - INFO - Epoch 5 test loss: -677.4189476966858\n",
"2024-06-25 09:01:53,537 - INFO - Epoch 6 train loss: -677.5993602275848\n",
"2024-06-25 09:02:01,438 - INFO - Epoch 6 test loss: -677.5275483131409\n",
"2024-06-25 09:02:10,913 - INFO - Epoch 7 train loss: -677.417388677597\n",
"2024-06-25 09:02:18,622 - INFO - Epoch 7 test loss: -677.5215902328491\n",
"2024-06-25 09:02:28,085 - INFO - Epoch 8 train loss: -677.415346622467\n",
"2024-06-25 09:02:36,597 - INFO - Epoch 8 test loss: -677.5785489082336\n",
"2024-06-25 09:02:46,112 - INFO - Epoch 9 train loss: -677.4984295368195\n",
"2024-06-25 09:02:54,038 - INFO - Epoch 9 test loss: -677.5197842121124\n",
"2024-06-25 09:03:03,558 - INFO - Epoch 10 train loss: -677.4539258480072\n",
"2024-06-25 09:03:11,486 - INFO - Epoch 10 test loss: -677.5460705757141\n",
"2024-06-25 09:03:21,255 - INFO - Epoch 11 train loss: -677.54083776474\n",
"2024-06-25 09:03:29,637 - INFO - Epoch 11 test loss: -677.658331155777\n",
"2024-06-25 09:03:29,818 - INFO - Saving model to /home/andras/projects/bipolaroid/saved_models/run_143.pth\n",
"2024-06-25 09:03:29,819 - INFO - Parameter count: 429457\n"
]
}
],
"source": [
"from training import random_hparam_search\n",
"from config import RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
"\n",
"test_models()\n",
"\n",
"random_hparam_search(\n",
" hyperparameters=hyperparameters,\n",
" train_data_paths=TRAIN_DATA,\n",
" test_data_paths=TEST_DATA,\n",
" models_path=MODELS_PATH,\n",
" tensorboard_path=RUNS_PATH,\n",
" timeout_hours=8,\n",
" tensorboard_path=RUNS_PATH / get_next_run_name(RUNS_PATH),\n",
" timeout_hours=4,\n",
" device=device,\n",
")"
]

View file

@ -36,16 +36,10 @@ def random_hparam_search(
log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
try:
train_data_loader = get_data_loader(
train_data_paths, **current_hyperparameters
)
test_data_loader = get_data_loader(
test_data_paths, **current_hyperparameters
)
model = train(
hyperparameters=current_hyperparameters,
train_data_loader=train_data_loader,
test_data_loader=test_data_loader,
train_data_paths=train_data_paths,
test_data_paths=test_data_paths,
max_duration=timedelta(hours=timeout_hours),
log_dir=log_dir,
use_tqdm=False,

View file

@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from torch.optim import Adam
@ -10,32 +10,33 @@ from datetime import timedelta, datetime
from torch.utils.data import DataLoader
import torch
from utils import serialise_hparams
from .get_data_loader import get_data_loader
EPSILON = 1e-5
def train(
hyperparameters: Dict[str, Any],
train_data_loader: DataLoader,
test_data_loader: DataLoader,
train_data_paths: List[Path],
test_data_paths: List[Path],
log_dir: Path,
max_duration: Optional[timedelta],
use_tqdm: bool,
device: torch.device,
model_type: str,
bin_count: int,
learning_rate: float,
scheduler_gamma: float,
num_epochs: int,
**_,
) -> torch.nn.Module:
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
test_data_loader = get_data_loader(test_data_paths, **hyperparameters)
start_time = datetime.now()
with SummaryWriter(log_dir) as writer:
model = create_model(
type=model_type,
bin_count=bin_count,
hyperparameters=hyperparameters,
device=device,
).train()
writer.add_graph(model, next(iter(train_data_loader))[0].to(device))