From 95829e6b1cdb14fac3fe0ad465f0abe08f7dbe21 Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Fri, 21 Jun 2024 22:30:53 +0100 Subject: [PATCH] Extract model persistance --- src/editor/models/__init__.py | 43 +++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/src/editor/models/__init__.py b/src/editor/models/__init__.py index 5f2ed39..867f5b2 100644 --- a/src/editor/models/__init__.py +++ b/src/editor/models/__init__.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, Tuple from .v1 import HistogramRestorationNet as v1 from .simple_cnn import SimpleCNN from .residual import Residual @@ -9,6 +10,10 @@ from .attention2 import EnhancedAestheticHistogramNet from .attention import PhotoEnhanceNetAdvanced from .advanced_attention import PhotoEnhanceNetAdvanced as advanced_attention import torch +import torch.nn as nn +from pathlib import Path +import logging +import json MODELS = { @@ -29,9 +34,39 @@ def create_model(type: str, bin_count: int): return MODELS[type](bin_count) -def test_models(): +def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path): + model_path = path.with_suffix(".pth") + params_path = path.with_suffix(".json") + + logging.info(f"Saving model to {model_path}") + with open(model_path, "wb") as f: + torch.save(model.state_dict(), f) + with open(params_path, "w") as f: + json.dump(hyperparameters, f, indent=2) + + +def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, Any]]: + logging.info(f"Loading model from {path}") + + params_path = path.with_suffix(".json") + with open(params_path) as f: + hyperparameters = json.load(f) + logging.info(f"Hyperparameters: {hyperparameters}") + + model_path = path.with_suffix(".pth") + model = create_model( + type=hyperparameters["model_type"], + bin_count=hyperparameters["bin_count"], + ).to(device) + model.load_state_dict(torch.load(model_path)) + model.eval() + + return model, hyperparameters + + +def _test_models(): for model_name, model_constructor in MODELS.items(): - print(f"Testing model {model_name}") + logging.info(f"Testing model {model_name}") _test_network_dimensions(model_constructor) @@ -48,3 +83,7 @@ def _test_network_dimensions(constructor): input_tensor.shape == output.shape ), f"Expected output shape {input_tensor.shape}, but got {output.shape}" print("Test passed! Output shape matches input shape.") + + +if __name__ == "__main__": + _test_models()