Fix train/test split

This commit is contained in:
Andras Schmelczer 2024-06-04 22:48:07 +01:00
parent af56ec3fec
commit edeac12e37
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
7 changed files with 9738 additions and 26131 deletions

14
src/data.py Normal file
View file

@ -0,0 +1,14 @@
import random
from config import DATA, TRAIN_SIZE
random.seed(42)
length = len(DATA)
indices = list(range(length))
random.shuffle(indices)
train_indices = indices[: int(length * TRAIN_SIZE)]
test_indices = indices[int(length * TRAIN_SIZE) :]
TRAIN_DATA = [DATA[i] for i in train_indices]
TEST_DATA = [DATA[i] for i in test_indices]