Update networks
This commit is contained in:
parent
0e4fe7ab63
commit
b7755bfea8
9 changed files with 173 additions and 144 deletions
|
|
@ -2,13 +2,9 @@ from typing import Any, Dict, Tuple
|
|||
from .v1 import HistogramRestorationNet as v1
|
||||
from .simple_cnn import SimpleCNN
|
||||
from .residual import Residual
|
||||
from .normalised_cnn import NormalisedCNN
|
||||
from .smart_res import SmartRes
|
||||
from .attention_net import AttentionNet
|
||||
from .res2 import Res2
|
||||
from .attention2 import EnhancedAestheticHistogramNet
|
||||
from .attention import PhotoEnhanceNetAdvanced
|
||||
from .advanced_attention import PhotoEnhanceNetAdvanced as advanced_attention
|
||||
from .residual2 import Residual2
|
||||
from .residual3 import Residual3
|
||||
from .dummy import Dummy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
|
|
@ -17,16 +13,12 @@ import json
|
|||
|
||||
|
||||
MODELS = {
|
||||
# "v1": v1,
|
||||
# "SimpleCNN": SimpleCNN,
|
||||
"v1": v1,
|
||||
"Dummy": Dummy,
|
||||
"SimpleCNN": SimpleCNN,
|
||||
"Residual": Residual,
|
||||
# "NormalisedCNN": NormalisedCNN,
|
||||
# "SmartRes": SmartRes,
|
||||
# "AttentionNet": AttentionNet,
|
||||
# "attention2": EnhancedAestheticHistogramNet,
|
||||
# "advanced_attention": advanced_attention,
|
||||
# "Res2": Res2,
|
||||
# "attention1": PhotoEnhanceNetAdvanced,
|
||||
# "Residual2": Residual2,
|
||||
"Residual3": Residual3,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -39,6 +31,7 @@ def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path):
|
|||
params_path = path.with_suffix(".json")
|
||||
|
||||
logging.info(f"Saving model to {model_path}")
|
||||
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
|
||||
with open(model_path, "wb") as f:
|
||||
torch.save(model.state_dict(), f)
|
||||
with open(params_path, "w") as f:
|
||||
|
|
@ -60,11 +53,12 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
|
|||
).to(device)
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
model.eval()
|
||||
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
|
||||
|
||||
return model, hyperparameters
|
||||
|
||||
|
||||
def _test_models():
|
||||
def test_models():
|
||||
for model_name, model_constructor in MODELS.items():
|
||||
logging.info(f"Testing model {model_name}")
|
||||
_test_network_dimensions(model_constructor)
|
||||
|
|
@ -72,18 +66,13 @@ def _test_models():
|
|||
|
||||
def _test_network_dimensions(constructor):
|
||||
for bin_count in [16, 32, 64]:
|
||||
model = constructor(bin_count=bin_count)
|
||||
model = constructor()
|
||||
|
||||
# Create a dummy input tensor of the correct shape
|
||||
# Create a dummy input tensor of the correct shape, the mini-batch size is 4
|
||||
input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count)
|
||||
|
||||
# Test the model output
|
||||
output = model(input_tensor)
|
||||
assert (
|
||||
input_tensor.shape == output.shape
|
||||
), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
|
||||
print("Test passed! Output shape matches input shape.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_models()
|
||||
logging.info("Test passed! Output shape matches input shape.")
|
||||
|
|
|
|||
10
src/editor/models/dummy.py
Normal file
10
src/editor/models/dummy.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Dummy(nn.Module):
|
||||
def __init__(self, **_):
|
||||
super(Dummy, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class NormalisedCNN(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(NormalisedCNN, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
|
||||
# Define the layers of the neural network
|
||||
self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
|
||||
self.bn1 = nn.BatchNorm3d(16)
|
||||
self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm3d(32)
|
||||
self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
|
||||
self.bn3 = nn.BatchNorm3d(64)
|
||||
self.conv4 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
|
||||
self.bn4 = nn.BatchNorm3d(32)
|
||||
self.conv5 = nn.Conv3d(32, 16, kernel_size=3, padding=1)
|
||||
self.bn5 = nn.BatchNorm3d(16)
|
||||
self.conv6 = nn.Conv3d(16, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(
|
||||
-1, 1, self.bin_count, self.bin_count, self.bin_count
|
||||
) # Reshape input to (N, C, D, H, W)
|
||||
x = F.relu(self.bn1(self.conv1(x)))
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.relu(self.bn3(self.conv3(x)))
|
||||
x = F.relu(self.bn4(self.conv4(x)))
|
||||
x = F.relu(self.bn5(self.conv5(x)))
|
||||
x = self.conv6(x)
|
||||
|
||||
return x
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm3d(channels),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x) + x
|
||||
|
||||
|
||||
# Define the network
|
||||
class Res2(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(Res2, self).__init__()
|
||||
self.input_layer = nn.Sequential(
|
||||
nn.Conv3d(1, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm3d(16),
|
||||
)
|
||||
self.res_blocks = nn.Sequential(
|
||||
ResidualBlock(16), ResidualBlock(16), ResidualBlock(16), ResidualBlock(16)
|
||||
)
|
||||
self.output_layer = nn.Conv3d(16, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.res_blocks(x)
|
||||
x = self.output_layer(x)
|
||||
return x
|
||||
|
|
@ -2,9 +2,8 @@ import torch.nn as nn
|
|||
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, bin_count: int):
|
||||
def __init__(self, **_):
|
||||
super(Residual, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
|
||||
# Assuming the input histograms are 3D tensors of shape (bin_count, bin_count, bin_count)
|
||||
# Convolutional layers to extract features from the histograms
|
||||
|
|
|
|||
146
src/editor/models/residual3.py
Normal file
146
src/editor/models/residual3.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class DepthwiseSeparableConv3d(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, padding):
|
||||
super(DepthwiseSeparableConv3d, self).__init__()
|
||||
self.depthwise = nn.Conv3d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
groups=in_channels,
|
||||
)
|
||||
self.pointwise = nn.Conv3d(in_channels, out_channels, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.depthwise(x)
|
||||
x = self.pointwise(x)
|
||||
return x
|
||||
|
||||
|
||||
class Residual3(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
elu_alpha: float = 1,
|
||||
dropout_prob: float = 0.1,
|
||||
use_depthwise_separable_conv: bool = False,
|
||||
feature_map_sizes: list[int] = [16, 32, 64],
|
||||
kernel_sizes: list[int] = [3, 3, 3],
|
||||
):
|
||||
super(Residual3, self).__init__()
|
||||
|
||||
conv = DepthwiseSeparableConv3d if use_depthwise_separable_conv else nn.Conv3d
|
||||
|
||||
# Assuming the input histograms are 3D tensors of shape (bin_count, bin_count, bin_count)
|
||||
# Convolutional layers to extract features from the histograms
|
||||
self.conv1 = conv(
|
||||
1, feature_map_sizes[0], kernel_size=kernel_sizes[0], padding=1
|
||||
)
|
||||
self.conv2 = conv(
|
||||
feature_map_sizes[0],
|
||||
feature_map_sizes[1],
|
||||
kernel_size=kernel_sizes[1],
|
||||
padding=1,
|
||||
)
|
||||
self.conv3 = conv(
|
||||
feature_map_sizes[1],
|
||||
feature_map_sizes[2],
|
||||
kernel_size=kernel_sizes[2],
|
||||
padding=1,
|
||||
)
|
||||
|
||||
self.activation = nn.ELU(elu_alpha, inplace=True)
|
||||
|
||||
self.bn1 = nn.BatchNorm3d(feature_map_sizes[0])
|
||||
self.bn2 = nn.BatchNorm3d(feature_map_sizes[1])
|
||||
self.bn3 = nn.BatchNorm3d(feature_map_sizes[2])
|
||||
|
||||
self.resblock1 = nn.Sequential(
|
||||
conv(
|
||||
feature_map_sizes[2],
|
||||
feature_map_sizes[2],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
nn.ELU(elu_alpha, inplace=True),
|
||||
nn.BatchNorm3d(feature_map_sizes[2]),
|
||||
conv(
|
||||
feature_map_sizes[2],
|
||||
feature_map_sizes[2],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
nn.ELU(elu_alpha, inplace=True),
|
||||
nn.BatchNorm3d(feature_map_sizes[2]),
|
||||
)
|
||||
|
||||
# Deconvolutional layers
|
||||
self.deconv1 = nn.ConvTranspose3d(
|
||||
feature_map_sizes[2],
|
||||
feature_map_sizes[1],
|
||||
kernel_size=feature_map_sizes[2],
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
self.deconv2 = nn.ConvTranspose3d(
|
||||
feature_map_sizes[1],
|
||||
feature_map_sizes[0],
|
||||
kernel_size=feature_map_sizes[1],
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
self.deconv3 = nn.ConvTranspose3d(
|
||||
feature_map_sizes[0], 1, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout3d(p=dropout_prob)
|
||||
self._initialize_weights()
|
||||
|
||||
def forward(self, x):
|
||||
out = self.dropout(self.bn1(self.activation(self.conv1(x))))
|
||||
out = self.dropout(self.bn2(self.activation(self.conv2(out))))
|
||||
out = self.dropout(self.bn3(self.activation(self.conv2(out))))
|
||||
|
||||
out = out + self.resblock1(out)
|
||||
|
||||
out = self.activation(self.deconv1(out))
|
||||
out = self.activation(self.deconv2(out))
|
||||
out = self.activation(self.deconv3(out))
|
||||
|
||||
return self._normalize(out)
|
||||
|
||||
def _normalize(self, x):
|
||||
x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
return x / torch.where(x_sum == 0, torch.ones_like(x_sum), x_sum)
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
|
||||
nn.init.xavier_normal_(
|
||||
m.weight,
|
||||
)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
def _test_network_dimensions(constructor):
|
||||
for bin_count in [16, 32, 64]:
|
||||
model = constructor()
|
||||
|
||||
# Create a dummy input tensor of the correct shape
|
||||
input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count)
|
||||
|
||||
# Test the model output
|
||||
output = model(input_tensor)
|
||||
assert (
|
||||
input_tensor.shape == output.shape
|
||||
), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
|
||||
|
||||
|
||||
_test_network_dimensions(Residual3)
|
||||
|
|
@ -3,10 +3,8 @@ import torch.nn.functional as F
|
|||
|
||||
|
||||
class SimpleCNN(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
def __init__(self, **_):
|
||||
super(SimpleCNN, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
|
||||
# Define the convolutional layers
|
||||
self.conv1 = nn.Conv3d(
|
||||
1, 16, kernel_size=3, padding=1
|
||||
|
|
|
|||
|
|
@ -1,42 +0,0 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||
self.bn1 = nn.BatchNorm3d(channels)
|
||||
self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm3d(channels)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out += identity
|
||||
return F.relu(out)
|
||||
|
||||
|
||||
class SmartRes(nn.Module):
|
||||
def __init__(self, bin_count):
|
||||
super(SmartRes, self).__init__()
|
||||
self.bin_count = bin_count
|
||||
self.initial_conv = nn.Conv3d(1, 16, kernel_size=3, padding=1)
|
||||
self.bn0 = nn.BatchNorm3d(16)
|
||||
self.resblock1 = ResidualBlock(16)
|
||||
self.resblock2 = ResidualBlock(16)
|
||||
self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm3d(32)
|
||||
self.dilated_conv = nn.Conv3d(32, 32, kernel_size=3, padding=2, dilation=2)
|
||||
self.bn_dilated = nn.BatchNorm3d(32)
|
||||
self.final_conv = nn.Conv3d(32, 1, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.bn0(self.initial_conv(x)))
|
||||
x = self.resblock1(x)
|
||||
x = self.resblock2(x)
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.relu(self.bn_dilated(self.dilated_conv(x)))
|
||||
x = self.final_conv(x)
|
||||
return x
|
||||
|
|
@ -4,7 +4,7 @@ import torch
|
|||
|
||||
|
||||
class HistogramRestorationNet(nn.Module):
|
||||
def __init__(self, bin_count: int):
|
||||
def __init__(self, **_):
|
||||
super(HistogramRestorationNet, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue