Update networks
This commit is contained in:
parent
0e4fe7ab63
commit
b7755bfea8
9 changed files with 173 additions and 144 deletions
|
|
@ -2,13 +2,9 @@ from typing import Any, Dict, Tuple
|
|||
from .v1 import HistogramRestorationNet as v1
|
||||
from .simple_cnn import SimpleCNN
|
||||
from .residual import Residual
|
||||
from .normalised_cnn import NormalisedCNN
|
||||
from .smart_res import SmartRes
|
||||
from .attention_net import AttentionNet
|
||||
from .res2 import Res2
|
||||
from .attention2 import EnhancedAestheticHistogramNet
|
||||
from .attention import PhotoEnhanceNetAdvanced
|
||||
from .advanced_attention import PhotoEnhanceNetAdvanced as advanced_attention
|
||||
from .residual2 import Residual2
|
||||
from .residual3 import Residual3
|
||||
from .dummy import Dummy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
|
|
@ -17,16 +13,12 @@ import json
|
|||
|
||||
|
||||
MODELS = {
|
||||
# "v1": v1,
|
||||
# "SimpleCNN": SimpleCNN,
|
||||
"v1": v1,
|
||||
"Dummy": Dummy,
|
||||
"SimpleCNN": SimpleCNN,
|
||||
"Residual": Residual,
|
||||
# "NormalisedCNN": NormalisedCNN,
|
||||
# "SmartRes": SmartRes,
|
||||
# "AttentionNet": AttentionNet,
|
||||
# "attention2": EnhancedAestheticHistogramNet,
|
||||
# "advanced_attention": advanced_attention,
|
||||
# "Res2": Res2,
|
||||
# "attention1": PhotoEnhanceNetAdvanced,
|
||||
# "Residual2": Residual2,
|
||||
"Residual3": Residual3,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -39,6 +31,7 @@ def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path):
|
|||
params_path = path.with_suffix(".json")
|
||||
|
||||
logging.info(f"Saving model to {model_path}")
|
||||
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
|
||||
with open(model_path, "wb") as f:
|
||||
torch.save(model.state_dict(), f)
|
||||
with open(params_path, "w") as f:
|
||||
|
|
@ -60,11 +53,12 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
|
|||
).to(device)
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
model.eval()
|
||||
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
|
||||
|
||||
return model, hyperparameters
|
||||
|
||||
|
||||
def _test_models():
|
||||
def test_models():
|
||||
for model_name, model_constructor in MODELS.items():
|
||||
logging.info(f"Testing model {model_name}")
|
||||
_test_network_dimensions(model_constructor)
|
||||
|
|
@ -72,18 +66,13 @@ def _test_models():
|
|||
|
||||
def _test_network_dimensions(constructor):
|
||||
for bin_count in [16, 32, 64]:
|
||||
model = constructor(bin_count=bin_count)
|
||||
model = constructor()
|
||||
|
||||
# Create a dummy input tensor of the correct shape
|
||||
# Create a dummy input tensor of the correct shape, the mini-batch size is 4
|
||||
input_tensor = torch.rand(4, 1, bin_count, bin_count, bin_count)
|
||||
|
||||
# Test the model output
|
||||
output = model(input_tensor)
|
||||
assert (
|
||||
input_tensor.shape == output.shape
|
||||
), f"Expected output shape {input_tensor.shape}, but got {output.shape}"
|
||||
print("Test passed! Output shape matches input shape.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_models()
|
||||
logging.info("Test passed! Output shape matches input shape.")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue