Update training

This commit is contained in:
Andras Schmelczer 2024-06-27 22:30:27 +01:00
parent c7c0f292c6
commit 7863611f86
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
4 changed files with 97 additions and 137 deletions

View file

@ -36,16 +36,10 @@ def random_hparam_search(
log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
try:
train_data_loader = get_data_loader(
train_data_paths, **current_hyperparameters
)
test_data_loader = get_data_loader(
test_data_paths, **current_hyperparameters
)
model = train(
hyperparameters=current_hyperparameters,
train_data_loader=train_data_loader,
test_data_loader=test_data_loader,
train_data_paths=train_data_paths,
test_data_paths=test_data_paths,
max_duration=timedelta(hours=timeout_hours),
log_dir=log_dir,
use_tqdm=False,

View file

@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from torch.optim import Adam
@ -10,32 +10,33 @@ from datetime import timedelta, datetime
from torch.utils.data import DataLoader
import torch
from utils import serialise_hparams
from .get_data_loader import get_data_loader
EPSILON = 1e-5
def train(
hyperparameters: Dict[str, Any],
train_data_loader: DataLoader,
test_data_loader: DataLoader,
train_data_paths: List[Path],
test_data_paths: List[Path],
log_dir: Path,
max_duration: Optional[timedelta],
use_tqdm: bool,
device: torch.device,
model_type: str,
bin_count: int,
learning_rate: float,
scheduler_gamma: float,
num_epochs: int,
**_,
) -> torch.nn.Module:
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
test_data_loader = get_data_loader(test_data_paths, **hyperparameters)
start_time = datetime.now()
with SummaryWriter(log_dir) as writer:
model = create_model(
type=model_type,
bin_count=bin_count,
hyperparameters=hyperparameters,
device=device,
).train()
writer.add_graph(model, next(iter(train_data_loader))[0].to(device))