Extract training functions

This commit is contained in:
Andras Schmelczer 2024-06-25 08:23:59 +01:00
parent c966866abc
commit d336ec3be6
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
4 changed files with 183 additions and 0 deletions

View file

@ -0,0 +1,57 @@
from datetime import timedelta
import logging
from pathlib import Path
from random import choice
from itertools import count
import json
from typing import Any, Dict, List
from .train import train
from .get_next_run_name import get_next_run_name
from models import save_model
from torch.utils.data import DataLoader
def random_hparam_search(
hyperparameters: List[Dict[str, Any]],
training_data_path: DataLoader,
test_data_path: DataLoader,
models_path: Path,
tensorboard_path: Path,
timeout_hours: int,
) -> None:
for _ in count():
current_hyperparameters = {
k: v.rvs() if hasattr(v, "rvs") else choice(v)
for k, v in choice(hyperparameters).items()
}
serialized_hparams = json.dumps(
current_hyperparameters, indent=2, sort_keys=True
)
logging.info(
f"Starting {get_next_run_name(tensorboard_path)} with hparams {serialized_hparams}"
)
log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
try:
model = train(
hyperparameters=current_hyperparameters,
training_data_path=training_data_path,
test_data_path=test_data_path,
max_duration=timedelta(hours=timeout_hours),
log_dir=log_dir,
use_tqdm=False,
**current_hyperparameters,
)
model_path = models_path / get_next_run_name(models_path)
save_model(model, hyperparameters, model_path)
del model
except KeyboardInterrupt as e:
logging.info("Interrupted, stopping")
break
except TimeoutError as e:
logging.warning(f"Timeout, aborting experiment")
except Exception as e:
logging.error(
f"Error with hparams {current_hyperparameters}:\n\t{e}", stack_info=True
)

121
src/training/train.py Normal file
View file

@ -0,0 +1,121 @@
import logging
from typing import Any, Dict, Optional
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from torch.optim import Adam
from tqdm.notebook import tqdm
from .get_next_run_name import get_next_run_name
from visualisation import plot_histograms_in_2d
from models import create_model, save_model
from datetime import timedelta, datetime
from torch.utils.data import DataLoader
import torch
from utils import serialise_hparams
def train(
hyperparameters: Dict[str, Any],
training_data_path: DataLoader,
test_data_path: DataLoader,
log_dir: Path,
max_duration: Optional[timedelta],
use_tqdm: bool,
device: torch.device,
model_type: str,
bin_count: int,
learning_rate: float,
scheduler_gamma: float,
num_epochs: int,
**_,
) -> torch.nn.Module:
start_time = datetime.now()
with SummaryWriter(log_dir) as writer:
train_data_loader = training_data_path
test_data_loader = test_data_path
model = create_model(
type=model_type,
bin_count=bin_count,
device=device,
).train()
writer.add_graph(model, next(iter(train_data_loader))[0].to(device))
optimizer = Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, step_size=1, gamma=scheduler_gamma
)
loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
for epoch in range(num_epochs):
epoch_loss = 0
writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch)
for batch_id, (edited_histogram, original_histogram) in enumerate(
tqdm(train_data_loader, desc=f"Epoch {epoch}", unit="batch")
if use_tqdm
else train_data_loader
):
current_time = datetime.now()
if (
max_duration is not None
and current_time - start_time > max_duration
):
raise TimeoutError(f"Time limit {max_duration} exceeded")
optimizer.zero_grad()
predicted_original = model(edited_histogram.to(device))
loss = loss_function(
torch.log(torch.clamp(predicted_original, 1e-5, 1)),
original_histogram.to(device),
)
epoch_loss += loss.item()
writer.add_scalar(
"Loss/train/batch",
loss,
global_step=epoch * len(train_data_loader) + batch_id,
)
loss.backward()
optimizer.step()
logging.info(f"Epoch {epoch} train loss: {epoch_loss}")
with torch.no_grad():
model.eval()
loader = iter(test_data_loader)
edited_histogram, original_histogram = next(loader)
predicted_original = model(edited_histogram.to(device))
writer.add_figure(
"histogram",
plot_histograms_in_2d(
{
"original": original_histogram[0].numpy().squeeze(),
"edited": edited_histogram.cpu()[0].numpy().squeeze(),
"predicted": predicted_original.cpu()[0].numpy().squeeze(),
}
),
epoch,
)
epoch_test_loss = 0
for batch_id, (edited_histogram, original_histogram) in enumerate(
test_data_loader
):
predicted_original = model(edited_histogram.to(device))
epoch_test_loss += loss_function(
torch.log(torch.clamp(predicted_original, 1e-10, 1)),
original_histogram.to(device),
).item()
writer.add_hparams(
serialise_hparams(hyperparameters),
{
"Loss/test/epoch": epoch_test_loss,
"Loss/train/epoch": epoch_loss,
},
global_step=epoch,
run_name=log_dir.absolute(),
)
logging.info(f"Epoch {epoch} test loss: {epoch_test_loss}")
model.train()
scheduler.step()
return model

View file

@ -0,0 +1,5 @@
from typing import Any, Dict
def serialise_hparams(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:
return {k: str(v) if isinstance(v, list) else v for k, v in hyperparameters.items()}