From 129a31522885ef0f2b1f03b7de92db5c772d48de Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Sat, 22 Jun 2024 17:47:45 +0100 Subject: [PATCH] Add better model --- src/editor/models/__init__.py | 6 +- src/editor/models/residual3.py | 225 ++++++++++++++++----------------- 2 files changed, 111 insertions(+), 120 deletions(-) diff --git a/src/editor/models/__init__.py b/src/editor/models/__init__.py index f9b026d..de6ff8d 100644 --- a/src/editor/models/__init__.py +++ b/src/editor/models/__init__.py @@ -2,7 +2,6 @@ from typing import Any, Dict, Tuple from .v1 import HistogramRestorationNet as v1 from .simple_cnn import SimpleCNN from .residual import Residual -from .residual2 import Residual2 from .residual3 import Residual3 from .dummy import Dummy import torch @@ -17,13 +16,12 @@ MODELS = { "Dummy": Dummy, "SimpleCNN": SimpleCNN, "Residual": Residual, - "Residual2": Residual2, "Residual3": Residual3, } -def create_model(type: str, bin_count: int): - return MODELS[type](bin_count) +def create_model(type: str, bin_count: int, device: torch.device) -> nn.Module: + return MODELS[type](bin_count).to(device) def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path): diff --git a/src/editor/models/residual3.py b/src/editor/models/residual3.py index f0dd7e1..a492796 100644 --- a/src/editor/models/residual3.py +++ b/src/editor/models/residual3.py @@ -2,145 +2,138 @@ import torch import torch.nn as nn -class DepthwiseSeparableConv3d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, padding): - super(DepthwiseSeparableConv3d, self).__init__() - self.depthwise = nn.Conv3d( - in_channels, - in_channels, - kernel_size=kernel_size, - padding=padding, - groups=in_channels, - ) - self.pointwise = nn.Conv3d(in_channels, out_channels, kernel_size=1) - - def forward(self, x): - x = self.depthwise(x) - x = self.pointwise(x) - return x - - class Residual3(nn.Module): def __init__( self, elu_alpha: float = 1, - dropout_prob: float = 0.1, - use_depthwise_separable_conv: bool = False, - feature_map_sizes: list[int] = [16, 32, 64], + dropout_prob: float = 0.05, + features: list[int] = [16, 32, 64], kernel_sizes: list[int] = [3, 3, 3], + use_instance_norm: bool = True, + use_elu: bool = True, + leaky_relu_alpha: float = 0.01, + **_ ): super(Residual3, self).__init__() + self._elu_alpha = elu_alpha + self._dropout_prob = dropout_prob + self._features = features + self._kernel_sizes = kernel_sizes + self._use_instance_norm = use_instance_norm + self._use_elu = use_elu + self._leaky_relu_alpha = leaky_relu_alpha - conv = DepthwiseSeparableConv3d if use_depthwise_separable_conv else nn.Conv3d + self.conv1 = self._make_conv_layer(1, features[0], kernel_sizes[0]) + self.res1 = self._make_resblock(features[0], kernel_sizes[0]) + self.conv2 = self._make_conv_layer(features[0], features[1], kernel_sizes[1]) + self.res2 = self._make_resblock(features[1], kernel_sizes[1]) + self.conv3 = self._make_conv_layer(features[1], features[2], kernel_sizes[2]) + self.res3 = self._make_resblock(features[2], kernel_sizes[2]) - # Assuming the input histograms are 3D tensors of shape (bin_count, bin_count, bin_count) - # Convolutional layers to extract features from the histograms - self.conv1 = conv( - 1, feature_map_sizes[0], kernel_size=kernel_sizes[0], padding=1 + self.deconv1 = self._make_deconv_layer( + features[2], features[1], kernel_sizes[2] ) - self.conv2 = conv( - feature_map_sizes[0], - feature_map_sizes[1], - kernel_size=kernel_sizes[1], - padding=1, - ) - self.conv3 = conv( - feature_map_sizes[1], - feature_map_sizes[2], - kernel_size=kernel_sizes[2], - padding=1, + self.deconv2 = self._make_deconv_layer( + features[1], features[0], kernel_sizes[1] ) + self.deconv3 = self._make_deconv_layer(features[0], 1, kernel_sizes[0]) - self.activation = nn.ELU(elu_alpha, inplace=True) - - self.bn1 = nn.BatchNorm3d(feature_map_sizes[0]) - self.bn2 = nn.BatchNorm3d(feature_map_sizes[1]) - self.bn3 = nn.BatchNorm3d(feature_map_sizes[2]) - - self.resblock1 = nn.Sequential( - conv( - feature_map_sizes[2], - feature_map_sizes[2], - kernel_size=3, - stride=1, - padding=1, - bias=False, - ), - nn.ELU(elu_alpha, inplace=True), - nn.BatchNorm3d(feature_map_sizes[2]), - conv( - feature_map_sizes[2], - feature_map_sizes[2], - kernel_size=3, - stride=1, - padding=1, - bias=False, - ), - nn.ELU(elu_alpha, inplace=True), - nn.BatchNorm3d(feature_map_sizes[2]), - ) - - # Deconvolutional layers - self.deconv1 = nn.ConvTranspose3d( - feature_map_sizes[2], - feature_map_sizes[1], - kernel_size=feature_map_sizes[2], - stride=1, - padding=1, - ) - self.deconv2 = nn.ConvTranspose3d( - feature_map_sizes[1], - feature_map_sizes[0], - kernel_size=feature_map_sizes[1], - stride=1, - padding=1, - ) - self.deconv3 = nn.ConvTranspose3d( - feature_map_sizes[0], 1, kernel_size=3, stride=1, padding=1 - ) - - self.dropout = nn.Dropout3d(p=dropout_prob) self._initialize_weights() + def _make_conv_layer( + self, in_channels: int, out_channels: int, kernel_size: int + ) -> nn.Sequential: + return nn.Sequential( + nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=1, + bias=False, + ), + ( + nn.ELU(self._elu_alpha) + if self._use_elu + else nn.LeakyReLU(self._leaky_relu_alpha) + ), + (nn.InstanceNorm3d if self._use_instance_norm else nn.BatchNorm3d)( + out_channels + ), + nn.Dropout(p=self._dropout_prob), + ) + + def _make_resblock(self, channels: int, kernel_size: int) -> nn.Sequential: + return nn.Sequential( + nn.Conv3d( + in_channels=channels, + out_channels=channels, + kernel_size=kernel_size, + padding=1, + bias=False, + ), + ( + nn.ELU(self._elu_alpha) + if self._use_elu + else nn.LeakyReLU(self._leaky_relu_alpha)( + nn.InstanceNorm3d if self._use_instance_norm else nn.BatchNorm3d + )(channels) + ), + nn.Conv3d( + in_channels=channels, + out_channels=channels, + kernel_size=kernel_size, + padding=1, + bias=False, + ), + ( + nn.ELU(self._elu_alpha) + if self._use_elu + else nn.LeakyReLU(self._leaky_relu_alpha) + ), + (nn.InstanceNorm3d if self._use_instance_norm else nn.BatchNorm3d)( + channels + ), + ) + + def _make_deconv_layer( + self, in_channels: int, out_channels: int, kernel_size: int + ) -> nn.Sequential: + return nn.Sequential( + nn.ConvTranspose3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=1, + ), + ( + nn.ELU(self._elu_alpha) + if self._use_elu + else nn.LeakyReLU(self._leaky_relu_alpha) + ), + ) + def forward(self, x): - out = self.dropout(self.bn1(self.activation(self.conv1(x)))) - out = self.dropout(self.bn2(self.activation(self.conv2(out)))) - out = self.dropout(self.bn3(self.activation(self.conv2(out)))) + out = self.conv1(x) + out = out + self.res1(out) + out = self.conv2(out) + out = out + self.res2(out) + out = self.conv3(out) + out = out + self.res3(out) - out = out + self.resblock1(out) - - out = self.activation(self.deconv1(out)) - out = self.activation(self.deconv2(out)) - out = self.activation(self.deconv3(out)) + out = self.deconv1(out) + out = self.deconv2(out) + out = self.deconv3(out) return self._normalize(out) def _normalize(self, x): x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True) - return x / torch.where(x_sum == 0, torch.ones_like(x_sum), x_sum) + return x / (x_sum + 1e-6) def _initialize_weights(self): for m in self.modules(): - if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d): - nn.init.xavier_normal_( - m.weight, - ) + if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)): + # Applying He normal initialization + nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0) - - -def _test_network_dimensions(constructor): - for bin_count in [16, 32, 64]: - model = constructor() - - # Create a dummy input tensor of the correct shape - input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count) - - # Test the model output - output = model(input_tensor) - assert ( - input_tensor.shape == output.shape - ), f"Expected output shape {input_tensor.shape}, but got {output.shape}" - - -_test_network_dimensions(Residual3)