Add better model

This commit is contained in:
Andras Schmelczer 2024-06-22 17:47:45 +01:00
parent 3583cd559a
commit 129a315228
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 111 additions and 120 deletions

View file

@ -2,7 +2,6 @@ from typing import Any, Dict, Tuple
from .v1 import HistogramRestorationNet as v1
from .simple_cnn import SimpleCNN
from .residual import Residual
from .residual2 import Residual2
from .residual3 import Residual3
from .dummy import Dummy
import torch
@ -17,13 +16,12 @@ MODELS = {
"Dummy": Dummy,
"SimpleCNN": SimpleCNN,
"Residual": Residual,
"Residual2": Residual2,
"Residual3": Residual3,
}
def create_model(type: str, bin_count: int):
return MODELS[type](bin_count)
def create_model(type: str, bin_count: int, device: torch.device) -> nn.Module:
return MODELS[type](bin_count).to(device)
def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path):