Fix errors

This commit is contained in:
Andras Schmelczer 2024-06-27 23:13:37 +01:00
parent 7863611f86
commit 35eb747abf
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
3 changed files with 26 additions and 73 deletions

View file

@ -47,7 +47,7 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
model_path = path.with_suffix(".pth")
model = create_model(
type=hyperparameters["model_type"],
**hyperparameters,
hyperparameters=hyperparameters,
device=device,
)
model.load_state_dict(torch.load(model_path))

View file

@ -33,5 +33,10 @@ class SimpleCNN(nn.Module):
x = F.relu(self.conv4(x))
x = F.relu(self.conv5(x))
x = self.conv6(x)
sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
return x / sum
return self._normalize(x)
@staticmethod
def _normalize(x, *, dims=(2, 3, 4), eps=1e-5):
    """Rescale ``x`` so that its non-negative part sums to ~1 over ``dims``.

    Negative entries are clamped to zero first, then every element is
    divided by the sum taken over ``dims`` (kept as size-1 axes so the
    division broadcasts back to the input shape).

    Args:
        x: Input tensor. It must have the axes named in ``dims``; with the
            default that means at least five dimensions.
            NOTE(review): presumably (batch, channel, d0, d1, d2) -- confirm
            against the caller.
        dims: Axes to sum over. Defaults to ``(2, 3, 4)``.
        eps: Constant added to the denominator. Defaults to ``1e-5``.

    Returns:
        A tensor with the same shape as ``x`` and no negative entries.
    """
    non_negative = torch.clamp(x, min=0)
    total = torch.sum(non_negative, dim=dims, keepdim=True)
    # After clamping the sum is >= 0, so adding a positive eps keeps the
    # denominator strictly positive: an all-zero slice yields zeros, not NaN.
    return non_negative / (total + eps)

View file

@ -2,27 +2,17 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"metadata": {}
},
"outputs": [
{
"data": {
"text/plain": [
"'Using device cuda:0'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"import torch\n",
"import os\n",
"from utils import set_up_logging\n",
"from config import LOGS_PATH\n",
"from training import train, random_hparam_search\n",
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
"\n",
"set_up_logging(LOGS_PATH)\n",
"\n",
@ -33,14 +23,10 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# from training import train\n",
"from config import RUNS_PATH, TRAIN_DATA, TEST_DATA, MODELS_PATH\n",
"\n",
"\n",
"# hparams = {\n",
"# \"batch_size\": 64,\n",
"# \"edit_count\": 12,\n",
@ -65,55 +51,11 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-06-27 22:27:47,571 - INFO - Testing model Dummy\n",
"2024-06-27 22:27:47,576 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-27 22:27:47,576 - INFO - Testing model SimpleCNN\n",
"2024-06-27 22:27:48,012 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-27 22:27:48,014 - INFO - Testing model Residual\n",
"2024-06-27 22:27:49,044 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-27 22:27:49,048 - INFO - Testing model HistogramNet\n",
"2024-06-27 22:27:50,526 - INFO - Test passed! Output shape matches input shape.\n",
"2024-06-27 22:27:50,536 - INFO - Starting run_0 with hparams {\n",
" \"batch_size\": 64,\n",
" \"bin_count\": 16,\n",
" \"dropout_prob\": 0.07931334001160453,\n",
" \"edit_count\": 12,\n",
" \"elu_alpha\": 1.4632917556519958,\n",
" \"features\": [\n",
" 8,\n",
" 16,\n",
" 32\n",
" ],\n",
" \"kernel_size\": 5,\n",
" \"leaky_relu_alpha\": 0.04009731835331415,\n",
" \"leaky_relu_slope\": 0.005283501654875031,\n",
" \"learning_rate\": 0.000798531032420656,\n",
" \"model_type\": \"HistogramNet\",\n",
" \"num_epochs\": 12,\n",
" \"scheduler_gamma\": 0.8420623161905719,\n",
" \"use_elu\": true,\n",
" \"use_instance_norm\": false,\n",
" \"use_residual\": false\n",
"}\n",
"2024-06-27 22:27:50,584 - INFO - Loaded 22479 original images\n",
"2024-06-27 22:27:50,588 - INFO - Loaded 2498 original images\n",
"2024-06-27 22:27:55,003 - INFO - Original result 322461.46875\n",
"2024-06-27 22:28:40,001 - INFO - Epoch 0 train loss: 5343.839758455753\n",
"2024-06-27 22:28:44,685 - INFO - Epoch 0 test loss: 531.2615858912468\n",
"2024-06-27 22:28:45,500 - INFO - Original result 404787.25\n"
]
}
],
"outputs": [],
"source": [
"from scipy.stats import loguniform, uniform, randint\n",
"from training import random_hparam_search, get_next_run_name\n",
"from models import MODELS, test_models\n",
"\n",
"\n",
@ -129,16 +71,22 @@
" \"leaky_relu_slope\": uniform(0, 0.03),\n",
" \"dropout_prob\": uniform(0, 0.1),\n",
" \"features\": [\n",
" [16, 32, 64],\n",
" [32, 64],\n",
" [16, 32],\n",
" [16, 32, 64],\n",
" [16, 32, 64, 128],\n",
" [32, 64],\n",
" [32, 64, 128],\n",
" [8, 16, 32],\n",
" [8, 8, 8],\n",
" [8, 8, 8, 8, 8],\n",
" [8, 8, 8, 8, 8, 8, 8],\n",
" [16, 16, 16, 16, 16],\n",
" [16, 16, 16],\n",
" [32, 32],\n",
" [32, 32], \n",
" [32, 32, 32],\n",
" [32, 32, 32, 32],\n",
" [64, 64],\n",
" [64, 64, 64]\n",
" ],\n",
" \"use_residual\": [True, False],\n",
" \"kernel_size\": [3, 5],\n",
@ -156,7 +104,7 @@
" train_data_paths=TRAIN_DATA,\n",
" test_data_paths=TEST_DATA,\n",
" models_path=MODELS_PATH,\n",
" tensorboard_path=RUNS_PATH / get_next_run_name(RUNS_PATH),\n",
" tensorboard_path=RUNS_PATH,\n",
" timeout_hours=4,\n",
" device=device,\n",
")"
@ -179,7 +127,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.1.-1"
}
},
"nbformat": 4,