Adjust training script for Residual3

This commit is contained in:
Andras Schmelczer 2024-06-22 18:31:43 +01:00
parent 129a315228
commit e5959268c1
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C

View file

@ -22,19 +22,10 @@
"import torch\n",
"import logging\n",
"import os\n",
"from datetime import datetime\n",
"from editor.utils import set_up_logging\n",
"from config import LOGS_PATH\n",
"\n",
"logging.basicConfig(\n",
" level=logging.INFO,\n",
" format=\"%(asctime)s - %(levelname)s - %(message)s\",\n",
" handlers=[\n",
" logging.StreamHandler(),\n",
" logging.FileHandler(\n",
" LOGS_PATH / f\"train-{datetime.now().isoformat(timespec='minutes')}.log\"\n",
" ),\n",
" ],\n",
")\n",
"set_up_logging(LOGS_PATH)\n",
"\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
@ -45,34 +36,49 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 15:59:06,999 - INFO - Testing model Dummy\n",
"2024-06-22 15:59:07,004 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-22 15:59:07,004 - INFO - Testing model SimpleCNN\n",
"2024-06-22 15:59:07,478 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-22 15:59:07,481 - INFO - Testing model Residual\n",
"2024-06-22 15:59:08,560 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-22 15:59:08,566 - INFO - Testing model Residual2\n",
"2024-06-22 15:59:09,671 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-22 15:59:09,676 - INFO - Testing model Residual3\n",
"2024-06-22 15:59:11,272 - INFO - Test passed! Output shape matches input shape.\n"
]
}
],
"source": [
"from scipy.stats import loguniform, uniform, randint\n",
"from editor.models import MODELS\n",
"from editor.models import MODELS, test_models\n",
"\n",
"common_hyperparameters = {\n",
" \"batch_size\": [64],\n",
"\n",
"hyperparameters = [{\n",
" \"batch_size\": [32, 64, 128],\n",
" \"edit_count\": [12],\n",
" \"bin_count\": [16],\n",
" \"clip_gradients\": [False],\n",
" \"learning_rate\": loguniform(1e-4, 5e-3),\n",
" \"scheduler_gamma\": uniform(loc=0.7, scale=0.3),\n",
" \"num_epochs\": [24],\n",
" # \"num_epochs\": randint(5, 10),\n",
" \"model_type\": list(MODELS.keys()),\n",
"}\n",
"hyperparameters = [\n",
" # {\n",
" # **common_hyperparameters,\n",
" # \"loss\": [\"progressive\"],\n",
" # \"loss_sizes\": [[4, 8, 16, 32], [8, 16, 32], [16, 32], [8, 32]],\n",
" # \"loss_damping\": uniform(0.2, 5),\n",
" # },\n",
" {\n",
" **common_hyperparameters,\n",
" \"loss\": [\"kl\"],\n",
" },\n",
"]"
" \"learning_rate\": loguniform(5e-4, 5e-3),\n",
" \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n",
" \"num_epochs\": [12],\n",
" \"elu_alpha\": uniform(0.5, 1.5),\n",
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
" \"dropout_prob\": uniform(0, 0.1),\n",
" \"features\": [[16, 32, 64], [32, 64, 128], [8, 16, 32], [8, 8, 8], [16, 16, 16]],\n",
" \"kernel_sizes\": [[3, 3, 3]],\n",
" \"model_type\": [\"Residual3\"], # list(MODELS.keys()),\n",
" \"clip_gradients\": [True, False],\n",
" \"use_instance_norm\": [True, False],\n",
" \"use_elu\": [True, False],\n",
" \"leaky_relu_alpha\": uniform(0, 0.05),\n",
"}]\n",
"\n",
"test_models()"
]
},
{
@ -111,36 +117,24 @@
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing model Residual\n",
"Test passed! Output shape matches input shape.\n"
]
}
],
"outputs": [],
"source": [
"from typing import Optional\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"from pathlib import Path\n",
"from torch.optim import Adam\n",
"from tqdm.notebook import tqdm\n",
"from torch.nn.utils import clip_grad_norm_\n",
"from editor.training import ProgressivePoolingLoss\n",
"from editor.utils import get_next_run_name\n",
"from editor.visualisation import plot_histograms_in_2d\n",
"from editor.models import create_model, test_models\n",
"from editor.models import create_model, save_model\n",
"from datetime import timedelta, datetime\n",
"import json\n",
"from config import MODELS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n",
"\n",
"\n",
"test_models()\n",
"\n",
"\n",
"def train(\n",
" hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool\n",
" hyperparameters: Dict[str, Any],\n",
" max_duration: Optional[timedelta] = None,\n",
" use_tqdm: bool = True,\n",
") -> Path:\n",
" start_time = datetime.now()\n",
"\n",
@ -149,28 +143,18 @@
" train_data_loader = get_data_loader(TRAIN_DATA, hyperparameters)\n",
" test_data_loader = get_data_loader(TEST_DATA, hyperparameters)\n",
"\n",
" model = (\n",
" create_model(\n",
" type=hyperparameters[\"model_type\"],\n",
" bin_count=hyperparameters[\"bin_count\"],\n",
" )\n",
" .train()\n",
" .to(device)\n",
" )\n",
" model = create_model(\n",
" type=hyperparameters[\"model_type\"],\n",
" bin_count=hyperparameters[\"bin_count\"],\n",
" device=device,\n",
" ).train()\n",
" writer.add_graph(model, next(iter(train_data_loader))[0].to(device))\n",
"\n",
" optimizer = Adam(model.parameters(), lr=hyperparameters[\"learning_rate\"])\n",
" scheduler = torch.optim.lr_scheduler.StepLR(\n",
" optimizer, step_size=1, gamma=hyperparameters[\"scheduler_gamma\"]\n",
" )\n",
"\n",
" loss_function = {\n",
" \"progressive\": lambda: ProgressivePoolingLoss(\n",
" target_sizes=hyperparameters[\"loss_sizes\"],\n",
" damping=hyperparameters[\"loss_damping\"],\n",
" ),\n",
" \"kl\": lambda: torch.nn.KLDivLoss(reduction=\"batchmean\"),\n",
" }[hyperparameters[\"loss\"]]().to(device)\n",
" loss_function = torch.nn.KLDivLoss(reduction=\"batchmean\").to(device)\n",
"\n",
" try:\n",
" for epoch in range(hyperparameters[\"num_epochs\"]):\n",
@ -184,30 +168,18 @@
" else train_data_loader\n",
" ):\n",
" current_time = datetime.now()\n",
" if current_time - start_time > max_duration:\n",
" if (\n",
" max_duration is not None\n",
" and current_time - start_time > max_duration\n",
" ):\n",
" raise TimeoutError(f\"Time limit {max_duration} exceeded\")\n",
" edited_histogram = edited_histogram.to(device)\n",
" original_histogram = original_histogram.to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" predicted_original = model(edited_histogram)\n",
" sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n",
" predicted_original = predicted_original / sum\n",
"\n",
" if hyperparameters[\"loss\"] == \"kl\":\n",
" predicted_original = torch.clamp(\n",
" predicted_original, 0.0000000000000000000001, 1\n",
" )\n",
"\n",
" loss = {\n",
" \"kl\": lambda: loss_function(\n",
" torch.log(predicted_original),\n",
" original_histogram,\n",
" ),\n",
" \"progressive\": lambda: loss_function(\n",
" predicted_original, original_histogram\n",
" ),\n",
" }[hyperparameters[\"loss\"]]()\n",
" predicted_original = model(edited_histogram.to(device))\n",
" loss = loss_function(\n",
" torch.log(torch.clamp(predicted_original, 1e-10, 1)),\n",
" original_histogram.to(device),\n",
" )\n",
"\n",
" epoch_loss += loss.item()\n",
" writer.add_scalar(\n",
@ -216,9 +188,6 @@
" global_step=epoch * len(train_data_loader) + batch_id,\n",
" )\n",
" loss.backward()\n",
"\n",
" if hyperparameters[\"clip_gradients\"]:\n",
" clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
" optimizer.step()\n",
"\n",
" logging.info(f\"Epoch {epoch} train loss: {epoch_loss}\")\n",
@ -226,18 +195,12 @@
" model.eval()\n",
" loader = iter(test_data_loader)\n",
" edited_histogram, original_histogram = next(loader)\n",
" edited_histogram = edited_histogram.to(device)\n",
" original_histogram = original_histogram.to(device)\n",
" predicted_original = model(edited_histogram)\n",
" sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n",
" predicted_original = predicted_original / sum\n",
" predicted_original = model(edited_histogram.to(device))\n",
" writer.add_figure(\n",
" \"histogram\",\n",
" plot_histograms_in_2d(\n",
" {\n",
" \"original\": original_histogram.cpu()[0]\n",
" .numpy()\n",
" .squeeze(),\n",
" \"original\": original_histogram[0].numpy().squeeze(),\n",
" \"edited\": edited_histogram.cpu()[0].numpy().squeeze(),\n",
" \"predicted\": predicted_original.cpu()[0]\n",
" .numpy()\n",
@ -251,29 +214,11 @@
" for batch_id, (edited_histogram, original_histogram) in enumerate(\n",
" test_data_loader\n",
" ):\n",
" edited_histogram = edited_histogram.to(device)\n",
" original_histogram = original_histogram.to(device)\n",
"\n",
" predicted_original = model(edited_histogram)\n",
" sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n",
" predicted_original = predicted_original / sum\n",
"\n",
" if hyperparameters[\"loss\"] == \"kl\":\n",
" predicted_original = torch.clamp(\n",
" predicted_original, 0.0000000000000000000001, 1\n",
" )\n",
"\n",
" loss = {\n",
" \"kl\": lambda: loss_function(\n",
" torch.log(predicted_original),\n",
" original_histogram,\n",
" ),\n",
" \"progressive\": lambda: loss_function(\n",
" predicted_original, original_histogram\n",
" ),\n",
" }[hyperparameters[\"loss\"]]()\n",
"\n",
" epoch_test_loss += loss.item()\n",
" predicted_original = model(edited_histogram.to(device))\n",
" epoch_test_loss += loss_function(\n",
" torch.log(torch.clamp(predicted_original, 1e-10, 1)),\n",
" original_histogram.to(device),\n",
" ).item()\n",
" writer.add_hparams(\n",
" serialise_hparams(hyperparameters),\n",
" {\n",
@ -287,40 +232,323 @@
"\n",
" model.train()\n",
" scheduler.step()\n",
" except Exception:\n",
" raise\n",
" finally:\n",
" run_name = get_next_run_name(MODELS_PATH)\n",
" model_path = (MODELS_PATH / run_name).with_suffix(\".pth\")\n",
" params_path = (MODELS_PATH / run_name).with_suffix(\".json\")\n",
"\n",
" logging.info(f\"Saving model to {model_path}\")\n",
" with open(model_path, \"wb\") as f:\n",
" torch.save(model.state_dict(), f)\n",
" with open(params_path, \"w\") as f:\n",
" json.dump(hyperparameters, f, indent=2)\n",
" model_path = MODELS_PATH / get_next_run_name(MODELS_PATH)\n",
" save_model(model, hyperparameters, model_path)\n",
" del model\n",
" torch.cuda.empty_cache()\n",
" return model_path"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 16:57:28,986 - INFO - Loaded 22479 original images\n",
"2024-06-22 16:57:28,991 - INFO - Loaded 2498 original images\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "63e068d857484e59afb823ca4b5c3a58",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 0: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:00:34,218 - INFO - Epoch 0 train loss: 11718.475350141525\n",
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/matplotlib/collections.py:996: RuntimeWarning: invalid value encountered in sqrt\n",
" scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor\n",
"2024-06-22 17:00:43,669 - INFO - Epoch 0 test loss: 575.4344878196716\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "594bc7c470a149b4a1c89a5728d67ab1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 1: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:03:46,881 - INFO - Epoch 1 train loss: 9741.187401413918\n",
"2024-06-22 17:03:56,471 - INFO - Epoch 1 test loss: 536.2769713401794\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "830ea2fbef8d4e1c8e45650f82560c74",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 2: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:06:59,896 - INFO - Epoch 2 train loss: 9120.070751070976\n",
"2024-06-22 17:07:09,641 - INFO - Epoch 2 test loss: 553.2901458740234\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "077a671090584e97a88adbfc35007919",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 3: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:10:12,942 - INFO - Epoch 3 train loss: 5763.117876529694\n",
"2024-06-22 17:10:22,622 - INFO - Epoch 3 test loss: 507.6950304508209\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "786b050c3f85486096d84816aab6affd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 4: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:13:25,908 - INFO - Epoch 4 train loss: 6363.094870328903\n",
"2024-06-22 17:13:36,532 - INFO - Epoch 4 test loss: 532.4468264579773\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "070b1d8d19ec4958b9d1175cef666802",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 5: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:16:40,056 - INFO - Epoch 5 train loss: 4596.043945550919\n",
"2024-06-22 17:16:49,784 - INFO - Epoch 5 test loss: 438.763400554657\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5531f44fbc84422e9739dd94121b100f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 6: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:19:53,205 - INFO - Epoch 6 train loss: 5266.503381967545\n",
"2024-06-22 17:20:02,990 - INFO - Epoch 6 test loss: 573.5293898582458\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a73cdeb982ab418185fde63e80fe07db",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 7: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:23:06,445 - INFO - Epoch 7 train loss: 5163.991681098938\n",
"2024-06-22 17:23:16,136 - INFO - Epoch 7 test loss: 672.4951323270798\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8f04a8e6f5324b688ce8a6c9f4453f78",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 8: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:26:19,453 - INFO - Epoch 8 train loss: 12930.857147455215\n",
"2024-06-22 17:26:29,204 - INFO - Epoch 8 test loss: 636.4001806974411\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6dee2d0ebb564ff89604ad8a81787998",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 9: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:29:32,560 - INFO - Epoch 9 train loss: 13841.072596549988\n",
"2024-06-22 17:29:42,246 - INFO - Epoch 9 test loss: 2833.8614711761475\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0b4729e7da054a3c8b2e0609459a18a7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 10: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:32:45,585 - INFO - Epoch 10 train loss: 15531.006411075592\n",
"2024-06-22 17:32:55,355 - INFO - Epoch 10 test loss: 469.56569051742554\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9fda0f52c620475b927bbff090b048d5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 11: 0%| | 0/2108 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-22 17:35:58,670 - INFO - Epoch 11 train loss: 17766.0949113369\n",
"2024-06-22 17:36:08,527 - INFO - Epoch 11 test loss: 3254.2825841903687\n",
"2024-06-22 17:36:08,529 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_66.pth\n",
"2024-06-22 17:36:08,529 - INFO - Parameter count: 429457\n"
]
},
{
"data": {
"text/plain": [
"PosixPath('/home/andras/projects/bipolaroid/models/run_66')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# train(\n",
"# {\n",
"# \"batch_size\": 64,\n",
"# \"edit_count\": 8,\n",
"# \"batch_size\": 128,\n",
"# \"edit_count\": 12,\n",
"# \"bin_count\": 16,\n",
"# \"clip_gradients\": False,\n",
"# \"learning_rate\": 0.0005220900529274365,\n",
"# \"scheduler_gamma\": 0.5479991284291021,\n",
"# \"num_epochs\": 24,\n",
"# \"model_type\": \"Residual\",\n",
"# \"loss\": \"kl\",\n",
"# \"learning_rate\": 1e-3,\n",
"# \"scheduler_gamma\": 0.8,\n",
"# \"elu_alpha\": 1,\n",
"# \"dropout_prob\": 0.05,\n",
"# \"features\": [8, 16, 32],\n",
"# \"kernel_sizes\": [3, 3, 3],\n",
"# \"num_epochs\": 12,\n",
"# \"model_type\": \"Residual3\",\n",
"# \"clip_gradients\": True,\n",
"# \"use_instance_norm\": True,\n",
"# \"use_elu\": False,\n",
"# \"leaky_relu_alpha\": 0.01,\n",
"# }\n",
"# )"
]
@ -329,45 +557,7 @@
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-16 20:21:51,962 - INFO - Starting run_0 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 32,\n",
" \"clip_gradients\": false,\n",
" \"edit_count\": 8,\n",
" \"learning_rate\": 0.0013249692770052317,\n",
" \"loss\": \"kl\",\n",
" \"model_type\": \"Residual\",\n",
" \"num_epochs\": 16,\n",
" \"scheduler_gamma\": 1.3114281184948258\n",
"}\n",
"2024-06-16 20:21:52,012 - INFO - Loaded 22479 original images\n",
"2024-06-16 20:21:52,016 - INFO - Loaded 2498 original images\n",
"2024-06-16 20:35:43,995 - INFO - Epoch 0 train loss: 6540.840226650238\n",
"2024-06-16 20:36:15,017 - INFO - Epoch 0 test loss: 1531.6546006202698\n",
"2024-06-16 20:49:58,543 - INFO - Epoch 1 train loss: 5763.938045859337\n",
"2024-06-16 20:50:29,893 - INFO - Epoch 1 test loss: 1608.853798866272\n",
"2024-06-16 21:04:13,577 - INFO - Epoch 2 train loss: 5448.952376246452\n",
"2024-06-16 21:04:45,607 - INFO - Epoch 2 test loss: 1465.128571987152\n",
"2024-06-16 21:18:31,962 - INFO - Epoch 3 train loss: 5633.2793600559235\n",
"2024-06-16 21:19:09,149 - INFO - Epoch 3 test loss: 1330.329261302948\n",
"2024-06-16 21:32:58,465 - INFO - Epoch 4 train loss: 5338.784257531166\n",
"2024-06-16 21:33:37,006 - INFO - Epoch 4 test loss: 2083.3998107910156\n",
"2024-06-16 21:47:25,527 - INFO - Epoch 5 train loss: 5321.843332529068\n",
"2024-06-16 21:48:04,110 - INFO - Epoch 5 test loss: 1314.629390001297\n",
"2024-06-16 22:01:51,264 - INFO - Epoch 6 train loss: 5337.748890757561\n",
"2024-06-16 22:02:29,786 - INFO - Epoch 6 test loss: 1290.2974362373352\n",
"2024-06-16 22:16:17,529 - INFO - Epoch 7 train loss: 5167.580719232559\n",
"2024-06-16 22:16:56,284 - INFO - Epoch 7 test loss: 1889.667366027832\n",
"2024-06-16 22:26:15,914 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_67.pth\n",
"2024-06-16 22:26:16,086 - INFO - Interrupted, stopping\n"
]
}
],
"outputs": [],
"source": [
"from random import choice\n",
"from itertools import count\n",
@ -380,9 +570,7 @@
" for k, v in choice(hyperparameters).items()\n",
" }\n",
" key = json.dumps(current_hyperparameters, indent=2, sort_keys=True)\n",
" logging.info(\n",
" f\"Starting {get_next_run_name(RUNS_PATH)} with hparams {key}\"\n",
" )\n",
" logging.info(f\"Starting {get_next_run_name(RUNS_PATH)} with hparams {key}\")\n",
" try:\n",
" train(current_hyperparameters, max_duration=timedelta(hours=8), use_tqdm=False)\n",
" except KeyboardInterrupt as e:\n",
@ -413,7 +601,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.1.-1"
"version": "3.12.2"
}
},
"nbformat": 4,