This commit is contained in:
Andras Schmelczer 2024-06-29 10:14:12 +01:00
parent ae2995d0e9
commit 137ba1c475
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 12 additions and 18 deletions

View file

@ -22,6 +22,7 @@ def random_hparam_search(
device: torch.device,
) -> None:
for _ in count():
run_id = get_next_run_name(tensorboard_path)
current_hyperparameters = {
k: v.rvs() if hasattr(v, "rvs") else choice(v)
for k, v in choice(hyperparameters).items()
@ -29,11 +30,9 @@ def random_hparam_search(
serialized_hparams = json.dumps(
current_hyperparameters, indent=2, sort_keys=True
)
logging.info(
f"Starting {get_next_run_name(tensorboard_path)} with hparams {serialized_hparams}"
)
logging.info(f"Starting {run_id} with hparams {serialized_hparams}")
log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
log_dir = tensorboard_path / run_id
try:
model = train(
@ -46,7 +45,7 @@ def random_hparam_search(
device=device,
**current_hyperparameters,
)
model_path = models_path / get_next_run_name(models_path)
model_path = models_path / run_id
save_model(model, current_hyperparameters, model_path)
del model
except KeyboardInterrupt as e: