diff --git a/src/editor/models/__init__.py b/src/editor/models/__init__.py index 42279df..ae72fb6 100644 --- a/src/editor/models/__init__.py +++ b/src/editor/models/__init__.py @@ -1 +1,19 @@ -from .create_model import create_model +from .v1 import HistogramRestorationNet as v1 +from .simple_cnn import SimpleCNN +from .residual import Residual +from .normalised_cnn import NormalisedCNN +from .smart_res import SmartRes +from .attention_net import AttentionNet +from .res2 import Res2 + + +def create_model(type: str, bin_count: int): + return { + # "v1": v1, + "SimpleCNN": SimpleCNN, + "Residual": Residual, + "NormalisedCNN": NormalisedCNN, + "SmartRes": SmartRes, + "AttentionNet": AttentionNet, + "Res2": Res2, + }[type](bin_count) diff --git a/src/editor/models/attention_net.py b/src/editor/models/attention_net.py new file mode 100644 index 0000000..66e2ac0 --- /dev/null +++ b/src/editor/models/attention_net.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn + + +# Define the self-attention module +class SelfAttention(nn.Module): + def __init__(self, channels): + super(SelfAttention, self).__init__() + self.query_conv = nn.Conv3d(channels, channels // 8, kernel_size=1) + self.key_conv = nn.Conv3d(channels, channels // 8, kernel_size=1) + self.value_conv = nn.Conv3d(channels, channels, kernel_size=1) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + batch_size, channels, depth, height, width = x.size() + query = ( + self.query_conv(x) + .view(batch_size, -1, depth * height * width) + .permute(0, 2, 1) + ) + key = self.key_conv(x).view(batch_size, -1, depth * height * width) + value = self.value_conv(x).view(batch_size, -1, depth * height * width) + + attention = self.softmax(torch.bmm(query, key)) # Batch matrix multiplication + out = torch.bmm(value, attention.permute(0, 2, 1)) + out = out.view(batch_size, channels, depth, height, width) + + return x + out + + +# Define the residual block +class ResidualBlock(nn.Module): + def __init__(self, channels): + super(ResidualBlock, self).__init__() + self.conv = nn.Sequential( + nn.Conv3d(channels, channels, kernel_size=3, padding=1), + nn.BatchNorm3d(channels), + nn.ReLU(inplace=True), + nn.Conv3d(channels, channels, kernel_size=3, padding=1), + nn.BatchNorm3d(channels), + ) + self.attention = SelfAttention(channels) + + def forward(self, x): + return self.attention(self.conv(x)) + x + + +# Define the network +class AttentionNet(nn.Module): + def __init__(self, bin_count): + super(AttentionNet, self).__init__() + self.input_layer = nn.Sequential( + nn.Conv3d(1, 16, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.BatchNorm3d(16), + ) + self.res_blocks = nn.Sequential(ResidualBlock(16), ResidualBlock(16)) + self.output_layer = nn.Conv3d(16, 1, kernel_size=3, padding=1) + + def forward(self, x): + x = self.input_layer(x) + x = self.res_blocks(x) + x = self.output_layer(x) + return x diff --git a/src/editor/models/create_model.py b/src/editor/models/create_model.py deleted file mode 100644 index 43ba8f4..0000000 --- a/src/editor/models/create_model.py +++ /dev/null @@ -1,5 +0,0 @@ -from .v1 import HistogramRestorationNet as v1 - - -def create_model(type: str, bin_count: int): - return {"v1": v1}[type](bin_count) diff --git a/src/editor/models/normalised_cnn.py b/src/editor/models/normalised_cnn.py new file mode 100644 index 0000000..e38b5b9 --- /dev/null +++ b/src/editor/models/normalised_cnn.py @@ -0,0 +1,34 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class NormalisedCNN(nn.Module): + def __init__(self, bin_count): + super(NormalisedCNN, self).__init__() + self.bin_count = bin_count + + # Define the layers of the neural network + self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm3d(16) + self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm3d(32) + self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1) + self.bn3 = nn.BatchNorm3d(64) + self.conv4 = nn.Conv3d(64, 32, kernel_size=3, padding=1) + self.bn4 = nn.BatchNorm3d(32) + self.conv5 = nn.Conv3d(32, 16, kernel_size=3, padding=1) + self.bn5 = nn.BatchNorm3d(16) + self.conv6 = nn.Conv3d(16, 1, kernel_size=3, padding=1) + + def forward(self, x): + x = x.view( + -1, 1, self.bin_count, self.bin_count, self.bin_count + ) # Reshape input to (N, C, D, H, W) + x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) + x = F.relu(self.bn3(self.conv3(x))) + x = F.relu(self.bn4(self.conv4(x))) + x = F.relu(self.bn5(self.conv5(x))) + x = self.conv6(x) + + return x diff --git a/src/editor/models/res2.py b/src/editor/models/res2.py new file mode 100644 index 0000000..6d19bc6 --- /dev/null +++ b/src/editor/models/res2.py @@ -0,0 +1,37 @@ +import torch.nn as nn + + +class ResidualBlock(nn.Module): + def __init__(self, channels): + super(ResidualBlock, self).__init__() + self.conv = nn.Sequential( + nn.Conv3d(channels, channels, kernel_size=3, padding=1), + nn.BatchNorm3d(channels), + nn.ReLU(inplace=True), + nn.Conv3d(channels, channels, kernel_size=3, padding=1), + nn.BatchNorm3d(channels), + ) + + def forward(self, x): + return self.conv(x) + x + + +# Define the network +class Res2(nn.Module): + def __init__(self, bin_count): + super(Res2, self).__init__() + self.input_layer = nn.Sequential( + nn.Conv3d(1, 16, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.BatchNorm3d(16), + ) + self.res_blocks = nn.Sequential( + ResidualBlock(16), ResidualBlock(16), ResidualBlock(16), ResidualBlock(16) + ) + self.output_layer = nn.Conv3d(16, 1, kernel_size=3, padding=1) + + def forward(self, x): + x = self.input_layer(x) + x = self.res_blocks(x) + x = self.output_layer(x) + return x diff --git a/src/editor/models/residual.py b/src/editor/models/residual.py new file mode 100644 index 0000000..ca5be9a --- /dev/null +++ b/src/editor/models/residual.py @@ -0,0 +1,52 @@ +import torch.nn as nn + + +class Residual(nn.Module): + def __init__(self, bin_count: int): + super(Residual, self).__init__() + self.bin_count = bin_count + + # Assuming the input histograms are 3D tensors of shape (bin_count, bin_count, bin_count) + # Convolutional layers to extract features from the histograms + self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1) + self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1) + self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1) + + # Batch normalization layers for better convergence + self.bn1 = nn.BatchNorm3d(16) + self.bn2 = nn.BatchNorm3d(32) + self.bn3 = nn.BatchNorm3d(64) + + # Residual block to help the network learn identity functions effectively + self.resblock1 = nn.Sequential( + nn.Conv3d(64, 64, kernel_size=3, padding=1, bias=False), + nn.BatchNorm3d(64), + nn.ReLU(inplace=True), + nn.Conv3d(64, 64, kernel_size=3, padding=1, bias=False), + nn.BatchNorm3d(64), + ) + + # ReLU activation + self.relu = nn.ReLU(inplace=True) + + self.deconv1 = nn.ConvTranspose3d(64, 32, kernel_size=3, stride=1, padding=1) + self.deconv2 = nn.ConvTranspose3d(32, 16, kernel_size=3, stride=1, padding=1) + self.deconv3 = nn.ConvTranspose3d(16, 1, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = self.relu(self.bn1(self.conv1(x))) + x = self.relu(self.bn2(self.conv2(x))) + x = self.relu(self.bn3(self.conv3(x))) + + # Apply residual blocks + residual = x + out = self.resblock1(x) + out += residual + out = self.relu(out) + + # Upsample to original size + out = self.relu(self.deconv1(out)) + out = self.relu(self.deconv2(out)) + out = self.relu(self.deconv3(out)) + + return out diff --git a/src/editor/models/simple_cnn.py b/src/editor/models/simple_cnn.py new file mode 100644 index 0000000..cfd85cb --- /dev/null +++ b/src/editor/models/simple_cnn.py @@ -0,0 +1,37 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class SimpleCNN(nn.Module): + def __init__(self, bin_count): + super(SimpleCNN, self).__init__() + self.bin_count = bin_count + + # Define the convolutional layers + self.conv1 = nn.Conv3d( + 1, 16, kernel_size=3, padding=1 + ) # input channels = 1, output channels = 16 + self.conv2 = nn.Conv3d( + 16, 32, kernel_size=3, padding=1 + ) # input channels = 16, output channels = 32 + self.conv3 = nn.Conv3d( + 32, 64, kernel_size=3, padding=1 + ) # input channels = 32, output channels = 64 + self.conv4 = nn.Conv3d( + 64, 32, kernel_size=3, padding=1 + ) # input channels = 64, output channels = 32 + self.conv5 = nn.Conv3d( + 32, 16, kernel_size=3, padding=1 + ) # input channels = 32, output channels = 16 + self.conv6 = nn.Conv3d( + 16, 1, kernel_size=3, padding=1 + ) # input channels = 16, output channels = 1 + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + x = F.relu(self.conv4(x)) + x = F.relu(self.conv5(x)) + x = self.conv6(x) + return x diff --git a/src/editor/models/smart_res.py b/src/editor/models/smart_res.py new file mode 100644 index 0000000..209791d --- /dev/null +++ b/src/editor/models/smart_res.py @@ -0,0 +1,42 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, channels): + super(ResidualBlock, self).__init__() + self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm3d(channels) + self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm3d(channels) + + def forward(self, x): + identity = x + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += identity + return F.relu(out) + + +class SmartRes(nn.Module): + def __init__(self, bin_count): + super(SmartRes, self).__init__() + self.bin_count = bin_count + self.initial_conv = nn.Conv3d(1, 16, kernel_size=3, padding=1) + self.bn0 = nn.BatchNorm3d(16) + self.resblock1 = ResidualBlock(16) + self.resblock2 = ResidualBlock(16) + self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm3d(32) + self.dilated_conv = nn.Conv3d(32, 32, kernel_size=3, padding=2, dilation=2) + self.bn_dilated = nn.BatchNorm3d(32) + self.final_conv = nn.Conv3d(32, 1, kernel_size=3, padding=1) + + def forward(self, x): + x = F.relu(self.bn0(self.initial_conv(x))) + x = self.resblock1(x) + x = self.resblock2(x) + x = F.relu(self.bn2(self.conv2(x))) + x = F.relu(self.bn_dilated(self.dilated_conv(x))) + x = self.final_conv(x) + return x diff --git a/src/train.ipynb b/src/train.ipynb index 8806e88..924b80c 100644 --- a/src/train.ipynb +++ b/src/train.ipynb @@ -36,23 +36,32 @@ "metadata": {}, "outputs": [], "source": [ + "from scipy.stats import loguniform, uniform\n", + "\n", + "\n", "common_hyperparameters = {\n", " \"batch_size\": [8, 16, 32, 64],\n", - " \"edit_count\": [4, 8, 16, 32],\n", - " \"bin_count\": [32],\n", + " \"edit_count\": [8, 16, 32],\n", + " \"bin_count\": [16, 32, 64],\n", " \"clip_gradients\": [True, False],\n", - " \"learning_rate\": [0.0001],\n", - " \"scheduler_gamma\": [0.1, 0.9],\n", - " \"num_epochs\": [20],\n", - " \"model_type\": [\"v1\"],\n", + " \"learning_rate\": loguniform(0.00001, 0.005),\n", + " \"scheduler_gamma\": uniform(0.1, 0.9),\n", + " \"num_epochs\": [10],\n", + " \"model_type\": [\n", + " \"NormalisedCNN\",\n", + " \"SimpleCNN\",\n", + " \"Residual\",\n", + " \"SmartRes\",\n", + " \"Res2\",\n", + " ],\n", "}\n", "hyperparameters = [\n", - " {\n", - " **common_hyperparameters,\n", - " \"loss\": [\"progressive\"],\n", - " \"loss_sizes\": [[4, 8, 16, 32], [8, 16, 32], [16, 32], [8, 32]],\n", - " \"loss_damping\": [1, 2, 3, 4, 5],\n", - " },\n", + " # {\n", + " # **common_hyperparameters,\n", + " # \"loss\": [\"progressive\"],\n", + " # \"loss_sizes\": [[4, 8, 16, 32], [8, 16, 32], [16, 32], [8, 32]],\n", + " # \"loss_damping\": uniform(0.2, 5),\n", + " # },\n", " {\n", " **common_hyperparameters,\n", " \"loss\": [\"kl\"],\n", @@ -81,7 +90,7 @@ "\n", "\n", "def train(hyperparameters: Dict[str, Any]) -> Path:\n", - " model_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\"pth\")\n", + " model_path = (MODELS_PATH / get_next_run_name(Path(\"runs\"))).with_suffix(\".pth\")\n", "\n", " log_dir = Path(\"runs\") / get_next_run_name(Path(\"runs\"))\n", " with SummaryWriter(log_dir) as writer:\n", @@ -129,19 +138,21 @@ "\n", " optimizer.zero_grad()\n", " predicted_original = model(edited_histogram)\n", + " sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n", + " predicted_original = predicted_original / sum\n", "\n", " if hyperparameters[\"loss\"] == \"kl\":\n", " predicted_original = torch.clamp(\n", - " predicted_original, 0.0000000000000000000000001, 1\n", + " predicted_original, 0.0000000000000000000001, 1\n", " )\n", "\n", " loss = {\n", " \"kl\": lambda: loss_function(\n", - " torch.log(predicted_original.unsqueeze(1)),\n", + " torch.log(predicted_original),\n", " original_histogram,\n", " ),\n", " \"progressive\": lambda: loss_function(\n", - " predicted_original.unsqueeze(1), original_histogram\n", + " predicted_original, original_histogram\n", " ),\n", " }[hyperparameters[\"loss\"]]()\n", "\n", @@ -175,6 +186,8 @@ " edited_histogram = edited_histogram.to(device)\n", " original_histogram = original_histogram.to(device)\n", " predicted_original = model(edited_histogram)\n", + " sum = torch.sum(predicted_original, dim=(2, 3, 4), keepdim=True)\n", + " predicted_original = predicted_original / sum\n", " writer.add_figure(\n", " \"histogram\",\n", " plot_histograms_in_2d(\n", @@ -192,32 +205,60 @@ " )\n", " model.train()\n", " scheduler.step()\n", + " except Exception as e:\n", + " raise\n", " finally:\n", " torch.save(model.state_dict(), model_path)\n", - " return model_path" + " del model\n", + " torch.cuda.empty_cache()\n", + " return model_path" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, + "outputs": [], + "source": [ + "# train(\n", + "# {\n", + "# \"batch_size\": 64,\n", + "# \"edit_count\": 25,\n", + "# \"bin_count\": 32,\n", + "# \"clip_gradients\": True,\n", + "# \"learning_rate\": 0.005,\n", + "# \"scheduler_gamma\": 0.7,\n", + "# \"num_epochs\": 20,\n", + "# \"model_type\": \"NormalisedCNN\",\n", + "# \"loss\": \"progressive\",\n", + "# \"loss_sizes\": [16, 32],\n", + "# \"loss_damping\": 2,\n", + "# }\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-12 19:44:36,319 - INFO - Loaded 561982 training images and 62443 test images\n" + "2024-05-12 21:54:49,789 - INFO - Starting run_71 with hparams {'batch_size': 8, 'edit_count': 8, 'bin_count': 32, 'clip_gradients': True, 'learning_rate': 0.0001, 'scheduler_gamma': 0.5, 'num_epochs': 2, 'model_type': 'Res2', 'loss': 'kl'}\n", + "2024-05-12 21:54:49,792 - INFO - Loaded 72 training images and 8 test images\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "866f474ead21440e81dc91f2d0e55046", + "model_id": "461659956f3944a085b5cc3a5af6ec31", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Epoch 0: 0%| | 0/17562 [00:00 99\u001b[0m \u001b[43mwriter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_hparams\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 100\u001b[0m \u001b[43m \u001b[49m\u001b[43mcurrent_hyperparameters\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 101\u001b[0m \u001b[43m \u001b[49m\u001b[43m{\u001b[49m\n\u001b[1;32m 102\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mLoss/train/epoch\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch_loss\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 103\u001b[0m \u001b[43m \u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 104\u001b[0m \u001b[43m \u001b[49m\u001b[43mglobal_step\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 105\u001b[0m \u001b[43m \u001b[49m\u001b[43mrun_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_dir\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mabsolute\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 106\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 108\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n", - "File \u001b[0;32m~/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/torch/utils/tensorboard/writer.py:341\u001b[0m, in \u001b[0;36mSummaryWriter.add_hparams\u001b[0;34m(self, hparam_dict, metric_dict, hparam_domain_discrete, run_name, global_step)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(hparam_dict) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mdict\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(metric_dict) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mdict\u001b[39m:\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhparam_dict and metric_dict should be dictionary.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 341\u001b[0m exp, ssi, sei \u001b[38;5;241m=\u001b[39m \u001b[43mhparams\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhparam_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetric_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhparam_domain_discrete\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m run_name:\n\u001b[1;32m 344\u001b[0m run_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(time\u001b[38;5;241m.\u001b[39mtime())\n", - "File \u001b[0;32m~/miniconda3/envs/bipolaroid/lib/python3.12/site-packages/torch/utils/tensorboard/summary.py:316\u001b[0m, in \u001b[0;36mhparams\u001b[0;34m(hparam_dict, metric_dict, hparam_domain_discrete)\u001b[0m\n\u001b[1;32m 314\u001b[0m hps\u001b[38;5;241m.\u001b[39mappend(HParamInfo(name\u001b[38;5;241m=\u001b[39mk, \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39mDataType\u001b[38;5;241m.\u001b[39mValue(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDATA_TYPE_FLOAT64\u001b[39m\u001b[38;5;124m\"\u001b[39m)))\n\u001b[1;32m 315\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[0;32m--> 316\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 317\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalue should be one of int, float, str, bool, or torch.Tensor\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 318\u001b[0m )\n\u001b[1;32m 320\u001b[0m content \u001b[38;5;241m=\u001b[39m HParamsPluginData(session_start_info\u001b[38;5;241m=\u001b[39mssi, version\u001b[38;5;241m=\u001b[39mPLUGIN_DATA_VERSION)\n\u001b[1;32m 321\u001b[0m smd \u001b[38;5;241m=\u001b[39m SummaryMetadata(\n\u001b[1;32m 322\u001b[0m plugin_data\u001b[38;5;241m=\u001b[39mSummaryMetadata\u001b[38;5;241m.\u001b[39mPluginData(\n\u001b[1;32m 323\u001b[0m plugin_name\u001b[38;5;241m=\u001b[39mPLUGIN_NAME, content\u001b[38;5;241m=\u001b[39mcontent\u001b[38;5;241m.\u001b[39mSerializeToString()\n\u001b[1;32m 324\u001b[0m )\n\u001b[1;32m 325\u001b[0m )\n", - "\u001b[0;31mValueError\u001b[0m: value should be one of int, float, str, bool, or torch.Tensor" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "707b28e45e3448fc9e449624c9a8467e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 1: 0%| | 0/9 [00:00