Add models
This commit is contained in:
parent
09aceae9d4
commit
bd7033c3eb
9 changed files with 490 additions and 76 deletions
|
|
@ -1 +1,19 @@
|
|||
from .create_model import create_model
|
||||
from .v1 import HistogramRestorationNet as v1
|
||||
from .simple_cnn import SimpleCNN
|
||||
from .residual import Residual
|
||||
from .normalised_cnn import NormalisedCNN
|
||||
from .smart_res import SmartRes
|
||||
from .attention_net import AttentionNet
|
||||
from .res2 import Res2
|
||||
|
||||
|
||||
def create_model(type: str, bin_count: int):
|
||||
return {
|
||||
# "v1": v1,
|
||||
"SimpleCNN": SimpleCNN,
|
||||
"Residual": Residual,
|
||||
"NormalisedCNN": NormalisedCNN,
|
||||
"SmartRes": SmartRes,
|
||||
"AttentionNet": AttentionNet,
|
||||
"Res2": Res2,
|
||||
}[type](bin_count)
|
||||
|
|
|
|||
64
src/editor/models/attention_net.py
Normal file
64
src/editor/models/attention_net.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Define the self-attention module
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.query_conv = nn.Conv3d(channels, channels // 8, kernel_size=1)
|
||||
self.key_conv = nn.Conv3d(channels, channels // 8, kernel_size=1)
|
||||
self.value_conv = nn.Conv3d(channels, channels, kernel_size=1)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, channels, depth, height, width = x.size()
|
||||
query = (
|
||||
self.query_conv(x)
|
||||
.view(batch_size, -1, depth * height * width)
|
||||
.permute(0, 2, 1)
|
||||
)
|
||||
key = self.key_conv(x).view(batch_size, -1, depth * height * width)
|
||||
value = self.value_conv(x).view(batch_size, -1, depth * height * width)
|
||||
|
||||
attention = self.softmax(torch.bmm(query, key)) # Batch matrix multiplication
|
||||
out = torch.bmm(value, attention.permute(0, 2, 1))
|
||||
out = out.view(batch_size, channels, depth, height, width)
|
||||
|
||||
return x + out
|
||||
|
||||
|
||||
# Define the residual block
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
)
|
||||
self.attention = SelfAttention(channels)
|
||||
|
||||
def forward(self, x):
|
||||
return self.attention(self.conv(x)) + x
|
||||
|
||||
|
||||
# Define the network
|
||||
class AttentionNet(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(AttentionNet, self).__init__()
|
||||
self.input_layer = nn.Sequential(
|
||||
nn.Conv3d(1, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm3d(16),
|
||||
)
|
||||
self.res_blocks = nn.Sequential(ResidualBlock(16), ResidualBlock(16))
|
||||
self.output_layer = nn.Conv3d(16, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.res_blocks(x)
|
||||
x = self.output_layer(x)
|
||||
return x
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
from .v1 import HistogramRestorationNet as v1
|
||||
|
||||
|
||||
def create_model(type: str, bin_count: int):
|
||||
return {"v1": v1}[type](bin_count)
|
||||
34
src/editor/models/normalised_cnn.py
Normal file
34
src/editor/models/normalised_cnn.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class NormalisedCNN(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(NormalisedCNN, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
|
||||
# Define the layers of the neural network
|
||||
self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
|
||||
self.bn1 = nn.BatchNorm3d(16)
|
||||
self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm3d(32)
|
||||
self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
|
||||
self.bn3 = nn.BatchNorm3d(64)
|
||||
self.conv4 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
|
||||
self.bn4 = nn.BatchNorm3d(32)
|
||||
self.conv5 = nn.Conv3d(32, 16, kernel_size=3, padding=1)
|
||||
self.bn5 = nn.BatchNorm3d(16)
|
||||
self.conv6 = nn.Conv3d(16, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(
|
||||
-1, 1, self.bin_count, self.bin_count, self.bin_count
|
||||
) # Reshape input to (N, C, D, H, W)
|
||||
x = F.relu(self.bn1(self.conv1(x)))
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.relu(self.bn3(self.conv3(x)))
|
||||
x = F.relu(self.bn4(self.conv4(x)))
|
||||
x = F.relu(self.bn5(self.conv5(x)))
|
||||
x = self.conv6(x)
|
||||
|
||||
return x
|
||||
37
src/editor/models/res2.py
Normal file
37
src/editor/models/res2.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x) + x
|
||||
|
||||
|
||||
# Define the network
|
||||
class Res2(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(Res2, self).__init__()
|
||||
self.input_layer = nn.Sequential(
|
||||
nn.Conv3d(1, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm3d(16),
|
||||
)
|
||||
self.res_blocks = nn.Sequential(
|
||||
ResidualBlock(16), ResidualBlock(16), ResidualBlock(16), ResidualBlock(16)
|
||||
)
|
||||
self.output_layer = nn.Conv3d(16, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.res_blocks(x)
|
||||
x = self.output_layer(x)
|
||||
return x
|
||||
52
src/editor/models/residual.py
Normal file
52
src/editor/models/residual.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, bin_count: int):
|
||||
super(Residual, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
|
||||
# Assuming the input histograms are 3D tensors of shape (bin_count, bin_count, bin_count)
|
||||
# Convolutional layers to extract features from the histograms
|
||||
self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
|
||||
self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
|
||||
|
||||
# Batch normalization layers for better convergence
|
||||
self.bn1 = nn.BatchNorm3d(16)
|
||||
self.bn2 = nn.BatchNorm3d(32)
|
||||
self.bn3 = nn.BatchNorm3d(64)
|
||||
|
||||
# Residual block to help the network learn identity functions effectively
|
||||
self.resblock1 = nn.Sequential(
|
||||
nn.Conv3d(64, 64, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm3d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv3d(64, 64, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm3d(64),
|
||||
)
|
||||
|
||||
# ReLU activation
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.deconv1 = nn.ConvTranspose3d(64, 32, kernel_size=3, stride=1, padding=1)
|
||||
self.deconv2 = nn.ConvTranspose3d(32, 16, kernel_size=3, stride=1, padding=1)
|
||||
self.deconv3 = nn.ConvTranspose3d(16, 1, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(self.bn1(self.conv1(x)))
|
||||
x = self.relu(self.bn2(self.conv2(x)))
|
||||
x = self.relu(self.bn3(self.conv3(x)))
|
||||
|
||||
# Apply residual blocks
|
||||
residual = x
|
||||
out = self.resblock1(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
# Upsample to original size
|
||||
out = self.relu(self.deconv1(out))
|
||||
out = self.relu(self.deconv2(out))
|
||||
out = self.relu(self.deconv3(out))
|
||||
|
||||
return out
|
||||
37
src/editor/models/simple_cnn.py
Normal file
37
src/editor/models/simple_cnn.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SimpleCNN(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(SimpleCNN, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
|
||||
# Define the convolutional layers
|
||||
self.conv1 = nn.Conv3d(
|
||||
1, 16, kernel_size=3, padding=1
|
||||
) # input channels = 1, output channels = 16
|
||||
self.conv2 = nn.Conv3d(
|
||||
16, 32, kernel_size=3, padding=1
|
||||
) # input channels = 16, output channels = 32
|
||||
self.conv3 = nn.Conv3d(
|
||||
32, 64, kernel_size=3, padding=1
|
||||
) # input channels = 32, output channels = 64
|
||||
self.conv4 = nn.Conv3d(
|
||||
64, 32, kernel_size=3, padding=1
|
||||
) # input channels = 64, output channels = 32
|
||||
self.conv5 = nn.Conv3d(
|
||||
32, 16, kernel_size=3, padding=1
|
||||
) # input channels = 32, output channels = 16
|
||||
self.conv6 = nn.Conv3d(
|
||||
16, 1, kernel_size=3, padding=1
|
||||
) # input channels = 16, output channels = 1
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.relu(self.conv3(x))
|
||||
x = F.relu(self.conv4(x))
|
||||
x = F.relu(self.conv5(x))
|
||||
x = self.conv6(x)
|
||||
return x
|
||||
42
src/editor/models/smart_res.py
Normal file
42
src/editor/models/smart_res.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||
self.bn1 = nn.BatchNorm3d(channels)
|
||||
self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm3d(channels)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out += identity
|
||||
return F.relu(out)
|
||||
|
||||
|
||||
class SmartRes(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(SmartRes, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
self.initial_conv = nn.Conv3d(1, 16, kernel_size=3, padding=1)
|
||||
self.bn0 = nn.BatchNorm3d(16)
|
||||
self.resblock1 = ResidualBlock(16)
|
||||
self.resblock2 = ResidualBlock(16)
|
||||
self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm3d(32)
|
||||
self.dilated_conv = nn.Conv3d(32, 32, kernel_size=3, padding=2, dilation=2)
|
||||
self.bn_dilated = nn.BatchNorm3d(32)
|
||||
self.final_conv = nn.Conv3d(32, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.bn0(self.initial_conv(x)))
|
||||
x = self.resblock1(x)
|
||||
x = self.resblock2(x)
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.relu(self.bn_dilated(self.dilated_conv(x)))
|
||||
x = self.final_conv(x)
|
||||
return x
|
||||
Loading…
Add table
Add a link
Reference in a new issue