Fix??
This commit is contained in:
parent
2475d7c8dd
commit
16ba45546b
3 changed files with 116 additions and 502 deletions
612
src/train.ipynb
612
src/train.ipynb
|
|
@ -20,9 +20,8 @@
|
|||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import logging\n",
|
||||
"import os\n",
|
||||
"from editor.utils import set_up_logging\n",
|
||||
"from utils import set_up_logging\n",
|
||||
"from config import LOGS_PATH\n",
|
||||
"\n",
|
||||
"set_up_logging(LOGS_PATH)\n",
|
||||
|
|
@ -41,42 +40,42 @@
|
|||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 15:59:06,999 - INFO - Testing model Dummy\n",
|
||||
"2024-06-22 15:59:07,004 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-22 15:59:07,004 - INFO - Testing model SimpleCNN\n",
|
||||
"2024-06-22 15:59:07,478 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-22 15:59:07,481 - INFO - Testing model Residual\n",
|
||||
"2024-06-22 15:59:08,560 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-22 15:59:08,566 - INFO - Testing model Residual2\n",
|
||||
"2024-06-22 15:59:09,671 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-22 15:59:09,676 - INFO - Testing model Residual3\n",
|
||||
"2024-06-22 15:59:11,272 - INFO - Test passed! Output shape matches input shape.\n"
|
||||
"2024-06-25 08:59:52,244 - INFO - Testing model Dummy\n",
|
||||
"2024-06-25 08:59:52,249 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-25 08:59:52,249 - INFO - Testing model SimpleCNN\n",
|
||||
"2024-06-25 08:59:52,746 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-25 08:59:52,752 - INFO - Testing model Residual\n",
|
||||
"2024-06-25 08:59:53,853 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-25 08:59:53,917 - INFO - Testing model Residual3\n",
|
||||
"2024-06-25 08:59:55,590 - INFO - Test passed! Output shape matches input shape.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from scipy.stats import loguniform, uniform, randint\n",
|
||||
"from editor.models import MODELS, test_models\n",
|
||||
"from models import MODELS, test_models\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"hyperparameters = [{\n",
|
||||
" \"batch_size\": [32, 64, 128],\n",
|
||||
" \"edit_count\": [12],\n",
|
||||
" \"bin_count\": [16],\n",
|
||||
" \"learning_rate\": loguniform(5e-4, 5e-3),\n",
|
||||
" \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n",
|
||||
" \"num_epochs\": [12],\n",
|
||||
" \"elu_alpha\": uniform(0.5, 1.5),\n",
|
||||
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
|
||||
" \"dropout_prob\": uniform(0, 0.1),\n",
|
||||
" \"features\": [[16, 32, 64], [32, 64, 128], [8, 16, 32], [8, 8, 8], [16, 16, 16]],\n",
|
||||
" \"kernel_sizes\": [[3, 3, 3]],\n",
|
||||
" \"model_type\": [\"Residual3\"], # list(MODELS.keys()),\n",
|
||||
" \"clip_gradients\": [True, False],\n",
|
||||
" \"use_instance_norm\": [True, False],\n",
|
||||
" \"use_elu\": [True, False],\n",
|
||||
" \"leaky_relu_alpha\": uniform(0, 0.05),\n",
|
||||
"}]\n",
|
||||
"hyperparameters = [\n",
|
||||
" {\n",
|
||||
" \"batch_size\": [32, 64, 128],\n",
|
||||
" \"edit_count\": [12],\n",
|
||||
" \"bin_count\": [16],\n",
|
||||
" \"learning_rate\": loguniform(5e-4, 5e-3),\n",
|
||||
" \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n",
|
||||
" \"num_epochs\": [12],\n",
|
||||
" \"elu_alpha\": uniform(0.5, 1.5),\n",
|
||||
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
|
||||
" \"dropout_prob\": uniform(0, 0.1),\n",
|
||||
" \"features\": [[16, 32, 64], [32, 64, 128], [8, 16, 32], [8, 8, 8], [16, 16, 16]],\n",
|
||||
" \"kernel_sizes\": [[3, 3, 3]],\n",
|
||||
" \"model_type\": [\"Residual3\"], # list(MODELS.keys()),\n",
|
||||
" \"clip_gradients\": [True, False],\n",
|
||||
" \"use_instance_norm\": [True, False],\n",
|
||||
" \"use_elu\": [True, False],\n",
|
||||
" \"leaky_relu_alpha\": uniform(0, 0.05),\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"test_models()"
|
||||
]
|
||||
|
|
@ -86,451 +85,6 @@
|
|||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"from typing import List, Any, Dict\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"from config import CACHE_PATH\n",
|
||||
"from editor.training import HistogramDataset\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_data_loader(data: List[Path], hyperparameters: Dict[str, Any]) -> DataLoader:\n",
|
||||
" return DataLoader(\n",
|
||||
" dataset=HistogramDataset(\n",
|
||||
" paths=data,\n",
|
||||
" edit_count=hyperparameters[\"edit_count\"],\n",
|
||||
" bin_count=hyperparameters[\"bin_count\"],\n",
|
||||
" delete_corrupt_images=False,\n",
|
||||
" cache_path=CACHE_PATH,\n",
|
||||
" ),\n",
|
||||
" batch_size=hyperparameters[\"batch_size\"],\n",
|
||||
" shuffle=True,\n",
|
||||
" num_workers=os.cpu_count(),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def serialise_hparams(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:\n",
|
||||
" return {k: str(v) if isinstance(v, list) else v for k, v in hyperparameters.items()}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Optional\n",
|
||||
"from torch.utils.tensorboard import SummaryWriter\n",
|
||||
"from pathlib import Path\n",
|
||||
"from torch.optim import Adam\n",
|
||||
"from tqdm.notebook import tqdm\n",
|
||||
"from editor.utils import get_next_run_name\n",
|
||||
"from editor.visualisation import plot_histograms_in_2d\n",
|
||||
"from editor.models import create_model, save_model\n",
|
||||
"from datetime import timedelta, datetime\n",
|
||||
"from config import MODELS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def train(\n",
|
||||
" hyperparameters: Dict[str, Any],\n",
|
||||
" max_duration: Optional[timedelta] = None,\n",
|
||||
" use_tqdm: bool = True,\n",
|
||||
") -> Path:\n",
|
||||
" start_time = datetime.now()\n",
|
||||
"\n",
|
||||
" log_dir = RUNS_PATH / get_next_run_name(RUNS_PATH)\n",
|
||||
" with SummaryWriter(log_dir) as writer:\n",
|
||||
" train_data_loader = get_data_loader(TRAIN_DATA, hyperparameters)\n",
|
||||
" test_data_loader = get_data_loader(TEST_DATA, hyperparameters)\n",
|
||||
"\n",
|
||||
" model = create_model(\n",
|
||||
" type=hyperparameters[\"model_type\"],\n",
|
||||
" bin_count=hyperparameters[\"bin_count\"],\n",
|
||||
" device=device,\n",
|
||||
" ).train()\n",
|
||||
" writer.add_graph(model, next(iter(train_data_loader))[0].to(device))\n",
|
||||
"\n",
|
||||
" optimizer = Adam(model.parameters(), lr=hyperparameters[\"learning_rate\"])\n",
|
||||
" scheduler = torch.optim.lr_scheduler.StepLR(\n",
|
||||
" optimizer, step_size=1, gamma=hyperparameters[\"scheduler_gamma\"]\n",
|
||||
" )\n",
|
||||
" loss_function = torch.nn.KLDivLoss(reduction=\"batchmean\").to(device)\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" for epoch in range(hyperparameters[\"num_epochs\"]):\n",
|
||||
" epoch_loss = 0\n",
|
||||
" writer.add_scalar(\n",
|
||||
" \"Actual learning rate\", scheduler.get_last_lr()[0], epoch\n",
|
||||
" )\n",
|
||||
" for batch_id, (edited_histogram, original_histogram) in enumerate(\n",
|
||||
" tqdm(train_data_loader, desc=f\"Epoch {epoch}\", unit=\"batch\")\n",
|
||||
" if use_tqdm\n",
|
||||
" else train_data_loader\n",
|
||||
" ):\n",
|
||||
" current_time = datetime.now()\n",
|
||||
" if (\n",
|
||||
" max_duration is not None\n",
|
||||
" and current_time - start_time > max_duration\n",
|
||||
" ):\n",
|
||||
" raise TimeoutError(f\"Time limit {max_duration} exceeded\")\n",
|
||||
"\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" predicted_original = model(edited_histogram.to(device))\n",
|
||||
" loss = loss_function(\n",
|
||||
" torch.log(torch.clamp(predicted_original, 1e-10, 1)),\n",
|
||||
" original_histogram.to(device),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" epoch_loss += loss.item()\n",
|
||||
" writer.add_scalar(\n",
|
||||
" \"Loss/train/batch\",\n",
|
||||
" loss,\n",
|
||||
" global_step=epoch * len(train_data_loader) + batch_id,\n",
|
||||
" )\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" logging.info(f\"Epoch {epoch} train loss: {epoch_loss}\")\n",
|
||||
" with torch.no_grad():\n",
|
||||
" model.eval()\n",
|
||||
" loader = iter(test_data_loader)\n",
|
||||
" edited_histogram, original_histogram = next(loader)\n",
|
||||
" predicted_original = model(edited_histogram.to(device))\n",
|
||||
" writer.add_figure(\n",
|
||||
" \"histogram\",\n",
|
||||
" plot_histograms_in_2d(\n",
|
||||
" {\n",
|
||||
" \"original\": original_histogram[0].numpy().squeeze(),\n",
|
||||
" \"edited\": edited_histogram.cpu()[0].numpy().squeeze(),\n",
|
||||
" \"predicted\": predicted_original.cpu()[0]\n",
|
||||
" .numpy()\n",
|
||||
" .squeeze(),\n",
|
||||
" }\n",
|
||||
" ),\n",
|
||||
" epoch,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" epoch_test_loss = 0\n",
|
||||
" for batch_id, (edited_histogram, original_histogram) in enumerate(\n",
|
||||
" test_data_loader\n",
|
||||
" ):\n",
|
||||
" predicted_original = model(edited_histogram.to(device))\n",
|
||||
" epoch_test_loss += loss_function(\n",
|
||||
" torch.log(torch.clamp(predicted_original, 1e-10, 1)),\n",
|
||||
" original_histogram.to(device),\n",
|
||||
" ).item()\n",
|
||||
" writer.add_hparams(\n",
|
||||
" serialise_hparams(hyperparameters),\n",
|
||||
" {\n",
|
||||
" \"Loss/test/epoch\": epoch_test_loss,\n",
|
||||
" \"Loss/train/epoch\": epoch_loss,\n",
|
||||
" },\n",
|
||||
" global_step=epoch,\n",
|
||||
" run_name=log_dir.absolute(),\n",
|
||||
" )\n",
|
||||
" logging.info(f\"Epoch {epoch} test loss: {epoch_test_loss}\")\n",
|
||||
"\n",
|
||||
" model.train()\n",
|
||||
" scheduler.step()\n",
|
||||
" finally:\n",
|
||||
" model_path = MODELS_PATH / get_next_run_name(MODELS_PATH)\n",
|
||||
" save_model(model, hyperparameters, model_path)\n",
|
||||
" del model\n",
|
||||
" return model_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 16:57:28,986 - INFO - Loaded 22479 original images\n",
|
||||
"2024-06-22 16:57:28,991 - INFO - Loaded 2498 original images\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "63e068d857484e59afb823ca4b5c3a58",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 0: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:00:34,218 - INFO - Epoch 0 train loss: 11718.475350141525\n",
|
||||
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/matplotlib/collections.py:996: RuntimeWarning: invalid value encountered in sqrt\n",
|
||||
" scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor\n",
|
||||
"2024-06-22 17:00:43,669 - INFO - Epoch 0 test loss: 575.4344878196716\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "594bc7c470a149b4a1c89a5728d67ab1",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 1: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:03:46,881 - INFO - Epoch 1 train loss: 9741.187401413918\n",
|
||||
"2024-06-22 17:03:56,471 - INFO - Epoch 1 test loss: 536.2769713401794\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "830ea2fbef8d4e1c8e45650f82560c74",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 2: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:06:59,896 - INFO - Epoch 2 train loss: 9120.070751070976\n",
|
||||
"2024-06-22 17:07:09,641 - INFO - Epoch 2 test loss: 553.2901458740234\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "077a671090584e97a88adbfc35007919",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 3: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:10:12,942 - INFO - Epoch 3 train loss: 5763.117876529694\n",
|
||||
"2024-06-22 17:10:22,622 - INFO - Epoch 3 test loss: 507.6950304508209\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "786b050c3f85486096d84816aab6affd",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 4: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:13:25,908 - INFO - Epoch 4 train loss: 6363.094870328903\n",
|
||||
"2024-06-22 17:13:36,532 - INFO - Epoch 4 test loss: 532.4468264579773\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "070b1d8d19ec4958b9d1175cef666802",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 5: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:16:40,056 - INFO - Epoch 5 train loss: 4596.043945550919\n",
|
||||
"2024-06-22 17:16:49,784 - INFO - Epoch 5 test loss: 438.763400554657\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "5531f44fbc84422e9739dd94121b100f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 6: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:19:53,205 - INFO - Epoch 6 train loss: 5266.503381967545\n",
|
||||
"2024-06-22 17:20:02,990 - INFO - Epoch 6 test loss: 573.5293898582458\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a73cdeb982ab418185fde63e80fe07db",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 7: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:23:06,445 - INFO - Epoch 7 train loss: 5163.991681098938\n",
|
||||
"2024-06-22 17:23:16,136 - INFO - Epoch 7 test loss: 672.4951323270798\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "8f04a8e6f5324b688ce8a6c9f4453f78",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 8: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:26:19,453 - INFO - Epoch 8 train loss: 12930.857147455215\n",
|
||||
"2024-06-22 17:26:29,204 - INFO - Epoch 8 test loss: 636.4001806974411\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "6dee2d0ebb564ff89604ad8a81787998",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 9: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:29:32,560 - INFO - Epoch 9 train loss: 13841.072596549988\n",
|
||||
"2024-06-22 17:29:42,246 - INFO - Epoch 9 test loss: 2833.8614711761475\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "0b4729e7da054a3c8b2e0609459a18a7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 10: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:32:45,585 - INFO - Epoch 10 train loss: 15531.006411075592\n",
|
||||
"2024-06-22 17:32:55,355 - INFO - Epoch 10 test loss: 469.56569051742554\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "9fda0f52c620475b927bbff090b048d5",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 11: 0%| | 0/2108 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-22 17:35:58,670 - INFO - Epoch 11 train loss: 17766.0949113369\n",
|
||||
"2024-06-22 17:36:08,527 - INFO - Epoch 11 test loss: 3254.2825841903687\n",
|
||||
"2024-06-22 17:36:08,529 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_66.pth\n",
|
||||
"2024-06-22 17:36:08,529 - INFO - Parameter count: 429457\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"PosixPath('/home/andras/projects/bipolaroid/models/run_66')"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# train(\n",
|
||||
"# {\n",
|
||||
|
|
@ -555,33 +109,93 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-25 08:59:55,973 - INFO - Starting run_170 with hparams {\n",
|
||||
" \"batch_size\": 64,\n",
|
||||
" \"bin_count\": 16,\n",
|
||||
" \"clip_gradients\": true,\n",
|
||||
" \"dropout_prob\": 0.09784778880383105,\n",
|
||||
" \"edit_count\": 12,\n",
|
||||
" \"elu_alpha\": 0.5588538605400805,\n",
|
||||
" \"features\": [\n",
|
||||
" 8,\n",
|
||||
" 16,\n",
|
||||
" 32\n",
|
||||
" ],\n",
|
||||
" \"kernel_sizes\": [\n",
|
||||
" 3,\n",
|
||||
" 3,\n",
|
||||
" 3\n",
|
||||
" ],\n",
|
||||
" \"leaky_relu_alpha\": 0.012913890161555076,\n",
|
||||
" \"leaky_relu_slope\": 0.022615416455484896,\n",
|
||||
" \"learning_rate\": 0.002130094098871897,\n",
|
||||
" \"model_type\": \"Residual3\",\n",
|
||||
" \"num_epochs\": 12,\n",
|
||||
" \"scheduler_gamma\": 0.8142448793722726,\n",
|
||||
" \"use_elu\": true,\n",
|
||||
" \"use_instance_norm\": false\n",
|
||||
"}\n",
|
||||
"2024-06-25 08:59:55,976 - INFO - Loaded 1000 original images\n",
|
||||
"2024-06-25 08:59:55,979 - INFO - Loaded 1000 original images\n",
|
||||
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/torch/jit/_trace.py:1102: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:\n",
|
||||
"Tensor-likes are not close!\n",
|
||||
"\n",
|
||||
"Mismatched elements: 442 / 262144 (0.2%)\n",
|
||||
"Greatest absolute difference: 0.00028640031814575195 at index (14, 0, 3, 6, 1) (up to 1e-05 allowed)\n",
|
||||
"Greatest relative difference: 0.03313953488372093 at index (52, 0, 2, 8, 3) (up to 1e-05 allowed)\n",
|
||||
" _check_trace(\n",
|
||||
"2024-06-25 09:00:10,492 - INFO - Epoch 0 train loss: -668.9345343112946\n",
|
||||
"/home/andras/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/matplotlib/collections.py:996: RuntimeWarning: invalid value encountered in sqrt\n",
|
||||
" scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor\n",
|
||||
"2024-06-25 09:00:17,910 - INFO - Epoch 0 test loss: -677.4099225997925\n",
|
||||
"2024-06-25 09:00:27,392 - INFO - Epoch 1 train loss: -677.2842514514923\n",
|
||||
"2024-06-25 09:00:35,272 - INFO - Epoch 1 test loss: -677.350163936615\n",
|
||||
"2024-06-25 09:00:44,703 - INFO - Epoch 2 train loss: -677.5818197727203\n",
|
||||
"2024-06-25 09:00:52,274 - INFO - Epoch 2 test loss: -677.3433697223663\n",
|
||||
"2024-06-25 09:01:01,680 - INFO - Epoch 3 train loss: -677.3612790107727\n",
|
||||
"2024-06-25 09:01:09,430 - INFO - Epoch 3 test loss: -677.4993708133698\n",
|
||||
"2024-06-25 09:01:18,837 - INFO - Epoch 4 train loss: -677.3555946350098\n",
|
||||
"2024-06-25 09:01:26,721 - INFO - Epoch 4 test loss: -677.2696735858917\n",
|
||||
"2024-06-25 09:01:36,224 - INFO - Epoch 5 train loss: -677.4827179908752\n",
|
||||
"2024-06-25 09:01:44,065 - INFO - Epoch 5 test loss: -677.4189476966858\n",
|
||||
"2024-06-25 09:01:53,537 - INFO - Epoch 6 train loss: -677.5993602275848\n",
|
||||
"2024-06-25 09:02:01,438 - INFO - Epoch 6 test loss: -677.5275483131409\n",
|
||||
"2024-06-25 09:02:10,913 - INFO - Epoch 7 train loss: -677.417388677597\n",
|
||||
"2024-06-25 09:02:18,622 - INFO - Epoch 7 test loss: -677.5215902328491\n",
|
||||
"2024-06-25 09:02:28,085 - INFO - Epoch 8 train loss: -677.415346622467\n",
|
||||
"2024-06-25 09:02:36,597 - INFO - Epoch 8 test loss: -677.5785489082336\n",
|
||||
"2024-06-25 09:02:46,112 - INFO - Epoch 9 train loss: -677.4984295368195\n",
|
||||
"2024-06-25 09:02:54,038 - INFO - Epoch 9 test loss: -677.5197842121124\n",
|
||||
"2024-06-25 09:03:03,558 - INFO - Epoch 10 train loss: -677.4539258480072\n",
|
||||
"2024-06-25 09:03:11,486 - INFO - Epoch 10 test loss: -677.5460705757141\n",
|
||||
"2024-06-25 09:03:21,255 - INFO - Epoch 11 train loss: -677.54083776474\n",
|
||||
"2024-06-25 09:03:29,637 - INFO - Epoch 11 test loss: -677.658331155777\n",
|
||||
"2024-06-25 09:03:29,818 - INFO - Saving model to /home/andras/projects/bipolaroid/saved_models/run_143.pth\n",
|
||||
"2024-06-25 09:03:29,819 - INFO - Parameter count: 429457\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from random import choice\n",
|
||||
"from itertools import count\n",
|
||||
"import json\n",
|
||||
"from training import random_hparam_search\n",
|
||||
"from config import RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for _ in count():\n",
|
||||
" current_hyperparameters = {\n",
|
||||
" k: v.rvs() if hasattr(v, \"rvs\") else choice(v)\n",
|
||||
" for k, v in choice(hyperparameters).items()\n",
|
||||
" }\n",
|
||||
" key = json.dumps(current_hyperparameters, indent=2, sort_keys=True)\n",
|
||||
" logging.info(f\"Starting {get_next_run_name(RUNS_PATH)} with hparams {key}\")\n",
|
||||
" try:\n",
|
||||
" train(current_hyperparameters, max_duration=timedelta(hours=8), use_tqdm=False)\n",
|
||||
" except KeyboardInterrupt as e:\n",
|
||||
" logging.info(\"Interrupted, stopping\")\n",
|
||||
" break\n",
|
||||
" except TimeoutError as e:\n",
|
||||
" logging.warning(f\"Timeout, aborting experiment\")\n",
|
||||
" except Exception as e:\n",
|
||||
" logging.error(\n",
|
||||
" f\"Error with hparams {current_hyperparameters}:\\n\\t{e}\", stack_info=True\n",
|
||||
" )"
|
||||
"random_hparam_search(\n",
|
||||
" hyperparameters=hyperparameters,\n",
|
||||
" train_data_paths=TRAIN_DATA,\n",
|
||||
" test_data_paths=TEST_DATA,\n",
|
||||
" models_path=MODELS_PATH,\n",
|
||||
" tensorboard_path=RUNS_PATH,\n",
|
||||
" timeout_hours=8,\n",
|
||||
" device=device,\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def random_hparam_search(
|
|||
timeout_hours: int,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
for _ in range(1):
|
||||
for _ in count():
|
||||
current_hyperparameters = {
|
||||
k: v.rvs() if hasattr(v, "rvs") else choice(v)
|
||||
for k, v in choice(hyperparameters).items()
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ def train(
|
|||
optimizer.zero_grad()
|
||||
predicted_original = model(edited_histogram.to(device))
|
||||
loss = loss_function(
|
||||
torch.log(torch.clamp(predicted_original, EPSILON, 1)),
|
||||
torch.log(predicted_original + EPSILON),
|
||||
original_histogram.to(device),
|
||||
)
|
||||
|
||||
|
|
@ -101,7 +101,7 @@ def train(
|
|||
):
|
||||
predicted_original = model(edited_histogram.to(device))
|
||||
epoch_test_loss += loss_function(
|
||||
torch.log(torch.clamp(predicted_original, EPSILON, 1)),
|
||||
torch.log(predicted_original + EPSILON),
|
||||
original_histogram.to(device),
|
||||
).item()
|
||||
writer.add_hparams(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue