Update training

This commit is contained in:
Andras Schmelczer 2024-06-27 22:30:27 +01:00
parent c7c0f292c6
commit 7863611f86
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
4 changed files with 97 additions and 137 deletions

View file

@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from torch.optim import Adam
@ -10,32 +10,33 @@ from datetime import timedelta, datetime
from torch.utils.data import DataLoader
import torch
from utils import serialise_hparams
from .get_data_loader import get_data_loader
EPSILON = 1e-5
def train(
hyperparameters: Dict[str, Any],
train_data_loader: DataLoader,
test_data_loader: DataLoader,
train_data_paths: List[Path],
test_data_paths: List[Path],
log_dir: Path,
max_duration: Optional[timedelta],
use_tqdm: bool,
device: torch.device,
model_type: str,
bin_count: int,
learning_rate: float,
scheduler_gamma: float,
num_epochs: int,
**_,
) -> torch.nn.Module:
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
test_data_loader = get_data_loader(test_data_paths, **hyperparameters)
start_time = datetime.now()
with SummaryWriter(log_dir) as writer:
model = create_model(
type=model_type,
bin_count=bin_count,
hyperparameters=hyperparameters,
device=device,
).train()
writer.add_graph(model, next(iter(train_data_loader))[0].to(device))