More models

This commit is contained in:
Andras Schmelczer 2024-06-03 07:48:50 +01:00
parent bd7033c3eb
commit a6a15ec650
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
4 changed files with 213 additions and 9 deletions

View file

@ -5,15 +5,46 @@ from .normalised_cnn import NormalisedCNN
from .smart_res import SmartRes
from .attention_net import AttentionNet
from .res2 import Res2
from .attention2 import EnhancedAestheticHistogramNet
from .attention import PhotoEnhanceNetAdvanced
from .advanced_attention import PhotoEnhanceNetAdvanced as advanced_attention
import torch
# Registry of selectable architectures: maps the public model name to the
# callable that builds it. Every constructor is invoked with ``bin_count``
# (see ``create_model`` and ``_test_network_dimensions``). Insertion order is
# the order in which ``test_models`` exercises the entries.
MODELS = {
    # "v1": v1,
    "SimpleCNN": SimpleCNN,
    "Residual": Residual,
    "NormalisedCNN": NormalisedCNN,
    "SmartRes": SmartRes,
    # "AttentionNet": AttentionNet,
    "attention2": EnhancedAestheticHistogramNet,
    "advanced_attention": advanced_attention,
    "Res2": Res2,
    "attention1": PhotoEnhanceNetAdvanced,
}
def create_model(type: str, bin_count: int):
    """Build the model registered under ``type``.

    The lookup goes through the module-level ``MODELS`` registry so that
    every registered architecture is constructible here; a private copy of
    the table would silently drift out of sync with the registry.

    Args:
        type: Registry key of the architecture. The name shadows the builtin
            but is kept so existing keyword callers continue to work.
        bin_count: Passed unchanged to the model constructor.

    Returns:
        The freshly constructed model instance.

    Raises:
        KeyError: If ``type`` is not a key of ``MODELS``.
    """
    try:
        constructor = MODELS[type]
    except KeyError:
        # Same exception type as a plain dict lookup, but tell the caller
        # which names would have worked.
        raise KeyError(
            f"Unknown model type {type!r}; available: {sorted(MODELS)}"
        ) from None
    return constructor(bin_count)
def test_models():
    """Run the output-shape smoke test against every entry of ``MODELS``."""
    for name in MODELS:
        print(f"Testing model {name}")
        _test_network_dimensions(MODELS[name])
def _test_network_dimensions(constructor):
    """Check that a model returns a tensor shaped exactly like its input.

    For each bin count in (16, 32, 64) a model is built with
    ``constructor(bin_count=...)`` and fed a random batch of shape
    ``(4, 1, bin_count, bin_count, bin_count)``.

    Args:
        constructor: Callable accepting a ``bin_count`` keyword and returning
            a callable model.

    Raises:
        AssertionError: If the output shape differs from the input shape.
    """
    for bin_count in (16, 32, 64):
        model = constructor(bin_count=bin_count)

        # Dummy batch: 4 samples, 1 channel, cubic volume of side bin_count.
        input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count)

        # Only the shape is inspected, so disable autograd: otherwise every
        # intermediate activation is retained for a backward pass that never
        # happens, which is costly for the 64^3 volumes.
        with torch.no_grad():
            output = model(input_tensor)

        assert (
            input_tensor.shape == output.shape
        ), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
        # NOTE(review): reported once per bin count; the original placement
        # relative to the loop could not be recovered from the diff view.
        print("Test passed! Output shape matches input shape.")

View file

