Small updates
This commit is contained in:
parent
48227feba5
commit
e98cfa31fa
4 changed files with 163 additions and 18 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -1,5 +1,5 @@
|
|||
__pycache__
|
||||
runs*
|
||||
*.log
|
||||
saved_models/*
|
||||
saved_models*
|
||||
train.py
|
||||
|
|
|
|||
173
src/train.ipynb
173
src/train.ipynb
|
|
@ -2,22 +2,32 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"metadata": {}
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Using device cuda:0'"
|
||||
]
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import os\n",
|
||||
"from utils import set_up_logging\n",
|
||||
"from utils import set_up_logging, get_device\n",
|
||||
"from training import train, random_hparam_search\n",
|
||||
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
|
||||
"\n",
|
||||
"set_up_logging(LOGS_PATH)\n",
|
||||
"\n",
|
||||
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
|
||||
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
|
||||
"device = get_device()\n",
|
||||
"f\"Using device {device}\""
|
||||
]
|
||||
},
|
||||
|
|
@ -49,6 +59,144 @@
|
|||
"# )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-29 10:45:18,539 - INFO - Loaded 22479 original images\n",
|
||||
"2024-06-29 10:45:18,544 - INFO - Loaded 2498 original images\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "61d4288a87a24255824e1430d8e69804",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 0: 0%| | 0/4215 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-29 10:50:49,231 - INFO - Epoch 0 train loss: 5050.157035887241\n",
|
||||
"2024-06-29 10:51:01,594 - INFO - Epoch 0 test loss: 460.928723692894\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "55aa7c6adc41427081af06330d541c2d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 1: 0%| | 0/4215 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-29 10:56:26,321 - INFO - Epoch 1 train loss: 4134.876303792\n",
|
||||
"2024-06-29 10:56:39,064 - INFO - Epoch 1 test loss: 432.3766641020775\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "2b7d28fc6f62489ebea844441a57178f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 2: 0%| | 0/4215 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-29 11:02:03,858 - INFO - Epoch 2 train loss: 3871.2984607219696\n",
|
||||
"2024-06-29 11:02:16,040 - INFO - Epoch 2 test loss: 408.74887996912\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "1395274c7a38465099c93c2d241ea1da",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 3: 0%| | 0/4215 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[2], line 20\u001b[0m\n\u001b[1;32m 1\u001b[0m hparams \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 2\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124medit_count\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m12\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mleaky_relu_alpha\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m0.03745605986732464\u001b[39m,\n\u001b[1;32m 18\u001b[0m }\n\u001b[0;32m---> 20\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mhparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTRAIN_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest_data_paths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTEST_DATA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mRUNS_PATH\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 25\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_duration\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 27\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 28\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhparams\u001b[49m\n\u001b[1;32m 29\u001b[0m \u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m/mnt/wsl/PHYSICALDRIVE1/projects/bipolaroid/src/training/train.py:55\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(hyperparameters, train_data_paths, test_data_paths, log_dir, use_tqdm, device, model_type, learning_rate, scheduler_gamma, num_epochs, **_)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_id, (edited_histogram, original_histogram) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\n\u001b[1;32m 50\u001b[0m tqdm(train_data_loader, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_tqdm\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m train_data_loader\n\u001b[1;32m 53\u001b[0m ):\n\u001b[1;32m 54\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 55\u001b[0m predicted_original \u001b[38;5;241m=\u001b[39m model(\u001b[43medited_histogram\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 56\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_function(\n\u001b[1;32m 57\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(predicted_original \u001b[38;5;241m+\u001b[39m EPSILON),\n\u001b[1;32m 58\u001b[0m original_histogram\u001b[38;5;241m.\u001b[39mto(device),\n\u001b[1;32m 59\u001b[0m )\n\u001b[1;32m 61\u001b[0m epoch_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# hparams = {\n",
|
||||
"# \"batch_size\": 64,\n",
|
||||
"# \"edit_count\": 12,\n",
|
||||
"# \"bin_count\": 16,\n",
|
||||
"# \"learning_rate\": 0.0006126108207352808,\n",
|
||||
"# \"scheduler_gamma\": 0.9382286228762693,\n",
|
||||
"# \"num_epochs\": 10,\n",
|
||||
"# \"elu_alpha\": 1.3092260477215776,\n",
|
||||
"# \"leaky_relu_slope\": 0.029438156325552762,\n",
|
||||
"# \"dropout_prob\": 0.06261255195786307,\n",
|
||||
"# \"features\": [8, 16, 32],\n",
|
||||
"# \"use_residual\": True,\n",
|
||||
"# \"kernel_size\": 5,\n",
|
||||
"# \"model_type\": \"HistogramNet\",\n",
|
||||
"# \"use_instance_norm\": True,\n",
|
||||
"# \"use_elu\": False,\n",
|
||||
"# \"leaky_relu_alpha\": 0.03745605986732464,\n",
|
||||
"# }\n",
|
||||
"\n",
|
||||
"# train(\n",
|
||||
"# hparams,\n",
|
||||
"# train_data_paths=TRAIN_DATA,\n",
|
||||
"# test_data_paths=TEST_DATA,\n",
|
||||
"# log_dir=RUNS_PATH,\n",
|
||||
"# max_duration=None,\n",
|
||||
"# use_tqdm=True,\n",
|
||||
"# device=device,\n",
|
||||
"# **hparams\n",
|
||||
"# )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
|
@ -61,37 +209,34 @@
|
|||
"\n",
|
||||
"hyperparameters = [\n",
|
||||
" {\n",
|
||||
" \"batch_size\": [32, 64, 128],\n",
|
||||
" \"batch_size\": [64, 128, 256],\n",
|
||||
" \"edit_count\": [12],\n",
|
||||
" \"bin_count\": [16],\n",
|
||||
" \"learning_rate\": loguniform(5e-4, 5e-3),\n",
|
||||
" \"scheduler_gamma\": uniform(loc=0.8, scale=0.15),\n",
|
||||
" \"num_epochs\": [12],\n",
|
||||
" \"num_epochs\": [10],\n",
|
||||
" \"elu_alpha\": uniform(0.5, 1.5),\n",
|
||||
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
|
||||
" \"dropout_prob\": uniform(0, 0.1),\n",
|
||||
" \"features\": [\n",
|
||||
" [16, 32],\n",
|
||||
" [16, 32, 64],\n",
|
||||
" [16, 32, 64, 128],\n",
|
||||
" [32, 64],\n",
|
||||
" [32, 64, 128],\n",
|
||||
" [8, 16, 32],\n",
|
||||
" [8, 8, 8],\n",
|
||||
" [8, 8, 8, 8, 8],\n",
|
||||
" [8, 8, 8, 8, 8, 8, 8],\n",
|
||||
" [16, 16, 16, 16, 16],\n",
|
||||
" [16, 16, 16],\n",
|
||||
" [32, 32], \n",
|
||||
" [16, 16, 16, 16, 16],\n",
|
||||
" [32, 32, 32],\n",
|
||||
" [32, 32, 32, 32],\n",
|
||||
" [64, 64],\n",
|
||||
" [64, 64, 64]\n",
|
||||
" [64, 64, 64],\n",
|
||||
" ],\n",
|
||||
" \"use_residual\": [True, False],\n",
|
||||
" \"kernel_size\": [3, 5],\n",
|
||||
" \"model_type\": [\"HistogramNet\"],\n",
|
||||
" \"use_instance_norm\": [True, False],\n",
|
||||
" \"use_instance_norm\": [True],\n",
|
||||
" \"use_elu\": [True, False],\n",
|
||||
" \"leaky_relu_alpha\": uniform(0, 0.05),\n",
|
||||
" }\n",
|
||||
|
|
@ -127,7 +272,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.1.-1"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
|||
|
|
@ -18,8 +18,8 @@ class HistogramDataset(Dataset):
|
|||
self,
|
||||
/,
|
||||
paths: List[Path],
|
||||
bin_count: int,
|
||||
edit_count: int = 5,
|
||||
bin_count: int = 16,
|
||||
edit_count: int = 12,
|
||||
target_size=(240, 240),
|
||||
delete_corrupt_images: bool = False,
|
||||
cache_path: Optional[Path] = None,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import numpy as np
|
|||
|
||||
def compute_histogram(
|
||||
image: Image.Image | np.ndarray,
|
||||
bins: int,
|
||||
bins: int = 16,
|
||||
value_range=(0, 256),
|
||||
normalize: bool = True,
|
||||
) -> np.ndarray:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue