Remove clutter
This commit is contained in:
parent
68397b565a
commit
49d9ece2ec
4 changed files with 6 additions and 68 deletions
File diff suppressed because one or more lines are too long
|
|
@ -1,4 +1,4 @@
|
||||||
from .histogram_dataset import HistogramDataset
|
from .histogram_dataset import HistogramDataset
|
||||||
from .get_next_run_name import get_next_run_name
|
from .get_next_run_name import get_next_run_name
|
||||||
from .random_hparam_search import random_hparam_search
|
|
||||||
from .train import train
|
from .train import train
|
||||||
|
from .train_with_ray import train_with_ray_factory
|
||||||
|
|
|
||||||
|
|
@ -56,10 +56,7 @@ class HistogramDataset(Dataset):
|
||||||
original_idx = idx // self._edit_count
|
original_idx = idx // self._edit_count
|
||||||
edit_idx = idx % self._edit_count
|
edit_idx = idx % self._edit_count
|
||||||
|
|
||||||
edited_histogram = None
|
|
||||||
original_histogram = None
|
|
||||||
cached_data_path = None
|
cached_data_path = None
|
||||||
|
|
||||||
if self._cache_path is not None:
|
if self._cache_path is not None:
|
||||||
cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin"
|
cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin"
|
||||||
cached_data_path.parent.mkdir(parents=True, exist_ok=True)
|
cached_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -89,13 +86,11 @@ class HistogramDataset(Dataset):
|
||||||
cached_data_path,
|
cached_data_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = (
|
return (
|
||||||
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
||||||
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
|
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_2_histograms(tensor1: np.ndarray, tensor2: np.ndarray, path: Path):
|
def save_2_histograms(tensor1: np.ndarray, tensor2: np.ndarray, path: Path):
|
||||||
flat_array1 = tensor1.flatten().astype(np.float32)
|
flat_array1 = tensor1.flatten().astype(np.float32)
|
||||||
|
|
|
||||||
|
|
@ -1,57 +0,0 @@
|
||||||
from datetime import timedelta
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from random import choice
|
|
||||||
from itertools import count
|
|
||||||
import json
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
from .train import train
|
|
||||||
from .get_next_run_name import get_next_run_name
|
|
||||||
from models import save_model
|
|
||||||
from .get_data_loader import get_data_loader
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def random_hparam_search(
|
|
||||||
hyperparameters: List[Dict[str, Any]],
|
|
||||||
train_data_paths: List[Path],
|
|
||||||
test_data_paths: List[Path],
|
|
||||||
models_path: Path,
|
|
||||||
tensorboard_path: Path,
|
|
||||||
timeout_hours: int,
|
|
||||||
device: torch.device,
|
|
||||||
) -> None:
|
|
||||||
for _ in count():
|
|
||||||
run_id = get_next_run_name(tensorboard_path)
|
|
||||||
current_hyperparameters = {
|
|
||||||
k: v.rvs() if hasattr(v, "rvs") else choice(v)
|
|
||||||
for k, v in choice(hyperparameters).items()
|
|
||||||
}
|
|
||||||
serialized_hparams = json.dumps(
|
|
||||||
current_hyperparameters, indent=2, sort_keys=True
|
|
||||||
)
|
|
||||||
logging.info(f"Starting {run_id} with hparams {serialized_hparams}")
|
|
||||||
|
|
||||||
log_dir = tensorboard_path / run_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
model = train(
|
|
||||||
hyperparameters=current_hyperparameters,
|
|
||||||
train_data_paths=train_data_paths,
|
|
||||||
test_data_paths=test_data_paths,
|
|
||||||
max_duration=timedelta(hours=timeout_hours),
|
|
||||||
log_dir=log_dir,
|
|
||||||
use_tqdm=False,
|
|
||||||
device=device,
|
|
||||||
**current_hyperparameters,
|
|
||||||
)
|
|
||||||
model_path = models_path / run_id
|
|
||||||
save_model(model, current_hyperparameters, model_path)
|
|
||||||
del model
|
|
||||||
except KeyboardInterrupt as e:
|
|
||||||
logging.info("Interrupted, stopping")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(
|
|
||||||
f"Error with hparams {current_hyperparameters}:\n\t{e}", stack_info=True
|
|
||||||
)
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue