Extract data loading

This commit is contained in:
Andras Schmelczer 2024-05-12 19:51:05 +01:00
parent de17b3c91b
commit b31fa39ca4
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
3 changed files with 289 additions and 111440 deletions

View file

@ -1,3 +1,4 @@
from .histogram_dataset import HistogramDataset
from .random_edit import random_edit
from .progressive_pooling_loss import ProgressivePoolingLoss
from .create_data_loaders import create_data_loaders

View file

@ -0,0 +1,46 @@
from pathlib import Path
from typing import List, Tuple
from torch.utils.data import DataLoader, random_split
from editor.training import HistogramDataset
import logging
import torch
from config import CACHE_PATH
import os
def create_data_loaders(
    data: List[Path],
    edit_count: int,
    bin_count: int,
    training_batch_size: int,
    train_size: float = 0.9,
    delete_corrupt_images: bool = False,
) -> Tuple[DataLoader, DataLoader]:
    """Build the training and test data loaders over a ``HistogramDataset``.

    A single ``HistogramDataset`` is created from ``data`` and split into a
    training and a test subset. The split uses a generator seeded with a
    fixed value (42), so the same inputs yield the same partition on every
    call.

    Args:
        data: Paths forwarded to ``HistogramDataset``.
        edit_count: Forwarded unchanged to ``HistogramDataset``.
        bin_count: Forwarded unchanged to ``HistogramDataset``.
        training_batch_size: Batch size of the training loader. The test
            loader always uses a batch size of 1.
        train_size: Fraction of the dataset assigned to the training subset;
            must lie within ``[0, 1]``. The training subset gets
            ``int(train_size * len(dataset))`` items and the test subset
            gets the remainder.
        delete_corrupt_images: Forwarded unchanged to ``HistogramDataset``.

    Returns:
        A ``(train_data_loader, test_data_loader)`` tuple. The training
        loader shuffles its items; the test loader does not.

    Raises:
        ValueError: If ``train_size`` is outside ``[0, 1]``.
    """
    # Fail fast, before the dataset is constructed: a fraction outside [0, 1]
    # would otherwise produce a negative or oversized split length.
    if not 0.0 <= train_size <= 1.0:
        raise ValueError(f"train_size must be within [0, 1], got {train_size}")

    dataset = HistogramDataset(
        data,
        edit_count=edit_count,
        bin_count=bin_count,
        delete_corrupt_images=delete_corrupt_images,
        cache_path=CACHE_PATH,
    )

    total_count = len(dataset)
    train_count = int(train_size * total_count)
    test_count = total_count - train_count

    train_dataset, test_dataset = random_split(
        dataset,
        [train_count, test_count],
        # Fixed seed keeps the train/test partition reproducible across runs.
        generator=torch.Generator().manual_seed(42),
    )

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # DataLoader needs an int, so fall back to loading in the main process.
    worker_count = os.cpu_count() or 0

    train_data_loader = DataLoader(
        train_dataset,
        batch_size=training_batch_size,
        shuffle=True,
        num_workers=worker_count,
    )
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=worker_count,
    )

    logging.info(
        "Loaded %d training images and %d test images",
        len(train_dataset),
        len(test_dataset),
    )
    return train_data_loader, test_data_loader

File diff suppressed because one or more lines are too long