More models
This commit is contained in:
parent
bd7033c3eb
commit
a6a15ec650
4 changed files with 213 additions and 9 deletions
|
|
@ -5,15 +5,46 @@ from .normalised_cnn import NormalisedCNN
|
|||
from .smart_res import SmartRes
|
||||
from .attention_net import AttentionNet
|
||||
from .res2 import Res2
|
||||
from .attention2 import EnhancedAestheticHistogramNet
|
||||
from .attention import PhotoEnhanceNetAdvanced
|
||||
from .advanced_attention import PhotoEnhanceNetAdvanced as advanced_attention
|
||||
import torch
|
||||
|
||||
|
||||
# Registry of available architectures, keyed by the name passed to
# ``create_model``. Each value is called as ``constructor(bin_count)``.
# ``test_models`` iterates this dict, so every entry listed here is
# exercised by the shape check, in insertion order.
MODELS = {
    # "v1": v1,
    "SimpleCNN": SimpleCNN,
    "Residual": Residual,
    "NormalisedCNN": NormalisedCNN,
    "SmartRes": SmartRes,
    # Still imported at the top of the module, but currently not registered.
    # "AttentionNet": AttentionNet,
    "attention2": EnhancedAestheticHistogramNet,
    # Alias for PhotoEnhanceNetAdvanced imported from .advanced_attention.
    "advanced_attention": advanced_attention,
    "Res2": Res2,
    # PhotoEnhanceNetAdvanced imported from .attention.
    "attention1": PhotoEnhanceNetAdvanced,
}
|
||||
|
||||
|
||||
def create_model(type: str, bin_count: int):
    """Instantiate the model registered under ``type``.

    Args:
        type: Key into the module-level ``MODELS`` registry. The parameter
            name shadows the ``type`` builtin; it is kept so existing
            keyword callers keep working.
        bin_count: Passed as the single positional argument to the selected
            model constructor.

    Returns:
        A new instance of the selected model class.

    Raises:
        KeyError: If ``type`` is not a key of ``MODELS``.
    """
    # MODELS is the single source of truth. A second, inline copy of the
    # registry used to be returned first, which made this lookup unreachable
    # and hid every model that was only added to MODELS.
    return MODELS[type](bin_count)
|
||||
|
||||
|
||||
def test_models():
    """Run the input/output shape check against every registered model."""
    for name in MODELS:
        print(f"Testing model {name}")
        constructor = MODELS[name]
        _test_network_dimensions(constructor)
|
||||
|
||||
|
||||
def _test_network_dimensions(constructor):
    """Check that a model's output has the same shape as its input.

    For each bin count in (16, 32, 64) a model is built with
    ``constructor(bin_count=bin_count)`` and fed a random tensor of shape
    ``(4, 1, bin_count, bin_count, bin_count)``.

    Args:
        constructor: Callable accepting a ``bin_count`` keyword argument and
            returning a callable model.

    Raises:
        AssertionError: If the output shape differs from the input shape.
    """
    for bin_count in (16, 32, 64):
        model = constructor(bin_count=bin_count)

        # Create a dummy input tensor of the correct shape
        input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count)

        # Only the output shape is inspected, so disable autograd: otherwise
        # every intermediate activation is retained for a backward pass that
        # never happens, which is costly for the larger volumes.
        with torch.no_grad():
            output = model(input_tensor)

        assert (
            input_tensor.shape == output.shape
        ), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
        print("Test passed! Output shape matches input shape.")
|
||||
|
|
|
|||
48
src/editor/models/advanced_attention.py
Normal file
48
src/editor/models/advanced_attention.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
    """3D convolution followed by batch normalisation and a ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        """Build the conv -> batch-norm -> ReLU stack.

        Args:
            in_channels: Number of channels of the incoming tensor.
            out_channels: Number of channels produced by the convolution.
            kernel_size: Convolution kernel size (default 3).
            padding: Convolution padding (default 1).
        """
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.bn = nn.BatchNorm3d(num_features=out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, then batch norm, then ReLU to ``x``."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.relu(normalised)
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
    """Channel gating: global average pool, bottleneck MLP, sigmoid, rescale."""

    def __init__(self, channels, reduction=16):
        """Build the pooling layer and the bottleneck MLP.

        Args:
            channels: Number of channels of the tensors passed to ``forward``.
            reduction: Divisor applied to ``channels`` to size the hidden
                layer of the MLP (default 16).
        """
        super().__init__()
        # Clamp to at least one hidden unit: with channels < reduction the
        # integer division yields 0, which would create an empty bottleneck
        # and make the gate a constant 0.5 for every channel.
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale each channel of a 5-D tensor ``x`` by its learned gate.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Unpacking five sizes also rejects inputs that are not 5-D.
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1, 1)
        # The (b, c, 1, 1, 1) gate broadcasts over the three spatial axes.
        return x * y
|
||||
|
||||
|
||||
class PhotoEnhanceNetAdvanced(nn.Module):
    """Four ConvBlocks, channel attention, then a 1x1x1 conv back to one channel."""

    def __init__(self, bin_count):
        """Build the network.

        Args:
            bin_count: Accepted for a uniform constructor signature; this
                class does not read it.
        """
        super().__init__()
        widths = (1, 16, 32, 64, 128)
        self.features = nn.Sequential(
            *(ConvBlock(c_in, c_out) for c_in, c_out in zip(widths, widths[1:]))
        )
        self.channel_attention = ChannelAttention(widths[-1])
        # Reduce channel size to 1
        self.final_conv = nn.Conv3d(widths[-1], 1, kernel_size=1)

    def forward(self, x):
        """Run feature extraction, channel attention and the final projection."""
        extracted = self.features(x)
        gated = self.channel_attention(extracted)
        # Final reduction to match the input channel dimensions
        return self.final_conv(gated)
|
||||
56
src/editor/models/attention.py
Normal file
56
src/editor/models/attention.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PhotoEnhanceNetAdvanced(nn.Module):
    """3D conv network with a per-voxel attention gate and a dense output head.

    A 4-layer conv stack produces 128 feature channels, a 1x1x1 conv plus
    sigmoid produces a 128-channel attention map, and the head consumes the
    256-channel concatenation of the raw and the attended features to
    produce a single-channel output.
    """

    def __init__(self, bin_count):
        """Build the network.

        Args:
            bin_count: Stored on ``self.bin_count``; not otherwise used by
                this class.
        """
        super().__init__()
        self.bin_count = bin_count

        # Feature extractor: 1 -> 16 -> 32 -> 64 -> 128 channels.
        self.features = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.Conv3d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
        )

        # Attention map has the same number of channels as the features so
        # the two can be multiplied element-wise.
        self.attention = nn.Sequential(
            nn.Conv3d(128, 128, 1),
            nn.Sigmoid(),
        )

        # Head input is 256 channels: 128 raw + 128 attended features.
        self.dense = nn.Sequential(
            nn.Conv3d(256, 192, kernel_size=3, padding=1),
            nn.BatchNorm3d(192),
            nn.ReLU(),
            nn.Conv3d(192, 1, kernel_size=1),
        )

    def forward(self, x):
        """Return a single-channel tensor with the same spatial size as ``x``."""
        features = self.features(x)

        # Apply attention (element-wise gating of the features).
        attention = self.attention(features)
        attended = features * attention

        # Skip connection: concatenate the raw features with the attended
        # features. Previously the attended tensor was computed and then
        # overwritten, so the raw attention map was concatenated instead and
        # the gating had no effect on the output.
        x = torch.cat((features, attended), dim=1)

        return self.dense(x)
|
||||
69
src/editor/models/attention2.py
Normal file
69
src/editor/models/attention2.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class EnhancedAestheticHistogramNet(nn.Module):
    """Dilated 3D conv network with an attention gate and two skip connections."""

    def __init__(self, bin_count):
        """Build the network.

        Args:
            bin_count: Stored on ``self.bin_count``; not otherwise used by
                this class.
        """
        super().__init__()
        self.bin_count = bin_count

        # Stem: 1 -> 32 channels.
        self.initial_conv = nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1)
        self.initial_bn = nn.BatchNorm3d(32)
        self.initial_relu = nn.ReLU()

        # Widening stages with growing dilation (1, 2, 4).
        self.conv1 = self._make_layer(32, 64, dilation=1)
        self.conv2 = self._make_layer(64, 128, dilation=2)
        self.conv3 = self._make_layer(128, 256, dilation=4)

        # Attention gate built from two pointwise convolutions.
        pointwise_in = nn.Conv3d(256, 256, kernel_size=1)
        pointwise_out = nn.Conv3d(256, 256, kernel_size=1)
        self.attention = nn.Sequential(
            pointwise_in,
            nn.BatchNorm3d(256),
            nn.ReLU(),
            pointwise_out,
            nn.Sigmoid(),
        )

        # 1x1x1 projections that lift the skip tensors to 256 channels.
        self.res1 = nn.Conv3d(1, 256, kernel_size=1)  # from the raw input
        self.res2 = nn.Conv3d(128, 256, kernel_size=1)  # from the conv2 output

        # Project back to a single channel.
        self.final_conv = nn.Conv3d(256, 1, kernel_size=3, stride=1, padding=1)

    def _make_layer(self, in_channels, out_channels, dilation):
        """Return a dilated conv -> batch-norm -> ReLU stage.

        Padding is set equal to ``dilation`` for the 3x3x3 kernel.
        """
        return nn.Sequential(
            nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return a single-channel tensor computed from ``x``."""
        skip_input = self.res1(x)  # first skip connection

        stem = self.initial_conv(x)
        stem = self.initial_bn(stem)
        stem = self.initial_relu(stem)

        mid = self.conv2(self.conv1(stem))
        skip_mid = self.res2(mid)  # second skip connection

        deep = self.conv3(mid)
        gated = self.attention(deep) * deep  # apply attention

        # Add the second skip first, then the first, as two separate sums.
        merged = gated + skip_mid
        merged = merged + skip_input

        return self.final_conv(merged)
|
||||
Loading…
Add table
Add a link
Reference in a new issue