diff --git a/src/train.ipynb b/src/train.ipynb index 84be519..c424838 100644 --- a/src/train.ipynb +++ b/src/train.ipynb @@ -21,66 +21,98 @@ "source": [ "import os\n", "from utils import set_up_logging, get_device\n", - "from training import train, random_hparam_search\n", - "from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n", + "from training import train_with_ray_factory\n", + "from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n", + "from ray import tune\n", + "from ray.tune.schedulers import ASHAScheduler\n", + "import os\n", + "from ray.air import RunConfig\n", "\n", "set_up_logging(LOGS_PATH)\n", "\n", + "TRIAL_COUNT = 100\n", + "CHUNK_COUNT = 40\n", + "EPOCH_COUNT = 2\n", + "\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n", "device = get_device()\n", "f\"Using device {device}\"" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# hparams = {\n", - "# \"batch_size\": 64,\n", - "# \"edit_count\": 12,\n", - "# \"bin_count\": 16,\n", - "# \"learning_rate\": 0.001,\n", - "# \"scheduler_gamma\": 0.9,\n", - "# \"num_epochs\": 12,\n", - "# \"model_type\": \"SimpleCNN\",\n", - "# }\n", - "\n", - "# train(\n", - "# hparams,\n", - "# train_data_paths=TRAIN_DATA,\n", - "# test_data_paths=TEST_DATA,\n", - "# log_dir=RUNS_PATH,\n", - "# max_duration=None,\n", - "# use_tqdm=True,\n", - "# device=device,\n", - "# **hparams\n", - "# )" - ] - }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-06-29 10:45:18,539 - INFO - Loaded 22479 original images\n", - "2024-06-29 10:45:18,544 - INFO - Loaded 2498 original images\n" - ] - }, { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "61d4288a87a24255824e1430d8e69804", - "version_major": 2, - "version_minor": 0 - }, + "text/html": [ + "
\n", + "
\n", + "
\n", + "

Tune Status

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Current time:2024-09-01 22:06:10
Running for: 00:02:06.64
Memory: 22.4/47.0 GiB
\n", + "
\n", + "
\n", + "
\n", + "

System Info

\n", + " Using AsyncHyperBand: num_stopped=1
Bracket: Iter 2.000: -5.516964137554169
Logical resource usage: 32.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "

Trial Status

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Trial name status loc batch_size dropout_prob elu_alphafeatures kernel_size leaky_relu_alpha leaky_relu_slope learning_ratemodel_type scheduler_gammause_elu use_residual iter total time (s) chunk_test_loss chunk_training_loss
train_with_ray_b7d3c_00000TERMINATED172.29.235.222:1134109 32 0.00395892 1.35107[8, 16, 32] 5 0.0448715 0.00664526 0.0029226 HistogramNet 0.816752True True 3 70.652 5.22149 56.5472
train_with_ray_b7d3c_00001TERMINATED172.29.235.222:1140440 32 0.0439061 1.74642[8, 8, 8, 8, 8,_c140 3 0.00579656 0.0125715 0.00155182HistogramNet 0.898059True False 2 45.4228 5.79311 64.0481
\n", + "
\n", + "
\n", + "\n" + ], "text/plain": [ - "Epoch 0: 0%| | 0/4215 [00:00" ] }, "metadata": {}, @@ -90,81 +122,113 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-06-29 10:50:49,231 - INFO - Epoch 0 train loss: 5050.157035887241\n", - "2024-06-29 10:51:01,594 - INFO - Epoch 0 test loss: 460.928723692894\n" + "\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000000)\n", + "\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000001)\n", + "\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000002)\n", + "\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000000)\n", + "2024-09-01 22:06:10,132\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/andras/projects/bipolaroid/runs5/tune' in 0.0020s.\n", + "2024-09-01 22:06:10,135\tINFO tune.py:1041 -- Total run time: 126.67 seconds (126.64 seconds for the tuning loop).\n" ] }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "55aa7c6adc41427081af06330d541c2d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Epoch 1: 0%| | 0/4215 [00:00 20\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mhparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTRAIN_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTEST_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mRUNS_PATH\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 25\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_duration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 28\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhparams\u001b[49m\n\u001b[1;32m 29\u001b[0m \u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/mnt/wsl/PHYSICALDRIVE1/projects/bipolaroid/src/training/train.py:55\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(hyperparameters, train_data_paths, test_data_paths, log_dir, use_tqdm, device, model_type, learning_rate, scheduler_gamma, num_epochs, **_)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_id, (edited_histogram, original_histogram) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\n\u001b[1;32m 50\u001b[0m tqdm(train_data_loader, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_tqdm\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m train_data_loader\n\u001b[1;32m 53\u001b[0m ):\n\u001b[1;32m 54\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 55\u001b[0m predicted_original \u001b[38;5;241m=\u001b[39m model(\u001b[43medited_histogram\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 56\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_function(\n\u001b[1;32m 57\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(predicted_original \u001b[38;5;241m+\u001b[39m EPSILON),\n\u001b[1;32m 58\u001b[0m original_histogram\u001b[38;5;241m.\u001b[39mto(device),\n\u001b[1;32m 59\u001b[0m )\n\u001b[1;32m 61\u001b[0m epoch_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000001)\n" ] } ], + "source": [ + "config = {\n", + " \"batch_size\": 32,\n", + " \"edit_count\": EPOCH_COUNT,\n", + " \"bin_count\": 32,\n", + " \"learning_rate\": tune.loguniform(5e-4, 5e-3),\n", + " \"scheduler_gamma\": tune.uniform(0.8, 0.95),\n", + " \"elu_alpha\": tune.uniform(0.5, 2),\n", + " \"leaky_relu_slope\": tune.uniform(0, 0.03),\n", + " \"dropout_prob\": tune.uniform(0, 0.1),\n", + " \"chunk_count\": CHUNK_COUNT,\n", + " \"features\": tune.choice(\n", + " [\n", + " [16, 32, 64],\n", + " [16, 32, 64, 128],\n", + " [32, 64],\n", + " [32, 128],\n", + " [8, 16, 32],\n", + " [8, 8, 8, 8, 8],\n", + " [8, 8, 8, 8, 8, 8, 8],\n", + " [16, 16, 16],\n", + " [16, 16, 16, 16, 16],\n", + " [32, 32, 32],\n", + " [32, 32, 32, 32],\n", + " [64, 64],\n", + " [64, 64, 64],\n", + " ]\n", + " ),\n", + " \"use_residual\": tune.choice([True, False]),\n", + " \"kernel_size\": tune.choice([3, 5]),\n", + " \"model_type\": tune.choice([\"HistogramNet\"]),\n", + " \"use_instance_norm\": True,\n", + " \"use_elu\": tune.choice([True, False]),\n", + " \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n", + "}\n", + "scheduler = ASHAScheduler(max_t=config[\"chunk_count\"], grace_period=2)\n", + "\n", + "tuner = tune.Tuner(\n", + " tune.with_resources(\n", + " tune.with_parameters(\n", + " train_with_ray_factory(\n", + " train_data_paths=TRAIN_DATA,\n", + " test_data_paths=TEST_DATA,\n", + " device=device,\n", + " log_dir=RUNS_PATH / \"custom\",\n", + " )\n", + " ),\n", + " resources={\"cpu\": 32, \"gpu\": 1},\n", + " ),\n", + " run_config=RunConfig(storage_path=RUNS_PATH, name=\"tune\"),\n", + " tune_config=tune.TuneConfig(\n", + " metric=\"chunk_test_loss\",\n", + " mode=\"min\",\n", + " scheduler=scheduler,\n", + " num_samples=TRIAL_COUNT,\n", + " ),\n", + " param_space=config,\n", + ")\n", + "results = tuner.fit()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best trial config: {'batch_size': 32, 'edit_count': 3, 'bin_count': 32, 'learning_rate': 0.0029226033808016005, 'scheduler_gamma': 0.8167516482513361, 'elu_alpha': 1.3510723758569865, 'leaky_relu_slope': 0.0066452562138349025, 'dropout_prob': 0.0039589213934103865, 'chunk_count': 4, 'features': [8, 16, 32], 'use_residual': True, 'kernel_size': 5, 'model_type': 'HistogramNet', 'use_instance_norm': True, 'use_elu': True, 'leaky_relu_alpha': 0.04487148648446764}\n", + "Best trial final validation loss: 5.2214884757995605\n" + ] + } + ], + "source": [ + "best_result = results.get_best_result(\"chunk_test_loss\", \"min\")\n", + "\n", + "print(\"Best trial config: {}\".format(best_result.config))\n", + "print(\n", + " \"Best trial final validation loss: {}\".format(\n", + " best_result.metrics[\"chunk_test_loss\"]\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], "source": [ "# hparams = {\n", "# \"batch_size\": 64,\n", @@ -196,64 +260,6 @@ "# **hparams\n", "# )" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from scipy.stats import loguniform, uniform, randint\n", - "from models import MODELS, test_models\n", - "\n", - "\n", - "hyperparameters = [\n", - " {\n", - " \"batch_size\": [64, 128, 256],\n", - " \"edit_count\": [12],\n", - " \"bin_count\": [16],\n", - " \"learning_rate\": loguniform(5e-4, 5e-3),\n", - " \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n", - " \"num_epochs\": [10],\n", - " \"elu_alpha\": uniform(0.5, 1.5),\n", - " \"leaky_relu_slope\": uniform(0, 0.03),\n", - " \"dropout_prob\": uniform(0, 0.1),\n", - " \"features\": [\n", - " [16, 32, 64],\n", - " [16, 32, 64, 128],\n", - " [32, 64],\n", - " [32, 64, 128],\n", - " [8, 16, 32],\n", - " [8, 8, 8, 8, 8],\n", - " [8, 8, 8, 8, 8, 8, 8],\n", - " [16, 16, 16],\n", - " [16, 16, 16, 16, 16],\n", - " [32, 32, 32],\n", - " [32, 32, 32, 32],\n", - " [64, 64],\n", - " [64, 64, 64],\n", - " ],\n", - " \"use_residual\": [True, False],\n", - " \"kernel_size\": [3, 5],\n", - " \"model_type\": [\"HistogramNet\"],\n", - " \"use_instance_norm\": [True],\n", - " \"use_elu\": [True, False],\n", - " \"leaky_relu_alpha\": uniform(0, 0.05),\n", - " }\n", - "]\n", - "\n", - "test_models()\n", - "\n", - "random_hparam_search(\n", - " hyperparameters=hyperparameters,\n", - " train_data_paths=TRAIN_DATA,\n", - " test_data_paths=TEST_DATA,\n", - " models_path=MODELS_PATH,\n", - " tensorboard_path=RUNS_PATH,\n", - " timeout_hours=4,\n", - " device=device,\n", - ")" - ] } ], "metadata": { diff --git a/src/training/train_with_ray.py b/src/training/train_with_ray.py new file mode 100644 index 0000000..5ddcc1d --- /dev/null +++ b/src/training/train_with_ray.py @@ -0,0 +1,205 @@ +from typing import Any, Dict, List, Tuple +from torch.utils.tensorboard import SummaryWriter +from pathlib import Path +from torch.optim import Adam +from .get_next_run_name import get_next_run_name +from utils import serialise_hparams +from visualisation import plot_histograms_in_2d +from models import create_model, load_model, save_model +import torch +from .get_data_loader import get_data_loader +from ray import train +from ray.train import Checkpoint +import os +import tempfile +from more_itertools import distribute + + +EPSILON = 1e-5 + + +def train_with_ray_factory( + train_data_paths: List[Path], + test_data_paths: List[Path], + device: torch.device, + log_dir: Path, +): + def train_with_ray(hyperparameters: Dict[str, Any]): + def inner( + hyperparameters: Dict[str, Any], + chunk_count: int, + **_, + ) -> torch.nn.Module: + train_data_loader = get_data_loader(train_data_paths, **hyperparameters) + test_data_loader = get_data_loader( + test_data_paths, **{**hyperparameters, "edit_count": 1} + ) + examples = next(iter(test_data_loader)) + + model, optimizer, scheduler, start_chunk_id, run_name = ( + load_or_create_state( + device=device, + log_dir=log_dir, + **hyperparameters, + ) + ) + loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device) + + with SummaryWriter(log_dir=log_dir / run_name) as writer: + writer.add_graph(model, examples[0].to(device)) + for chunk_id, chunk in enumerate( + distribute(chunk_count, train_data_loader)[start_chunk_id:-1], + start=start_chunk_id, + ): + chunk_training_loss = 0 + writer.add_scalar( + "Actual learning rate", + scheduler.get_last_lr()[0], + chunk_id, + ) + for batch_id, (edited_histogram, original_histogram) in enumerate( + chunk + ): + global_step = ( + chunk_id * (len(train_data_loader) // chunk_count) + + batch_id + ) + optimizer.zero_grad() + predicted_original = model(edited_histogram.to(device)) + loss = loss_function( + torch.log(predicted_original + EPSILON), + original_histogram.to(device), + ) + + chunk_training_loss += loss.item() + writer.add_scalar( + "Loss/train/batch", loss, global_step=global_step + ) + loss.backward() + optimizer.step() + + with torch.no_grad(): + model.eval() + write_histograms( + model=model, + examples=examples, + writer=writer, + device=device, + global_step=global_step, + ) + chunk_test_loss = 0 + for ( + edited_histogram, + original_histogram, + ) in test_data_loader: + predicted_original = model(edited_histogram.to(device)) + chunk_test_loss += loss_function( + torch.log(predicted_original + EPSILON), + original_histogram.to(device), + ).item() + model.train() + + with tempfile.TemporaryDirectory() as temp_checkpoint_dir: + temp_checkpoint_dir = Path(temp_checkpoint_dir) + checkpoint_path = temp_checkpoint_dir / "checkpoint.pt" + torch.save( + ( + optimizer.state_dict(), + scheduler.state_dict(), + chunk_id, + run_name, + ), + checkpoint_path, + ) + save_model( + model, hyperparameters, temp_checkpoint_dir / "model" + ) + writer.add_hparams( + serialise_hparams(hyperparameters), + { + "Loss/test/epoch": chunk_test_loss, + "Loss/train/epoch": chunk_training_loss, + }, + global_step=global_step, + run_name=(log_dir / run_name).absolute(), + ) + train.report( + { + "chunk_test_loss": chunk_test_loss, + "chunk_training_loss": chunk_training_loss, + }, + checkpoint=Checkpoint.from_directory(temp_checkpoint_dir), + ) + + scheduler.step() + + return inner(hyperparameters=hyperparameters, **hyperparameters) + + return train_with_ray + + +def load_or_create_state( + device, log_dir, model_type, learning_rate, scheduler_gamma, **hyperparameters +) -> Tuple[ + torch.nn.Module, + torch.optim.Optimizer, + torch.optim.lr_scheduler.LRScheduler, + int, + str, +]: + loaded_checkpoint = train.get_checkpoint() + if loaded_checkpoint: + with loaded_checkpoint.as_directory() as loaded_checkpoint_dir: + loaded_checkpoint_dir = Path(loaded_checkpoint_dir) + model, hyperparameters = load_model( + loaded_checkpoint_dir / "model", device=device + ) + optimizer = Adam(model.parameters(), lr=learning_rate) + + optimizer_state, scheduler_state, start_chunk_id, run_name = torch.load( + loaded_checkpoint_dir / "checkpoint.pt" + ) + optimizer.load_state_dict(optimizer_state) + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, step_size=1, gamma=scheduler_gamma + ) + scheduler.load_state_dict(scheduler_state) + else: + run_name = get_next_run_name(log_dir) + model = create_model( + type=model_type, + hyperparameters=hyperparameters, + device=device, + ).train() + optimizer = Adam(model.parameters(), lr=learning_rate) + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, step_size=1, gamma=scheduler_gamma + ) + start_chunk_id = 0 + + return model, optimizer, scheduler, start_chunk_id, run_name + + +def write_histograms( + model: torch.nn.Module, + examples: List[Tuple[torch.Tensor, torch.Tensor]], + writer: SummaryWriter, + device: torch.device, + global_step: int, +): + edited_histograms, original_histograms = examples + predicted_originals = model(edited_histograms.to(device)) + for i, (original, edited, predicted) in enumerate( + zip(original_histograms, edited_histograms, predicted_originals) + ): + writer.add_figure( + f"histogram_{i}", + plot_histograms_in_2d( + { + "original": original[0].numpy().squeeze(), + "edited": edited.cpu()[0].numpy().squeeze(), + "predicted": predicted.cpu()[0].numpy().squeeze(), + } + ), + global_step=global_step, + )