Fix "extract training"
This commit is contained in:
parent
856dc83c77
commit
2475d7c8dd
3 changed files with 40 additions and 8 deletions
23
src/training/get_data_loader.py
Normal file
23
src/training/get_data_loader.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
from torch.utils.data import DataLoader
|
||||
from config import CACHE_PATH
|
||||
from training import HistogramDataset
|
||||
import os
|
||||
|
||||
|
||||
def get_data_loader(
|
||||
data: List[Path], edit_count: int, bin_count: int, batch_size: int, **_
|
||||
) -> DataLoader:
|
||||
return DataLoader(
|
||||
dataset=HistogramDataset(
|
||||
paths=data,
|
||||
edit_count=edit_count,
|
||||
bin_count=bin_count,
|
||||
delete_corrupt_images=False,
|
||||
cache_path=CACHE_PATH,
|
||||
),
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=os.cpu_count(),
|
||||
)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from torch.utils.data import Dataset
|
||||
from typing import List, Optional, Tuple
|
||||
from utils import compute_histogram
|
||||
from .random_edit import random_edit
|
||||
from operations.random_edit import random_edit
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
|
|
|
|||
|
|
@ -8,18 +8,20 @@ from typing import Any, Dict, List
|
|||
from .train import train
|
||||
from .get_next_run_name import get_next_run_name
|
||||
from models import save_model
|
||||
from torch.utils.data import DataLoader
|
||||
from .get_data_loader import get_data_loader
|
||||
import torch
|
||||
|
||||
|
||||
def random_hparam_search(
|
||||
hyperparameters: List[Dict[str, Any]],
|
||||
training_data_path: DataLoader,
|
||||
test_data_path: DataLoader,
|
||||
train_data_paths: List[Path],
|
||||
test_data_paths: List[Path],
|
||||
models_path: Path,
|
||||
tensorboard_path: Path,
|
||||
timeout_hours: int,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
for _ in count():
|
||||
for _ in range(1):
|
||||
current_hyperparameters = {
|
||||
k: v.rvs() if hasattr(v, "rvs") else choice(v)
|
||||
for k, v in choice(hyperparameters).items()
|
||||
|
|
@ -34,17 +36,24 @@ def random_hparam_search(
|
|||
log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
|
||||
|
||||
try:
|
||||
train_data_loader = get_data_loader(
|
||||
train_data_paths, **current_hyperparameters
|
||||
)
|
||||
test_data_loader = get_data_loader(
|
||||
test_data_paths, **current_hyperparameters
|
||||
)
|
||||
model = train(
|
||||
hyperparameters=current_hyperparameters,
|
||||
training_data_path=training_data_path,
|
||||
test_data_path=test_data_path,
|
||||
train_data_loader=train_data_loader,
|
||||
test_data_loader=test_data_loader,
|
||||
max_duration=timedelta(hours=timeout_hours),
|
||||
log_dir=log_dir,
|
||||
use_tqdm=False,
|
||||
device=device,
|
||||
**current_hyperparameters,
|
||||
)
|
||||
model_path = models_path / get_next_run_name(models_path)
|
||||
save_model(model, hyperparameters, model_path)
|
||||
save_model(model, current_hyperparameters, model_path)
|
||||
del model
|
||||
except KeyboardInterrupt as e:
|
||||
logging.info("Interrupted, stopping")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue