Add models
This commit is contained in:
parent
09aceae9d4
commit
bd7033c3eb
9 changed files with 490 additions and 76 deletions
|
|
@ -1 +1,19 @@
|
|||
from .create_model import create_model
|
||||
from .v1 import HistogramRestorationNet as v1
|
||||
from .simple_cnn import SimpleCNN
|
||||
from .residual import Residual
|
||||
from .normalised_cnn import NormalisedCNN
|
||||
from .smart_res import SmartRes
|
||||
from .attention_net import AttentionNet
|
||||
from .res2 import Res2
|
||||
|
||||
|
||||
def create_model(type: str, bin_count: int):
|
||||
return {
|
||||
# "v1": v1,
|
||||
"SimpleCNN": SimpleCNN,
|
||||
"Residual": Residual,
|
||||
"NormalisedCNN": NormalisedCNN,
|
||||
"SmartRes": SmartRes,
|
||||
"AttentionNet": AttentionNet,
|
||||
"Res2": Res2,
|
||||
}[type](bin_count)
|
||||
|
|
|
|||
64
src/editor/models/attention_net.py
Normal file
64
src/editor/models/attention_net.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Define the self-attention module
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.query_conv = nn.Conv3d(channels, channels // 8, kernel_size=1)
|
||||
self.key_conv = nn.Conv3d(channels, channels // 8, kernel_size=1)
|
||||
self.value_conv = nn.Conv3d(channels, channels, kernel_size=1)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, channels, depth, height, width = x.size()
|
||||
query = (
|
||||
self.query_conv(x)
|
||||
.view(batch_size, -1, depth * height * width)
|
||||
.permute(0, 2, 1)
|
||||
)
|
||||
key = self.key_conv(x).view(batch_size, -1, depth * height * width)
|
||||
value = self.value_conv(x).view(batch_size, -1, depth * height * width)
|
||||
|
||||
attention = self.softmax(torch.bmm(query, key)) # Batch matrix multiplication
|
||||
out = torch.bmm(value, attention.permute(0, 2, 1))
|
||||
out = out.view(batch_size, channels, depth, height, width)
|
||||
|
||||
return x + out
|
||||
|
||||
|
||||
# Define the residual block
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
)
|
||||
self.attention = SelfAttention(channels)
|
||||
|
||||
def forward(self, x):
|
||||
return self.attention(self.conv(x)) + x
|
||||
|
||||
|
||||
# Define the network
|
||||
class AttentionNet(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(AttentionNet, self).__init__()
|
||||
self.input_layer = nn.Sequential(
|
||||
nn.Conv3d(1, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm3d(16),
|
||||
)
|
||||
self.res_blocks = nn.Sequential(ResidualBlock(16), ResidualBlock(16))
|
||||
self.output_layer = nn.Conv3d(16, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.res_blocks(x)
|
||||
x = self.output_layer(x)
|
||||
return x
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
from .v1 import HistogramRestorationNet as v1
|
||||
|
||||
|
||||
def create_model(type: str, bin_count: int):
|
||||
return {"v1": v1}[type](bin_count)
|
||||
34
src/editor/models/normalised_cnn.py
Normal file
34
src/editor/models/normalised_cnn.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class NormalisedCNN(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(NormalisedCNN, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
|
||||
# Define the layers of the neural network
|
||||
self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
|
||||
self.bn1 = nn.BatchNorm3d(16)
|
||||
self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm3d(32)
|
||||
self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
|
||||
self.bn3 = nn.BatchNorm3d(64)
|
||||
self.conv4 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
|
||||
self.bn4 = nn.BatchNorm3d(32)
|
||||
self.conv5 = nn.Conv3d(32, 16, kernel_size=3, padding=1)
|
||||
self.bn5 = nn.BatchNorm3d(16)
|
||||
self.conv6 = nn.Conv3d(16, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(
|
||||
-1, 1, self.bin_count, self.bin_count, self.bin_count
|
||||
) # Reshape input to (N, C, D, H, W)
|
||||
x = F.relu(self.bn1(self.conv1(x)))
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.relu(self.bn3(self.conv3(x)))
|
||||
x = F.relu(self.bn4(self.conv4(x)))
|
||||
x = F.relu(self.bn5(self.conv5(x)))
|
||||
x = self.conv6(x)
|
||||
|
||||
return x
|
||||
37
src/editor/models/res2.py
Normal file
37
src/editor/models/res2.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x) + x
|
||||
|
||||
|
||||
# Define the network
|
||||
class Res2(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(Res2, self).__init__()
|
||||
self.input_layer = nn.Sequential(
|
||||
nn.Conv3d(1, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm3d(16),
|
||||
)
|
||||
self.res_blocks = nn.Sequential(
|
||||
ResidualBlock(16), ResidualBlock(16), ResidualBlock(16), ResidualBlock(16)
|
||||
)
|
||||
self.output_layer = nn.Conv3d(16, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.res_blocks(x)
|
||||
x = self.output_layer(x)
|
||||
return x
|
||||
52
src/editor/models/residual.py
Normal file
52
src/editor/models/residual.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, bin_count: int):
|
||||
super(Residual, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
|
||||
# Assuming the input histograms are 3D tensors of shape (bin_count, bin_count, bin_count)
|
||||
# Convolutional layers to extract features from the histograms
|
||||
self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
|
||||
self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
|
||||
|
||||
# Batch normalization layers for better convergence
|
||||
self.bn1 = nn.BatchNorm3d(16)
|
||||
self.bn2 = nn.BatchNorm3d(32)
|
||||
self.bn3 = nn.BatchNorm3d(64)
|
||||
|
||||
# Residual block to help the network learn identity functions effectively
|
||||
self.resblock1 = nn.Sequential(
|
||||
nn.Conv3d(64, 64, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm3d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv3d(64, 64, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm3d(64),
|
||||
)
|
||||
|
||||
# ReLU activation
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.deconv1 = nn.ConvTranspose3d(64, 32, kernel_size=3, stride=1, padding=1)
|
||||
self.deconv2 = nn.ConvTranspose3d(32, 16, kernel_size=3, stride=1, padding=1)
|
||||
self.deconv3 = nn.ConvTranspose3d(16, 1, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(self.bn1(self.conv1(x)))
|
||||
x = self.relu(self.bn2(self.conv2(x)))
|
||||
x = self.relu(self.bn3(self.conv3(x)))
|
||||
|
||||
# Apply residual blocks
|
||||
residual = x
|
||||
out = self.resblock1(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
# Upsample to original size
|
||||
out = self.relu(self.deconv1(out))
|
||||
out = self.relu(self.deconv2(out))
|
||||
out = self.relu(self.deconv3(out))
|
||||
|
||||
return out
|
||||
37
src/editor/models/simple_cnn.py
Normal file
37
src/editor/models/simple_cnn.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SimpleCNN(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(SimpleCNN, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
|
||||
# Define the convolutional layers
|
||||
self.conv1 = nn.Conv3d(
|
||||
1, 16, kernel_size=3, padding=1
|
||||
) # input channels = 1, output channels = 16
|
||||
self.conv2 = nn.Conv3d(
|
||||
16, 32, kernel_size=3, padding=1
|
||||
) # input channels = 16, output channels = 32
|
||||
self.conv3 = nn.Conv3d(
|
||||
32, 64, kernel_size=3, padding=1
|
||||
) # input channels = 32, output channels = 64
|
||||
self.conv4 = nn.Conv3d(
|
||||
64, 32, kernel_size=3, padding=1
|
||||
) # input channels = 64, output channels = 32
|
||||
self.conv5 = nn.Conv3d(
|
||||
32, 16, kernel_size=3, padding=1
|
||||
) # input channels = 32, output channels = 16
|
||||
self.conv6 = nn.Conv3d(
|
||||
16, 1, kernel_size=3, padding=1
|
||||
) # input channels = 16, output channels = 1
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.relu(self.conv3(x))
|
||||
x = F.relu(self.conv4(x))
|
||||
x = F.relu(self.conv5(x))
|
||||
x = self.conv6(x)
|
||||
return x
|
||||
42
src/editor/models/smart_res.py
Normal file
42
src/editor/models/smart_res.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||
self.bn1 = nn.BatchNorm3d(channels)
|
||||
self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm3d(channels)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out += identity
|
||||
return F.relu(out)
|
||||
|
||||
|
||||
class SmartRes(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(SmartRes, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
self.initial_conv = nn.Conv3d(1, 16, kernel_size=3, padding=1)
|
||||
self.bn0 = nn.BatchNorm3d(16)
|
||||
self.resblock1 = ResidualBlock(16)
|
||||
self.resblock2 = ResidualBlock(16)
|
||||
self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm3d(32)
|
||||
self.dilated_conv = nn.Conv3d(32, 32, kernel_size=3, padding=2, dilation=2)
|
||||
self.bn_dilated = nn.BatchNorm3d(32)
|
||||
self.final_conv = nn.Conv3d(32, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.bn0(self.initial_conv(x)))
|
||||
x = self.resblock1(x)
|
||||
x = self.resblock2(x)
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.relu(self.bn_dilated(self.dilated_conv(x)))
|
||||
x = self.final_conv(x)
|
||||
return x
|
||||
275
src/train.ipynb
275
src/train.ipynb
|
|
@ -36,23 +36,32 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from scipy.stats import loguniform, uniform\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"common_hyperparameters = {\n",
|
||||
" \"batch_size\": [8, 16, 32, 64],\n",
|
||||
" \"edit_count\": [4, 8, 16, 32],\n",
|
||||
" \"bin_count\": [32],\n",
|
||||
" \"edit_count\": [8, 16, 32],\n",
|
||||
" \"bin_count\": [16, 32, 64],\n",
|
||||
" \"clip_gradients\": [True, False],\n",
|
||||
" \"learning_rate\": [0.0001],\n",
|
||||
" \"scheduler_gamma\": [0.1, 0.9],\n",
|
||||
" \"num_epochs\": [20],\n",
|
||||
" \"model_type\": [\"v1\"],\n",
|
||||
" \"learning_rate\": loguniform(0.00001, 0.005),\n",
|
||||
" \"scheduler_gamma\": uniform(0.1, 0.9),\n",
|
||||
" \"num_epochs\": [10],\n",
|
||||
" \"model_type\": [\n",
|
||||
" \"NormalisedCNN\",\n",
|
||||
" \"SimpleCNN\",\n",
|
||||
" \"Residual\",\n",
|
||||
" \"SmartRes\",\n",
|
||||
" \"Res2\",\n",
|
||||
" ],\n",
|
||||
"}\n",
|
||||
"hyperparameters = [\n",
|
||||
" {\n",
|
||||
" **common_hyperparameters,\n",
|
||||
" \"loss\": [\"progressive\"],\n",
|
||||
" \"loss_sizes\": [[4, 8, 16, 32], [8, 16, 32], [16, 32], [8, 32]],\n",
|
||||
" \"loss_damping\": [1, 2, 3, 4, 5],\n",
|
||||
" },\n",
|
||||
" # {\n",
|
||||
" # **common_hyperparameters,\n",
|
||||
" # \"loss\": [\"progressive\"],\n",
|
||||
" # \"loss_sizes\": [[4, 8, 16, 32], [8, 16, 32], [16, 32], [8, 32]],\n",
|
||||
" # \"loss_damping\": uniform(0.2, 5),\n",
|
||||
" # },\n",
|
||||
" {\n",
|
||||
" **common_hyperparameters,\n",
|
||||
" \"loss\": [\"kl\"],\n",
|
||||
|
|
@ -81,7 +90,7 @@
|
|||
"\n",
|
||||
"\n",
|
||||
"def train(hyperparameters: Dict[str, Any]) -> Path:\n",
|
||||
" model_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\"pth\")\n",
|
||||
" model_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\".pth\")\n",
|
||||
"\n",
|
||||
" log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n",
|
||||
" with SummaryWriter(log_dir) as writer:\n",
|
||||
|
|
@ -129,19 +138,21 @@
|
|||
"\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" predicted_original = model(edited_histogram)\n",
|
||||
" sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n",
|
||||
" predicted_original = predicted_original / sum\n",
|
||||
"\n",
|
||||
" if hyperparameters[\"loss\"] == \"kl\":\n",
|
||||
" predicted_original = torch.clamp(\n",
|
||||
" predicted_original, 0.0000000000000000000000001, 1\n",
|
||||
" predicted_original, 0.0000000000000000000001, 1\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" loss = {\n",
|
||||
" \"kl\": lambda: loss_function(\n",
|
||||
" torch.log(predicted_original.unsqueeze(1)),\n",
|
||||
" torch.log(predicted_original),\n",
|
||||
" original_histogram,\n",
|
||||
" ),\n",
|
||||
" \"progressive\": lambda: loss_function(\n",
|
||||
" predicted_original.unsqueeze(1), original_histogram\n",
|
||||
" predicted_original, original_histogram\n",
|
||||
" ),\n",
|
||||
" }[hyperparameters[\"loss\"]]()\n",
|
||||
"\n",
|
||||
|
|
@ -175,6 +186,8 @@
|
|||
" edited_histogram = edited_histogram.to(device)\n",
|
||||
" original_histogram = original_histogram.to(device)\n",
|
||||
" predicted_original = model(edited_histogram)\n",
|
||||
" sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n",
|
||||
" predicted_original = predicted_original / sum\n",
|
||||
" writer.add_figure(\n",
|
||||
" \"histogram\",\n",
|
||||
" plot_histograms_in_2d(\n",
|
||||
|
|
@ -192,32 +205,60 @@
|
|||
" )\n",
|
||||
" model.train()\n",
|
||||
" scheduler.step()\n",
|
||||
" except Exception as e:\n",
|
||||
" raise\n",
|
||||
" finally:\n",
|
||||
" torch.save(model.state_dict(), model_path)\n",
|
||||
" return model_path"
|
||||
" del model\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" return model_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# train(\n",
|
||||
"# {\n",
|
||||
"# \"batch_size\": 64,\n",
|
||||
"# \"edit_count\": 25,\n",
|
||||
"# \"bin_count\": 32,\n",
|
||||
"# \"clip_gradients\": True,\n",
|
||||
"# \"learning_rate\": 0.005,\n",
|
||||
"# \"scheduler_gamma\": 0.7,\n",
|
||||
"# \"num_epochs\": 20,\n",
|
||||
"# \"model_type\": \"NormalisedCNN\",\n",
|
||||
"# \"loss\": \"progressive\",\n",
|
||||
"# \"loss_sizes\": [16, 32],\n",
|
||||
"# \"loss_damping\": 2,\n",
|
||||
"# }\n",
|
||||
"# )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-05-12 19:44:36,319 - INFO - Loaded 561982 training images and 62443 test images\n"
|
||||
"2024-05-12 21:54:49,789 - INFO - Starting run_71 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'Res2', 'loss': 'kl'}\n",
|
||||
"2024-05-12 21:54:49,792 - INFO - Loaded 72 training images and 8 test images\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "866f474ead21440e81dc91f2d0e55046",
|
||||
"model_id": "461659956f3944a085b5cc3a5af6ec31",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 0: 0%| | 0/17562 [00:00<?, ?batch/s]"
|
||||
"Epoch 0: 0%| | 0/9 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
|
|
@ -234,97 +275,191 @@
|
|||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a7e35a195b784da98ff4b21e493420c5",
|
||||
"model_id": "1ea8edb16ad742d0ab6673560f231b20",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 1: 0%| | 0/17562 [00:00<?, ?batch/s]"
|
||||
"Epoch 1: 0%| | 0/9 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"train(\n",
|
||||
" {\n",
|
||||
" \"batch_size\": 32,\n",
|
||||
" \"edit_count\": 25,\n",
|
||||
" \"bin_count\": 32,\n",
|
||||
" \"clip_gradients\": True,\n",
|
||||
" \"learning_rate\": 0.0001,\n",
|
||||
" \"scheduler_gamma\": 0.5,\n",
|
||||
" \"num_epochs\": 20,\n",
|
||||
" \"model_type\": \"v1\",\n",
|
||||
" \"loss\": \"kl\",\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:root:Loaded 28800 training images and 3200 test images\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Starting run 0: {'batch_size': 32, 'edit_count': 32, 'bin_count': 32, 'clip_gradients': False, 'learning_rate': 0.0001, 'scheduler_gamma': 0.9, 'num_epochs': 20, 'model_type': 'v1', 'loss': 'progressive', 'loss_sizes': [16, 32], 'loss_damping': 5}\n"
|
||||
"2024-05-12 21:55:07,098 - INFO - Starting run_72 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'Residual', 'loss': 'kl'}\n",
|
||||
"2024-05-12 21:55:07,100 - INFO - Loaded 72 training images and 8 test images\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "4f937ddaa2694b0db368c7a9bdd11330",
|
||||
"model_id": "ede7a736204f4198868a98381b8862a7",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 0: 0%| | 0/900 [00:00<?, ?batch/s]"
|
||||
"Epoch 0: 0%| | 0/9 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"ename": "ValueError",
|
||||
"evalue": "value should be one of int, float, str, bool, or torch.Tensor",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[3], line 99\u001b[0m\n\u001b[1;32m 96\u001b[0m clip_grad_norm_(model\u001b[38;5;241m.\u001b[39mparameters(), max_norm\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1.0\u001b[39m)\n\u001b[1;32m 97\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[0;32m---> 99\u001b[0m \u001b[43mwriter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_hparams\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 100\u001b[0m \u001b[43m \u001b[49m\u001b[43mcurrent_hyperparameters\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 101\u001b[0m \u001b[43m \u001b[49m\u001b[43m{\u001b[49m\n\u001b[1;32m 102\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mLoss/train/epoch\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch_loss\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 103\u001b[0m \u001b[43m \u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 104\u001b[0m \u001b[43m \u001b[49m\u001b[43mglobal_step\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 105\u001b[0m \u001b[43m \u001b[49m\u001b[43mrun_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_dir\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mabsolute\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 106\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 108\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n",
|
||||
"File \u001b[0;32m~/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/torch/utils/tensorboard/writer.py:341\u001b[0m, in \u001b[0;36mSummaryWriter.add_hparams\u001b[0;34m(self, hparam_dict, metric_dict, hparam_domain_discrete, run_name, global_step)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(hparam_dict) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mdict\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(metric_dict) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mdict\u001b[39m:\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhparam_dict and metric_dict should be dictionary.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 341\u001b[0m exp, ssi, sei \u001b[38;5;241m=\u001b[39m \u001b[43mhparams\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhparam_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetric_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhparam_domain_discrete\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m run_name:\n\u001b[1;32m 344\u001b[0m run_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(time\u001b[38;5;241m.\u001b[39mtime())\n",
|
||||
"File \u001b[0;32m~/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/torch/utils/tensorboard/summary.py:316\u001b[0m, in \u001b[0;36mhparams\u001b[0;34m(hparam_dict, metric_dict, hparam_domain_discrete)\u001b[0m\n\u001b[1;32m 314\u001b[0m hps\u001b[38;5;241m.\u001b[39mappend(HParamInfo(name\u001b[38;5;241m=\u001b[39mk, \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39mDataType\u001b[38;5;241m.\u001b[39mValue(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDATA_TYPE_FLOAT64\u001b[39m\u001b[38;5;124m\"\u001b[39m)))\n\u001b[1;32m 315\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[0;32m--> 316\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 317\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalue should be one of int, float, str, bool, or torch.Tensor\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 318\u001b[0m )\n\u001b[1;32m 320\u001b[0m content \u001b[38;5;241m=\u001b[39m HParamsPluginData(session_start_info\u001b[38;5;241m=\u001b[39mssi, version\u001b[38;5;241m=\u001b[39mPLUGIN_DATA_VERSION)\n\u001b[1;32m 321\u001b[0m smd \u001b[38;5;241m=\u001b[39m SummaryMetadata(\n\u001b[1;32m 322\u001b[0m plugin_data\u001b[38;5;241m=\u001b[39mSummaryMetadata\u001b[38;5;241m.\u001b[39mPluginData(\n\u001b[1;32m 323\u001b[0m plugin_name\u001b[38;5;241m=\u001b[39mPLUGIN_NAME, content\u001b[38;5;241m=\u001b[39mcontent\u001b[38;5;241m.\u001b[39mSerializeToString()\n\u001b[1;32m 324\u001b[0m )\n\u001b[1;32m 325\u001b[0m )\n",
|
||||
"\u001b[0;31mValueError\u001b[0m: value should be one of int, float, str, bool, or torch.Tensor"
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "707b28e45e3448fc9e449624c9a8467e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 1: 0%| | 0/9 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-05-12 21:55:21,888 - INFO - Starting run_73 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'SimpleCNN', 'loss': 'kl'}\n",
|
||||
"2024-05-12 21:55:21,890 - INFO - Loaded 72 training images and 8 test images\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "8f94f38fd9c245bd9cc99ffe9cb0c058",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 0: 0%| | 0/9 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "9bdbfb83837640099d2c1b5e8ac0cfe2",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 1: 0%| | 0/9 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-05-12 21:55:36,871 - INFO - Starting run_74 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'NormalisedCNN', 'loss': 'kl'}\n",
|
||||
"2024-05-12 21:55:36,872 - INFO - Loaded 72 training images and 8 test images\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "3b415e0ac9334aaf89b0fd8e234aa7c4",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 0: 0%| | 0/9 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "7e632fae2e4b4929b6e70dffcad6a341",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 1: 0%| | 0/9 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-05-12 21:55:52,083 - INFO - Starting run_75 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'SmartRes', 'loss': 'kl'}\n",
|
||||
"2024-05-12 21:55:52,084 - INFO - Loaded 72 training images and 8 test images\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "46909535e98543ccb700caeb718dde48",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 0: 0%| | 0/9 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "36f757acc512484e9ac0f34efcd4c1c4",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Epoch 1: 0%| | 0/9 [00:00<?, ?batch/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from random import choice\n",
|
||||
"from itertools import count\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for i in count():\n",
|
||||
" current_hyperparameters = {k: choice(v) for k, v in choice(hyperparameters).items()}\n",
|
||||
" logging.info(f\"Starting run {i} with hparams {current_hyperparameters}\")\n",
|
||||
"tried = set()\n",
|
||||
"\n",
|
||||
"for _ in count():\n",
|
||||
" current_hyperparameters = {\n",
|
||||
" k: v.rvs() if hasattr(v, \"rvs\") else choice(v)\n",
|
||||
" for k, v in choice(hyperparameters).items()\n",
|
||||
" }\n",
|
||||
" key = json.dumps(current_hyperparameters)\n",
|
||||
" if key in tried:\n",
|
||||
" continue\n",
|
||||
" tried.add(key)\n",
|
||||
" logging.info(\n",
|
||||
" f\"Starting {get_next_run_name(Path(\"runs\"))} with hparams {current_hyperparameters}\"\n",
|
||||
" )\n",
|
||||
" try:\n",
|
||||
" train(current_hyperparameters)\n",
|
||||
" except KeyboardInterrupt as e:\n",
|
||||
" logging.info(\"Interrupted, stopping\")\n",
|
||||
" break\n",
|
||||
" except Exception as e:\n",
|
||||
" logging.error(f\"Error with hparams {current_hyperparameters}:\\n\\t{e}\")\n",
|
||||
" continue"
|
||||
" logging.error(\n",
|
||||
" f\"Error with hparams {current_hyperparameters}:\\n\\t{e}\", stack_info=True\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue