Don't eval after load

This commit is contained in:
Andras Schmelczer 2024-09-01 22:10:31 +01:00
parent 9e286264b8
commit a5d74c650c
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
3 changed files with 2245 additions and 2199 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -51,7 +51,6 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
device=device,
)
model.load_state_dict(torch.load(model_path))
model.eval()
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
return model, hyperparameters