Remove clutter
This commit is contained in:
parent
95829e6b1c
commit
0e4fe7ab63
4 changed files with 0 additions and 237 deletions
|
|
@ -1,48 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
|
||||
super(ConvBlock, self).__init__()
|
||||
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding)
|
||||
self.bn = nn.BatchNorm3d(out_channels)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.relu(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
def __init__(self, channels, reduction=16):
|
||||
super(ChannelAttention, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool3d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channels, channels // reduction, bias=False),
|
||||
nn.ReLU(),
|
||||
nn.Linear(channels // reduction, channels, bias=False),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _, _ = x.size()
|
||||
y = self.avg_pool(x).view(b, c)
|
||||
y = self.fc(y).view(b, c, 1, 1, 1)
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class PhotoEnhanceNetAdvanced(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(PhotoEnhanceNetAdvanced, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
ConvBlock(1, 16), ConvBlock(16, 32), ConvBlock(32, 64), ConvBlock(64, 128)
|
||||
)
|
||||
self.channel_attention = ChannelAttention(128)
|
||||
self.final_conv = nn.Conv3d(128, 1, kernel_size=1) # Reduce channel size to 1
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.channel_attention(x)
|
||||
x = self.final_conv(x) # Final reduction to match the input channel dimensions
|
||||
return x
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PhotoEnhanceNetAdvanced(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(PhotoEnhanceNetAdvanced, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
|
||||
# Enhance complexity of the network
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv3d(1, 16, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(16),
|
||||
nn.ReLU(),
|
||||
nn.Conv3d(16, 32, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(32),
|
||||
nn.ReLU(),
|
||||
nn.Conv3d(32, 64, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(64),
|
||||
nn.ReLU(),
|
||||
nn.Conv3d(64, 128, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(128),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# Adjusted attention layer to match features channel size
|
||||
self.attention = nn.Sequential(
|
||||
nn.Conv3d(
|
||||
128, 128, 1
|
||||
), # Ensure the attention map has the same number of channels
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
# Using dense connections
|
||||
self.dense = nn.Sequential(
|
||||
nn.Conv3d(
|
||||
256, 192, kernel_size=3, padding=1
|
||||
), # Adjust input channels to account for concatenated layers
|
||||
nn.BatchNorm3d(192),
|
||||
nn.ReLU(),
|
||||
nn.Conv3d(192, 1, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
features = self.features(x)
|
||||
|
||||
# Apply attention
|
||||
attention = self.attention(features)
|
||||
x = features * attention # Element-wise multiplication
|
||||
|
||||
# Concatenate for dense connection (skip connection)
|
||||
x = torch.cat((features, attention), dim=1) # Combining feature maps
|
||||
|
||||
x = self.dense(x)
|
||||
return x
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class EnhancedAestheticHistogramNet(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(EnhancedAestheticHistogramNet, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
|
||||
# Initial convolution layer
|
||||
self.initial_conv = nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1)
|
||||
self.initial_bn = nn.BatchNorm3d(32)
|
||||
self.initial_relu = nn.ReLU()
|
||||
|
||||
# Deeper convolutional layers with increasing channels and dilation for expanded receptive field
|
||||
self.conv1 = self._make_layer(32, 64, dilation=1)
|
||||
self.conv2 = self._make_layer(64, 128, dilation=2)
|
||||
self.conv3 = self._make_layer(128, 256, dilation=4)
|
||||
|
||||
# Attention module
|
||||
self.attention = nn.Sequential(
|
||||
nn.Conv3d(256, 256, kernel_size=1), # Pointwise convolution
|
||||
nn.BatchNorm3d(256),
|
||||
nn.ReLU(),
|
||||
nn.Conv3d(256, 256, kernel_size=1), # Pointwise convolution
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
# Correctly adjusted residual connections
|
||||
self.res1 = nn.Conv3d(
|
||||
1, 256, kernel_size=1
|
||||
) # Match initial input channels to later layers
|
||||
self.res2 = nn.Conv3d(128, 256, kernel_size=1) # Match output of conv2 to conv3
|
||||
|
||||
# Final convolution to bring channels back to 1
|
||||
self.final_conv = nn.Conv3d(256, 1, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def _make_layer(self, in_channels, out_channels, dilation):
|
||||
layer = nn.Sequential(
|
||||
nn.Conv3d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
),
|
||||
nn.BatchNorm3d(out_channels),
|
||||
nn.ReLU(),
|
||||
)
|
||||
return layer
|
||||
|
||||
def forward(self, x):
|
||||
identity1 = self.res1(x) # First skip connection
|
||||
|
||||
out = self.initial_relu(self.initial_bn(self.initial_conv(x)))
|
||||
out = self.conv1(out)
|
||||
out = self.conv2(out)
|
||||
|
||||
identity2 = self.res2(out) # Second skip connection
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.attention(out) * out # Apply attention
|
||||
|
||||
out += identity2 # Add from second skip connection
|
||||
out += identity1 # Add from first skip connection
|
||||
|
||||
out = self.final_conv(out)
|
||||
return out
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Define the self-attention module
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.query_conv = nn.Conv3d(channels, channels // 8, kernel_size=1)
|
||||
self.key_conv = nn.Conv3d(channels, channels // 8, kernel_size=1)
|
||||
self.value_conv = nn.Conv3d(channels, channels, kernel_size=1)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, channels, depth, height, width = x.size()
|
||||
query = (
|
||||
self.query_conv(x)
|
||||
.view(batch_size, -1, depth * height * width)
|
||||
.permute(0, 2, 1)
|
||||
)
|
||||
key = self.key_conv(x).view(batch_size, -1, depth * height * width)
|
||||
value = self.value_conv(x).view(batch_size, -1, depth * height * width)
|
||||
|
||||
attention = self.softmax(torch.bmm(query, key)) # Batch matrix multiplication
|
||||
out = torch.bmm(value, attention.permute(0, 2, 1))
|
||||
out = out.view(batch_size, channels, depth, height, width)
|
||||
|
||||
return x + out
|
||||
|
||||
|
||||
# Define the residual block
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
)
|
||||
self.attention = SelfAttention(channels)
|
||||
|
||||
def forward(self, x):
|
||||
return self.attention(self.conv(x)) + x
|
||||
|
||||
|
||||
# Define the network
|
||||
class AttentionNet(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(AttentionNet, self).__init__()
|
||||
self.input_layer = nn.Sequential(
|
||||
nn.Conv3d(1, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm3d(16),
|
||||
)
|
||||
self.res_blocks = nn.Sequential(ResidualBlock(16), ResidualBlock(16))
|
||||
self.output_layer = nn.Conv3d(16, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.res_blocks(x)
|
||||
x = self.output_layer(x)
|
||||
return x
|
||||
Loading…
Add table
Add a link
Reference in a new issue