Fix errors

This commit is contained in:
Andras Schmelczer 2024-06-27 23:13:37 +01:00
parent 7863611f86
commit 35eb747abf
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
3 changed files with 26 additions and 73 deletions

View file

@ -47,7 +47,7 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
model_path = path.with_suffix(".pth")
model = create_model(
type=hyperparameters["model_type"],
**hyperparameters,
hyperparameters=hyperparameters,
device=device,
)
model.load_state_dict(torch.load(model_path))