House-keeping

This commit is contained in:
Andras Schmelczer 2024-06-25 08:25:00 +01:00
parent d336ec3be6
commit ab06d979e3
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
7 changed files with 7 additions and 42 deletions

View file

@ -48,7 +48,8 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
model = create_model(
type=hyperparameters["model_type"],
bin_count=hyperparameters["bin_count"],
).to(device)
device=device,
)
model.load_state_dict(torch.load(model_path))
model.eval()
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")