Extract model persistance
This commit is contained in:
parent
d0935eeae6
commit
95829e6b1c
1 changed files with 41 additions and 2 deletions
|
|
@ -1,3 +1,4 @@
|
||||||
|
from typing import Any, Dict, Tuple
|
||||||
from .v1 import HistogramRestorationNet as v1
|
from .v1 import HistogramRestorationNet as v1
|
||||||
from .simple_cnn import SimpleCNN
|
from .simple_cnn import SimpleCNN
|
||||||
from .residual import Residual
|
from .residual import Residual
|
||||||
|
|
@ -9,6 +10,10 @@ from .attention2 import EnhancedAestheticHistogramNet
|
||||||
from .attention import PhotoEnhanceNetAdvanced
|
from .attention import PhotoEnhanceNetAdvanced
|
||||||
from .advanced_attention import PhotoEnhanceNetAdvanced as advanced_attention
|
from .advanced_attention import PhotoEnhanceNetAdvanced as advanced_attention
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
MODELS = {
|
MODELS = {
|
||||||
|
|
@ -29,9 +34,39 @@ def create_model(type: str, bin_count: int):
|
||||||
return MODELS[type](bin_count)
|
return MODELS[type](bin_count)
|
||||||
|
|
||||||
|
|
||||||
def test_models():
|
def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path):
|
||||||
|
model_path = path.with_suffix(".pth")
|
||||||
|
params_path = path.with_suffix(".json")
|
||||||
|
|
||||||
|
logging.info(f"Saving model to {model_path}")
|
||||||
|
with open(model_path, "wb") as f:
|
||||||
|
torch.save(model.state_dict(), f)
|
||||||
|
with open(params_path, "w") as f:
|
||||||
|
json.dump(hyperparameters, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, Any]]:
|
||||||
|
logging.info(f"Loading model from {path}")
|
||||||
|
|
||||||
|
params_path = path.with_suffix(".json")
|
||||||
|
with open(params_path) as f:
|
||||||
|
hyperparameters = json.load(f)
|
||||||
|
logging.info(f"Hyperparameters: {hyperparameters}")
|
||||||
|
|
||||||
|
model_path = path.with_suffix(".pth")
|
||||||
|
model = create_model(
|
||||||
|
type=hyperparameters["model_type"],
|
||||||
|
bin_count=hyperparameters["bin_count"],
|
||||||
|
).to(device)
|
||||||
|
model.load_state_dict(torch.load(model_path))
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
return model, hyperparameters
|
||||||
|
|
||||||
|
|
||||||
|
def _test_models():
|
||||||
for model_name, model_constructor in MODELS.items():
|
for model_name, model_constructor in MODELS.items():
|
||||||
print(f"Testing model {model_name}")
|
logging.info(f"Testing model {model_name}")
|
||||||
_test_network_dimensions(model_constructor)
|
_test_network_dimensions(model_constructor)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -48,3 +83,7 @@ def _test_network_dimensions(constructor):
|
||||||
input_tensor.shape == output.shape
|
input_tensor.shape == output.shape
|
||||||
), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
|
), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
|
||||||
print("Test passed! Output shape matches input shape.")
|
print("Test passed! Output shape matches input shape.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
_test_models()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue