Extract training functions

This commit is contained in:
Andras Schmelczer 2024-06-25 08:23:59 +01:00
parent c966866abc
commit d336ec3be6
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
4 changed files with 183 additions and 0 deletions

View file

@ -0,0 +1,57 @@
from datetime import timedelta
import logging
from pathlib import Path
from random import choice
from itertools import count
import json
from typing import Any, Dict, List
from .train import train
from .get_next_run_name import get_next_run_name
from models import save_model
from torch.utils.data import DataLoader
def random_hparam_search(
    hyperparameters: List[Dict[str, Any]],
    training_data_path: DataLoader,
    test_data_path: DataLoader,
    models_path: Path,
    tensorboard_path: Path,
    timeout_hours: int,
) -> None:
    """Run random hyperparameter search until interrupted.

    Each iteration picks one search space at random, samples a concrete
    value for every entry, trains a model with it and saves the result.
    The loop only ends on ``KeyboardInterrupt`` raised during training or
    saving; timeouts and other errors abort the current experiment and
    the search continues with a fresh sample.

    Args:
        hyperparameters: Candidate search spaces. Each dict maps a
            hyperparameter name to either an object exposing ``rvs()``
            (sampled by calling it) or a sequence (sampled with
            ``random.choice``).
        training_data_path: Forwarded unchanged to ``train``.
            NOTE(review): annotated as ``DataLoader`` although the name
            says "path" -- confirm which one ``train`` expects.
        test_data_path: Forwarded unchanged to ``train`` (same caveat).
        models_path: Directory under which the trained model is saved,
            in a sub-path named by ``get_next_run_name(models_path)``.
        tensorboard_path: Directory under which the run's log directory
            is created, named by ``get_next_run_name(tensorboard_path)``.
        timeout_hours: Passed to ``train`` as
            ``max_duration=timedelta(hours=timeout_hours)``.
    """
    for _ in count():
        search_space = choice(hyperparameters)
        current_hyperparameters = {
            k: v.rvs() if hasattr(v, "rvs") else choice(v)
            for k, v in search_space.items()
        }

        # The serialized form is only used for the log line below.
        # default=str keeps one non-JSON-native sampled value (e.g. a NumPy
        # integer) from raising outside the try block and killing the search.
        serialized_hparams = json.dumps(
            current_hyperparameters, indent=2, sort_keys=True, default=str
        )

        # Resolve the run name once so the logged name always matches the
        # directory that is actually used.
        run_name = get_next_run_name(tensorboard_path)
        logging.info(f"Starting {run_name} with hparams {serialized_hparams}")
        log_dir = tensorboard_path / run_name

        try:
            model = train(
                hyperparameters=current_hyperparameters,
                training_data_path=training_data_path,
                test_data_path=test_data_path,
                max_duration=timedelta(hours=timeout_hours),
                log_dir=log_dir,
                use_tqdm=False,
                **current_hyperparameters,
            )
            model_path = models_path / get_next_run_name(models_path)
            # Save the hyperparameters this model was trained with, not the
            # whole search space it was sampled from.
            save_model(model, current_hyperparameters, model_path)
            # Drop the reference before the next iteration builds a new model.
            del model
        except KeyboardInterrupt:
            logging.info("Interrupted, stopping")
            break
        except TimeoutError:
            logging.warning("Timeout, aborting experiment")
        except Exception as e:
            # exc_info attaches the traceback of the caught exception;
            # stack_info would only show the stack of this logging call.
            logging.error(
                f"Error with hparams {current_hyperparameters}:\n\t{e}",
                exc_info=True,
            )