46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
from pathlib import Path
|
|
from typing import List, Tuple
|
|
from torch.utils.data import DataLoader, random_split
|
|
from editor.training import HistogramDataset
|
|
import logging
|
|
import torch
|
|
from config import CACHE_PATH
|
|
import os
|
|
|
|
|
|
def create_data_loaders(
    data: List[Path],
    edit_count: int,
    bin_count: int,
    training_batch_size: int,
    train_size=0.9,
    delete_corrupt_images: bool = False,
    *,
    test_batch_size: int = 1,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader]:
    """Build a shuffled training loader and an ordered test loader.

    The paths in *data* are wrapped in a ``HistogramDataset`` (using
    ``CACHE_PATH`` as its cache path) and split into train/test subsets with a
    ``torch.Generator`` seeded by *seed*. The same dataset and seed therefore
    always produce the same split.

    Args:
        data: Paths handed to ``HistogramDataset``.
        edit_count: Forwarded to ``HistogramDataset``.
        bin_count: Forwarded to ``HistogramDataset``.
        training_batch_size: Batch size of the training loader.
        train_size: Fraction of the dataset, in ``(0, 1]``, assigned to the
            training split. The resulting count is truncated with ``int()``;
            the remainder forms the test split.
        delete_corrupt_images: Forwarded to ``HistogramDataset``.
        test_batch_size: Batch size of the test loader (default ``1``).
        seed: Seed for the train/test split (default ``42``).

    Returns:
        ``(train_data_loader, test_data_loader)``. The training loader shuffles
        every epoch; the test loader keeps a fixed order.

    Raises:
        ValueError: If *train_size* is outside ``(0, 1]``, or if the split
            would leave the training set empty.
    """
    # Validate before building the dataset: a fraction > 1 would give
    # random_split a negative test length (which still sums to the total and
    # so is not rejected), silently producing nonsense subsets.
    if not 0.0 < train_size <= 1.0:
        raise ValueError(f"train_size must be in (0, 1], got {train_size!r}")

    dataset = HistogramDataset(
        data,
        edit_count=edit_count,
        bin_count=bin_count,
        delete_corrupt_images=delete_corrupt_images,
        cache_path=CACHE_PATH,
    )

    total_size = len(dataset)
    # Truncate (not round) so previously seeded splits stay reproducible.
    train_count = int(train_size * total_size)
    test_count = total_size - train_count
    if train_count == 0:
        # A shuffled DataLoader over an empty subset would otherwise fail
        # later inside RandomSampler with a far less helpful message.
        raise ValueError(
            f"Training split is empty: {total_size} images with "
            f"train_size={train_size!r}"
        )

    train_dataset, test_dataset = random_split(
        dataset,
        [train_count, test_count],
        generator=torch.Generator().manual_seed(seed),
    )

    # os.cpu_count() returns None when the count cannot be determined; the
    # DataLoader needs an int, so fall back to main-process loading (0).
    num_workers = os.cpu_count() or 0

    train_data_loader = DataLoader(
        train_dataset,
        batch_size=training_batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    logging.info(
        "Loaded %d training images and %d test images",
        len(train_dataset),
        len(test_dataset),
    )

    return train_data_loader, test_data_loader
|