Update networks

This commit is contained in:
Andras Schmelczer 2024-06-22 15:01:02 +01:00
parent 0e4fe7ab63
commit b7755bfea8
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
9 changed files with 173 additions and 144 deletions

View file

@ -2,13 +2,9 @@ from typing import Any, Dict, Tuple
from .v1 import HistogramRestorationNet as v1
from .simple_cnn import SimpleCNN
from .residual import Residual
from .normalised_cnn import NormalisedCNN
from .smart_res import SmartRes
from .attention_net import AttentionNet
from .res2 import Res2
from .attention2 import EnhancedAestheticHistogramNet
from .attention import PhotoEnhanceNetAdvanced
from .advanced_attention import PhotoEnhanceNetAdvanced as advanced_attention
from .residual2 import Residual2
from .residual3 import Residual3
from .dummy import Dummy
import torch
import torch.nn as nn
from pathlib import Path
@ -17,16 +13,12 @@ import json
MODELS = {
# "v1": v1,
# "SimpleCNN": SimpleCNN,
"v1": v1,
"Dummy": Dummy,
"SimpleCNN": SimpleCNN,
"Residual": Residual,
# "NormalisedCNN": NormalisedCNN,
# "SmartRes": SmartRes,
# "AttentionNet": AttentionNet,
# "attention2": EnhancedAestheticHistogramNet,
# "advanced_attention": advanced_attention,
# "Res2": Res2,
# "attention1": PhotoEnhanceNetAdvanced,
# "Residual2": Residual2,
"Residual3": Residual3,
}
@ -39,6 +31,7 @@ def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path):
params_path = path.with_suffix(".json")
logging.info(f"Saving model to {model_path}")
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
with open(model_path, "wb") as f:
torch.save(model.state_dict(), f)
with open(params_path, "w") as f:
@ -60,11 +53,12 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
).to(device)
model.load_state_dict(torch.load(model_path))
model.eval()
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
return model, hyperparameters
def _test_models():
def test_models():
for model_name, model_constructor in MODELS.items():
logging.info(f"Testing model {model_name}")
_test_network_dimensions(model_constructor)
@ -72,18 +66,13 @@ def _test_models():
def _test_network_dimensions(constructor):
for bin_count in [16, 32, 64]:
model = constructor(bin_count=bin_count)
model = constructor()
# Create a dummy input tensor of the correct shape
# Create a dummy input tensor of the correct shape, the mini-batch size is 4
input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count)
# Test the model output
output = model(input_tensor)
assert (
input_tensor.shape == output.shape
), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
print("Test passed! Output shape matches input shape.")
if __name__ == "__main__":
_test_models()
logging.info("Test passed! Output shape matches input shape.")

View file

@ -0,0 +1,10 @@
import torch.nn as nn
import torch.nn.functional as F
class Dummy(nn.Module):
    """Parameter-free placeholder network whose forward pass is the identity."""

    def __init__(self, **_):
        # Keyword arguments are accepted and ignored.
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself, untouched."""
        return x

View file

@ -1,34 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F
class NormalisedCNN(nn.Module):
    """Six-layer 3D CNN with batch normalisation after every hidden convolution.

    Channel widths go 1 -> 16 -> 32 -> 64 -> 32 -> 16 -> 1. Every convolution
    uses a 3x3x3 kernel with padding 1, so the spatial size is preserved.

    Args:
        bin_count: Edge length used to reshape the input into a cube.
    """

    def __init__(self, bin_count):
        super().__init__()
        self.bin_count = bin_count

        # Hidden stages: convolution followed by its own batch-norm layer.
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(16)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(32)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm3d(64)
        self.conv4 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm3d(32)
        self.conv5 = nn.Conv3d(32, 16, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm3d(16)

        # Output projection back to a single channel (no norm, no activation).
        self.conv6 = nn.Conv3d(16, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Reshape ``x`` to (N, 1, D, H, W) and run it through the network."""
        edge = self.bin_count
        x = x.view(-1, 1, edge, edge, edge)

        hidden_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for convolution, batch_norm in hidden_stages:
            x = F.relu(batch_norm(convolution(x)))

        return self.conv6(x)

View file

@ -1,37 +0,0 @@
import torch.nn as nn
class ResidualBlock(nn.Module):
    """Two 3x3x3 conv + batch-norm layers wrapped in an additive skip connection.

    Args:
        channels: Number of input and output channels (kept equal so the skip
            connection can be added directly).
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(channels),
        )

    def forward(self, x):
        """Return the convolutional branch output plus the unchanged input."""
        branch = self.conv(x)
        return branch + x
# Network definition
class Res2(nn.Module):
    """Single-channel 3D network: conv stem, four residual blocks, conv head.

    Args:
        bin_count: Accepted for constructor compatibility; not used here.
    """

    def __init__(self, bin_count):
        super().__init__()
        width = 16

        # Stem lifts the single input channel to `width` feature maps.
        self.input_layer = nn.Sequential(
            nn.Conv3d(1, width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm3d(width),
        )

        self.res_blocks = nn.Sequential(*[ResidualBlock(width) for _ in range(4)])

        # Head projects back down to one channel.
        self.output_layer = nn.Conv3d(width, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Run stem, residual stack and head in sequence."""
        features = self.input_layer(x)
        refined = self.res_blocks(features)
        return self.output_layer(refined)

View file

@ -2,9 +2,8 @@ import torch.nn as nn
class Residual(nn.Module):
def __init__(self, bin_count: int):
def __init__(self, **_):
super(Residual, self).__init__()
self.bin_count = bin_count
# Assuming the input histograms are 3D tensors of shape (bin_count, bin_count, bin_count)
# Convolutional layers to extract features from the histograms

View file

@ -0,0 +1,146 @@
import torch
import torch.nn as nn
class DepthwiseSeparableConv3d(nn.Module):
    """3D depthwise-separable convolution.

    A per-channel (``groups=in_channels``) convolution is followed by a 1x1x1
    pointwise convolution that mixes channels.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of channels produced by the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        stride: Stride of the depthwise convolution. Defaults to 1, matching
            ``nn.Conv3d``.
        bias: Whether both inner convolutions learn a bias. Defaults to True,
            matching ``nn.Conv3d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding,
        *,
        stride=1,
        bias=True,
    ):
        super().__init__()
        # `stride` and `bias` are accepted so this class can be swapped in for
        # nn.Conv3d by callers that pass those keywords.
        self.depthwise = nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )
        self.pointwise = nn.Conv3d(
            in_channels, out_channels, kernel_size=1, bias=bias
        )

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise(self.depthwise(x))
class Residual3(nn.Module):
    """Conv encoder, residual block and transposed-conv decoder for 3D inputs.

    The input is expected as ``(N, 1, D, H, W)``. Three convolutions widen the
    channel dimension, one residual block refines the features, and three
    transposed convolutions mirror the encoder back down to a single channel.
    The result is divided by its sum over the three spatial axes.

    Every encoder convolution and its mirrored transposed convolution share the
    same kernel size and a padding of 1, so whatever spatial size the encoder
    removes the decoder restores, and the output shape equals the input shape.

    Args:
        elu_alpha: ``alpha`` of every ELU activation.
        dropout_prob: Drop probability of the ``Dropout3d`` after each encoder stage.
        use_depthwise_separable_conv: Build the encoder and residual block from
            ``DepthwiseSeparableConv3d`` instead of ``nn.Conv3d``.
        feature_map_sizes: Output channel counts of the three encoder stages.
        kernel_sizes: Kernel sizes of the three encoder stages; the decoder
            reuses them in reverse order.
    """

    # NOTE: the list defaults below are only read, never mutated.
    def __init__(
        self,
        elu_alpha: float = 1,
        dropout_prob: float = 0.1,
        use_depthwise_separable_conv: bool = False,
        feature_map_sizes: list[int] = [16, 32, 64],
        kernel_sizes: list[int] = [3, 3, 3],
    ):
        super().__init__()

        conv = DepthwiseSeparableConv3d if use_depthwise_separable_conv else nn.Conv3d

        # Encoder: convolutional layers extracting features from the input volume.
        self.conv1 = conv(
            1, feature_map_sizes[0], kernel_size=kernel_sizes[0], padding=1
        )
        self.conv2 = conv(
            feature_map_sizes[0],
            feature_map_sizes[1],
            kernel_size=kernel_sizes[1],
            padding=1,
        )
        self.conv3 = conv(
            feature_map_sizes[1],
            feature_map_sizes[2],
            kernel_size=kernel_sizes[2],
            padding=1,
        )

        self.activation = nn.ELU(elu_alpha, inplace=True)

        self.bn1 = nn.BatchNorm3d(feature_map_sizes[0])
        self.bn2 = nn.BatchNorm3d(feature_map_sizes[1])
        self.bn3 = nn.BatchNorm3d(feature_map_sizes[2])

        # Shape-preserving residual branch at the widest channel count.
        self.resblock1 = nn.Sequential(
            conv(
                feature_map_sizes[2],
                feature_map_sizes[2],
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.ELU(elu_alpha, inplace=True),
            nn.BatchNorm3d(feature_map_sizes[2]),
            conv(
                feature_map_sizes[2],
                feature_map_sizes[2],
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.ELU(elu_alpha, inplace=True),
            nn.BatchNorm3d(feature_map_sizes[2]),
        )

        # Decoder: each transposed convolution mirrors one encoder stage. The
        # kernel size must come from `kernel_sizes` (not the channel counts),
        # otherwise the output volume grows far beyond the input volume.
        self.deconv1 = nn.ConvTranspose3d(
            feature_map_sizes[2],
            feature_map_sizes[1],
            kernel_size=kernel_sizes[2],
            stride=1,
            padding=1,
        )
        self.deconv2 = nn.ConvTranspose3d(
            feature_map_sizes[1],
            feature_map_sizes[0],
            kernel_size=kernel_sizes[1],
            stride=1,
            padding=1,
        )
        self.deconv3 = nn.ConvTranspose3d(
            feature_map_sizes[0],
            1,
            kernel_size=kernel_sizes[0],
            stride=1,
            padding=1,
        )

        self.dropout = nn.Dropout3d(p=dropout_prob)

        self._initialize_weights()

    def forward(self, x):
        """Map ``x`` of shape (N, 1, D, H, W) to a normalised tensor of the same shape."""
        out = self.dropout(self.bn1(self.activation(self.conv1(x))))
        out = self.dropout(self.bn2(self.activation(self.conv2(out))))
        # Third stage must use conv3; reusing conv2 here mismatches channels.
        out = self.dropout(self.bn3(self.activation(self.conv3(out))))

        out = out + self.resblock1(out)

        out = self.activation(self.deconv1(out))
        out = self.activation(self.deconv2(out))
        out = self.activation(self.deconv3(out))

        return self._normalize(out)

    def _normalize(self, x):
        """Divide ``x`` by its sum over the spatial axes, leaving zero-sum volumes as is."""
        x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
        # Substitute 1 for an exactly-zero sum to avoid dividing by zero.
        return x / torch.where(x_sum == 0, torch.ones_like(x_sum), x_sum)

    def _initialize_weights(self):
        """Xavier-normal initialise every (transposed) conv weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
