Fix train/test split
This commit is contained in:
parent
af56ec3fec
commit
edeac12e37
7 changed files with 9738 additions and 26131 deletions
|
|
@@ -1,4 +1,3 @@
|
|||
from .histogram_dataset import HistogramDataset
|
||||
from .random_edit import random_edit
|
||||
from .progressive_pooling_loss import ProgressivePoolingLoss
|
||||
from .create_data_loaders import create_data_loaders
|
||||
|
|
|
|||
|
|
@@ -1,46 +0,0 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from editor.training import HistogramDataset
|
||||
import logging
|
||||
import torch
|
||||
from config import CACHE_PATH
|
||||
import os
|
||||
|
||||
|
||||
def create_data_loaders(
    data: List[Path],
    edit_count: int,
    bin_count: int,
    training_batch_size: int,
    train_size=0.9,
    delete_corrupt_images: bool = False,
) -> Tuple[DataLoader, DataLoader]:
    """Build training and test data loaders over one ``HistogramDataset``.

    A single dataset is constructed from ``data`` and then partitioned with
    ``random_split`` using a generator seeded with 42, so the partition is
    the same on every call for a dataset of the same length.

    Args:
        data: Paths handed to ``HistogramDataset``.
        edit_count: Forwarded to ``HistogramDataset``.
        bin_count: Forwarded to ``HistogramDataset``.
        training_batch_size: Batch size of the training loader. The test
            loader always uses a batch size of 1.
        train_size: Fraction of the dataset (between 0 and 1 inclusive)
            assigned to the training split; the remainder becomes the test
            split.
        delete_corrupt_images: Forwarded to ``HistogramDataset``.

    Returns:
        A ``(train_data_loader, test_data_loader)`` tuple. The training
        loader shuffles; the test loader does not.

    Raises:
        ValueError: If ``train_size`` is not within ``[0, 1]``.
    """
    # Reject fractions that would produce a negative or oversized split
    # length before any (potentially expensive) dataset work happens.
    if not 0 <= train_size <= 1:
        raise ValueError(f"train_size must be within [0, 1], got {train_size}")

    dataset = HistogramDataset(
        data,
        edit_count=edit_count,
        bin_count=bin_count,
        delete_corrupt_images=delete_corrupt_images,
        cache_path=CACHE_PATH,
    )

    total_size = len(dataset)
    # Keep the fraction parameter intact; use separate names for the counts.
    train_count = int(train_size * total_size)
    test_count = total_size - train_count
    train_dataset, test_dataset = random_split(
        dataset,
        [train_count, test_count],
        generator=torch.Generator().manual_seed(42),
    )

    # os.cpu_count() returns None when the count cannot be determined, and
    # DataLoader requires an integer; 0 means "load in the main process".
    worker_count = os.cpu_count() or 0

    train_data_loader = DataLoader(
        train_dataset,
        batch_size=training_batch_size,
        shuffle=True,
        num_workers=worker_count,
    )
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=worker_count,
    )
    logging.info(
        f"Loaded {len(train_dataset)} training images and {len(test_dataset)} test images"
    )

    return train_data_loader, test_data_loader
|
||||
|
|
@@ -24,6 +24,8 @@ class HistogramDataset(Dataset):
|
|||
cache_path: Optional[Path] = None,
|
||||
):
|
||||
self._paths = sorted(paths)
|
||||
logging.info(f"Loaded {len(self._paths)} original images")
|
||||
|
||||
self._edit_count = edit_count
|
||||
self._bin_count = bin_count
|
||||
self._target_size = target_size
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue