23 lines
614 B
Python
23 lines
614 B
Python
from pathlib import Path
|
|
from typing import List
|
|
from torch.utils.data import DataLoader
|
|
from config import CACHE_PATH
|
|
from training import HistogramDataset
|
|
import os
|
|
|
|
|
|
def get_data_loader(
    data: List[Path], edit_count: int, bin_count: int, batch_size: int, **_
) -> DataLoader:
    """Build a shuffled ``DataLoader`` over a ``HistogramDataset``.

    Args:
        data: Paths forwarded to ``HistogramDataset`` as ``paths``.
        edit_count: Forwarded to ``HistogramDataset`` as ``edit_count``.
        bin_count: Forwarded to ``HistogramDataset`` as ``bin_count``.
        batch_size: Number of samples per batch.
        **_: Any extra keyword arguments are accepted and ignored
            (presumably so a larger config mapping can be unpacked straight
            into this call — confirm against callers).

    Returns:
        A ``DataLoader`` that reshuffles every epoch and uses one worker
        process per available CPU (or loads in the main process when the
        CPU count is unknown).
    """
    dataset = HistogramDataset(
        paths=data,
        edit_count=edit_count,
        bin_count=bin_count,
        delete_corrupt_images=False,
        cache_path=CACHE_PATH,
    )
    # os.cpu_count() returns None when the count cannot be determined, but
    # DataLoader requires an int; 0 means "load in the main process".
    num_workers = os.cpu_count() or 0
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
|