Fix bugs
This commit is contained in:
parent
ae2995d0e9
commit
137ba1c475
2 changed files with 12 additions and 18 deletions
|
|
@ -22,6 +22,7 @@ def random_hparam_search(
|
|||
device: torch.device,
|
||||
) -> None:
|
||||
for _ in count():
|
||||
run_id = get_next_run_name(tensorboard_path)
|
||||
current_hyperparameters = {
|
||||
k: v.rvs() if hasattr(v, "rvs") else choice(v)
|
||||
for k, v in choice(hyperparameters).items()
|
||||
|
|
@ -29,11 +30,9 @@ def random_hparam_search(
|
|||
serialized_hparams = json.dumps(
|
||||
current_hyperparameters, indent=2, sort_keys=True
|
||||
)
|
||||
logging.info(
|
||||
f"Starting {get_next_run_name(tensorboard_path)} with hparams {serialized_hparams}"
|
||||
)
|
||||
logging.info(f"Starting {run_id} with hparams {serialized_hparams}")
|
||||
|
||||
log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
|
||||
log_dir = tensorboard_path / run_id
|
||||
|
||||
try:
|
||||
model = train(
|
||||
|
|
@ -46,7 +45,7 @@ def random_hparam_search(
|
|||
device=device,
|
||||
**current_hyperparameters,
|
||||
)
|
||||
model_path = models_path / get_next_run_name(models_path)
|
||||
model_path = models_path / run_id
|
||||
save_model(model, current_hyperparameters, model_path)
|
||||
del model
|
||||
except KeyboardInterrupt as e:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue