Fix "extract training"

This commit is contained in:
Andras Schmelczer 2024-06-25 09:02:45 +01:00
parent 856dc83c77
commit 2475d7c8dd
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
3 changed files with 40 additions and 8 deletions

View file

@ -0,0 +1,23 @@
from pathlib import Path
from typing import List
from torch.utils.data import DataLoader
from config import CACHE_PATH
from training import HistogramDataset
import os
def get_data_loader(
data: List[Path], edit_count: int, bin_count: int, batch_size: int, **_
) -> DataLoader:
return DataLoader(
dataset=HistogramDataset(
paths=data,
edit_count=edit_count,
bin_count=bin_count,
delete_corrupt_images=False,
cache_path=CACHE_PATH,
),
batch_size=batch_size,
shuffle=True,
num_workers=os.cpu_count(),
)

View file

@ -1,7 +1,7 @@
from torch.utils.data import Dataset
from typing import List, Optional, Tuple
from utils import compute_histogram
from .random_edit import random_edit
from operations.random_edit import random_edit
from PIL import Image
from tqdm import tqdm
import logging

View file

@ -8,18 +8,20 @@ from typing import Any, Dict, List
from .train import train
from .get_next_run_name import get_next_run_name
from models import save_model
from torch.utils.data import DataLoader
from .get_data_loader import get_data_loader
import torch
def random_hparam_search(
hyperparameters: List[Dict[str, Any]],
training_data_path: DataLoader,
test_data_path: DataLoader,
train_data_paths: List[Path],
test_data_paths: List[Path],
models_path: Path,
tensorboard_path: Path,
timeout_hours: int,
device: torch.device,
) -> None:
for _ in count():
for _ in range(1):
current_hyperparameters = {
k: v.rvs() if hasattr(v, "rvs") else choice(v)
for k, v in choice(hyperparameters).items()
@ -34,17 +36,24 @@ def random_hparam_search(
log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
try:
train_data_loader = get_data_loader(
train_data_paths, **current_hyperparameters
)
test_data_loader = get_data_loader(
test_data_paths, **current_hyperparameters
)
model = train(
hyperparameters=current_hyperparameters,
training_data_path=training_data_path,
test_data_path=test_data_path,
train_data_loader=train_data_loader,
test_data_loader=test_data_loader,
max_duration=timedelta(hours=timeout_hours),
log_dir=log_dir,
use_tqdm=False,
device=device,
**current_hyperparameters,
)
model_path = models_path / get_next_run_name(models_path)
save_model(model, hyperparameters, model_path)
save_model(model, current_hyperparameters, model_path)
del model
except KeyboardInterrupt as e:
logging.info("Interrupted, stopping")