diff --git a/src/training/get_data_loader.py b/src/training/get_data_loader.py
new file mode 100644
index 0000000..b6735a9
--- /dev/null
+++ b/src/training/get_data_loader.py
@@ -0,0 +1,31 @@
+from pathlib import Path
+from typing import List
+from torch.utils.data import DataLoader
+from config import CACHE_PATH
+from training import HistogramDataset
+import os
+
+
+def get_data_loader(
+    data: List[Path], edit_count: int, bin_count: int, batch_size: int, **_
+) -> DataLoader:
+    """Build a shuffled DataLoader over a HistogramDataset of the given paths.
+
+    Extra keyword arguments are accepted and ignored so that a complete
+    hyperparameter dict can be unpacked straight into this function.
+    """
+    # os.cpu_count() returns None when the CPU count cannot be determined;
+    # fall back to 0 workers (load in the main process) rather than handing
+    # DataLoader a None worker count.
+    return DataLoader(
+        dataset=HistogramDataset(
+            paths=data,
+            edit_count=edit_count,
+            bin_count=bin_count,
+            delete_corrupt_images=False,
+            cache_path=CACHE_PATH,
+        ),
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=os.cpu_count() or 0,
+    )
diff --git a/src/training/histogram_dataset.py b/src/training/histogram_dataset.py
index 93a0513..7aa1b56 100644
--- a/src/training/histogram_dataset.py
+++ b/src/training/histogram_dataset.py
@@ -1,7 +1,7 @@
 from torch.utils.data import Dataset
 from typing import List, Optional, Tuple
 from utils import compute_histogram
-from .random_edit import random_edit
+from operations.random_edit import random_edit
 from PIL import Image
 from tqdm import tqdm
 import logging
diff --git a/src/training/random_hparam_search.py b/src/training/random_hparam_search.py
index 907a27a..0e0f6e5 100644
--- a/src/training/random_hparam_search.py
+++ b/src/training/random_hparam_search.py
@@ -8,18 +8,20 @@ from typing import Any, Dict, List
 from .train import train
 from .get_next_run_name import get_next_run_name
 from models import save_model
-from torch.utils.data import DataLoader
+from .get_data_loader import get_data_loader
+import torch
 
 
 def random_hparam_search(
     hyperparameters: List[Dict[str, Any]],
-    training_data_path: DataLoader,
-    test_data_path: DataLoader,
+    train_data_paths: List[Path],
+    test_data_paths: List[Path],
     models_path: Path,
     tensorboard_path: Path,
     timeout_hours: int,
+    device: torch.device,
 ) -> None:
-    for _ in count():
+    for _ in range(1):
         current_hyperparameters = {
             k: v.rvs() if hasattr(v, "rvs") else choice(v)
             for k, v in choice(hyperparameters).items()
@@ -34,17 +36,24 @@ def random_hparam_search(
         log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
 
         try:
+            train_data_loader = get_data_loader(
+                train_data_paths, **current_hyperparameters
+            )
+            test_data_loader = get_data_loader(
+                test_data_paths, **current_hyperparameters
+            )
             model = train(
                 hyperparameters=current_hyperparameters,
-                training_data_path=training_data_path,
-                test_data_path=test_data_path,
+                train_data_loader=train_data_loader,
+                test_data_loader=test_data_loader,
                 max_duration=timedelta(hours=timeout_hours),
                 log_dir=log_dir,
                 use_tqdm=False,
+                device=device,
                 **current_hyperparameters,
             )
             model_path = models_path / get_next_run_name(models_path)
-            save_model(model, hyperparameters, model_path)
+            save_model(model, current_hyperparameters, model_path)
             del model
         except KeyboardInterrupt as e:
             logging.info("Interrupted, stopping")