def _test_network_dimensions(constructor):
    """Check that a freshly constructed network preserves its input shape.

    Args:
        constructor: Zero-argument callable returning the network to test.

    Raises:
        AssertionError: If the output shape differs from the input shape for
            any of the tested cube sizes (16, 32 and 64).
    """
    for bin_count in [16, 32, 64]:
        model = constructor()

        # Dummy input: a mini-batch of 4 single-channel cubes.
        input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count)

        # Only shapes are inspected, so skip autograd bookkeeping; this keeps
        # memory use down for the larger cubes.
        with torch.no_grad():
            output = model(input_tensor)

        assert (
            input_tensor.shape == output.shape
        ), f"Expected output shape {input_tensor.shape}, but got {output.shape}"


if __name__ == "__main__":
    # Run the self-check only when executed as a script, so importing this
    # module does not trigger forward passes over large volumes.
    _test_network_dimensions(Residual3)

View file

@ -3,10 +3,8 @@ import torch.nn.functional as F
class SimpleCNN(nn.Module):
def __init__(self, bin_count):
def __init__(self, **_):
super(SimpleCNN, self).__init__()
self.bin_count = bin_count
# Define the convolutional layers
self.conv1 = nn.Conv3d(
1, 16, kernel_size=3, padding=1

View file

@ -1,42 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Conv-BN-ReLU-Conv-BN branch added to its input, followed by a final ReLU.

    Args:
        channels: Number of input and output channels (kept equal so the
            input can be added back onto the branch output).
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(channels)
        self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(channels)

    def forward(self, x):
        """Return ``relu(branch(x) + x)``."""
        shortcut = x
        branch = self.bn1(self.conv1(x))
        branch = F.relu(branch)
        branch = self.bn2(self.conv2(branch))
        # Accumulate the skip connection into the branch output in place.
        branch += shortcut
        return F.relu(branch)
class SmartRes(nn.Module):
    """Residual 3D CNN: conv stem, two residual blocks, a widening conv,
    a dilated conv and a single-channel output conv.

    Args:
        bin_count: Stored on the instance; not otherwise used by the layers.
    """

    def __init__(self, bin_count):
        super().__init__()
        self.bin_count = bin_count

        # Stem: 1 -> 16 channels.
        self.initial_conv = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        self.bn0 = nn.BatchNorm3d(16)

        self.resblock1 = ResidualBlock(16)
        self.resblock2 = ResidualBlock(16)

        # Widen to 32 channels.
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(32)

        # Dilation 2 with padding 2 keeps the spatial size for a 3x3x3 kernel.
        self.dilated_conv = nn.Conv3d(32, 32, kernel_size=3, padding=2, dilation=2)
        self.bn_dilated = nn.BatchNorm3d(32)

        self.final_conv = nn.Conv3d(32, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the input through every stage and return the single-channel output."""
        features = F.relu(self.bn0(self.initial_conv(x)))

        for residual_block in (self.resblock1, self.resblock2):
            features = residual_block(features)

        features = F.relu(self.bn2(self.conv2(features)))
        features = F.relu(self.bn_dilated(self.dilated_conv(features)))

        return self.final_conv(features)

View file

@ -4,7 +4,7 @@ import torch
class HistogramRestorationNet(nn.Module):
def __init__(self, bin_count: int):
def __init__(self, **_):
super(HistogramRestorationNet, self).__init__()
self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, padding=1)