Small updates

This commit is contained in:
Andras Schmelczer 2024-06-30 22:18:16 +01:00
parent 48227feba5
commit e98cfa31fa
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
4 changed files with 163 additions and 18 deletions

2
.gitignore vendored
View file

@ -1,5 +1,5 @@
__pycache__
runs*
*.log
saved_models/*
saved_models*
train.py

View file

@ -2,22 +2,32 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {
"metadata": {}
},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'Using device cuda:0'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import os\n",
"from utils import set_up_logging\n",
"from utils import set_up_logging, get_device\n",
"from training import train, random_hparam_search\n",
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
"\n",
"set_up_logging(LOGS_PATH)\n",
"\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"device = get_device()\n",
"f\"Using device {device}\""
]
},
@ -49,6 +59,144 @@
"# )"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-29 10:45:18,539 - INFO - Loaded 22479 original images\n",
"2024-06-29 10:45:18,544 - INFO - Loaded 2498 original images\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "61d4288a87a24255824e1430d8e69804",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 0: 0%| | 0/4215 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-29 10:50:49,231 - INFO - Epoch 0 train loss: 5050.157035887241\n",
"2024-06-29 10:51:01,594 - INFO - Epoch 0 test loss: 460.928723692894\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "55aa7c6adc41427081af06330d541c2d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 1: 0%| | 0/4215 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-29 10:56:26,321 - INFO - Epoch 1 train loss: 4134.876303792\n",
"2024-06-29 10:56:39,064 - INFO - Epoch 1 test loss: 432.3766641020775\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2b7d28fc6f62489ebea844441a57178f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 2: 0%| | 0/4215 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-29 11:02:03,858 - INFO - Epoch 2 train loss: 3871.2984607219696\n",
"2024-06-29 11:02:16,040 - INFO - Epoch 2 test loss: 408.74887996912\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1395274c7a38465099c93c2d241ea1da",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch 3: 0%| | 0/4215 [00:00<?, ?batch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 20\u001b[0m\n\u001b[1;32m 1\u001b[0m hparams \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 2\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124medit_count\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m12\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mleaky_relu_alpha\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0.03745605986732464\u001b[39m,\n\u001b[1;32m 18\u001b[0m }\n\u001b[0;32m---> 20\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mhparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTRAIN_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTEST_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mRUNS_PATH\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 25\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_duration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 28\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhparams\u001b[49m\n\u001b[1;32m 29\u001b[0m \u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/mnt/wsl/PHYSICALDRIVE1/projects/bipolaroid/src/training/train.py:55\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(hyperparameters, train_data_paths, test_data_paths, log_dir, use_tqdm, device, model_type, learning_rate, scheduler_gamma, num_epochs, **_)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_id, (edited_histogram, original_histogram) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\n\u001b[1;32m 50\u001b[0m tqdm(train_data_loader, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_tqdm\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m train_data_loader\n\u001b[1;32m 53\u001b[0m ):\n\u001b[1;32m 54\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 55\u001b[0m predicted_original \u001b[38;5;241m=\u001b[39m model(\u001b[43medited_histogram\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 56\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_function(\n\u001b[1;32m 57\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(predicted_original \u001b[38;5;241m+\u001b[39m EPSILON),\n\u001b[1;32m 58\u001b[0m original_histogram\u001b[38;5;241m.\u001b[39mto(device),\n\u001b[1;32m 59\u001b[0m )\n\u001b[1;32m 61\u001b[0m epoch_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# hparams = {\n",
"# \"batch_size\": 64,\n",
"# \"edit_count\": 12,\n",
"# \"bin_count\": 16,\n",
"# \"learning_rate\": 0.0006126108207352808,\n",
"# \"scheduler_gamma\": 0.9382286228762693,\n",
"# \"num_epochs\": 10,\n",
"# \"elu_alpha\": 1.3092260477215776,\n",
"# \"leaky_relu_slope\": 0.029438156325552762,\n",
"# \"dropout_prob\": 0.06261255195786307,\n",
"# \"features\": [8, 16, 32],\n",
"# \"use_residual\": True,\n",
"# \"kernel_size\": 5,\n",
"# \"model_type\": \"HistogramNet\",\n",
"# \"use_instance_norm\": True,\n",
"# \"use_elu\": False,\n",
"# \"leaky_relu_alpha\": 0.03745605986732464,\n",
"# }\n",
"\n",
"# train(\n",
"# hparams,\n",
"# train_data_paths=TRAIN_DATA,\n",
"# test_data_paths=TEST_DATA,\n",
"# log_dir=RUNS_PATH,\n",
"# max_duration=None,\n",
"# use_tqdm=True,\n",
"# device=device,\n",
"# **hparams\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -61,37 +209,34 @@
"\n",
"hyperparameters = [\n",
" {\n",
" \"batch_size\": [32, 64, 128],\n",
" \"batch_size\": [64, 128, 256],\n",
" \"edit_count\": [12],\n",
" \"bin_count\": [16],\n",
" \"learning_rate\": loguniform(5e-4, 5e-3),\n",
" \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n",
" \"num_epochs\": [12],\n",
" \"num_epochs\": [10],\n",
" \"elu_alpha\": uniform(0.5, 1.5),\n",
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
" \"dropout_prob\": uniform(0, 0.1),\n",
" \"features\": [\n",
" [16, 32],\n",
" [16, 32, 64],\n",
" [16, 32, 64, 128],\n",
" [32, 64],\n",
" [32, 64, 128],\n",
" [8, 16, 32],\n",
" [8, 8, 8],\n",
" [8, 8, 8, 8, 8],\n",
" [8, 8, 8, 8, 8, 8, 8],\n",
" [16, 16, 16, 16, 16],\n",
" [16, 16, 16],\n",
" [32, 32], \n",
" [16, 16, 16, 16, 16],\n",
" [32, 32, 32],\n",
" [32, 32, 32, 32],\n",
" [64, 64],\n",
" [64, 64, 64]\n",
" [64, 64, 64],\n",
" ],\n",
" \"use_residual\": [True, False],\n",
" \"kernel_size\": [3, 5],\n",
" \"model_type\": [\"HistogramNet\"],\n",
" \"use_instance_norm\": [True, False],\n",
" \"use_instance_norm\": [True],\n",
" \"use_elu\": [True, False],\n",
" \"leaky_relu_alpha\": uniform(0, 0.05),\n",
" }\n",
@ -127,7 +272,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.1.-1"
"version": "3.12.2"
}
},
"nbformat": 4,

View file

@ -18,8 +18,8 @@ class HistogramDataset(Dataset):
self,
/,
paths: List[Path],
bin_count: int,
edit_count: int = 5,
bin_count: int = 16,
edit_count: int = 12,
target_size=(240, 240),
delete_corrupt_images: bool = False,
cache_path: Optional[Path] = None,

View file

@ -4,7 +4,7 @@ import numpy as np
def compute_histogram(
image: Image.Image | np.ndarray,
bins: int,
bins: int = 16,
value_range=(0, 256),
normalize: bool = True,
) -> np.ndarray: