Remove clutter

This commit is contained in:
Andras Schmelczer 2024-06-22 11:12:52 +01:00
parent 95829e6b1c
commit 0e4fe7ab63
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
4 changed files with 0 additions and 237 deletions

View file

@ -1,48 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
super(ConvBlock, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding)
self.bn = nn.BatchNorm3d(out_channels)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.bn(self.conv(x)))
class ChannelAttention(nn.Module):
def __init__(self, channels, reduction=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool3d(1)
self.fc = nn.Sequential(
nn.Linear(channels, channels // reduction, bias=False),
nn.ReLU(),
nn.Linear(channels // reduction, channels, bias=False),
nn.Sigmoid(),
)
def forward(self, x):
b, c, _, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1, 1)
return x * y.expand_as(x)
class PhotoEnhanceNetAdvanced(nn.Module):
def __init__(self, bin_count):
super(PhotoEnhanceNetAdvanced, self).__init__()
self.features = nn.Sequential(
ConvBlock(1, 16), ConvBlock(16, 32), ConvBlock(32, 64), ConvBlock(64, 128)
)
self.channel_attention = ChannelAttention(128)
self.final_conv = nn.Conv3d(128, 1, kernel_size=1) # Reduce channel size to 1
def forward(self, x):
x = self.features(x)
x = self.channel_attention(x)
x = self.final_conv(x) # Final reduction to match the input channel dimensions
return x

View file

@ -1,56 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class PhotoEnhanceNetAdvanced(nn.Module):
def __init__(self, bin_count):
super(PhotoEnhanceNetAdvanced, self).__init__()
self.bin_count = bin_count
# Enhance complexity of the network
self.features = nn.Sequential(
nn.Conv3d(1, 16, kernel_size=3, padding=1),
nn.BatchNorm3d(16),
nn.ReLU(),
nn.Conv3d(16, 32, kernel_size=3, padding=1),
nn.BatchNorm3d(32),
nn.ReLU(),
nn.Conv3d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm3d(64),
nn.ReLU(),
nn.Conv3d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm3d(128),
nn.ReLU(),
)
# Adjusted attention layer to match features channel size
self.attention = nn.Sequential(
nn.Conv3d(
128, 128, 1
), # Ensure the attention map has the same number of channels
nn.Sigmoid(),
)
# Using dense connections
self.dense = nn.Sequential(
nn.Conv3d(
256, 192, kernel_size=3, padding=1
), # Adjust input channels to account for concatenated layers
nn.BatchNorm3d(192),
nn.ReLU(),
nn.Conv3d(192, 1, kernel_size=1),
)
def forward(self, x):
features = self.features(x)
# Apply attention
attention = self.attention(features)
x = features * attention # Element-wise multiplication
# Concatenate for dense connection (skip connection)
x = torch.cat((features, attention), dim=1) # Combining feature maps
x = self.dense(x)
return x

View file

@ -1,69 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class EnhancedAestheticHistogramNet(nn.Module):
def __init__(self, bin_count):
super(EnhancedAestheticHistogramNet, self).__init__()
self.bin_count = bin_count
# Initial convolution layer
self.initial_conv = nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1)
self.initial_bn = nn.BatchNorm3d(32)
self.initial_relu = nn.ReLU()
# Deeper convolutional layers with increasing channels and dilation for expanded receptive field
self.conv1 = self._make_layer(32, 64, dilation=1)
self.conv2 = self._make_layer(64, 128, dilation=2)
self.conv3 = self._make_layer(128, 256, dilation=4)
# Attention module
self.attention = nn.Sequential(
nn.Conv3d(256, 256, kernel_size=1), # Pointwise convolution
nn.BatchNorm3d(256),
nn.ReLU(),
nn.Conv3d(256, 256, kernel_size=1), # Pointwise convolution
nn.Sigmoid(),
)
# Correctly adjusted residual connections
self.res1 = nn.Conv3d(
1, 256, kernel_size=1
) # Match initial input channels to later layers
self.res2 = nn.Conv3d(128, 256, kernel_size=1) # Match output of conv2 to conv3
# Final convolution to bring channels back to 1
self.final_conv = nn.Conv3d(256, 1, kernel_size=3, stride=1, padding=1)
def _make_layer(self, in_channels, out_channels, dilation):
layer = nn.Sequential(
nn.Conv3d(
in_channels,
out_channels,
kernel_size=3,
padding=dilation,
dilation=dilation,
),
nn.BatchNorm3d(out_channels),
nn.ReLU(),
)
return layer
def forward(self, x):
identity1 = self.res1(x) # First skip connection
out = self.initial_relu(self.initial_bn(self.initial_conv(x)))
out = self.conv1(out)
out = self.conv2(out)
identity2 = self.res2(out) # Second skip connection
out = self.conv3(out)
out = self.attention(out) * out # Apply attention
out += identity2 # Add from second skip connection
out += identity1 # Add from first skip connection
out = self.final_conv(out)
return out

View file

@ -1,64 +0,0 @@
import torch
import torch.nn as nn
# Define the self-attention module
class SelfAttention(nn.Module):
def __init__(self, channels):
super(SelfAttention, self).__init__()
self.query_conv = nn.Conv3d(channels, channels // 8, kernel_size=1)
self.key_conv = nn.Conv3d(channels, channels // 8, kernel_size=1)
self.value_conv = nn.Conv3d(channels, channels, kernel_size=1)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
batch_size, channels, depth, height, width = x.size()
query = (
self.query_conv(x)
.view(batch_size, -1, depth * height * width)
.permute(0, 2, 1)
)
key = self.key_conv(x).view(batch_size, -1, depth * height * width)
value = self.value_conv(x).view(batch_size, -1, depth * height * width)
attention = self.softmax(torch.bmm(query, key)) # Batch matrix multiplication
out = torch.bmm(value, attention.permute(0, 2, 1))
out = out.view(batch_size, channels, depth, height, width)
return x + out
# Define the residual block
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv = nn.Sequential(
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
nn.BatchNorm3d(channels),
nn.ReLU(inplace=True),
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
nn.BatchNorm3d(channels),
)
self.attention = SelfAttention(channels)
def forward(self, x):
return self.attention(self.conv(x)) + x
# Define the network
class AttentionNet(nn.Module):
def __init__(self, bin_count):
super(AttentionNet, self).__init__()
self.input_layer = nn.Sequential(
nn.Conv3d(1, 16, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.BatchNorm3d(16),
)
self.res_blocks = nn.Sequential(ResidualBlock(16), ResidualBlock(16))
self.output_layer = nn.Conv3d(16, 1, kernel_size=3, padding=1)
def forward(self, x):
x = self.input_layer(x)
x = self.res_blocks(x)
x = self.output_layer(x)
return x