From b7755bfea8b1a5d0110d344fbaa94b755957e776 Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Sat, 22 Jun 2024 15:01:02 +0100 Subject: [PATCH] Update networks --- src/editor/models/__init__.py | 39 +++----- src/editor/models/dummy.py | 10 ++ src/editor/models/normalised_cnn.py | 34 ------- src/editor/models/res2.py | 37 ------- src/editor/models/residual.py | 3 +- src/editor/models/residual3.py | 146 ++++++++++++++++++++++++++++ src/editor/models/simple_cnn.py | 4 +- src/editor/models/smart_res.py | 42 -------- src/editor/models/v1.py | 2 +- 9 files changed, 173 insertions(+), 144 deletions(-) create mode 100644 src/editor/models/dummy.py delete mode 100644 src/editor/models/normalised_cnn.py delete mode 100644 src/editor/models/res2.py create mode 100644 src/editor/models/residual3.py delete mode 100644 src/editor/models/smart_res.py diff --git a/src/editor/models/__init__.py b/src/editor/models/__init__.py index 867f5b2..d4252d8 100644 --- a/src/editor/models/__init__.py +++ b/src/editor/models/__init__.py @@ -2,13 +2,9 @@ from typing import Any, Dict, Tuple from .v1 import HistogramRestorationNet as v1 from .simple_cnn import SimpleCNN from .residual import Residual -from .normalised_cnn import NormalisedCNN -from .smart_res import SmartRes -from .attention_net import AttentionNet -from .res2 import Res2 -from .attention2 import EnhancedAestheticHistogramNet -from .attention import PhotoEnhanceNetAdvanced -from .advanced_attention import PhotoEnhanceNetAdvanced as advanced_attention +from .residual2 import Residual2 +from .residual3 import Residual3 +from .dummy import Dummy import torch import torch.nn as nn from pathlib import Path @@ -17,16 +13,12 @@ import json MODELS = { - # "v1": v1, - # "SimpleCNN": SimpleCNN, + "v1": v1, + "Dummy": Dummy, + "SimpleCNN": SimpleCNN, "Residual": Residual, - # "NormalisedCNN": NormalisedCNN, - # "SmartRes": SmartRes, - # "AttentionNet": AttentionNet, - # "attention2": EnhancedAestheticHistogramNet, - # "advanced_attention": advanced_attention, - # "Res2": Res2, - # "attention1": PhotoEnhanceNetAdvanced, + # "Residual2": Residual2, + "Residual3": Residual3, } @@ -39,6 +31,7 @@ def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path): params_path = path.with_suffix(".json") logging.info(f"Saving model to {model_path}") + logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}") with open(model_path, "wb") as f: torch.save(model.state_dict(), f) with open(params_path, "w") as f: @@ -60,11 +53,12 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A ).to(device) model.load_state_dict(torch.load(model_path)) model.eval() + logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}") return model, hyperparameters -def _test_models(): +def test_models(): for model_name, model_constructor in MODELS.items(): logging.info(f"Testing model {model_name}") _test_network_dimensions(model_constructor) @@ -72,18 +66,13 @@ def _test_models(): def _test_network_dimensions(constructor): for bin_count in [16, 32, 64]: - model = constructor(bin_count=bin_count) + model = constructor() - # Create a dummy input tensor of the correct shape + # Create a dummy input tensor of the correct shape, the mini-batch size is 4 input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count) - # Test the model output output = model(input_tensor) assert ( input_tensor.shape == output.shape ), f"Expected output shape {input_tensor.shape}, but got {output.shape}" - print("Test passed! Output shape matches input shape.") - - -if __name__ == "__main__": - _test_models() + logging.info("Test passed! Output shape matches input shape.") diff --git a/src/editor/models/dummy.py b/src/editor/models/dummy.py new file mode 100644 index 0000000..29c1c97 --- /dev/null +++ b/src/editor/models/dummy.py @@ -0,0 +1,10 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class Dummy(nn.Module): + def __init__(self, **_): + super(Dummy, self).__init__() + + def forward(self, x): + return x diff --git a/src/editor/models/normalised_cnn.py b/src/editor/models/normalised_cnn.py deleted file mode 100644 index e38b5b9..0000000 --- a/src/editor/models/normalised_cnn.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F - - -class NormalisedCNN(nn.Module): - def __init__(self, bin_count): - super(NormalisedCNN, self).__init__() - self.bin_count = bin_count - - # Define the layers of the neural network - self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1) - self.bn1 = nn.BatchNorm3d(16) - self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1) - self.bn2 = nn.BatchNorm3d(32) - self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1) - self.bn3 = nn.BatchNorm3d(64) - self.conv4 = nn.Conv3d(64, 32, kernel_size=3, padding=1) - self.bn4 = nn.BatchNorm3d(32) - self.conv5 = nn.Conv3d(32, 16, kernel_size=3, padding=1) - self.bn5 = nn.BatchNorm3d(16) - self.conv6 = nn.Conv3d(16, 1, kernel_size=3, padding=1) - - def forward(self, x): - x = x.view( - -1, 1, self.bin_count, self.bin_count, self.bin_count - ) # Reshape input to (N, C, D, H, W) - x = F.relu(self.bn1(self.conv1(x))) - x = F.relu(self.bn2(self.conv2(x))) - x = F.relu(self.bn3(self.conv3(x))) - x = F.relu(self.bn4(self.conv4(x))) - x = F.relu(self.bn5(self.conv5(x))) - x = self.conv6(x) - - return x diff --git a/src/editor/models/res2.py b/src/editor/models/res2.py deleted file mode 100644 index 6d19bc6..0000000 --- a/src/editor/models/res2.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch.nn as nn - - -class ResidualBlock(nn.Module): - def __init__(self, channels): - super(ResidualBlock, self).__init__() - self.conv = nn.Sequential( - nn.Conv3d(channels, channels, kernel_size=3, padding=1), - nn.BatchNorm3d(channels), - nn.ReLU(inplace=True), - nn.Conv3d(channels, channels, kernel_size=3, padding=1), - nn.BatchNorm3d(channels), - ) - - def forward(self, x): - return self.conv(x) + x - - -# Define the network -class Res2(nn.Module): - def __init__(self, bin_count): - super(Res2, self).__init__() - self.input_layer = nn.Sequential( - nn.Conv3d(1, 16, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.BatchNorm3d(16), - ) - self.res_blocks = nn.Sequential( - ResidualBlock(16), ResidualBlock(16), ResidualBlock(16), ResidualBlock(16) - ) - self.output_layer = nn.Conv3d(16, 1, kernel_size=3, padding=1) - - def forward(self, x): - x = self.input_layer(x) - x = self.res_blocks(x) - x = self.output_layer(x) - return x diff --git a/src/editor/models/residual.py b/src/editor/models/residual.py index ca5be9a..c17cd2e 100644 --- a/src/editor/models/residual.py +++ b/src/editor/models/residual.py @@ -2,9 +2,8 @@ import torch.nn as nn class Residual(nn.Module): - def __init__(self, bin_count: int): + def __init__(self, **_): super(Residual, self).__init__() - self.bin_count = bin_count # Assuming the input histograms are 3D tensors of shape (bin_count, bin_count, bin_count) # Convolutional layers to extract features from the histograms diff --git a/src/editor/models/residual3.py b/src/editor/models/residual3.py new file mode 100644 index 0000000..f0dd7e1 --- /dev/null +++ b/src/editor/models/residual3.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn + + +class DepthwiseSeparableConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding): + super(DepthwiseSeparableConv3d, self).__init__() + self.depthwise = nn.Conv3d( + in_channels, + in_channels, + kernel_size=kernel_size, + padding=padding, + groups=in_channels, + ) + self.pointwise = nn.Conv3d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + x = self.depthwise(x) + x = self.pointwise(x) + return x + + +class Residual3(nn.Module): + def __init__( + self, + elu_alpha: float = 1, + dropout_prob: float = 0.1, + use_depthwise_separable_conv: bool = False, + feature_map_sizes: list[int] = [16, 32, 64], + kernel_sizes: list[int] = [3, 3, 3], + ): + super(Residual3, self).__init__() + + conv = DepthwiseSeparableConv3d if use_depthwise_separable_conv else nn.Conv3d + + # Assuming the input histograms are 3D tensors of shape (bin_count, bin_count, bin_count) + # Convolutional layers to extract features from the histograms + self.conv1 = conv( + 1, feature_map_sizes[0], kernel_size=kernel_sizes[0], padding=1 + ) + self.conv2 = conv( + feature_map_sizes[0], + feature_map_sizes[1], + kernel_size=kernel_sizes[1], + padding=1, + ) + self.conv3 = conv( + feature_map_sizes[1], + feature_map_sizes[2], + kernel_size=kernel_sizes[2], + padding=1, + ) + + self.activation = nn.ELU(elu_alpha, inplace=True) + + self.bn1 = nn.BatchNorm3d(feature_map_sizes[0]) + self.bn2 = nn.BatchNorm3d(feature_map_sizes[1]) + self.bn3 = nn.BatchNorm3d(feature_map_sizes[2]) + + self.resblock1 = nn.Sequential( + conv( + feature_map_sizes[2], + feature_map_sizes[2], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.ELU(elu_alpha, inplace=True), + nn.BatchNorm3d(feature_map_sizes[2]), + conv( + feature_map_sizes[2], + feature_map_sizes[2], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.ELU(elu_alpha, inplace=True), + nn.BatchNorm3d(feature_map_sizes[2]), + ) + + # Deconvolutional layers + self.deconv1 = nn.ConvTranspose3d( + feature_map_sizes[2], + feature_map_sizes[1], + kernel_size=feature_map_sizes[2], + stride=1, + padding=1, + ) + self.deconv2 = nn.ConvTranspose3d( + feature_map_sizes[1], + feature_map_sizes[0], + kernel_size=feature_map_sizes[1], + stride=1, + padding=1, + ) + self.deconv3 = nn.ConvTranspose3d( + feature_map_sizes[0], 1, kernel_size=3, stride=1, padding=1 + ) + + self.dropout = nn.Dropout3d(p=dropout_prob) + self._initialize_weights() + + def forward(self, x): + out = self.dropout(self.bn1(self.activation(self.conv1(x)))) + out = self.dropout(self.bn2(self.activation(self.conv2(out)))) + out = self.dropout(self.bn3(self.activation(self.conv2(out)))) + + out = out + self.resblock1(out) + + out = self.activation(self.deconv1(out)) + out = self.activation(self.deconv2(out)) + out = self.activation(self.deconv3(out)) + + return self._normalize(out) + + def _normalize(self, x): + x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True) + return x / torch.where(x_sum == 0, torch.ones_like(x_sum), x_sum) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d): + nn.init.xavier_normal_( + m.weight, + ) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +def _test_network_dimensions(constructor): + for bin_count in [16, 32, 64]: + model = constructor() + + # Create a dummy input tensor of the correct shape + input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count) + + # Test the model output + output = model(input_tensor) + assert ( + input_tensor.shape == output.shape + ), f"Expected output shape {input_tensor.shape}, but got {output.shape}" + + +_test_network_dimensions(Residual3) diff --git a/src/editor/models/simple_cnn.py b/src/editor/models/simple_cnn.py index cfd85cb..f5bd7f9 100644 --- a/src/editor/models/simple_cnn.py +++ b/src/editor/models/simple_cnn.py @@ -3,10 +3,8 @@ import torch.nn.functional as F class SimpleCNN(nn.Module): - def __init__(self, bin_count): + def __init__(self, **_): super(SimpleCNN, self).__init__() - self.bin_count = bin_count - # Define the convolutional layers self.conv1 = nn.Conv3d( 1, 16, kernel_size=3, padding=1 diff --git a/src/editor/models/smart_res.py b/src/editor/models/smart_res.py deleted file mode 100644 index 209791d..0000000 --- a/src/editor/models/smart_res.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F - - -class ResidualBlock(nn.Module): - def __init__(self, channels): - super(ResidualBlock, self).__init__() - self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1) - self.bn1 = nn.BatchNorm3d(channels) - self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1) - self.bn2 = nn.BatchNorm3d(channels) - - def forward(self, x): - identity = x - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += identity - return F.relu(out) - - -class SmartRes(nn.Module): - def __init__(self, bin_count): - super(SmartRes, self).__init__() - self.bin_count = bin_count - self.initial_conv = nn.Conv3d(1, 16, kernel_size=3, padding=1) - self.bn0 = nn.BatchNorm3d(16) - self.resblock1 = ResidualBlock(16) - self.resblock2 = ResidualBlock(16) - self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1) - self.bn2 = nn.BatchNorm3d(32) - self.dilated_conv = nn.Conv3d(32, 32, kernel_size=3, padding=2, dilation=2) - self.bn_dilated = nn.BatchNorm3d(32) - self.final_conv = nn.Conv3d(32, 1, kernel_size=3, padding=1) - - def forward(self, x): - x = F.relu(self.bn0(self.initial_conv(x))) - x = self.resblock1(x) - x = self.resblock2(x) - x = F.relu(self.bn2(self.conv2(x))) - x = F.relu(self.bn_dilated(self.dilated_conv(x))) - x = self.final_conv(x) - return x diff --git a/src/editor/models/v1.py b/src/editor/models/v1.py index b2b1a79..87f0bac 100644 --- a/src/editor/models/v1.py +++ b/src/editor/models/v1.py @@ -4,7 +4,7 @@ import torch class HistogramRestorationNet(nn.Module): - def __init__(self, bin_count: int): + def __init__(self, **_): super(HistogramRestorationNet, self).__init__() self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, padding=1)