Fix train/test split

This commit is contained in:
Andras Schmelczer 2024-06-04 22:48:07 +01:00
parent af56ec3fec
commit edeac12e37
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
7 changed files with 9738 additions and 26131 deletions

View file

@ -1,4 +1,3 @@
from .histogram_dataset import HistogramDataset
from .random_edit import random_edit
from .progressive_pooling_loss import ProgressivePoolingLoss
from .create_data_loaders import create_data_loaders

View file

@ -1,46 +0,0 @@
from pathlib import Path
from typing import List, Tuple
from torch.utils.data import DataLoader, random_split
from editor.training import HistogramDataset
import logging
import torch
from config import CACHE_PATH
import os
def create_data_loaders(
data: List[Path],
edit_count: int,
bin_count: int,
training_batch_size: int,
train_size=0.9,
delete_corrupt_images: bool = False,
) -> Tuple[DataLoader, DataLoader]:
dataset = HistogramDataset(
data,
edit_count=edit_count,
bin_count=bin_count,
delete_corrupt_images=delete_corrupt_images,
cache_path=CACHE_PATH,
)
total_size = len(dataset)
train_size = int(train_size * total_size)
test_size = total_size - train_size
train_dataset, test_dataset = random_split(
dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)
train_data_loader = DataLoader(
train_dataset,
batch_size=training_batch_size,
shuffle=True,
num_workers=os.cpu_count(),
)
test_data_loader = DataLoader(
test_dataset, batch_size=1, shuffle=False, num_workers=os.cpu_count()
)
logging.info(
f"Loaded {len(train_dataset)} training images and {len(test_dataset)} test images"
)
return train_data_loader, test_data_loader

View file

@ -24,6 +24,8 @@ class HistogramDataset(Dataset):
cache_path: Optional[Path] = None,
):
self._paths = sorted(paths)
logging.info(f"Loaded {len(self._paths)} original images")
self._edit_count = edit_count
self._bin_count = bin_count
self._target_size = target_size