diff --git a/src/models/__init__.py b/src/models/__init__.py index 658116b..a98da3c 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -1,8 +1,7 @@ from typing import Any, Dict, Tuple -from .v1 import HistogramRestorationNet as v1 from .simple_cnn import SimpleCNN from .residual import Residual -from .residual3 import Residual3 +from .histogram_net import HistogramNet from .dummy import Dummy import torch import torch.nn as nn @@ -12,16 +11,17 @@ import json MODELS = { - # "v1": v1, "Dummy": Dummy, "SimpleCNN": SimpleCNN, "Residual": Residual, - "Residual3": Residual3, + "HistogramNet": HistogramNet, } -def create_model(type: str, bin_count: int, device: torch.device) -> nn.Module: - return MODELS[type](bin_count).to(device) +def create_model( + type: str, hyperparameters: Dict[str, Any], device: torch.device +) -> nn.Module: + return MODELS[type](**hyperparameters).to(device) def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path): @@ -47,7 +47,7 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A model_path = path.with_suffix(".pth") model = create_model( type=hyperparameters["model_type"], - bin_count=hyperparameters["bin_count"], + **hyperparameters, device=device, ) model.load_state_dict(torch.load(model_path)) diff --git a/src/models/residual3.py b/src/models/histogram_net.py similarity index 64% rename from src/models/residual3.py rename to src/models/histogram_net.py index 92a63bb..38b973a 100644 --- a/src/models/residual3.py +++ b/src/models/histogram_net.py @@ -6,53 +6,55 @@ import torch.nn as nn EPSILON = 1e-5 -class Residual3(nn.Module): +class HistogramNet(nn.Module): def __init__( self, elu_alpha: float = 1, dropout_prob: float = 0.05, features: list[int] = [16, 32, 64], - kernel_sizes: list[int] = [3, 3, 3], + kernel_size: int = 3, use_instance_norm: bool = True, use_elu: bool = True, leaky_relu_alpha: float = 0.01, + use_residual: bool = True, **_, ): - super(Residual3, self).__init__() + super(HistogramNet, self).__init__() self._elu_alpha = elu_alpha self._dropout_prob = dropout_prob self._features = features - self._kernel_sizes = kernel_sizes + self._kernel_size = kernel_size self._use_instance_norm = use_instance_norm self._use_elu = use_elu self._leaky_relu_alpha = leaky_relu_alpha + self._use_residual = use_residual self.print_og_result = False - self.conv1 = self._make_conv_layer(1, features[0], kernel_sizes[0]) - self.res1 = self._make_resblock(features[0], kernel_sizes[0]) - self.conv2 = self._make_conv_layer(features[0], features[1], kernel_sizes[1]) - self.res2 = self._make_resblock(features[1], kernel_sizes[1]) - self.conv3 = self._make_conv_layer(features[1], features[2], kernel_sizes[2]) - self.res3 = self._make_resblock(features[2], kernel_sizes[2]) + self._convolutions = nn.ModuleList( + self._make_conv_layer(in_channels=in_channels, out_channels=out_channels) + for in_channels, out_channels in zip([1] + features[:-1], features) + ) - self.deconv1 = self._make_deconv_layer( - features[2], features[1], kernel_sizes[2] + if self._use_residual: + self._residual_blocks = nn.ModuleList( + self._make_resblock(channels) for channels in features + ) + + self._deconvolutions = nn.ModuleList( + self._make_deconv_layer(in_channels=in_channels, out_channels=out_channels) + for in_channels, out_channels in zip( + features[::-1], features[::-1][1:] + [1] + ) ) - self.deconv2 = self._make_deconv_layer( - features[1], features[0], kernel_sizes[1] - ) - self.deconv3 = self._make_deconv_layer(features[0], 1, kernel_sizes[0]) self._initialize_weights() - def _make_conv_layer( - self, in_channels: int, out_channels: int, kernel_size: int - ) -> nn.Sequential: + def _make_conv_layer(self, in_channels: int, out_channels: int) -> nn.Sequential: return nn.Sequential( nn.Conv3d( in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, + kernel_size=self._kernel_size, padding=1, bias=False, ), @@ -67,12 +69,12 @@ class Residual3(nn.Module): nn.Dropout(p=self._dropout_prob), ) - def _make_resblock(self, channels: int, kernel_size: int) -> nn.Sequential: + def _make_resblock(self, channels: int) -> nn.Sequential: return nn.Sequential( nn.Conv3d( in_channels=channels, out_channels=channels, - kernel_size=kernel_size, + kernel_size=self._kernel_size, padding=1, bias=False, ), @@ -86,7 +88,7 @@ class Residual3(nn.Module): nn.Conv3d( in_channels=channels, out_channels=channels, - kernel_size=kernel_size, + kernel_size=self._kernel_size, padding=1, bias=False, ), @@ -100,14 +102,12 @@ class Residual3(nn.Module): ), ) - def _make_deconv_layer( - self, in_channels: int, out_channels: int, kernel_size: int - ) -> nn.Sequential: + def _make_deconv_layer(self, in_channels: int, out_channels: int) -> nn.Sequential: return nn.Sequential( nn.ConvTranspose3d( in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, + kernel_size=self._kernel_size, padding=1, ), ( @@ -118,22 +118,22 @@ class Residual3(nn.Module): ) def forward(self, x): - out = self.conv1(x) - out = out + self.res1(out) - out = self.conv2(out) - out = out + self.res2(out) - out = self.conv3(out) - out = out + self.res3(out) + if self._use_residual: + for conv, res in zip(self._convolutions, self._residual_blocks): + x = conv(x) + x = x + res(x) + else: + for conv in self._convolutions: + x = conv(x) - out = self.deconv1(out) - out = self.deconv2(out) - out = self.deconv3(out) + for deconv in self._deconvolutions: + x = deconv(x) if self.print_og_result: - logging.info(f"Original result {torch.sum(out)}") + logging.info(f"Original result {torch.sum(x)}") self.print_og_result = False - return self._normalize(out) + return self._normalize(x) @staticmethod def _normalize(x): diff --git a/src/models/v1.py b/src/models/v1.py deleted file mode 100644 index 87f0bac..0000000 --- a/src/models/v1.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F -import torch - - -class HistogramRestorationNet(nn.Module): - def __init__(self, **_): - super(HistogramRestorationNet, self).__init__() - - self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, padding=1) - self.bn1 = nn.BatchNorm3d(16) - self.conv2 = nn.Conv3d(16, 32, 3, padding=1) - self.bn2 = nn.BatchNorm3d(32) - self.conv3 = nn.Conv3d(32, 64, 3, padding=1) - self.bn3 = nn.BatchNorm3d(64) - - # Adjusted residual connections with proper downsampling and channel matching - self.res1 = nn.Sequential( - nn.Conv3d(16, 32, 1, stride=1, padding=0), # Match channels - nn.BatchNorm3d(32), - nn.MaxPool3d(2), # Downsample to match size - ) - self.res2 = nn.Sequential( - nn.Conv3d(32, 64, 1, stride=1, padding=0), # Match channels - nn.BatchNorm3d(64), - nn.MaxPool3d(2), # Downsample to match size - ) - - self.fc1 = nn.Linear(64 * 4 * 4 * 4, 512) - self.fc_bn1 = nn.BatchNorm1d(512) - self.fc2 = nn.Linear(512, 32 * 32 * 32) - self.apply(HistogramRestorationNet._init_weights_he) - - @staticmethod - def _init_weights_he(m): - if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d): - torch.nn.init.kaiming_uniform_(m.weight, nonlinearity="relu") - if m.bias is not None: - torch.nn.init.zeros_(m.bias) - - def forward(self, x): - # Input dimensions: (batch_size, channels(1), 32, 32, 32) - - x = F.relu(self.bn1(self.conv1(x))) - x = F.max_pool3d(x, 2) - - # Apply first adjusted residual connection - res = self.res1(x) - x = F.relu(self.bn2(self.conv2(x))) - x = F.max_pool3d(x, 2) - x += res # Add adjusted residual - - # Apply second adjusted residual connection - res = self.res2(x) - x = F.relu(self.bn3(self.conv3(x))) - x = F.max_pool3d(x, 2) - x += res # Add adjusted residual - - # Flatten for fully connected layers - x = x.view(x.size(0), -1) - - x = F.relu(self.fc_bn1(self.fc1(x))) - x = self.fc2(x) - - # Reshape back to the histogram shape - x = x.view(-1, 32, 32, 32) - x /= torch.sum(x, (1, 2, 3)).view(x.size()[0], 1, 1, 1) - - return x