@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
    """3-D convolution followed by batch normalisation and ReLU.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size (default 3).
        padding: Padding applied by the convolution (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
        )
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, then normalisation, then the activation."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.relu(normalised)
class ChannelAttention(nn.Module):
    """Rescale each channel by a learned gate in (0, 1).

    The input is average-pooled to one value per channel, passed through a
    two-layer bias-free bottleneck ending in a sigmoid, and the resulting
    per-channel weights multiply the input.

    Args:
        channels: Number of channels of the incoming tensor.
        reduction: Divisor giving the bottleneck width,
            ``channels // reduction`` (default 16).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        bottleneck = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, bottleneck, bias=False),
            nn.ReLU(),
            nn.Linear(bottleneck, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` with every channel multiplied by its gate.

        ``x`` must be 5-dimensional; anything else fails on the shape unpack.
        """
        batch, chans, _, _, _ = x.shape
        pooled = self.avg_pool(x).view(batch, chans)
        gate = self.fc(pooled).view(batch, chans, 1, 1, 1)
        # Broadcasting stretches the (batch, chans, 1, 1, 1) gate over x.
        return x * gate
class PhotoEnhanceNetAdvanced(nn.Module):
    """Four stacked ``ConvBlock``s, channel attention, and a 1x1x1 projection.

    Channel widths grow 1 -> 16 -> 32 -> 64 -> 128 and the last convolution
    maps back to a single channel.

    Args:
        bin_count: Accepted so the constructor matches the other models; it
            is not used by this architecture.
    """

    def __init__(self, bin_count):
        super().__init__()
        widths = (1, 16, 32, 64, 128)
        self.features = nn.Sequential(
            *(ConvBlock(n_in, n_out) for n_in, n_out in zip(widths, widths[1:]))
        )
        self.channel_attention = ChannelAttention(widths[-1])
        # Pointwise convolution collapsing the feature channels back to one.
        self.final_conv = nn.Conv3d(widths[-1], 1, kernel_size=1)

    def forward(self, x):
        """Extract features, gate them per channel, project to one channel."""
        gated = self.channel_attention(self.features(x))
        return self.final_conv(gated)

View file

@ -0,0 +1,56 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class PhotoEnhanceNetAdvanced(nn.Module):
    """Convolutional feature extractor with a sigmoid gate and a dense head.

    Four Conv3d/BatchNorm3d/ReLU stages widen a single-channel input to 128
    channels. A pointwise convolution followed by a sigmoid produces a gate
    of the same shape. The raw features and the gated features are
    concatenated (256 channels) and reduced back to one channel by the dense
    head. All convolutions preserve the spatial size.

    Args:
        bin_count: Stored on the instance as ``bin_count``; the layers do not
            depend on it.
    """

    def __init__(self, bin_count):
        super().__init__()
        self.bin_count = bin_count

        # Feature extractor: 1 -> 16 -> 32 -> 64 -> 128 channels.
        self.features = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.Conv3d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
        )

        # Gate with as many channels as the features so the two can be
        # multiplied element-wise.
        self.attention = nn.Sequential(
            nn.Conv3d(128, 128, 1),
            nn.Sigmoid(),
        )

        # Head over the concatenation of raw and gated features (128 + 128).
        self.dense = nn.Sequential(
            nn.Conv3d(256, 192, kernel_size=3, padding=1),
            nn.BatchNorm3d(192),
            nn.ReLU(),
            nn.Conv3d(192, 1, kernel_size=1),
        )

    def forward(self, x):
        """Return a single-channel tensor with the spatial size of ``x``."""
        features = self.features(x)

        # Apply attention: element-wise gating of the features.
        attended = features * self.attention(features)

        # Skip connection: the head sees both raw and gated features.
        # Previously the gated product was computed and then discarded, and
        # the bare gate was concatenated instead. Checkpoints trained with
        # that forward pass will produce different outputs here.
        combined = torch.cat((features, attended), dim=1)
        return self.dense(combined)

View file

@ -0,0 +1,69 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class EnhancedAestheticHistogramNet(nn.Module):
    """Dilated 3-D CNN with a pointwise attention gate and two skip paths.

    A stem convolution (1 -> 32 channels) is followed by three
    Conv3d/BatchNorm3d/ReLU stages with dilations 1, 2 and 4 that widen the
    tensor to 256 channels. The result is gated by a sigmoid attention
    branch, summed with 1x1x1 projections of the stage-two output and of the
    raw input, and reduced to one channel.

    Args:
        bin_count: Stored on the instance as ``bin_count``; the layers do not
            depend on it.
    """

    def __init__(self, bin_count):
        super().__init__()
        self.bin_count = bin_count

        # Stem.
        self.initial_conv = nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1)
        self.initial_bn = nn.BatchNorm3d(32)
        self.initial_relu = nn.ReLU()

        # Widening stages; growing dilation enlarges the receptive field
        # while padding == dilation keeps the spatial size fixed.
        self.conv1 = self._make_layer(32, 64, dilation=1)
        self.conv2 = self._make_layer(64, 128, dilation=2)
        self.conv3 = self._make_layer(128, 256, dilation=4)

        # Attention gate built from two pointwise convolutions.
        self.attention = nn.Sequential(
            nn.Conv3d(256, 256, kernel_size=1),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.Conv3d(256, 256, kernel_size=1),
            nn.Sigmoid(),
        )

        # 1x1x1 projections bringing both skip sources up to 256 channels.
        self.res1 = nn.Conv3d(1, 256, kernel_size=1)  # from the raw input
        self.res2 = nn.Conv3d(128, 256, kernel_size=1)  # from conv2's output

        # Collapse back to a single channel.
        self.final_conv = nn.Conv3d(256, 1, kernel_size=3, stride=1, padding=1)

    def _make_layer(self, in_channels, out_channels, dilation):
        """Return a size-preserving dilated Conv3d + BatchNorm3d + ReLU stage."""
        dilated_conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
        )
        return nn.Sequential(dilated_conv, nn.BatchNorm3d(out_channels), nn.ReLU())

    def forward(self, x):
        """Return a single-channel tensor with the spatial size of ``x``."""
        input_skip = self.res1(x)

        stem = self.initial_conv(x)
        stem = self.initial_bn(stem)
        stem = self.initial_relu(stem)

        mid = self.conv2(self.conv1(stem))
        mid_skip = self.res2(mid)

        deep = self.conv3(mid)
        gated = self.attention(deep) * deep

        # Same summation order as before: stage-two skip first, then input skip.
        merged = gated + mid_skip
        merged = merged + input_skip
        return self.final_conv(merged)