diff --git a/src/editor/models/__init__.py b/src/editor/models/__init__.py
index ae72fb6..d519f7f 100644
--- a/src/editor/models/__init__.py
+++ b/src/editor/models/__init__.py
@@ -5,15 +5,48 @@ from .normalised_cnn import NormalisedCNN
 from .smart_res import SmartRes
 from .attention_net import AttentionNet
 from .res2 import Res2
+from .attention2 import EnhancedAestheticHistogramNet
+from .attention import PhotoEnhanceNetAdvanced
+from .advanced_attention import PhotoEnhanceNetAdvanced as advanced_attention
+import torch
+
+
+MODELS = {
+    # "v1": v1,
+    "SimpleCNN": SimpleCNN,
+    "Residual": Residual,
+    "NormalisedCNN": NormalisedCNN,
+    "SmartRes": SmartRes,
+    # "AttentionNet": AttentionNet,
+    "attention2": EnhancedAestheticHistogramNet,
+    "advanced_attention": advanced_attention,
+    "Res2": Res2,
+    "attention1": PhotoEnhanceNetAdvanced,
+}
 
 
 def create_model(type: str, bin_count: int):
-    return {
-        # "v1": v1,
-        "SimpleCNN": SimpleCNN,
-        "Residual": Residual,
-        "NormalisedCNN": NormalisedCNN,
-        "SmartRes": SmartRes,
-        "AttentionNet": AttentionNet,
-        "Res2": Res2,
-    }[type](bin_count)
+    return MODELS[type](bin_count)
+
+
+def test_models():
+    for model_name, model_constructor in MODELS.items():
+        print(f"Testing model {model_name}")
+        _test_network_dimensions(model_constructor)
+
+
+def _test_network_dimensions(constructor):
+    for bin_count in [16, 32, 64]:
+        model = constructor(bin_count=bin_count)
+
+        # Create a dummy input tensor of the correct shape
+        input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count)
+
+        # Test the model output; a shape check needs no autograd graph, and
+        # skipping it keeps the full-resolution 3D activations from piling up
+        with torch.no_grad():
+            output = model(input_tensor)
+        assert (
+            input_tensor.shape == output.shape
+        ), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
+        print("Test passed! Output shape matches input shape.")
diff --git a/src/editor/models/advanced_attention.py b/src/editor/models/advanced_attention.py
new file mode 100644
index 0000000..0192325
--- /dev/null
+++ b/src/editor/models/advanced_attention.py
@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
+        super(ConvBlock, self).__init__()
+        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding)
+        self.bn = nn.BatchNorm3d(out_channels)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        return self.relu(self.bn(self.conv(x)))
+
+
+class ChannelAttention(nn.Module):
+    def __init__(self, channels, reduction=16):
+        super(ChannelAttention, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool3d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channels, channels // reduction, bias=False),
+            nn.ReLU(),
+            nn.Linear(channels // reduction, channels, bias=False),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        b, c, _, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1, 1)
+        return x * y.expand_as(x)
+
+
+class PhotoEnhanceNetAdvanced(nn.Module):
+    def __init__(self, bin_count):
+        super(PhotoEnhanceNetAdvanced, self).__init__()
+        self.features = nn.Sequential(
+            ConvBlock(1, 16), ConvBlock(16, 32), ConvBlock(32, 64), ConvBlock(64, 128)
+        )
+        self.channel_attention = ChannelAttention(128)
+        self.final_conv = nn.Conv3d(128, 1, kernel_size=1)  # Reduce channel size to 1
+
+    def forward(self, x):
+        x = self.features(x)
+        x = self.channel_attention(x)
+        x = self.final_conv(x)  # Final reduction to match the input channel dimensions
+        return x
diff --git a/src/editor/models/attention.py b/src/editor/models/attention.py
new file mode 100644
index 0000000..0f805f4
--- /dev/null
+++ b/src/editor/models/attention.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PhotoEnhanceNetAdvanced(nn.Module):
+    def __init__(self, bin_count):
+        super(PhotoEnhanceNetAdvanced, self).__init__()
+        self.bin_count = bin_count
+
+        # Enhance complexity of the network
+        self.features = nn.Sequential(
+            nn.Conv3d(1, 16, kernel_size=3, padding=1),
+            nn.BatchNorm3d(16),
+            nn.ReLU(),
+            nn.Conv3d(16, 32, kernel_size=3, padding=1),
+            nn.BatchNorm3d(32),
+            nn.ReLU(),
+            nn.Conv3d(32, 64, kernel_size=3, padding=1),
+            nn.BatchNorm3d(64),
+            nn.ReLU(),
+            nn.Conv3d(64, 128, kernel_size=3, padding=1),
+            nn.BatchNorm3d(128),
+            nn.ReLU(),
+        )
+
+        # Adjusted attention layer to match features channel size
+        self.attention = nn.Sequential(
+            nn.Conv3d(
+                128, 128, 1
+            ),  # Ensure the attention map has the same number of channels
+            nn.Sigmoid(),
+        )
+
+        # Using dense connections
+        self.dense = nn.Sequential(
+            nn.Conv3d(
+                256, 192, kernel_size=3, padding=1
+            ),  # Adjust input channels to account for concatenated layers
+            nn.BatchNorm3d(192),
+            nn.ReLU(),
+            nn.Conv3d(192, 1, kernel_size=1),
+        )
+
+    def forward(self, x):
+        features = self.features(x)
+
+        # Apply attention
+        attention = self.attention(features)
+        attended = features * attention  # Element-wise gating of the features
+
+        # Concatenate raw and gated features (skip connection), 128 + 128 = 256
+        x = torch.cat((features, attended), dim=1)  # Previously dropped the gated map
+
+        x = self.dense(x)
+        return x
diff --git a/src/editor/models/attention2.py b/src/editor/models/attention2.py
new file mode 100644
index 0000000..fa6dfaa
--- /dev/null
+++ b/src/editor/models/attention2.py
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class EnhancedAestheticHistogramNet(nn.Module):
+    def __init__(self, bin_count):
+        super(EnhancedAestheticHistogramNet, self).__init__()
+        self.bin_count = bin_count
+
+        # Initial convolution layer
+        self.initial_conv = nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1)
+        self.initial_bn = nn.BatchNorm3d(32)
+        self.initial_relu = nn.ReLU()
+
+        # Deeper convolutional layers with increasing channels and dilation for expanded receptive field
+        self.conv1 = self._make_layer(32, 64, dilation=1)
+        self.conv2 = self._make_layer(64, 128, dilation=2)
+        self.conv3 = self._make_layer(128, 256, dilation=4)
+
+        # Attention module
+        self.attention = nn.Sequential(
+            nn.Conv3d(256, 256, kernel_size=1),  # Pointwise convolution
+            nn.BatchNorm3d(256),
+            nn.ReLU(),
+            nn.Conv3d(256, 256, kernel_size=1),  # Pointwise convolution
+            nn.Sigmoid(),
+        )
+
+        # Correctly adjusted residual connections
+        self.res1 = nn.Conv3d(
+            1, 256, kernel_size=1
+        )  # Match initial input channels to later layers
+        self.res2 = nn.Conv3d(128, 256, kernel_size=1)  # Match output of conv2 to conv3
+
+        # Final convolution to bring channels back to 1
+        self.final_conv = nn.Conv3d(256, 1, kernel_size=3, stride=1, padding=1)
+
+    def _make_layer(self, in_channels, out_channels, dilation):
+        layer = nn.Sequential(
+            nn.Conv3d(
+                in_channels,
+                out_channels,
+                kernel_size=3,
+                padding=dilation,
+                dilation=dilation,
+            ),
+            nn.BatchNorm3d(out_channels),
+            nn.ReLU(),
+        )
+        return layer
+
+    def forward(self, x):
+        identity1 = self.res1(x)  # First skip connection
+
+        out = self.initial_relu(self.initial_bn(self.initial_conv(x)))
+        out = self.conv1(out)
+        out = self.conv2(out)
+
+        identity2 = self.res2(out)  # Second skip connection
+
+        out = self.conv3(out)
+        out = self.attention(out) * out  # Apply attention
+
+        out += identity2  # Add from second skip connection
+        out += identity1  # Add from first skip connection
+
+        out = self.final_conv(out)
+        return out