bipolaroidbipolaroid/src/train.ipynb

409 lines
16 KiB
Text

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"metadata": {}
},
"outputs": [
{
"data": {
"text/plain": [
"'Using device cuda:0'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import logging\n",
"import os\n",
"from datetime import datetime\n",
"\n",
"# Ask PyTorch's CUDA allocator for expandable segments before CUDA is queried below.\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
"\n",
"# Log to stderr and to a per-session file whose name is stamped to the minute.\n",
"logging.basicConfig(\n",
"    format=\"%(asctime)s - %(levelname)s - %(message)s\",\n",
"    level=logging.INFO,\n",
"    handlers=[\n",
"        logging.StreamHandler(),\n",
"        logging.FileHandler(f\"train-{datetime.now().isoformat(timespec='minutes')}.log\"),\n",
"    ],\n",
")\n",
"\n",
"# Use the first GPU when one is visible, otherwise stay on the CPU.\n",
"if torch.cuda.is_available():\n",
"    device = torch.device(\"cuda:0\")\n",
"else:\n",
"    device = torch.device(\"cpu\")\n",
"f\"Using device {device}\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import loguniform, uniform, randint\n",
"from editor.models import MODELS\n",
"\n",
"# Random-search space: list entries are sampled with random.choice and scipy\n",
"# distributions through .rvs() by the search loop at the end of the notebook.\n",
"common_hyperparameters = {\n",
"    \"batch_size\": [64],\n",
"    \"edit_count\": [8],\n",
"    \"bin_count\": [16, 24, 32],\n",
"    \"clip_gradients\": [False],\n",
"    \"learning_rate\": loguniform(3e-4, 3e-3),\n",
"    # scipy's uniform(loc, scale) covers [loc, loc + scale]: this is [0.5, 1.0].\n",
"    # A scale of 1 would allow gamma > 1, which makes StepLR raise the learning\n",
"    # rate every epoch instead of decaying it.\n",
"    \"scheduler_gamma\": uniform(0.5, 0.5),\n",
"    \"num_epochs\": [16],\n",
"    # \"num_epochs\": randint(5, 10),\n",
"    \"model_type\": list(MODELS.keys()),\n",
"}\n",
"hyperparameters = [\n",
"    # {\n",
"    #     **common_hyperparameters,\n",
"    #     \"loss\": [\"progressive\"],\n",
"    #     \"loss_sizes\": [[4, 8, 16, 32], [8, 16, 32], [16, 32], [8, 32]],\n",
"    #     # NOTE(review): uniform(0.2, 5) spans [0.2, 5.2] - confirm the intended upper bound.\n",
"    #     \"loss_damping\": uniform(0.2, 5),\n",
"    # },\n",
"    {\n",
"        **common_hyperparameters,\n",
"        \"loss\": [\"kl\"],\n",
"    },\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from typing import List, Any, Dict\n",
"from torch.utils.data import DataLoader\n",
"from config import CACHE_PATH\n",
"from editor.training import HistogramDataset\n",
"\n",
"\n",
"def get_data_loader(data: List[Path], hyperparameters: Dict[str, Any]) -> DataLoader:\n",
"    \"\"\"Build a shuffling DataLoader over a HistogramDataset for the given paths.\n",
"\n",
"    Args:\n",
"        data: Paths handed to ``HistogramDataset`` as ``paths``.\n",
"        hyperparameters: Must contain ``edit_count``, ``bin_count`` and\n",
"            ``batch_size``.\n",
"\n",
"    Returns:\n",
"        A ``DataLoader`` with ``shuffle=True`` and one worker per reported CPU.\n",
"    \"\"\"\n",
"    # os.cpu_count() returns None when the count cannot be determined; fall back\n",
"    # to loading in the main process instead of passing None to DataLoader.\n",
"    worker_count = os.cpu_count() or 0\n",
"    dataset = HistogramDataset(\n",
"        paths=data,\n",
"        edit_count=hyperparameters[\"edit_count\"],\n",
"        bin_count=hyperparameters[\"bin_count\"],\n",
"        delete_corrupt_images=False,\n",
"        cache_path=CACHE_PATH,\n",
"    )\n",
"    return DataLoader(\n",
"        dataset=dataset,\n",
"        batch_size=hyperparameters[\"batch_size\"],\n",
"        shuffle=True,\n",
"        num_workers=worker_count,\n",
"    )\n",
"\n",
"\n",
"def serialise_hparams(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:\n",
"    \"\"\"Return a copy of the mapping in which every list value is replaced by str(value).\n",
"\n",
"    Values that are not lists are passed through unchanged.\n",
"    \"\"\"\n",
"    serialised: Dict[str, Any] = {}\n",
"    for name, value in hyperparameters.items():\n",
"        if isinstance(value, list):\n",
"            serialised[name] = str(value)\n",
"        else:\n",
"            serialised[name] = value\n",
"    return serialised"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing model SimpleCNN\n",
"Test passed! Output shape matches input shape.\n",
"Testing model Residual\n",
"Test passed! Output shape matches input shape.\n"
]
}
],
"source": [
"from torch.utils.tensorboard import SummaryWriter\n",
"from pathlib import Path\n",
"from torch.optim import Adam\n",
"from tqdm.notebook import tqdm\n",
"from torch.nn.utils import clip_grad_norm_\n",
"from editor.training import ProgressivePoolingLoss\n",
"from editor.utils import get_next_run_name\n",
"from editor.visualisation import plot_histograms_in_2d\n",
"from editor.models import create_model, test_models\n",
"from data import TRAIN_DATA, TEST_DATA\n",
"from datetime import timedelta, datetime\n",
"import json\n",
"from config import MODELS_PATH\n",
"\n",
"\n",
"# Run the models' self-test before training; the recorded output shows one shape check per model.\n",
"test_models()\n",
"\n",
"\n",
"def train(\n",
"    hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool\n",
") -> Path:\n",
"    \"\"\"Train one model with concrete hyperparameters and save its weights.\n",
"\n",
"    Logs the model graph, the learning rate, the per-batch training loss, a\n",
"    histogram figure and per-epoch hparams/metrics to a new TensorBoard run\n",
"    under ``runs/``.\n",
"\n",
"    Args:\n",
"        hyperparameters: Sampled values (not distributions). Keys read here:\n",
"            ``model_type``, ``bin_count``, ``learning_rate``,\n",
"            ``scheduler_gamma``, ``num_epochs``, ``clip_gradients``, ``loss``\n",
"            and, for the ``\"progressive\"`` loss, ``loss_sizes`` and\n",
"            ``loss_damping``. The mapping is also passed to ``get_data_loader``.\n",
"        max_duration: Wall-clock budget, checked before every training batch.\n",
"        use_tqdm: Show a tqdm progress bar over the training batches.\n",
"\n",
"    Returns:\n",
"        Path of the saved ``.pth`` state dict.\n",
"\n",
"    Raises:\n",
"        TimeoutError: If ``max_duration`` is exceeded. The weights and the\n",
"            hyperparameters are still written before an exception raised\n",
"            during the epoch loop propagates.\n",
"    \"\"\"\n",
"    start_time = datetime.now()\n",
"    # Lower clamp bound applied to predictions before the log in the KL branch.\n",
"    kl_floor = 0.0000000000000000000001\n",
"\n",
"    log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n",
"    with SummaryWriter(log_dir) as writer:\n",
"        train_data_loader = get_data_loader(TRAIN_DATA, hyperparameters)\n",
"        test_data_loader = get_data_loader(TEST_DATA, hyperparameters)\n",
"\n",
"        model = (\n",
"            create_model(\n",
"                type=hyperparameters[\"model_type\"],\n",
"                bin_count=hyperparameters[\"bin_count\"],\n",
"            )\n",
"            .train()\n",
"            .to(device)\n",
"        )\n",
"        writer.add_graph(model, next(iter(train_data_loader))[0].to(device))\n",
"\n",
"        optimizer = Adam(model.parameters(), lr=hyperparameters[\"learning_rate\"])\n",
"        scheduler = torch.optim.lr_scheduler.StepLR(\n",
"            optimizer, step_size=1, gamma=hyperparameters[\"scheduler_gamma\"]\n",
"        )\n",
"\n",
"        # Lambdas keep construction lazy, so the \"progressive\" keys are only\n",
"        # required when that loss is actually selected.\n",
"        loss_function = {\n",
"            \"progressive\": lambda: ProgressivePoolingLoss(\n",
"                target_sizes=hyperparameters[\"loss_sizes\"],\n",
"                damping=hyperparameters[\"loss_damping\"],\n",
"            ),\n",
"            \"kl\": lambda: torch.nn.KLDivLoss(reduction=\"batchmean\"),\n",
"        }[hyperparameters[\"loss\"]]().to(device)\n",
"        use_kl = hyperparameters[\"loss\"] == \"kl\"\n",
"\n",
"        def predict(edited_histogram):\n",
"            \"\"\"Run the model and divide the output by its sum over dims 2-4.\"\"\"\n",
"            raw_prediction = model(edited_histogram)\n",
"            total = torch.sum(raw_prediction, dim=(2, 3, 4), keepdim=True)\n",
"            return raw_prediction / total\n",
"\n",
"        def compute_loss(predicted_original, original_histogram):\n",
"            \"\"\"Evaluate the configured loss for one batch.\"\"\"\n",
"            if use_kl:\n",
"                # KLDivLoss expects log-probabilities as input; clamp first so\n",
"                # the log never sees zero or negative values.\n",
"                clamped = torch.clamp(predicted_original, kl_floor, 1)\n",
"                return loss_function(torch.log(clamped), original_histogram)\n",
"            return loss_function(predicted_original, original_histogram)\n",
"\n",
"        try:\n",
"            for epoch in range(hyperparameters[\"num_epochs\"]):\n",
"                epoch_loss = 0\n",
"                writer.add_scalar(\n",
"                    \"Actual learning rate\", scheduler.get_last_lr()[0], epoch\n",
"                )\n",
"                train_batches = (\n",
"                    tqdm(train_data_loader, desc=f\"Epoch {epoch}\", unit=\"batch\")\n",
"                    if use_tqdm\n",
"                    else train_data_loader\n",
"                )\n",
"                for batch_id, (edited_histogram, original_histogram) in enumerate(\n",
"                    train_batches\n",
"                ):\n",
"                    if datetime.now() - start_time > max_duration:\n",
"                        raise TimeoutError(f\"Time limit {max_duration} exceeded\")\n",
"                    edited_histogram = edited_histogram.to(device)\n",
"                    original_histogram = original_histogram.to(device)\n",
"\n",
"                    optimizer.zero_grad()\n",
"                    loss = compute_loss(predict(edited_histogram), original_histogram)\n",
"\n",
"                    epoch_loss += loss.item()\n",
"                    writer.add_scalar(\n",
"                        \"Loss/train/batch\",\n",
"                        loss,\n",
"                        global_step=epoch * len(train_data_loader) + batch_id,\n",
"                    )\n",
"                    loss.backward()\n",
"\n",
"                    if hyperparameters[\"clip_gradients\"]:\n",
"                        clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
"                    optimizer.step()\n",
"\n",
"                logging.info(f\"Epoch {epoch} train loss: {epoch_loss}\")\n",
"                with torch.no_grad():\n",
"                    model.eval()\n",
"\n",
"                    # Plot the first sample of one test batch for this epoch.\n",
"                    edited_histogram, original_histogram = next(iter(test_data_loader))\n",
"                    edited_histogram = edited_histogram.to(device)\n",
"                    original_histogram = original_histogram.to(device)\n",
"                    predicted_original = predict(edited_histogram)\n",
"                    writer.add_figure(\n",
"                        \"histogram\",\n",
"                        plot_histograms_in_2d(\n",
"                            {\n",
"                                \"original\": original_histogram.cpu()[0].numpy().squeeze(),\n",
"                                \"edited\": edited_histogram.cpu()[0].numpy().squeeze(),\n",
"                                \"predicted\": predicted_original.cpu()[0].numpy().squeeze(),\n",
"                            }\n",
"                        ),\n",
"                        epoch,\n",
"                    )\n",
"\n",
"                    # Both epoch losses are sums over batches, not means.\n",
"                    epoch_test_loss = 0\n",
"                    for edited_histogram, original_histogram in test_data_loader:\n",
"                        edited_histogram = edited_histogram.to(device)\n",
"                        original_histogram = original_histogram.to(device)\n",
"                        loss = compute_loss(predict(edited_histogram), original_histogram)\n",
"                        epoch_test_loss += loss.item()\n",
"                    writer.add_hparams(\n",
"                        serialise_hparams(hyperparameters),\n",
"                        {\n",
"                            \"Loss/test/epoch\": epoch_test_loss,\n",
"                            \"Loss/train/epoch\": epoch_loss,\n",
"                        },\n",
"                        global_step=epoch,\n",
"                        run_name=log_dir.absolute(),\n",
"                    )\n",
"                    logging.info(f\"Epoch {epoch} test loss: {epoch_test_loss}\")\n",
"\n",
"                model.train()\n",
"                scheduler.step()\n",
"        finally:\n",
"            # Runs on success, timeout, interrupt and errors alike, so a partially\n",
"            # trained model is still written to disk.\n",
"            run_name = get_next_run_name(MODELS_PATH)\n",
"            model_path = (MODELS_PATH / run_name).with_suffix(\".pth\")\n",
"            params_path = (MODELS_PATH / run_name).with_suffix(\".json\")\n",
"\n",
"            logging.info(f\"Saving model to {model_path}\")\n",
"            with open(model_path, \"wb\") as f:\n",
"                torch.save(model.state_dict(), f)\n",
"            with open(params_path, \"w\") as f:\n",
"                json.dump(hyperparameters, f, indent=2)\n",
"            del model\n",
"            torch.cuda.empty_cache()\n",
"    return model_path"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# train(\n",
"# {\n",
"# \"batch_size\": 64,\n",
"# \"edit_count\": 25,\n",
"# \"bin_count\": 32,\n",
"# \"clip_gradients\": True,\n",
"# \"learning_rate\": 0.005,\n",
"# \"scheduler_gamma\": 0.7,\n",
"# \"num_epochs\": 20,\n",
"# \"model_type\": \"NormalisedCNN\",\n",
"# \"loss\": \"progressive\",\n",
"# \"loss_sizes\": [16, 32],\n",
"# \"loss_damping\": 2,\n",
"# }\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-06 08:19:45,651 - INFO - Starting run_26 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 32,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 8,\n",
" \"learning_rate\": 0.0020871393198725404,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"Residual\",\n",
" \"num_epochs\": 16,\n",
" \"scheduler_gamma\": 1.2086440138363033\n",
"}\n",
"2024-06-06 08:19:45,768 - INFO - Loaded 22479 original images\n",
"2024-06-06 08:19:45,778 - INFO - Loaded 2498 original images\n",
"2024-06-06 08:20:02,104 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_26.pth\n",
"2024-06-06 08:20:02,248 - INFO - Interrupted, stopping\n"
]
}
],
"source": [
"from random import choice\n",
"from itertools import count\n",
"import json\n",
"\n",
"\n",
"# Random search: sample a configuration, train it, repeat until interrupted.\n",
"while True:\n",
"    # Distributions expose .rvs(); plain lists are sampled with choice().\n",
"    current_hyperparameters = {\n",
"        k: v.rvs() if hasattr(v, \"rvs\") else choice(v)\n",
"        for k, v in choice(hyperparameters).items()\n",
"    }\n",
"    key = json.dumps(current_hyperparameters, indent=2, sort_keys=True)\n",
"    # Resolved outside the f-string: reusing double quotes inside an f-string\n",
"    # expression only parses on Python 3.12+.\n",
"    next_run_name = get_next_run_name(Path(\"runs\"))\n",
"    logging.info(f\"Starting {next_run_name} with hparams {key}\")\n",
"    try:\n",
"        train(current_hyperparameters, max_duration=timedelta(hours=8), use_tqdm=False)\n",
"    except KeyboardInterrupt:\n",
"        logging.info(\"Interrupted, stopping\")\n",
"        break\n",
"    except TimeoutError:\n",
"        logging.warning(\"Timeout, aborting experiment\")\n",
"    except Exception as e:\n",
"        # exc_info attaches the traceback of the failure; stack_info would only\n",
"        # record the stack of this logging call.\n",
"        logging.error(\n",
"            f\"Error with hparams {current_hyperparameters}:\\n\\t{e}\", exc_info=True\n",
"        )"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "bipolaroid",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}