Update networks

This commit is contained in:
Andras Schmelczer 2024-06-22 15:01:02 +01:00
parent 0e4fe7ab63
commit b7755bfea8
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
9 changed files with 173 additions and 144 deletions

View file

@ -2,13 +2,9 @@ from typing import Any, Dict, Tuple
from .v1 import HistogramRestorationNet as v1
from .simple_cnn import SimpleCNN
from .residual import Residual
from .normalised_cnn import NormalisedCNN
from .smart_res import SmartRes
from .attention_net import AttentionNet
from .res2 import Res2
from .attention2 import EnhancedAestheticHistogramNet
from .attention import PhotoEnhanceNetAdvanced
from .advanced_attention import PhotoEnhanceNetAdvanced as advanced_attention
from .residual2 import Residual2
from .residual3 import Residual3
from .dummy import Dummy
import torch
import torch.nn as nn
from pathlib import Path
@ -17,16 +13,12 @@ import json
MODELS = {
# "v1": v1,
# "SimpleCNN": SimpleCNN,
"v1": v1,
"Dummy": Dummy,
"SimpleCNN": SimpleCNN,
"Residual": Residual,
# "NormalisedCNN": NormalisedCNN,
# "SmartRes": SmartRes,
# "AttentionNet": AttentionNet,
# "attention2": EnhancedAestheticHistogramNet,
# "advanced_attention": advanced_attention,
# "Res2": Res2,
# "attention1": PhotoEnhanceNetAdvanced,
# "Residual2": Residual2,
"Residual3": Residual3,
}
@ -39,6 +31,7 @@ def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path):
params_path = path.with_suffix(".json")
logging.info(f"Saving model to {model_path}")
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
with open(model_path, "wb") as f:
torch.save(model.state_dict(), f)
with open(params_path, "w") as f:
@ -60,11 +53,12 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
).to(device)
model.load_state_dict(torch.load(model_path))
model.eval()
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
return model, hyperparameters
def _test_models():
def test_models():
for model_name, model_constructor in MODELS.items():
logging.info(f"Testing model {model_name}")
_test_network_dimensions(model_constructor)
@ -72,18 +66,13 @@ def _test_models():
def _test_network_dimensions(constructor):
for bin_count in [16, 32, 64]:
model = constructor(bin_count=bin_count)
model = constructor()
# Create a dummy input tensor of the correct shape
# Create a dummy input tensor of the correct shape, the mini-batch size is 4
input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count)
# Test the model output
output = model(input_tensor)
assert (
input_tensor.shape == output.shape
), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
print("Test passed! Output shape matches input shape.")
if __name__ == "__main__":
_test_models()
logging.info("Test passed! Output shape matches input shape.")