Remove editor module
This commit is contained in:
parent
e5959268c1
commit
c966866abc
37 changed files with 7752 additions and 7345 deletions
76
src/models/__init__.py
Normal file
76
src/models/__init__.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
from typing import Any, Dict, Tuple
|
||||
from .v1 import HistogramRestorationNet as v1
|
||||
from .simple_cnn import SimpleCNN
|
||||
from .residual import Residual
|
||||
from .residual3 import Residual3
|
||||
from .dummy import Dummy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import json
|
||||
|
||||
|
||||
MODELS = {
|
||||
# "v1": v1,
|
||||
"Dummy": Dummy,
|
||||
"SimpleCNN": SimpleCNN,
|
||||
"Residual": Residual,
|
||||
"Residual3": Residual3,
|
||||
}
|
||||
|
||||
|
||||
def create_model(type: str, bin_count: int, device: torch.device) -> nn.Module:
|
||||
return MODELS[type](bin_count).to(device)
|
||||
|
||||
|
||||
def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path):
|
||||
model_path = path.with_suffix(".pth")
|
||||
params_path = path.with_suffix(".json")
|
||||
|
||||
logging.info(f"Saving model to {model_path}")
|
||||
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
|
||||
with open(model_path, "wb") as f:
|
||||
torch.save(model.state_dict(), f)
|
||||
with open(params_path, "w") as f:
|
||||
json.dump(hyperparameters, f, indent=2)
|
||||
|
||||
|
||||
def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, Any]]:
|
||||
logging.info(f"Loading model from {path}")
|
||||
|
||||
params_path = path.with_suffix(".json")
|
||||
with open(params_path) as f:
|
||||
hyperparameters = json.load(f)
|
||||
logging.info(f"Hyperparameters: {hyperparameters}")
|
||||
|
||||
model_path = path.with_suffix(".pth")
|
||||
model = create_model(
|
||||
type=hyperparameters["model_type"],
|
||||
bin_count=hyperparameters["bin_count"],
|
||||
).to(device)
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
model.eval()
|
||||
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
|
||||
|
||||
return model, hyperparameters
|
||||
|
||||
|
||||
def test_models():
|
||||
for model_name, model_constructor in MODELS.items():
|
||||
logging.info(f"Testing model {model_name}")
|
||||
_test_network_dimensions(model_constructor)
|
||||
|
||||
|
||||
def _test_network_dimensions(constructor):
|
||||
for bin_count in [16, 32, 64]:
|
||||
model = constructor()
|
||||
|
||||
# Create a dummy input tensor of the correct shape, the mini-batch size is 4
|
||||
input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count)
|
||||
|
||||
output = model(input_tensor)
|
||||
assert (
|
||||
input_tensor.shape == output.shape
|
||||
), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
|
||||
logging.info("Test passed! Output shape matches input shape.")
|
||||
10
src/models/dummy.py
Normal file
10
src/models/dummy.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Dummy(nn.Module):
|
||||
def __init__(self, **_):
|
||||
super(Dummy, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
53
src/models/residual.py
Normal file
53
src/models/residual.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, **_):
|
||||
super(Residual, self).__init__()
|
||||
|
||||
# Assuming the input histograms are 3D tensors of shape (bin_count, bin_count, bin_count)
|
||||
# Convolutional layers to extract features from the histograms
|
||||
self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
|
||||
self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
|
||||
|
||||
# Batch normalization layers for better convergence
|
||||
self.bn1 = nn.BatchNorm3d(16)
|
||||
self.bn2 = nn.BatchNorm3d(32)
|
||||
self.bn3 = nn.BatchNorm3d(64)
|
||||
|
||||
# Residual block to help the network learn identity functions effectively
|
||||
self.resblock1 = nn.Sequential(
|
||||
nn.Conv3d(64, 64, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm3d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv3d(64, 64, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm3d(64),
|
||||
)
|
||||
|
||||
# ReLU activation
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.deconv1 = nn.ConvTranspose3d(64, 32, kernel_size=3, stride=1, padding=1)
|
||||
self.deconv2 = nn.ConvTranspose3d(32, 16, kernel_size=3, stride=1, padding=1)
|
||||
self.deconv3 = nn.ConvTranspose3d(16, 1, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(self.bn1(self.conv1(x)))
|
||||
x = self.relu(self.bn2(self.conv2(x)))
|
||||
x = self.relu(self.bn3(self.conv3(x)))
|
||||
|
||||
# Apply residual blocks
|
||||
residual = x
|
||||
out = self.resblock1(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
# Upsample to original size
|
||||
out = self.relu(self.deconv1(out))
|
||||
out = self.relu(self.deconv2(out))
|
||||
out = self.relu(self.deconv3(out))
|
||||
|
||||
sum = torch.sum(out, dim=(2, 3, 4), keepdim=True)
|
||||
return out / sum
|
||||
139
src/models/residual3.py
Normal file
139
src/models/residual3.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Residual3(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
elu_alpha: float = 1,
|
||||
dropout_prob: float = 0.05,
|
||||
features: list[int] = [16, 32, 64],
|
||||
kernel_sizes: list[int] = [3, 3, 3],
|
||||
use_instance_norm: bool = True,
|
||||
use_elu: bool = True,
|
||||
leaky_relu_alpha: float = 0.01,
|
||||
**_
|
||||
):
|
||||
super(Residual3, self).__init__()
|
||||
self._elu_alpha = elu_alpha
|
||||
self._dropout_prob = dropout_prob
|
||||
self._features = features
|
||||
self._kernel_sizes = kernel_sizes
|
||||
self._use_instance_norm = use_instance_norm
|
||||
self._use_elu = use_elu
|
||||
self._leaky_relu_alpha = leaky_relu_alpha
|
||||
|
||||
self.conv1 = self._make_conv_layer(1, features[0], kernel_sizes[0])
|
||||
self.res1 = self._make_resblock(features[0], kernel_sizes[0])
|
||||
self.conv2 = self._make_conv_layer(features[0], features[1], kernel_sizes[1])
|
||||
self.res2 = self._make_resblock(features[1], kernel_sizes[1])
|
||||
self.conv3 = self._make_conv_layer(features[1], features[2], kernel_sizes[2])
|
||||
self.res3 = self._make_resblock(features[2], kernel_sizes[2])
|
||||
|
||||
self.deconv1 = self._make_deconv_layer(
|
||||
features[2], features[1], kernel_sizes[2]
|
||||
)
|
||||
self.deconv2 = self._make_deconv_layer(
|
||||
features[1], features[0], kernel_sizes[1]
|
||||
)
|
||||
self.deconv3 = self._make_deconv_layer(features[0], 1, kernel_sizes[0])
|
||||
|
||||
self._initialize_weights()
|
||||
|
||||
def _make_conv_layer(
|
||||
self, in_channels: int, out_channels: int, kernel_size: int
|
||||
) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
(
|
||||
nn.ELU(self._elu_alpha)
|
||||
if self._use_elu
|
||||
else nn.LeakyReLU(self._leaky_relu_alpha)
|
||||
),
|
||||
(nn.InstanceNorm3d if self._use_instance_norm else nn.BatchNorm3d)(
|
||||
out_channels
|
||||
),
|
||||
nn.Dropout(p=self._dropout_prob),
|
||||
)
|
||||
|
||||
def _make_resblock(self, channels: int, kernel_size: int) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv3d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
(
|
||||
nn.ELU(self._elu_alpha)
|
||||
if self._use_elu
|
||||
else nn.LeakyReLU(self._leaky_relu_alpha)(
|
||||
nn.InstanceNorm3d if self._use_instance_norm else nn.BatchNorm3d
|
||||
)(channels)
|
||||
),
|
||||
nn.Conv3d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
(
|
||||
nn.ELU(self._elu_alpha)
|
||||
if self._use_elu
|
||||
else nn.LeakyReLU(self._leaky_relu_alpha)
|
||||
),
|
||||
(nn.InstanceNorm3d if self._use_instance_norm else nn.BatchNorm3d)(
|
||||
channels
|
||||
),
|
||||
)
|
||||
|
||||
def _make_deconv_layer(
|
||||
self, in_channels: int, out_channels: int, kernel_size: int
|
||||
) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.ConvTranspose3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=1,
|
||||
),
|
||||
(
|
||||
nn.ELU(self._elu_alpha)
|
||||
if self._use_elu
|
||||
else nn.LeakyReLU(self._leaky_relu_alpha)
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = out + self.res1(out)
|
||||
out = self.conv2(out)
|
||||
out = out + self.res2(out)
|
||||
out = self.conv3(out)
|
||||
out = out + self.res3(out)
|
||||
|
||||
out = self.deconv1(out)
|
||||
out = self.deconv2(out)
|
||||
out = self.deconv3(out)
|
||||
|
||||
return self._normalize(out)
|
||||
|
||||
def _normalize(self, x):
|
||||
x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
return x / (x_sum + 1e-6)
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
|
||||
# Applying He normal initialization
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
37
src/models/simple_cnn.py
Normal file
37
src/models/simple_cnn.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SimpleCNN(nn.Module):
|
||||
def __init__(self, **_):
|
||||
super(SimpleCNN, self).__init__()
|
||||
# Define the convolutional layers
|
||||
self.conv1 = nn.Conv3d(
|
||||
1, 16, kernel_size=3, padding=1
|
||||
) # input channels = 1, output channels = 16
|
||||
self.conv2 = nn.Conv3d(
|
||||
16, 32, kernel_size=3, padding=1
|
||||
) # input channels = 16, output channels = 32
|
||||
self.conv3 = nn.Conv3d(
|
||||
32, 64, kernel_size=3, padding=1
|
||||
) # input channels = 32, output channels = 64
|
||||
self.conv4 = nn.Conv3d(
|
||||
64, 32, kernel_size=3, padding=1
|
||||
) # input channels = 64, output channels = 32
|
||||
self.conv5 = nn.Conv3d(
|
||||
32, 16, kernel_size=3, padding=1
|
||||
) # input channels = 32, output channels = 16
|
||||
self.conv6 = nn.Conv3d(
|
||||
16, 1, kernel_size=3, padding=1
|
||||
) # input channels = 16, output channels = 1
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.relu(self.conv3(x))
|
||||
x = F.relu(self.conv4(x))
|
||||
x = F.relu(self.conv5(x))
|
||||
x = self.conv6(x)
|
||||
sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
return x / sum
|
||||
69
src/models/v1.py
Normal file
69
src/models/v1.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
|
||||
|
||||
class HistogramRestorationNet(nn.Module):
|
||||
def __init__(self, **_):
|
||||
super(HistogramRestorationNet, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
|
||||
self.bn1 = nn.BatchNorm3d(16)
|
||||
self.conv2 = nn.Conv3d(16, 32, 3, padding=1)
|
||||
self.bn2 = nn.BatchNorm3d(32)
|
||||
self.conv3 = nn.Conv3d(32, 64, 3, padding=1)
|
||||
self.bn3 = nn.BatchNorm3d(64)
|
||||
|
||||
# Adjusted residual connections with proper downsampling and channel matching
|
||||
self.res1 = nn.Sequential(
|
||||
nn.Conv3d(16, 32, 1, stride=1, padding=0), # Match channels
|
||||
nn.BatchNorm3d(32),
|
||||
nn.MaxPool3d(2), # Downsample to match size
|
||||
)
|
||||
self.res2 = nn.Sequential(
|
||||
nn.Conv3d(32, 64, 1, stride=1, padding=0), # Match channels
|
||||
nn.BatchNorm3d(64),
|
||||
nn.MaxPool3d(2), # Downsample to match size
|
||||
)
|
||||
|
||||
self.fc1 = nn.Linear(64 * 4 * 4 * 4, 512)
|
||||
self.fc_bn1 = nn.BatchNorm1d(512)
|
||||
self.fc2 = nn.Linear(512, 32 * 32 * 32)
|
||||
self.apply(HistogramRestorationNet._init_weights_he)
|
||||
|
||||
@staticmethod
|
||||
def _init_weights_he(m):
|
||||
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d):
|
||||
torch.nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x):
|
||||
# Input dimensions: (batch_size, channels(1), 32, 32, 32)
|
||||
|
||||
x = F.relu(self.bn1(self.conv1(x)))
|
||||
x = F.max_pool3d(x, 2)
|
||||
|
||||
# Apply first adjusted residual connection
|
||||
res = self.res1(x)
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.max_pool3d(x, 2)
|
||||
x += res # Add adjusted residual
|
||||
|
||||
# Apply second adjusted residual connection
|
||||
res = self.res2(x)
|
||||
x = F.relu(self.bn3(self.conv3(x)))
|
||||
x = F.max_pool3d(x, 2)
|
||||
x += res # Add adjusted residual
|
||||
|
||||
# Flatten for fully connected layers
|
||||
x = x.view(x.size(0), -1)
|
||||
|
||||
x = F.relu(self.fc_bn1(self.fc1(x)))
|
||||
x = self.fc2(x)
|
||||
|
||||
# Reshape back to the histogram shape
|
||||
x = x.view(-1, 32, 32, 32)
|
||||
x /= torch.sum(x, (1, 2, 3)).view(x.size()[0], 1, 1, 1)
|
||||
|
||||
return x
|
||||
Loading…
Add table
Add a link
Reference in a new issue