Less clutter
This commit is contained in:
parent
47fce35a45
commit
aecd7ec9cb
3 changed files with 85 additions and 48 deletions
|
|
@ -9,7 +9,7 @@ TRAIN_SIZE = 0.95
|
|||
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE2/data/cache")
|
||||
MODELS_PATH = Path("/home/andras/projects/bipolaroid/saved_models")
|
||||
LOGS_PATH = Path("/home/andras/projects/bipolaroid/logs")
|
||||
RUNS_PATH = Path("/home/andras/projects/bipolaroid/runs")
|
||||
RUNS_PATH = Path("/home/andras/projects/bipolaroid/runs2")
|
||||
|
||||
|
||||
for path in [CACHE_PATH, MODELS_PATH, LOGS_PATH, RUNS_PATH]:
|
||||
|
|
|
|||
120
src/train.ipynb
120
src/train.ipynb
|
|
@ -11,7 +11,7 @@
|
|||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-09-03 22:20:43,878\tINFO worker.py:1774 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n"
|
||||
"2024-09-05 22:18:20,653\tINFO worker.py:1774 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://172.29.235.222:8265 \u001b[39m\u001b[22m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -27,16 +27,11 @@
|
|||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import matplotlib\n",
|
||||
"\n",
|
||||
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = (\n",
|
||||
" \"expandable_segments:True\" # avoid fragmented CUDA memory\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"matplotlib.use(\n",
|
||||
" \"agg\"\n",
|
||||
") # avoid \"UserWarning: Starting a Matplotlib GUI outside of the main thread will likely fail\" warnings\n",
|
||||
"\n",
|
||||
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n",
|
||||
"from utils import set_up_logging\n",
|
||||
"\n",
|
||||
|
|
@ -50,9 +45,9 @@
|
|||
"from ray.air import RunConfig\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"TRIAL_COUNT = 100\n",
|
||||
"CHUNK_COUNT = 40\n",
|
||||
"EPOCH_COUNT = 2\n",
|
||||
"TRIAL_COUNT = 50\n",
|
||||
"EPOCH_COUNT = 4\n",
|
||||
"CHUNK_COUNT = EPOCH_COUNT * 40\n",
|
||||
"\n",
|
||||
"ray.init(include_dashboard=True, dashboard_host=\"0.0.0.0\")\n",
|
||||
"\n",
|
||||
|
|
@ -74,29 +69,72 @@
|
|||
" <h3>Tune Status</h3>\n",
|
||||
" <table>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>Current time:</td><td>2024-09-01 22:06:10</td></tr>\n",
|
||||
"<tr><td>Running for: </td><td>00:02:06.64 </td></tr>\n",
|
||||
"<tr><td>Memory: </td><td>22.4/47.0 GiB </td></tr>\n",
|
||||
"<tr><td>Current time:</td><td>2024-09-05 22:23:47</td></tr>\n",
|
||||
"<tr><td>Running for: </td><td>00:05:24.00 </td></tr>\n",
|
||||
"<tr><td>Memory: </td><td>22.7/54.9 GiB </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>\n",
|
||||
" </div>\n",
|
||||
" <div class=\"vDivider\"></div>\n",
|
||||
" <div class=\"systemInfo\">\n",
|
||||
" <h3>System Info</h3>\n",
|
||||
" Using AsyncHyperBand: num_stopped=1<br>Bracket: Iter 2.000: -5.516964137554169<br>Logical resource usage: 32.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)\n",
|
||||
" Using AsyncHyperBand: num_stopped=0<br>Bracket: Iter 128.000: None | Iter 32.000: None | Iter 8.000: None | Iter 2.000: None<br>Logical resource usage: 32.0/32 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:G)\n",
|
||||
" </div>\n",
|
||||
" <div class=\"vDivider\"></div>\n",
|
||||
"<div class=\"messages\">\n",
|
||||
" <h3>Messages</h3>\n",
|
||||
" \n",
|
||||
" ... 30 more trials not shown (30 PENDING)\n",
|
||||
" \n",
|
||||
"</div>\n",
|
||||
"<style>\n",
|
||||
".messages {\n",
|
||||
" color: var(--jp-ui-font-color1);\n",
|
||||
" display: flex;\n",
|
||||
" flex-direction: column;\n",
|
||||
" padding-left: 1em;\n",
|
||||
" overflow-y: auto;\n",
|
||||
"}\n",
|
||||
".messages h3 {\n",
|
||||
" font-weight: bold;\n",
|
||||
"}\n",
|
||||
".vDivider {\n",
|
||||
" border-left-width: var(--jp-border-width);\n",
|
||||
" border-left-color: var(--jp-border-color0);\n",
|
||||
" border-left-style: solid;\n",
|
||||
" margin: 0.5em 1em 0.5em 1em;\n",
|
||||
"}\n",
|
||||
"</style>\n",
|
||||
"\n",
|
||||
" </div>\n",
|
||||
" <div class=\"hDivider\"></div>\n",
|
||||
" <div class=\"trialStatus\">\n",
|
||||
" <h3>Trial Status</h3>\n",
|
||||
" <table>\n",
|
||||
"<thead>\n",
|
||||
"<tr><th>Trial name </th><th>status </th><th>loc </th><th style=\"text-align: right;\"> batch_size</th><th style=\"text-align: right;\"> dropout_prob</th><th style=\"text-align: right;\"> elu_alpha</th><th>features </th><th style=\"text-align: right;\"> kernel_size</th><th style=\"text-align: right;\"> leaky_relu_alpha</th><th style=\"text-align: right;\"> leaky_relu_slope</th><th style=\"text-align: right;\"> learning_rate</th><th>model_type </th><th style=\"text-align: right;\"> scheduler_gamma</th><th>use_elu </th><th>use_residual </th><th style=\"text-align: right;\"> iter</th><th style=\"text-align: right;\"> total time (s)</th><th style=\"text-align: right;\"> chunk_test_loss</th><th style=\"text-align: right;\"> chunk_training_loss</th></tr>\n",
|
||||
"<tr><th>Trial name </th><th>status </th><th>loc </th><th style=\"text-align: right;\"> dropout_prob</th><th>features </th><th style=\"text-align: right;\"> kernel_size</th><th style=\"text-align: right;\"> leaky_relu_alpha</th><th style=\"text-align: right;\"> leaky_relu_slope</th><th style=\"text-align: right;\"> learning_rate</th><th style=\"text-align: right;\"> scheduler_gamma</th></tr>\n",
|
||||
"</thead>\n",
|
||||
"<tbody>\n",
|
||||
"<tr><td>train_with_ray_b7d3c_00000</td><td>TERMINATED</td><td>172.29.235.222:1134109</td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\"> 0.00395892</td><td style=\"text-align: right;\"> 1.35107</td><td>[8, 16, 32] </td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0448715 </td><td style=\"text-align: right;\"> 0.00664526</td><td style=\"text-align: right;\"> 0.0029226 </td><td>HistogramNet</td><td style=\"text-align: right;\"> 0.816752</td><td>True </td><td>True </td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 70.652 </td><td style=\"text-align: right;\"> 5.22149</td><td style=\"text-align: right;\"> 56.5472</td></tr>\n",
|
||||
"<tr><td>train_with_ray_b7d3c_00001</td><td>TERMINATED</td><td>172.29.235.222:1140440</td><td style=\"text-align: right;\"> 32</td><td style=\"text-align: right;\"> 0.0439061 </td><td style=\"text-align: right;\"> 1.74642</td><td>[8, 8, 8, 8, 8,_c140</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.00579656</td><td style=\"text-align: right;\"> 0.0125715 </td><td style=\"text-align: right;\"> 0.00155182</td><td>HistogramNet</td><td style=\"text-align: right;\"> 0.898059</td><td>True </td><td>False </td><td style=\"text-align: right;\"> 2</td><td style=\"text-align: right;\"> 45.4228</td><td style=\"text-align: right;\"> 5.79311</td><td style=\"text-align: right;\"> 64.0481</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00000</td><td>RUNNING </td><td>172.29.235.222:3274755</td><td style=\"text-align: right;\"> 0.0240677 </td><td>[16, 32, 64, 12_25c0</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0147726 </td><td style=\"text-align: right;\"> 0.0164835 </td><td style=\"text-align: right;\"> 0.00409419 </td><td style=\"text-align: right;\"> 0.977701</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00001</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0252525 </td><td>[16, 32, 64, 12_9940</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0607274 </td><td style=\"text-align: right;\"> 0.00769014 </td><td style=\"text-align: right;\"> 0.00175452 </td><td style=\"text-align: right;\"> 0.971626</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00002</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0492015 </td><td>[16, 32, 64, 12_0940</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0522751 </td><td style=\"text-align: right;\"> 0.00532775 </td><td style=\"text-align: right;\"> 0.000102449</td><td style=\"text-align: right;\"> 0.985839</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00003</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0091843 </td><td>[16, 32, 64, 12_8e80</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0215584 </td><td style=\"text-align: right;\"> 0.00986109 </td><td style=\"text-align: right;\"> 0.0046811 </td><td style=\"text-align: right;\"> 0.969517</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00004</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0200707 </td><td>[16, 32, 64, 12_e200</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.00839403</td><td style=\"text-align: right;\"> 0.0236732 </td><td style=\"text-align: right;\"> 0.00456458 </td><td style=\"text-align: right;\"> 0.949592</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00005</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.071518 </td><td>[16, 32, 64, 12_4200</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.00955745</td><td style=\"text-align: right;\"> 0.000338297</td><td style=\"text-align: right;\"> 0.0012211 </td><td style=\"text-align: right;\"> 0.959114</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00006</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.00712943</td><td>[16, 32, 64, 12_f8c0</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0128474 </td><td style=\"text-align: right;\"> 0.0214861 </td><td style=\"text-align: right;\"> 0.000319318</td><td style=\"text-align: right;\"> 0.954027</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00007</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0554306 </td><td>[16, 32, 64, 12_1900</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0686442 </td><td style=\"text-align: right;\"> 0.0103842 </td><td style=\"text-align: right;\"> 0.00856337 </td><td style=\"text-align: right;\"> 0.966035</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00008</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0637642 </td><td>[16, 32, 64, 12_8f80</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0125659 </td><td style=\"text-align: right;\"> 0.00386933 </td><td style=\"text-align: right;\"> 0.00135947 </td><td style=\"text-align: right;\"> 0.955525</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00009</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0526107 </td><td>[16, 32, 64, 12_16c0</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.00576232</td><td style=\"text-align: right;\"> 0.0120612 </td><td style=\"text-align: right;\"> 0.000143394</td><td style=\"text-align: right;\"> 0.983664</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00010</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0321631 </td><td>[16, 32, 64, 12_fc00</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0100027 </td><td style=\"text-align: right;\"> 0.016943 </td><td style=\"text-align: right;\"> 0.000461478</td><td style=\"text-align: right;\"> 0.981847</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00011</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.015057 </td><td>[16, 32, 64, 12_f080</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0599078 </td><td style=\"text-align: right;\"> 0.0293855 </td><td style=\"text-align: right;\"> 0.00182771 </td><td style=\"text-align: right;\"> 0.965469</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00012</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0995986 </td><td>[16, 32, 64, 12_d240</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0349584 </td><td style=\"text-align: right;\"> 0.000219269</td><td style=\"text-align: right;\"> 0.000288736</td><td style=\"text-align: right;\"> 0.966034</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00013</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0227281 </td><td>[16, 32, 64, 12_c600</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0324662 </td><td style=\"text-align: right;\"> 0.0125853 </td><td style=\"text-align: right;\"> 0.000173035</td><td style=\"text-align: right;\"> 0.948412</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00014</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0307992 </td><td>[16, 32, 64, 12_d040</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.062771 </td><td style=\"text-align: right;\"> 0.0198744 </td><td style=\"text-align: right;\"> 0.000639558</td><td style=\"text-align: right;\"> 0.94965 </td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00015</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0896525 </td><td>[16, 32, 64, 12_2e80</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0150036 </td><td style=\"text-align: right;\"> 0.0171698 </td><td style=\"text-align: right;\"> 0.00206888 </td><td style=\"text-align: right;\"> 0.99602 </td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00016</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0932697 </td><td>[16, 32, 64, 12_6540</td><td style=\"text-align: right;\"> 5</td><td style=\"text-align: right;\"> 0.0699308 </td><td style=\"text-align: right;\"> 0.0154819 </td><td style=\"text-align: right;\"> 0.00384655 </td><td style=\"text-align: right;\"> 0.954913</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00017</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.00240956</td><td>[16, 32, 64, 12_8640</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0318032 </td><td style=\"text-align: right;\"> 0.00959876 </td><td style=\"text-align: right;\"> 0.000135257</td><td style=\"text-align: right;\"> 0.952462</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00018</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0883309 </td><td>[16, 32, 64, 12_b800</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0455664 </td><td style=\"text-align: right;\"> 0.00565 </td><td style=\"text-align: right;\"> 0.00066994 </td><td style=\"text-align: right;\"> 0.997372</td></tr>\n",
|
||||
"<tr><td>train_with_ray_61b01_00019</td><td>PENDING </td><td> </td><td style=\"text-align: right;\"> 0.0500287 </td><td>[16, 32, 64, 12_8440</td><td style=\"text-align: right;\"> 3</td><td style=\"text-align: right;\"> 0.0674193 </td><td style=\"text-align: right;\"> 0.0220701 </td><td style=\"text-align: right;\"> 0.000987128</td><td style=\"text-align: right;\"> 0.94602 </td></tr>\n",
|
||||
"</tbody>\n",
|
||||
"</table>\n",
|
||||
" </div>\n",
|
||||
|
|
@ -140,33 +178,21 @@
|
|||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000000)\n",
|
||||
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000001)\n",
|
||||
"\u001b[36m(train_with_ray pid=1134109)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00000_0_batch_size=32,dropout_prob=0.0040,elu_alpha=1.3511,features=8_16_32,kernel_size=5,leaky_relu_alpha=0._2024-09-01_22-04-03/checkpoint_000002)\n",
|
||||
"\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000000)\n",
|
||||
"2024-09-01 22:06:10,132\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/andras/projects/bipolaroid/runs5/tune' in 0.0020s.\n",
|
||||
"2024-09-01 22:06:10,135\tINFO tune.py:1041 -- Total run time: 126.67 seconds (126.64 seconds for the tuning loop).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[36m(train_with_ray pid=1140440)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/andras/projects/bipolaroid/runs5/tune/train_with_ray_b7d3c_00001_1_batch_size=32,dropout_prob=0.0439,elu_alpha=1.7464,features=8_8_8_8_8_8_8,kernel_size=3,leaky_relu_al_2024-09-01_22-04-03/checkpoint_000001)\n"
|
||||
"\u001b[33m(raylet)\u001b[0m Warning: The actor ImplicitFunc is very large (26 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"config = {\n",
|
||||
" \"batch_size\": 32,\n",
|
||||
" \"batch_size\": 48,\n",
|
||||
" \"edit_count\": EPOCH_COUNT,\n",
|
||||
" \"bin_count\": 32,\n",
|
||||
" \"learning_rate\": tune.loguniform(5e-4, 5e-3),\n",
|
||||
" \"scheduler_gamma\": tune.uniform(0.8, 0.95),\n",
|
||||
" \"elu_alpha\": tune.uniform(0.5, 2),\n",
|
||||
" \"learning_rate\": tune.loguniform(1e-4, 1e-2),\n",
|
||||
" \"scheduler_gamma\": tune.uniform(0.94, 0.9999),\n",
|
||||
" # \"elu_alpha\": tune.uniform(0.5, 2),\n",
|
||||
" \"leaky_relu_slope\": tune.uniform(0, 0.03),\n",
|
||||
" \"dropout_prob\": tune.uniform(0, 0.1),\n",
|
||||
" \"chunk_count\": CHUNK_COUNT,\n",
|
||||
|
|
@ -174,25 +200,29 @@
|
|||
" [\n",
|
||||
" [16, 32, 64],\n",
|
||||
" [16, 32, 64, 128],\n",
|
||||
" [16, 32, 64, 128, 256],\n",
|
||||
" [16, 32, 32, 32, 64],\n",
|
||||
" [32, 64],\n",
|
||||
" [32, 128],\n",
|
||||
" [8, 16, 32],\n",
|
||||
" [8, 8, 8, 8, 8],\n",
|
||||
" [8, 8, 8, 8, 8, 8, 8],\n",
|
||||
" [16, 16, 16],\n",
|
||||
" [32, 64, 128],\n",
|
||||
" [32, 64, 128, 256],\n",
|
||||
" [16, 16, 16, 16, 16],\n",
|
||||
" [16, 16, 16, 16, 16, 16, 16, 16],\n",
|
||||
" [16, 16, 16, 16, 16, 16, 16, 16, 16, 16],\n",
|
||||
" [32, 32, 32],\n",
|
||||
" [32, 32, 32, 32],\n",
|
||||
" [64, 64],\n",
|
||||
" [64, 64, 64],\n",
|
||||
" [64, 64, 64, 64],\n",
|
||||
" [64, 64, 64, 64, 64],\n",
|
||||
" [256, 64, 256],\n",
|
||||
" ]\n",
|
||||
" ),\n",
|
||||
" \"use_residual\": tune.choice([True, False]),\n",
|
||||
" \"use_residual\": True,\n",
|
||||
" \"kernel_size\": tune.choice([3, 5]),\n",
|
||||
" \"model_type\": tune.choice([\"HistogramNet\"]),\n",
|
||||
" \"model_type\": \"HistogramNet\",\n",
|
||||
" \"use_instance_norm\": True,\n",
|
||||
" \"use_elu\": tune.choice([True, False]),\n",
|
||||
" \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n",
|
||||
" \"use_elu\": False,\n",
|
||||
" \"leaky_relu_alpha\": tune.uniform(0, 0.07),\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"tuner = tune.Tuner(\n",
|
||||
|
|
@ -223,7 +253,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
|
@ -248,7 +278,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import logging
|
|||
from more_itertools import divide
|
||||
|
||||
EPSILON = 1e-5
|
||||
EXAMPLE_COUNT = 5
|
||||
|
||||
|
||||
def train_with_ray_factory(
|
||||
|
|
@ -30,6 +31,12 @@ def train_with_ray_factory(
|
|||
chunk_count: int,
|
||||
**_,
|
||||
) -> torch.nn.Module:
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use(
|
||||
"agg"
|
||||
) # avoid "UserWarning: Starting a Matplotlib GUI outside of the main thread will likely fail" warnings
|
||||
|
||||
model, optimizer, scheduler, start_chunk_id, run_name = (
|
||||
load_or_create_state(
|
||||
device=device,
|
||||
|
|
@ -48,7 +55,7 @@ def train_with_ray_factory(
|
|||
test_data_loader = get_data_loader(
|
||||
test_data_paths, **{**hyperparameters, "edit_count": 1}
|
||||
)
|
||||
examples = next(iter(test_data_loader))
|
||||
examples = next(iter(test_data_loader))[:EXAMPLE_COUNT]
|
||||
|
||||
with SummaryWriter(log_dir=log_dir / run_name) as writer:
|
||||
writer.add_graph(model, examples[0].to(device))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue