Fix up models
This commit is contained in:
parent
28b8b026a9
commit
bf524eea0b
3 changed files with 45 additions and 114 deletions
|
|
@ -1,8 +1,7 @@
|
|||
from typing import Any, Dict, Tuple
|
||||
from .v1 import HistogramRestorationNet as v1
|
||||
from .simple_cnn import SimpleCNN
|
||||
from .residual import Residual
|
||||
from .residual3 import Residual3
|
||||
from .histogram_net import HistogramNet
|
||||
from .dummy import Dummy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -12,16 +11,17 @@ import json
|
|||
|
||||
|
||||
MODELS = {
|
||||
# "v1": v1,
|
||||
"Dummy": Dummy,
|
||||
"SimpleCNN": SimpleCNN,
|
||||
"Residual": Residual,
|
||||
"Residual3": Residual3,
|
||||
"HistogramNet": HistogramNet,
|
||||
}
|
||||
|
||||
|
||||
def create_model(type: str, bin_count: int, device: torch.device) -> nn.Module:
|
||||
return MODELS[type](bin_count).to(device)
|
||||
def create_model(
|
||||
type: str, hyperparameters: Dict[str, Any], device: torch.device
|
||||
) -> nn.Module:
|
||||
return MODELS[type](**hyperparameters).to(device)
|
||||
|
||||
|
||||
def save_model(model: nn.Module, hyperparameters: Dict[str, Any], path: Path):
|
||||
|
|
@ -47,7 +47,7 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
|
|||
model_path = path.with_suffix(".pth")
|
||||
model = create_model(
|
||||
type=hyperparameters["model_type"],
|
||||
bin_count=hyperparameters["bin_count"],
|
||||
**hyperparameters,
|
||||
device=device,
|
||||
)
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue