Add better model

This commit is contained in:
Andras Schmelczer 2024-06-22 17:47:45 +01:00
parent 3583cd559a
commit 129a315228
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 111 additions and 120 deletions

View file

@ -2,7 +2,6 @@ from typing import Any, Dict, Tuple
from .v1 import HistogramRestorationNet as v1
from .simple_cnn import SimpleCNN
from .residual import Residual
from .residual2 import Residual2
from .residual3 import Residual3
from .dummy import Dummy
import torch
@ -17,13 +16,12 @@ MODELS = {
"Dummy": Dummy,
"SimpleCNN": SimpleCNN,
"Residual": Residual,
"Residual2": Residual2,
"Residual3": Residual3,
}
def create_model(type: str, bin_count: int, device: torch.device) -> nn.Module:
    """Instantiate a registered model and move it to ``device``.

    Args:
        type: Key into ``MODELS`` selecting the model class. The name shadows
            the ``type`` builtin but is kept so keyword callers keep working.
        bin_count: Passed as the single positional argument to the model class.
        device: Device the freshly built model is moved to.

    Returns:
        The constructed model, already placed on ``device``.

    Raises:
        KeyError: If ``type`` is not a key of ``MODELS``.
    """
    try:
        model_class = MODELS[type]
    except KeyError:
        # Still a KeyError for existing callers, but say what would have worked.
        raise KeyError(
            f"Unknown model type {type!r}; expected one of {sorted(MODELS)}"
        ) from None
    return model_class(bin_count).to(device)
def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path):

View file

@ -2,145 +2,138 @@ import torch
import torch.nn as nn
class DepthwiseSeparableConv3d(nn.Module):
    """3D convolution factored into a per-channel and a 1x1x1 convolution.

    ``depthwise`` convolves every input channel on its own
    (``groups=in_channels``); ``pointwise`` then mixes the channels with a
    kernel of size 1 and produces ``out_channels`` feature maps.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        # Spatial filtering: one filter per input channel, channel count unchanged.
        depthwise_options = dict(
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels,
        )
        self.depthwise = nn.Conv3d(in_channels, in_channels, **depthwise_options)
        # Channel mixing: kernel size 1 maps in_channels -> out_channels.
        self.pointwise = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise convolution followed by the pointwise one."""
        return self.pointwise(self.depthwise(x))
class Residual3(nn.Module):
    """3D residual encoder followed by a transposed-convolution decoder.

    The input is expected as ``(batch, 1, D, H, W)``. Three stages of
    ``convolution -> residual block`` widen the channels to ``features[2]``;
    three transposed-convolution layers then bring them back to a single
    channel. The output is divided by its sum over the three spatial
    dimensions.

    Args:
        elu_alpha: ``alpha`` passed to ``nn.ELU`` when ``use_elu`` is true.
        dropout_prob: Probability of the ``nn.Dropout`` closing each
            convolution stage.
        features: Output channel counts of the three convolution stages.
        kernel_sizes: Kernel sizes of the three stages; the decoder reuses
            them in reverse order. Padding is ``kernel_size // 2``, so odd
            sizes preserve the spatial shape (required by the residual adds).
        use_instance_norm: Use ``nn.InstanceNorm3d`` if true, otherwise
            ``nn.BatchNorm3d``.
        use_elu: Use ``nn.ELU`` if true, otherwise ``nn.LeakyReLU``.
        leaky_relu_alpha: Negative slope of ``nn.LeakyReLU``.
        **_: Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        elu_alpha: float = 1,
        dropout_prob: float = 0.05,
        features: list[int] = [16, 32, 64],
        kernel_sizes: list[int] = [3, 3, 3],
        use_instance_norm: bool = True,
        use_elu: bool = True,
        leaky_relu_alpha: float = 0.01,
        **_
    ):
        super().__init__()

        self._elu_alpha = elu_alpha
        self._dropout_prob = dropout_prob
        # Copy so neither the shared default lists nor caller-owned lists are aliased.
        self._features = list(features)
        self._kernel_sizes = list(kernel_sizes)
        self._use_instance_norm = use_instance_norm
        self._use_elu = use_elu
        self._leaky_relu_alpha = leaky_relu_alpha

        # Encoder: each stage changes the channel count, then refines it residually.
        self.conv1 = self._make_conv_layer(1, features[0], kernel_sizes[0])
        self.res1 = self._make_resblock(features[0], kernel_sizes[0])

        self.conv2 = self._make_conv_layer(features[0], features[1], kernel_sizes[1])
        self.res2 = self._make_resblock(features[1], kernel_sizes[1])

        self.conv3 = self._make_conv_layer(features[1], features[2], kernel_sizes[2])
        self.res3 = self._make_resblock(features[2], kernel_sizes[2])

        # Decoder: mirrors the encoder stages back down to a single channel.
        self.deconv1 = self._make_deconv_layer(
            features[2], features[1], kernel_sizes[2]
        )
        self.deconv2 = self._make_deconv_layer(
            features[1], features[0], kernel_sizes[1]
        )
        self.deconv3 = self._make_deconv_layer(features[0], 1, kernel_sizes[0])

        self._initialize_weights()

    def _make_activation(self) -> nn.Module:
        """Return a new activation module as selected by ``use_elu``."""
        if self._use_elu:
            return nn.ELU(self._elu_alpha)
        return nn.LeakyReLU(self._leaky_relu_alpha)

    def _make_norm(self, channels: int) -> nn.Module:
        """Return a new normalisation module as selected by ``use_instance_norm``."""
        norm_class = nn.InstanceNorm3d if self._use_instance_norm else nn.BatchNorm3d
        return norm_class(channels)

    def _make_conv_layer(
        self, in_channels: int, out_channels: int, kernel_size: int
    ) -> nn.Sequential:
        """Build ``conv -> activation -> norm -> dropout``."""
        return nn.Sequential(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias=False,
            ),
            self._make_activation(),
            self._make_norm(out_channels),
            nn.Dropout(p=self._dropout_prob),
        )

    def _make_resblock(self, channels: int, kernel_size: int) -> nn.Sequential:
        """Build the residual branch: two ``conv -> activation -> norm`` groups.

        The skip connection itself is added in ``forward``.

        NOTE: the first group previously lost its normalisation layer to a
        missing comma, so state dicts saved from that layout use different
        ``Sequential`` indices for the second convolution.
        """
        return nn.Sequential(
            nn.Conv3d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias=False,
            ),
            self._make_activation(),
            self._make_norm(channels),
            nn.Conv3d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias=False,
            ),
            self._make_activation(),
            self._make_norm(channels),
        )

    def _make_deconv_layer(
        self, in_channels: int, out_channels: int, kernel_size: int
    ) -> nn.Sequential:
        """Build ``transposed conv -> activation``."""
        return nn.Sequential(
            nn.ConvTranspose3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            ),
            self._make_activation(),
        )

    def forward(self, x):
        """Run the encoder, the decoder and the final normalisation."""
        out = self.conv1(x)
        out = out + self.res1(out)

        out = self.conv2(out)
        out = out + self.res2(out)

        out = self.conv3(out)
        out = out + self.res3(out)

        out = self.deconv1(out)
        out = self.deconv2(out)
        out = self.deconv3(out)

        return self._normalize(out)

    def _normalize(self, x):
        """Divide ``x`` by its sum over dimensions 2, 3 and 4.

        A small constant is added to the sum so an all-zero volume does not
        divide by zero.
        """
        x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
        return x / (x_sum + 1e-6)

    def _initialize_weights(self):
        """Apply He-normal initialisation to every (transposed) convolution."""
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
def _test_network_dimensions(constructor):
    """Check that models built by ``constructor`` preserve the input shape.

    For each cubic resolution a new model is created with no arguments and
    fed a random batch of four single-channel volumes; the output must have
    exactly the shape of the input.
    """
    batch_size, channel_count = 4, 1
    for bin_count in (16, 32, 64):
        # Build the model first, then draw the random input.
        model = constructor()
        volume_shape = (bin_count,) * 3
        input_tensor = torch.rand(batch_size, channel_count, *volume_shape)

        output = model(input_tensor)

        shapes_match = input_tensor.shape == output.shape
        assert (
            shapes_match
        ), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
# Shape smoke test for Residual3 built with its default arguments.
# NOTE(review): this looks like a module-level call, so it would run on every
# import of this module — confirm that is intended.
_test_network_dimensions(Residual3)