Fix errors
This commit is contained in:
parent
7863611f86
commit
35eb747abf
3 changed files with 26 additions and 73 deletions
|
|
@ -47,7 +47,7 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
|
|||
model_path = path.with_suffix(".pth")
|
||||
model = create_model(
|
||||
type=hyperparameters["model_type"],
|
||||
**hyperparameters,
|
||||
hyperparameters=hyperparameters,
|
||||
device=device,
|
||||
)
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
|
|
|
|||
|
|
@ -33,5 +33,10 @@ class SimpleCNN(nn.Module):
|
|||
x = F.relu(self.conv4(x))
|
||||
x = F.relu(self.conv5(x))
|
||||
x = self.conv6(x)
|
||||
sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
return x / sum
|
||||
return self._normalize(x)
|
||||
|
||||
@staticmethod
|
||||
def _normalize(x):
|
||||
x = torch.clamp(x, min=0)
|
||||
x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
return x / (x_sum + 1e-5)
|
||||
|
|
|
|||
|
|
@ -2,27 +2,17 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"metadata": {}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Using device cuda:0'"
|
||||
]
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import os\n",
|
||||
"from utils import set_up_logging\n",
|
||||
"from config import LOGS_PATH\n",
|
||||
"from training import train, random_hparam_search\n",
|
||||
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
|
||||
"\n",
|
||||
"set_up_logging(LOGS_PATH)\n",
|
||||
"\n",
|
||||
|
|
@ -33,14 +23,10 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# from training import train\n",
|
||||
"from config import RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# hparams = {\n",
|
||||
"# \"batch_size\": 64,\n",
|
||||
"# \"edit_count\": 12,\n",
|
||||
|
|
@ -65,55 +51,11 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-06-27 22:27:47,571 - INFO - Testing model Dummy\n",
|
||||
"2024-06-27 22:27:47,576 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-27 22:27:47,576 - INFO - Testing model SimpleCNN\n",
|
||||
"2024-06-27 22:27:48,012 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-27 22:27:48,014 - INFO - Testing model Residual\n",
|
||||
"2024-06-27 22:27:49,044 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-27 22:27:49,048 - INFO - Testing model HistogramNet\n",
|
||||
"2024-06-27 22:27:50,526 - INFO - Test passed! Output shape matches input shape.\n",
|
||||
"2024-06-27 22:27:50,536 - INFO - Starting run_0 with hparams {\n",
|
||||
" \"batch_size\": 64,\n",
|
||||
" \"bin_count\": 16,\n",
|
||||
" \"dropout_prob\": 0.07931334001160453,\n",
|
||||
" \"edit_count\": 12,\n",
|
||||
" \"elu_alpha\": 1.4632917556519958,\n",
|
||||
" \"features\": [\n",
|
||||
" 8,\n",
|
||||
" 16,\n",
|
||||
" 32\n",
|
||||
" ],\n",
|
||||
" \"kernel_size\": 5,\n",
|
||||
" \"leaky_relu_alpha\": 0.04009731835331415,\n",
|
||||
" \"leaky_relu_slope\": 0.005283501654875031,\n",
|
||||
" \"learning_rate\": 0.000798531032420656,\n",
|
||||
" \"model_type\": \"HistogramNet\",\n",
|
||||
" \"num_epochs\": 12,\n",
|
||||
" \"scheduler_gamma\": 0.8420623161905719,\n",
|
||||
" \"use_elu\": true,\n",
|
||||
" \"use_instance_norm\": false,\n",
|
||||
" \"use_residual\": false\n",
|
||||
"}\n",
|
||||
"2024-06-27 22:27:50,584 - INFO - Loaded 22479 original images\n",
|
||||
"2024-06-27 22:27:50,588 - INFO - Loaded 2498 original images\n",
|
||||
"2024-06-27 22:27:55,003 - INFO - Original result 322461.46875\n",
|
||||
"2024-06-27 22:28:40,001 - INFO - Epoch 0 train loss: 5343.839758455753\n",
|
||||
"2024-06-27 22:28:44,685 - INFO - Epoch 0 test loss: 531.2615858912468\n",
|
||||
"2024-06-27 22:28:45,500 - INFO - Original result 404787.25\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from scipy.stats import loguniform, uniform, randint\n",
|
||||
"from training import random_hparam_search, get_next_run_name\n",
|
||||
"from models import MODELS, test_models\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
|
@ -129,16 +71,22 @@
|
|||
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
|
||||
" \"dropout_prob\": uniform(0, 0.1),\n",
|
||||
" \"features\": [\n",
|
||||
" [16, 32, 64],\n",
|
||||
" [32, 64],\n",
|
||||
" [16, 32],\n",
|
||||
" [16, 32, 64],\n",
|
||||
" [16, 32, 64, 128],\n",
|
||||
" [32, 64],\n",
|
||||
" [32, 64, 128],\n",
|
||||
" [8, 16, 32],\n",
|
||||
" [8, 8, 8],\n",
|
||||
" [8, 8, 8, 8, 8],\n",
|
||||
" [8, 8, 8, 8, 8, 8, 8],\n",
|
||||
" [16, 16, 16, 16, 16],\n",
|
||||
" [16, 16, 16],\n",
|
||||
" [32, 32],\n",
|
||||
" [32, 32], \n",
|
||||
" [32, 32, 32],\n",
|
||||
" [32, 32, 32, 32],\n",
|
||||
" [64, 64],\n",
|
||||
" [64, 64, 64]\n",
|
||||
" ],\n",
|
||||
" \"use_residual\": [True, False],\n",
|
||||
" \"kernel_size\": [3, 5],\n",
|
||||
|
|
@ -156,7 +104,7 @@
|
|||
" train_data_paths=TRAIN_DATA,\n",
|
||||
" test_data_paths=TEST_DATA,\n",
|
||||
" models_path=MODELS_PATH,\n",
|
||||
" tensorboard_path=RUNS_PATH / get_next_run_name(RUNS_PATH),\n",
|
||||
" tensorboard_path=RUNS_PATH,\n",
|
||||
" timeout_hours=4,\n",
|
||||
" device=device,\n",
|
||||
")"
|
||||
|
|
@ -179,7 +127,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.1.-1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue