Update training
This commit is contained in:
parent
c7c0f292c6
commit
7863611f86
4 changed files with 97 additions and 137 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from pathlib import Path
|
||||
from torch.optim import Adam
|
||||
|
|
@ -10,32 +10,33 @@ from datetime import timedelta, datetime
|
|||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
from utils import serialise_hparams
|
||||
|
||||
from .get_data_loader import get_data_loader
|
||||
|
||||
EPSILON = 1e-5
|
||||
|
||||
|
||||
def train(
|
||||
hyperparameters: Dict[str, Any],
|
||||
train_data_loader: DataLoader,
|
||||
test_data_loader: DataLoader,
|
||||
train_data_paths: List[Path],
|
||||
test_data_paths: List[Path],
|
||||
log_dir: Path,
|
||||
max_duration: Optional[timedelta],
|
||||
use_tqdm: bool,
|
||||
device: torch.device,
|
||||
model_type: str,
|
||||
bin_count: int,
|
||||
learning_rate: float,
|
||||
scheduler_gamma: float,
|
||||
num_epochs: int,
|
||||
**_,
|
||||
) -> torch.nn.Module:
|
||||
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
|
||||
test_data_loader = get_data_loader(test_data_paths, **hyperparameters)
|
||||
start_time = datetime.now()
|
||||
|
||||
with SummaryWriter(log_dir) as writer:
|
||||
model = create_model(
|
||||
type=model_type,
|
||||
bin_count=bin_count,
|
||||
hyperparameters=hyperparameters,
|
||||
device=device,
|
||||
).train()
|
||||
writer.add_graph(model, next(iter(train_data_loader))[0].to(device))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue