Fix up models
This commit is contained in:
parent
28b8b026a9
commit
bf524eea0b
3 changed files with 45 additions and 114 deletions
|
|
@ -1,8 +1,7 @@
|
||||||
from typing import Any, Dict, Tuple
|
from typing import Any, Dict, Tuple
|
||||||
from .v1 import HistogramRestorationNet as v1
|
|
||||||
from .simple_cnn import SimpleCNN
|
from .simple_cnn import SimpleCNN
|
||||||
from .residual import Residual
|
from .residual import Residual
|
||||||
from .residual3 import Residual3
|
from .histogram_net import HistogramNet
|
||||||
from .dummy import Dummy
|
from .dummy import Dummy
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -12,16 +11,17 @@ import json
|
||||||
|
|
||||||
|
|
||||||
MODELS = {
|
MODELS = {
|
||||||
# "v1": v1,
|
|
||||||
"Dummy": Dummy,
|
"Dummy": Dummy,
|
||||||
"SimpleCNN": SimpleCNN,
|
"SimpleCNN": SimpleCNN,
|
||||||
"Residual": Residual,
|
"Residual": Residual,
|
||||||
"Residual3": Residual3,
|
"HistogramNet": HistogramNet,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def create_model(type: str, bin_count: int, device: torch.device) -> nn.Module:
|
def create_model(
|
||||||
return MODELS[type](bin_count).to(device)
|
type: str, hyperparameters: Dict[str, Any], device: torch.device
|
||||||
|
) -> nn.Module:
|
||||||
|
return MODELS[type](**hyperparameters).to(device)
|
||||||
|
|
||||||
|
|
||||||
def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path):
|
def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path):
|
||||||
|
|
@ -47,7 +47,7 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
|
||||||
model_path = path.with_suffix(".pth")
|
model_path = path.with_suffix(".pth")
|
||||||
model = create_model(
|
model = create_model(
|
||||||
type=hyperparameters["model_type"],
|
type=hyperparameters["model_type"],
|
||||||
bin_count=hyperparameters["bin_count"],
|
**hyperparameters,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
model.load_state_dict(torch.load(model_path))
|
model.load_state_dict(torch.load(model_path))
|
||||||
|
|
|
||||||
|
|
@ -6,53 +6,55 @@ import torch.nn as nn
|
||||||
EPSILON = 1e-5
|
EPSILON = 1e-5
|
||||||
|
|
||||||
|
|
||||||
class Residual3(nn.Module):
|
class HistogramNet(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
elu_alpha: float = 1,
|
elu_alpha: float = 1,
|
||||||
dropout_prob: float = 0.05,
|
dropout_prob: float = 0.05,
|
||||||
features: list[int] = [16, 32, 64],
|
features: list[int] = [16, 32, 64],
|
||||||
kernel_sizes: list[int] = [3, 3, 3],
|
kernel_size: int = 3,
|
||||||
use_instance_norm: bool = True,
|
use_instance_norm: bool = True,
|
||||||
use_elu: bool = True,
|
use_elu: bool = True,
|
||||||
leaky_relu_alpha: float = 0.01,
|
leaky_relu_alpha: float = 0.01,
|
||||||
|
use_residual: bool = True,
|
||||||
**_,
|
**_,
|
||||||
):
|
):
|
||||||
super(Residual3, self).__init__()
|
super(HistogramNet, self).__init__()
|
||||||
self._elu_alpha = elu_alpha
|
self._elu_alpha = elu_alpha
|
||||||
self._dropout_prob = dropout_prob
|
self._dropout_prob = dropout_prob
|
||||||
self._features = features
|
self._features = features
|
||||||
self._kernel_sizes = kernel_sizes
|
self._kernel_size = kernel_size
|
||||||
self._use_instance_norm = use_instance_norm
|
self._use_instance_norm = use_instance_norm
|
||||||
self._use_elu = use_elu
|
self._use_elu = use_elu
|
||||||
self._leaky_relu_alpha = leaky_relu_alpha
|
self._leaky_relu_alpha = leaky_relu_alpha
|
||||||
|
self._use_residual = use_residual
|
||||||
self.print_og_result = False
|
self.print_og_result = False
|
||||||
|
|
||||||
self.conv1 = self._make_conv_layer(1, features[0], kernel_sizes[0])
|
self._convolutions = nn.ModuleList(
|
||||||
self.res1 = self._make_resblock(features[0], kernel_sizes[0])
|
self._make_conv_layer(in_channels=in_channels, out_channels=out_channels)
|
||||||
self.conv2 = self._make_conv_layer(features[0], features[1], kernel_sizes[1])
|
for in_channels, out_channels in zip([1] + features[:-1], features)
|
||||||
self.res2 = self._make_resblock(features[1], kernel_sizes[1])
|
)
|
||||||
self.conv3 = self._make_conv_layer(features[1], features[2], kernel_sizes[2])
|
|
||||||
self.res3 = self._make_resblock(features[2], kernel_sizes[2])
|
|
||||||
|
|
||||||
self.deconv1 = self._make_deconv_layer(
|
if self._use_residual:
|
||||||
features[2], features[1], kernel_sizes[2]
|
self._residual_blocks = nn.ModuleList(
|
||||||
|
self._make_resblock(channels) for channels in features
|
||||||
|
)
|
||||||
|
|
||||||
|
self._deconvolutions = nn.ModuleList(
|
||||||
|
self._make_deconv_layer(in_channels=in_channels, out_channels=out_channels)
|
||||||
|
for in_channels, out_channels in zip(
|
||||||
|
features[::-1], features[::-1][1:] + [1]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.deconv2 = self._make_deconv_layer(
|
|
||||||
features[1], features[0], kernel_sizes[1]
|
|
||||||
)
|
|
||||||
self.deconv3 = self._make_deconv_layer(features[0], 1, kernel_sizes[0])
|
|
||||||
|
|
||||||
self._initialize_weights()
|
self._initialize_weights()
|
||||||
|
|
||||||
def _make_conv_layer(
|
def _make_conv_layer(self, in_channels: int, out_channels: int) -> nn.Sequential:
|
||||||
self, in_channels: int, out_channels: int, kernel_size: int
|
|
||||||
) -> nn.Sequential:
|
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.Conv3d(
|
nn.Conv3d(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
kernel_size=kernel_size,
|
kernel_size=self._kernel_size,
|
||||||
padding=1,
|
padding=1,
|
||||||
bias=False,
|
bias=False,
|
||||||
),
|
),
|
||||||
|
|
@ -67,12 +69,12 @@ class Residual3(nn.Module):
|
||||||
nn.Dropout(p=self._dropout_prob),
|
nn.Dropout(p=self._dropout_prob),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_resblock(self, channels: int, kernel_size: int) -> nn.Sequential:
|
def _make_resblock(self, channels: int) -> nn.Sequential:
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.Conv3d(
|
nn.Conv3d(
|
||||||
in_channels=channels,
|
in_channels=channels,
|
||||||
out_channels=channels,
|
out_channels=channels,
|
||||||
kernel_size=kernel_size,
|
kernel_size=self._kernel_size,
|
||||||
padding=1,
|
padding=1,
|
||||||
bias=False,
|
bias=False,
|
||||||
),
|
),
|
||||||
|
|
@ -86,7 +88,7 @@ class Residual3(nn.Module):
|
||||||
nn.Conv3d(
|
nn.Conv3d(
|
||||||
in_channels=channels,
|
in_channels=channels,
|
||||||
out_channels=channels,
|
out_channels=channels,
|
||||||
kernel_size=kernel_size,
|
kernel_size=self._kernel_size,
|
||||||
padding=1,
|
padding=1,
|
||||||
bias=False,
|
bias=False,
|
||||||
),
|
),
|
||||||
|
|
@ -100,14 +102,12 @@ class Residual3(nn.Module):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_deconv_layer(
|
def _make_deconv_layer(self, in_channels: int, out_channels: int) -> nn.Sequential:
|
||||||
self, in_channels: int, out_channels: int, kernel_size: int
|
|
||||||
) -> nn.Sequential:
|
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.ConvTranspose3d(
|
nn.ConvTranspose3d(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
kernel_size=kernel_size,
|
kernel_size=self._kernel_size,
|
||||||
padding=1,
|
padding=1,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
|
@ -118,22 +118,22 @@ class Residual3(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
out = self.conv1(x)
|
if self._use_residual:
|
||||||
out = out + self.res1(out)
|
for conv, res in zip(self._convolutions, self._residual_blocks):
|
||||||
out = self.conv2(out)
|
x = conv(x)
|
||||||
out = out + self.res2(out)
|
x = x + res(x)
|
||||||
out = self.conv3(out)
|
else:
|
||||||
out = out + self.res3(out)
|
for conv in self._convolutions:
|
||||||
|
x = conv(x)
|
||||||
|
|
||||||
out = self.deconv1(out)
|
for deconv in self._deconvolutions:
|
||||||
out = self.deconv2(out)
|
x = deconv(x)
|
||||||
out = self.deconv3(out)
|
|
||||||
|
|
||||||
if self.print_og_result:
|
if self.print_og_result:
|
||||||
logging.info(f"Original result {torch.sum(out)}")
|
logging.info(f"Original result {torch.sum(x)}")
|
||||||
self.print_og_result = False
|
self.print_og_result = False
|
||||||
|
|
||||||
return self._normalize(out)
|
return self._normalize(x)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normalize(x):
|
def _normalize(x):
|
||||||
|
|
@ -1,69 +0,0 @@
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class HistogramRestorationNet(nn.Module):
|
|
||||||
def __init__(self, **_):
|
|
||||||
super(HistogramRestorationNet, self).__init__()
|
|
||||||
|
|
||||||
self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
|
|
||||||
self.bn1 = nn.BatchNorm3d(16)
|
|
||||||
self.conv2 = nn.Conv3d(16, 32, 3, padding=1)
|
|
||||||
self.bn2 = nn.BatchNorm3d(32)
|
|
||||||
self.conv3 = nn.Conv3d(32, 64, 3, padding=1)
|
|
||||||
self.bn3 = nn.BatchNorm3d(64)
|
|
||||||
|
|
||||||
# Adjusted residual connections with proper downsampling and channel matching
|
|
||||||
self.res1 = nn.Sequential(
|
|
||||||
nn.Conv3d(16, 32, 1, stride=1, padding=0), # Match channels
|
|
||||||
nn.BatchNorm3d(32),
|
|
||||||
nn.MaxPool3d(2), # Downsample to match size
|
|
||||||
)
|
|
||||||
self.res2 = nn.Sequential(
|
|
||||||
nn.Conv3d(32, 64, 1, stride=1, padding=0), # Match channels
|
|
||||||
nn.BatchNorm3d(64),
|
|
||||||
nn.MaxPool3d(2), # Downsample to match size
|
|
||||||
)
|
|
||||||
|
|
||||||
self.fc1 = nn.Linear(64 * 4 * 4 * 4, 512)
|
|
||||||
self.fc_bn1 = nn.BatchNorm1d(512)
|
|
||||||
self.fc2 = nn.Linear(512, 32 * 32 * 32)
|
|
||||||
self.apply(HistogramRestorationNet._init_weights_he)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _init_weights_he(m):
|
|
||||||
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d):
|
|
||||||
torch.nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
|
|
||||||
if m.bias is not None:
|
|
||||||
torch.nn.init.zeros_(m.bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# Input dimensions: (batch_size, channels(1), 32, 32, 32)
|
|
||||||
|
|
||||||
x = F.relu(self.bn1(self.conv1(x)))
|
|
||||||
x = F.max_pool3d(x, 2)
|
|
||||||
|
|
||||||
# Apply first adjusted residual connection
|
|
||||||
res = self.res1(x)
|
|
||||||
x = F.relu(self.bn2(self.conv2(x)))
|
|
||||||
x = F.max_pool3d(x, 2)
|
|
||||||
x += res # Add adjusted residual
|
|
||||||
|
|
||||||
# Apply second adjusted residual connection
|
|
||||||
res = self.res2(x)
|
|
||||||
x = F.relu(self.bn3(self.conv3(x)))
|
|
||||||
x = F.max_pool3d(x, 2)
|
|
||||||
x += res # Add adjusted residual
|
|
||||||
|
|
||||||
# Flatten for fully connected layers
|
|
||||||
x = x.view(x.size(0), -1)
|
|
||||||
|
|
||||||
x = F.relu(self.fc_bn1(self.fc1(x)))
|
|
||||||
x = self.fc2(x)
|
|
||||||
|
|
||||||
# Reshape back to the histogram shape
|
|
||||||
x = x.view(-1, 32, 32, 32)
|
|
||||||
x /= torch.sum(x, (1, 2, 3)).view(x.size()[0], 1, 1, 1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue