Less clutter

This commit is contained in:
Andras Schmelczer 2024-09-05 22:26:21 +01:00
parent 47fce35a45
commit aecd7ec9cb
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
3 changed files with 85 additions and 48 deletions

View file

@ -9,7 +9,7 @@ TRAIN_SIZE = 0.95
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE2/data/cache") CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE2/data/cache")
MODELS_PATH = Path("/home/andras/projects/bipolaroid/saved_models") MODELS_PATH = Path("/home/andras/projects/bipolaroid/saved_models")
LOGS_PATH = Path("/home/andras/projects/bipolaroid/logs") LOGS_PATH = Path("/home/andras/projects/bipolaroid/logs")
RUNS_PATH = Path("/home/andras/projects/bipolaroid/runs") RUNS_PATH = Path("/home/andras/projects/bipolaroid/runs2")
for path in [CACHE_PATH, MODELS_PATH, LOGS_PATH, RUNS_PATH]: for path in [CACHE_PATH, MODELS_PATH, LOGS_PATH, RUNS_PATH]:

View file

@ -11,7 +11,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"2024-09-03 22:20:43,878\tINFO worker.py:1774 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n" "2024-09-05 22:18:20,653\tINFO worker.py:1774 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://172.29.235.222:8265 \u001b[39m\u001b[22m\n"
] ]
}, },
{ {
@ -27,16 +27,11 @@
], ],
"source": [ "source": [
"import os\n", "import os\n",
"import matplotlib\n",
"\n", "\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = (\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = (\n",
" \"expandable_segments:True\" # avoid fragmented CUDA memory\n", " \"expandable_segments:True\" # avoid fragmented CUDA memory\n",
")\n", ")\n",
"\n", "\n",
"matplotlib.use(\n",
" \"agg\"\n",
") # avoid \"UserWarning: Starting a Matplotlib GUI outside of the main thread will likely fail\" warnings\n",
"\n",
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n", "from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n",
"from utils import set_up_logging\n", "from utils import set_up_logging\n",
"\n", "\n",
@ -50,9 +45,9 @@
"from ray.air import RunConfig\n", "from ray.air import RunConfig\n",
"\n", "\n",
"\n", "\n",
"TRIAL_COUNT = 100\n", "TRIAL_COUNT = 50\n",
"CHUNK_COUNT = 40\n", "EPOCH_COUNT = 4\n",
"EPOCH_COUNT = 2\n", "CHUNK_COUNT = EPOCH_COUNT * 40\n",
"\n", "\n",
"ray.init(include_dashboard=True, dashboard_host=\"0.0.0.0\")\n", "ray.init(include_dashboard=True, dashboard_host=\"0.0.0.0\")\n",
"\n", "\n",
@ -74,29 +69,72 @@
" <h3>Tune Status</h3>\n", " <h3>Tune Status</h3>\n",
" <table>\n", " <table>\n",
"<tbody>\n", "<tbody>\n",
"<tr><td>Current time:</td><td>2024-09-01 22:06:10</td></tr>\n", "<tr><td>Current time:</td><td>2024-09-05 22:23:47</td></tr>\n",
"<tr><td>Running for: </td><td>00:02:06.64 </td></tr>\n", "<tr><td>Running for: </td><td>00:05:24.00 </td></tr>\n",
"<tr><td>Memory: </td><td>22.4/47.0 GiB </td></tr>\n", "<tr><td>Memory: </td><td>22.7/54.9 GiB </td></tr>\n",
"</tbody>\n", "</tbody>\n",
"</table>\n", "</table>\n",
" </div>\n", " </div>\n",
" <div class=\"vDivider\"></div>\n", " <div class=\"vDivider\"></div>\n",
" <div class=\"systemInfo\">\n", " <div class=\"systemInfo\">\n",
" <h3>System Info</h3>\n", " <h3>System Info</h3>\n",
" Using AsyncHyperBand: num_stopped=1<br>Bracket: Iter 2.000: -5.516964137554169<br>Logical resource usage: 32.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)\n", " Using AsyncHyperBand: num_stopped=0<br>Bracket: Iter 128.000: None | Iter 32.000: None | Iter 8.000: None | Iter 2.000: None<br>Logical resource usage: 32.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)\n",
" </div>\n", " </div>\n",
" \n", " <div class=\"vDivider\"></div>\n",
"<div class=\"messages\">\n",
" <h3>Messages</h3>\n",
" \n",
" ... 30 more trials not shown (30 PENDING)\n",
" \n",
"</div>\n",
"<style>\n",
".messages {\n",
" color: var(--jp-ui-font-color1);\n",
" display: flex;\n",
" flex-direction: column;\n",
" padding-left: 1em;\n",
" overflow-y: auto;\n",
"}\n",
".messages h3 {\n",
" font-weight: bold;\n",
"}\n",
".vDivider {\n",
" border-left-width: var(--jp-border-width);\n",
" border-left-color: var(--jp-border-color0);\n",
" border-left-style: solid;\n",
" margin: 0.5em 1em 0.5em 1em;\n",
"}\n",
"</style>\n",
"\n",
" </div>\n", " </div>\n",
" <div class=\"hDivider\"></div>\n", " <div class=\"hDivider\"></div>\n",
" <div class=\"trialStatus\">\n", " <div class=\"trialStatus\">\n",
" <h3>Trial Status</h3>\n", " <h3>Trial Status</h3>\n",
" <table>\n", " <table>\n",
"<thead>\n", "<thead>\n",
"<tr><th>Trial name </th><th>status </th><th>loc </th><th style=\"text-align: right;\"> batch_size</th><th style=\"text-align: right;\"> dropout_prob</th><th style=\"text-align: right;\"> elu_alpha</th><th>features </th><th style=\"text-align: right;\"> kernel_size</th><th style=\"text-align: right;\"> leaky_relu_alpha</th><th style=\"text-align: right;\"> leaky_relu_slope</th><th style=\"text-align: right;\"> learning_rate</th><th>model_type </th><th style=\"text-align: right;\"> scheduler_gamma</th><th>use_elu </th><th>use_residual </th><th style=\"text-align: right;\"> iter</th><th style=\"text-align: right;\"> total time (s)</th><th style=\"text-align: right;\"> chunk_test_loss</th><th style=\"text-align: right;\"> chunk_training_loss</th></tr>\n", "<tr><th>Trial name </th><th>status </th><th>loc </th><th style=\"text-align: right;\"> dropout_prob</th><th>features </th><th style=\"text-align: right;\"> kernel_size</th><th style=\"text-align: right;\"> leaky_relu_alpha</th><th style=\"text-align: right;\"> leaky_relu_slope</th><th style=\"text-align: right;\"> learning_rate</th><th style=\"text-align: right;\"> scheduler_gamma</th></tr>\n",
"</thead>\n", "</thead>\n",
"<tbody>\n", "<tbody>\n",
"<tr><td>train_with_ray_b7d3c_00000</td><td>TERMINATED</td><td>172.29.235.222:1134109</td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\"> 0.00395892</td><td style=\"text-align: right;\"> 1.35107</td><td>[8, 16, 32] </td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0448715 </td><td style=\"text-align: right;\"> 0.00664526</td><td style=\"text-align: right;\"> 0.0029226 </td><td>HistogramNet</td><td style=\"text-align: right;\"> 0.816752</td><td>True </td><td>True </td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 70.652 </td><td style=\"text-align: right;\"> 5.22149</td><td style=\"text-align: right;\"> 56.5472</td></tr>\n", "<tr><td>train_with_ray_61b01_00000</td><td>RUNNING </td><td>172.29.235.222:3274755</td><td style=\"text-align: right;\"> 0.0240677 </td><td>[16, 32, 64, 12_25c0</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0147726 </td><td style=\"text-align: right;\"> 0.0164835 </td><td style=\"text-align: right;\"> 0.00409419 </td><td style=\"text-align: right;\"> 0.977701</td></tr>\n",
"<tr><td>train_with_ray_b7d3c_00001</td><td>TERMINATED</td><td>172.29.235.222:1140440</td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\"> 0.0439061 </td><td style=\"text-align: right;\"> 1.74642</td><td>[8, 8, 8, 8, 8,_c140</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.00579656</td><td style=\"text-align: right;\"> 0.0125715 </td><td style=\"text-align: right;\"> 0.00155182</td><td>HistogramNet</td><td style=\"text-align: right;\"> 0.898059</td><td>True </td><td>False </td><td style=\"text-align: right;\"> 2</td><td style=\"text-align: right;\"> 45.4228</td><td style=\"text-align: right;\"> 5.79311</td><td style=\"text-align: right;\"> 64.0481</td></tr>\n", "<tr><td>train_with_ray_61b01_00001</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0252525 </td><td>[16, 32, 64, 12_9940</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0607274 </td><td style=\"text-align: right;\"> 0.00769014 </td><td style=\"text-align: right;\"> 0.00175452 </td><td style=\"text-align: right;\"> 0.971626</td></tr>\n",
"<tr><td>train_with_ray_61b01_00002</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0492015 </td><td>[16, 32, 64, 12_0940</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0522751 </td><td style=\"text-align: right;\"> 0.00532775 </td><td style=\"text-align: right;\"> 0.000102449</td><td style=\"text-align: right;\"> 0.985839</td></tr>\n",
"<tr><td>train_with_ray_61b01_00003</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0091843 </td><td>[16, 32, 64, 12_8e80</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0215584 </td><td style=\"text-align: right;\"> 0.00986109 </td><td style=\"text-align: right;\"> 0.0046811 </td><td style=\"text-align: right;\"> 0.969517</td></tr>\n",
"<tr><td>train_with_ray_61b01_00004</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0200707 </td><td>[16, 32, 64, 12_e200</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.00839403</td><td style=\"text-align: right;\"> 0.0236732 </td><td style=\"text-align: right;\"> 0.00456458 </td><td style=\"text-align: right;\"> 0.949592</td></tr>\n",
"<tr><td>train_with_ray_61b01_00005</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.071518 </td><td>[16, 32, 64, 12_4200</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.00955745</td><td style=\"text-align: right;\"> 0.000338297</td><td style=\"text-align: right;\"> 0.0012211 </td><td style=\"text-align: right;\"> 0.959114</td></tr>\n",
"<tr><td>train_with_ray_61b01_00006</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.00712943</td><td>[16, 32, 64, 12_f8c0</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0128474 </td><td style=\"text-align: right;\"> 0.0214861 </td><td style=\"text-align: right;\"> 0.000319318</td><td style=\"text-align: right;\"> 0.954027</td></tr>\n",
"<tr><td>train_with_ray_61b01_00007</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0554306 </td><td>[16, 32, 64, 12_1900</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0686442 </td><td style=\"text-align: right;\"> 0.0103842 </td><td style=\"text-align: right;\"> 0.00856337 </td><td style=\"text-align: right;\"> 0.966035</td></tr>\n",
"<tr><td>train_with_ray_61b01_00008</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0637642 </td><td>[16, 32, 64, 12_8f80</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0125659 </td><td style=\"text-align: right;\"> 0.00386933 </td><td style=\"text-align: right;\"> 0.00135947 </td><td style=\"text-align: right;\"> 0.955525</td></tr>\n",
"<tr><td>train_with_ray_61b01_00009</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0526107 </td><td>[16, 32, 64, 12_16c0</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.00576232</td><td style=\"text-align: right;\"> 0.0120612 </td><td style=\"text-align: right;\"> 0.000143394</td><td style=\"text-align: right;\"> 0.983664</td></tr>\n",
"<tr><td>train_with_ray_61b01_00010</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0321631 </td><td>[16, 32, 64, 12_fc00</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0100027 </td><td style=\"text-align: right;\"> 0.016943 </td><td style=\"text-align: right;\"> 0.000461478</td><td style=\"text-align: right;\"> 0.981847</td></tr>\n",
"<tr><td>train_with_ray_61b01_00011</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.015057 </td><td>[16, 32, 64, 12_f080</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0599078 </td><td style=\"text-align: right;\"> 0.0293855 </td><td style=\"text-align: right;\"> 0.00182771 </td><td style=\"text-align: right;\"> 0.965469</td></tr>\n",
"<tr><td>train_with_ray_61b01_00012</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0995986 </td><td>[16, 32, 64, 12_d240</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0349584 </td><td style=\"text-align: right;\"> 0.000219269</td><td style=\"text-align: right;\"> 0.000288736</td><td style=\"text-align: right;\"> 0.966034</td></tr>\n",
"<tr><td>train_with_ray_61b01_00013</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0227281 </td><td>[16, 32, 64, 12_c600</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0324662 </td><td style=\"text-align: right;\"> 0.0125853 </td><td style=\"text-align: right;\"> 0.000173035</td><td style=\"text-align: right;\"> 0.948412</td></tr>\n",
"<tr><td>train_with_ray_61b01_00014</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0307992 </td><td>[16, 32, 64, 12_d040</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.062771 </td><td style=\"text-align: right;\"> 0.0198744 </td><td style=\"text-align: right;\"> 0.000639558</td><td style=\"text-align: right;\"> 0.94965 </td></tr>\n",
"<tr><td>train_with_ray_61b01_00015</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0896525 </td><td>[16, 32, 64, 12_2e80</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0150036 </td><td style=\"text-align: right;\"> 0.0171698 </td><td style=\"text-align: right;\"> 0.00206888 </td><td style=\"text-align: right;\"> 0.99602 </td></tr>\n",
"<tr><td>train_with_ray_61b01_00016</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0932697 </td><td>[16, 32, 64, 12_6540</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0699308 </td><td style=\"text-align: right;\"> 0.0154819 </td><td style=\"text-align: right;\"> 0.00384655 </td><td style=\"text-align: right;\"> 0.954913</td></tr>\n",
"<tr><td>train_with_ray_61b01_00017</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.00240956</td><td>[16, 32, 64, 12_8640</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0318032 </td><td style=\"text-align: right;\"> 0.00959876 </td><td style=\"text-align: right;\"> 0.000135257</td><td style=\"text-align: right;\"> 0.952462</td></tr>\n",
"<tr><td>train_with_ray_61b01_00018</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0883309 </td><td>[16, 32, 64, 12_b800</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0455664 </td><td style=\"text-align: right;\"> 0.00565 </td><td style=\"text-align: right;\"> 0.00066994 </td><td style=\"text-align: right;\"> 0.997372</td></tr>\n",
"<tr><td>train_with_ray_61b01_00019</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0500287 </td><td>[16, 32, 64, 12_8440</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0674193 </td><td style=\"text-align: right;\"> 0.0220701 </td><td style=\"text-align: right;\"> 0.000987128</td><td style=\"text-align: right;\"> 0.94602 </td></tr>\n",
"</tbody>\n", "</tbody>\n",
"</table>\n", "</table>\n",
" </div>\n", " </div>\n",
@ -140,33 +178,21 @@
"output_type": "display_data" "output_type": "display_data"
}, },
{ {
"name": "stderr", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000000)\n", "\u001b[33m(raylet)\u001b[0m Warning: The actor ImplicitFunc is very large (26 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.\n"
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000001)\n",
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000002)\n",
"\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000000)\n",
"2024-09-01 22:06:10,132\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/andras/projects/bipolaroid/runs5/tune' in 0.0020s.\n",
"2024-09-01 22:06:10,135\tINFO tune.py:1041 -- Total run time: 126.67 seconds (126.64 seconds for the tuning loop).\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000001)\n"
] ]
} }
], ],
"source": [ "source": [
"config = {\n", "config = {\n",
" \"batch_size\": 32,\n", " \"batch_size\": 48,\n",
" \"edit_count\": EPOCH_COUNT,\n", " \"edit_count\": EPOCH_COUNT,\n",
" \"bin_count\": 32,\n", " \"bin_count\": 32,\n",
" \"learning_rate\": tune.loguniform(5e-4, 5e-3),\n", " \"learning_rate\": tune.loguniform(1e-4, 1e-2),\n",
" \"scheduler_gamma\": tune.uniform(0.8, 0.95),\n", " \"scheduler_gamma\": tune.uniform(0.94, 0.9999),\n",
" \"elu_alpha\": tune.uniform(0.5, 2),\n", " # \"elu_alpha\": tune.uniform(0.5, 2),\n",
" \"leaky_relu_slope\": tune.uniform(0, 0.03),\n", " \"leaky_relu_slope\": tune.uniform(0, 0.03),\n",
" \"dropout_prob\": tune.uniform(0, 0.1),\n", " \"dropout_prob\": tune.uniform(0, 0.1),\n",
" \"chunk_count\": CHUNK_COUNT,\n", " \"chunk_count\": CHUNK_COUNT,\n",
@ -174,25 +200,29 @@
" [\n", " [\n",
" [16, 32, 64],\n", " [16, 32, 64],\n",
" [16, 32, 64, 128],\n", " [16, 32, 64, 128],\n",
" [16, 32, 64, 128, 256],\n",
" [16, 32, 32, 32, 64],\n",
" [32, 64],\n", " [32, 64],\n",
" [32, 128],\n", " [32, 128],\n",
" [8, 16, 32],\n", " [32, 64, 128],\n",
" [8, 8, 8, 8, 8],\n", " [32, 64, 128, 256],\n",
" [8, 8, 8, 8, 8, 8, 8],\n",
" [16, 16, 16],\n",
" [16, 16, 16, 16, 16],\n", " [16, 16, 16, 16, 16],\n",
" [16, 16, 16, 16, 16, 16, 16, 16],\n",
" [16, 16, 16, 16, 16, 16, 16, 16, 16, 16],\n",
" [32, 32, 32],\n", " [32, 32, 32],\n",
" [32, 32, 32, 32],\n", " [32, 32, 32, 32],\n",
" [64, 64],\n",
" [64, 64, 64],\n", " [64, 64, 64],\n",
" [64, 64, 64, 64],\n",
" [64, 64, 64, 64, 64],\n",
" [256, 64, 256],\n",
" ]\n", " ]\n",
" ),\n", " ),\n",
" \"use_residual\": tune.choice([True, False]),\n", " \"use_residual\": True,\n",
" \"kernel_size\": tune.choice([3, 5]),\n", " \"kernel_size\": tune.choice([3, 5]),\n",
" \"model_type\": tune.choice([\"HistogramNet\"]),\n", " \"model_type\": \"HistogramNet\",\n",
" \"use_instance_norm\": True,\n", " \"use_instance_norm\": True,\n",
" \"use_elu\": tune.choice([True, False]),\n", " \"use_elu\": False,\n",
" \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n", " \"leaky_relu_alpha\": tune.uniform(0, 0.07),\n",
"}\n", "}\n",
"\n", "\n",
"tuner = tune.Tuner(\n", "tuner = tune.Tuner(\n",
@ -223,7 +253,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -248,7 +278,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [

View file

@ -15,6 +15,7 @@ import logging
from more_itertools import divide from more_itertools import divide
EPSILON = 1e-5 EPSILON = 1e-5
EXAMPLE_COUNT = 5
def train_with_ray_factory( def train_with_ray_factory(
@ -30,6 +31,12 @@ def train_with_ray_factory(
chunk_count: int, chunk_count: int,
**_, **_,
) -> torch.nn.Module: ) -> torch.nn.Module:
import matplotlib
matplotlib.use(
"agg"
) # avoid "UserWarning: Starting a Matplotlib GUI outside of the main thread will likely fail" warnings
model, optimizer, scheduler, start_chunk_id, run_name = ( model, optimizer, scheduler, start_chunk_id, run_name = (
load_or_create_state( load_or_create_state(
device=device, device=device,
@ -48,7 +55,7 @@ def train_with_ray_factory(
test_data_loader = get_data_loader( test_data_loader = get_data_loader(
test_data_paths, **{**hyperparameters, "edit_count": 1} test_data_paths, **{**hyperparameters, "edit_count": 1}
) )
examples = next(iter(test_data_loader)) examples = next(iter(test_data_loader))[:EXAMPLE_COUNT]
with SummaryWriter(log_dir=log_dir / run_name) as writer: with SummaryWriter(log_dir=log_dir / run_name) as writer:
writer.add_graph(model, examples[0].to(device)) writer.add_graph(model, examples[0].to(device))