Extract data loading
This commit is contained in:
parent
de17b3c91b
commit
b31fa39ca4
3 changed files with 289 additions and 111440 deletions
|
|
@ -1,3 +1,4 @@
|
|||
from .histogram_dataset import HistogramDataset
|
||||
from .random_edit import random_edit
|
||||
from .progressive_pooling_loss import ProgressivePoolingLoss
|
||||
from .create_data_loaders import create_data_loaders
|
||||
|
|
|
|||
46
src/editor/training/create_data_loaders.py
Normal file
46
src/editor/training/create_data_loaders.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from editor.training import HistogramDataset
|
||||
import logging
|
||||
import torch
|
||||
from config import CACHE_PATH
|
||||
import os
|
||||
|
||||
|
||||
def create_data_loaders(
|
||||
data: List[Path],
|
||||
edit_count: int,
|
||||
bin_count: int,
|
||||
training_batch_size: int,
|
||||
train_size=0.9,
|
||||
delete_corrupt_images: bool = False,
|
||||
) -> Tuple[DataLoader, DataLoader]:
|
||||
dataset = HistogramDataset(
|
||||
data,
|
||||
edit_count=edit_count,
|
||||
bin_count=bin_count,
|
||||
delete_corrupt_images=delete_corrupt_images,
|
||||
cache_path=CACHE_PATH,
|
||||
)
|
||||
total_size = len(dataset)
|
||||
train_size = int(train_size * total_size)
|
||||
test_size = total_size - train_size
|
||||
train_dataset, test_dataset = random_split(
|
||||
dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
|
||||
train_data_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=training_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=os.cpu_count(),
|
||||
)
|
||||
test_data_loader = DataLoader(
|
||||
test_dataset, batch_size=1, shuffle=False, num_workers=os.cpu_count()
|
||||
)
|
||||
logging.info(
|
||||
f"Loaded {len(train_dataset)} training images and {len(test_dataset)} test images"
|
||||
)
|
||||
|
||||
return train_data_loader, test_data_loader
|
||||
111682
src/train.ipynb
111682
src/train.ipynb
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue