# ---------------------------------------------------------------------------
# src/utils/get_next_run_name.py -> src/training/get_next_run_name.py
# (pure rename, similarity index 100%: no content change)
# ---------------------------------------------------------------------------


# ---------------------------------------------------------------------------
# src/training/random_hparam_search.py (new file)
# ---------------------------------------------------------------------------
from datetime import timedelta
import logging
from pathlib import Path
from random import choice
from itertools import count
import json
from typing import Any, Dict, List
from .train import train
from .get_next_run_name import get_next_run_name
from models import save_model
from torch.utils.data import DataLoader


def random_hparam_search(
    hyperparameters: List[Dict[str, Any]],
    training_data_path: DataLoader,
    test_data_path: DataLoader,
    models_path: Path,
    tensorboard_path: Path,
    timeout_hours: int,
) -> None:
    """Run random hyperparameter search until interrupted.

    Each iteration picks one search-space dict at random, samples one value
    per key, trains a model with the sampled values and saves it.

    Args:
        hyperparameters: List of search spaces. Each maps a hyperparameter
            name to either an object exposing ``rvs()`` (sampled by calling
            it) or a sequence (sampled with ``random.choice``).
        training_data_path: Passed through to ``train`` as the training data
            loader (despite the ``_path`` suffix it is annotated DataLoader).
        test_data_path: Passed through to ``train`` as the test data loader.
        models_path: Directory under which trained models are saved; the
            file name comes from ``get_next_run_name(models_path)``.
        tensorboard_path: Directory under which each run's log dir is
            created; the name comes from ``get_next_run_name``.
        timeout_hours: Per-run time limit handed to ``train`` as
            ``max_duration``.

    The loop only ends on ``KeyboardInterrupt``. A ``TimeoutError`` or any
    other ``Exception`` raised by a run is logged and the search continues.
    """
    max_duration = timedelta(hours=timeout_hours)

    for _ in count():
        search_space = choice(hyperparameters)
        current_hyperparameters = {
            k: v.rvs() if hasattr(v, "rvs") else choice(v)
            for k, v in search_space.items()
        }
        # default=str: sampled values may not be JSON-native (e.g. numpy
        # integers returned by rvs(), device objects). This dump is only used
        # for the log line and must not abort the whole search.
        serialized_hparams = json.dumps(
            current_hyperparameters, indent=2, sort_keys=True, default=str
        )

        # Resolve the run name once so the log message and the log directory
        # are guaranteed to agree.
        run_name = get_next_run_name(tensorboard_path)
        logging.info(f"Starting {run_name} with hparams {serialized_hparams}")
        log_dir = tensorboard_path / run_name

        try:
            model = train(
                hyperparameters=current_hyperparameters,
                training_data_path=training_data_path,
                test_data_path=test_data_path,
                max_duration=max_duration,
                log_dir=log_dir,
                use_tqdm=False,
                **current_hyperparameters,
            )
            model_path = models_path / get_next_run_name(models_path)
            # Store the values this model was actually trained with, not the
            # whole search space.
            # NOTE(review): confirm save_model expects the sampled dict here.
            save_model(model, current_hyperparameters, model_path)
            del model  # release the model before the next run allocates one
        except KeyboardInterrupt:
            logging.info("Interrupted, stopping")
            break
        except TimeoutError:
            logging.warning("Timeout, aborting experiment")
        except Exception as e:
            # exc_info (not stack_info) attaches the traceback of the caught
            # exception; stack_info would only show where logging was called.
            logging.error(
                f"Error with hparams {current_hyperparameters}:\n\t{e}",
                exc_info=True,
            )


# ---------------------------------------------------------------------------
# src/training/train.py (new file)
# ---------------------------------------------------------------------------
import logging
from typing import Any, Dict, Optional
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from torch.optim import Adam
from tqdm.notebook import tqdm
from .get_next_run_name import get_next_run_name
from visualisation import plot_histograms_in_2d
from models import create_model, save_model
from datetime import timedelta, datetime
from torch.utils.data import DataLoader
import torch
from utils import serialise_hparams

# Lower clamp bound applied to predictions before taking the log for
# KLDivLoss. Shared by the train and test loss so the two logged metrics are
# computed the same way.
_LOG_CLAMP_MIN = 1e-5


def _kl_loss(
    loss_function: torch.nn.Module,
    predicted: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """Apply ``loss_function`` to the clamped log of ``predicted`` vs ``target``."""
    return loss_function(
        torch.log(torch.clamp(predicted, _LOG_CLAMP_MIN, 1)), target
    )


def train(
    hyperparameters: Dict[str, Any],
    training_data_path: DataLoader,
    test_data_path: DataLoader,
    log_dir: Path,
    max_duration: Optional[timedelta],
    use_tqdm: bool,
    device: torch.device,
    model_type: str,
    bin_count: int,
    learning_rate: float,
    scheduler_gamma: float,
    num_epochs: int,
    **_,
) -> torch.nn.Module:
    """Train a model and log progress to TensorBoard.

    Args:
        hyperparameters: Full hyperparameter dict; only used for
            ``writer.add_hparams`` (after ``serialise_hparams``).
        training_data_path: Loader yielding ``(edited, original)`` batches
            used for optimisation.
        test_data_path: Loader yielding ``(edited, original)`` batches used
            for the per-epoch test loss and the logged figure.
        log_dir: Directory the ``SummaryWriter`` writes to.
        max_duration: Optional wall-clock limit measured from function entry.
        use_tqdm: Wrap the training loader in a tqdm progress bar.
        device: Device the model and batches are moved to.
        model_type: Forwarded to ``create_model`` as ``type``.
        bin_count: Forwarded to ``create_model``.
        learning_rate: Adam learning rate.
        scheduler_gamma: Decay factor of the ``StepLR`` scheduler, which
            steps once per epoch.
        num_epochs: Number of epochs to run.
        **_: Extra hyperparameters are accepted and ignored.

    Returns:
        The trained model, left in training mode.

    Raises:
        TimeoutError: If ``max_duration`` is exceeded. The check runs at the
            start of every training batch only.
    """
    start_time = datetime.now()

    with SummaryWriter(log_dir) as writer:
        train_data_loader = training_data_path
        test_data_loader = test_data_path

        model = create_model(
            type=model_type,
            bin_count=bin_count,
            device=device,
        ).train()
        writer.add_graph(model, next(iter(train_data_loader))[0].to(device))

        optimizer = Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=scheduler_gamma
        )
        loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
        batches_per_epoch = len(train_data_loader)

        for epoch in range(num_epochs):
            epoch_loss = 0
            writer.add_scalar(
                "Actual learning rate", scheduler.get_last_lr()[0], epoch
            )

            if use_tqdm:
                batches = tqdm(
                    train_data_loader, desc=f"Epoch {epoch}", unit="batch"
                )
            else:
                batches = train_data_loader

            for batch_id, (edited_histogram, original_histogram) in enumerate(
                batches
            ):
                if (
                    max_duration is not None
                    and datetime.now() - start_time > max_duration
                ):
                    raise TimeoutError(f"Time limit {max_duration} exceeded")

                optimizer.zero_grad()
                predicted_original = model(edited_histogram.to(device))
                loss = _kl_loss(
                    loss_function,
                    predicted_original,
                    original_histogram.to(device),
                )

                epoch_loss += loss.item()
                writer.add_scalar(
                    "Loss/train/batch",
                    loss,
                    global_step=epoch * batches_per_epoch + batch_id,
                )
                loss.backward()
                optimizer.step()

            logging.info(f"Epoch {epoch} train loss: {epoch_loss}")

            with torch.no_grad():
                model.eval()

                # Log a figure for the first sample of the first test batch.
                edited_histogram, original_histogram = next(
                    iter(test_data_loader)
                )
                predicted_original = model(edited_histogram.to(device))
                writer.add_figure(
                    "histogram",
                    plot_histograms_in_2d(
                        {
                            # .cpu() on all three: a no-op for CPU tensors,
                            # required if the loader yields device tensors.
                            "original": original_histogram.cpu()[0]
                            .numpy()
                            .squeeze(),
                            "edited": edited_histogram.cpu()[0]
                            .numpy()
                            .squeeze(),
                            "predicted": predicted_original.cpu()[0]
                            .numpy()
                            .squeeze(),
                        }
                    ),
                    epoch,
                )

                epoch_test_loss = 0
                for edited_histogram, original_histogram in test_data_loader:
                    predicted_original = model(edited_histogram.to(device))
                    epoch_test_loss += _kl_loss(
                        loss_function,
                        predicted_original,
                        original_histogram.to(device),
                    ).item()

                writer.add_hparams(
                    serialise_hparams(hyperparameters),
                    {
                        "Loss/test/epoch": epoch_test_loss,
                        "Loss/train/epoch": epoch_loss,
                    },
                    global_step=epoch,
                    # add_hparams documents run_name as str; an absolute path
                    # makes the hparams land in log_dir itself.
                    run_name=str(log_dir.absolute()),
                )
                logging.info(f"Epoch {epoch} test loss: {epoch_test_loss}")

            model.train()
            scheduler.step()

    return model


# ---------------------------------------------------------------------------
# src/utils/serialise_hparams.py (new file)
# ---------------------------------------------------------------------------
from typing import Any, Dict


def serialise_hparams(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``hyperparameters`` that is safe for ``add_hparams``.

    Values that are ``None``, ``bool``, ``int``, ``float``, ``str`` or
    ``torch.Tensor`` are kept unchanged. Every other value (lists, tuples,
    dicts, device objects, ...) is replaced by ``str(value)``.

    Args:
        hyperparameters: Mapping of hyperparameter name to value.

    Returns:
        A new dict with the same keys; the input is not modified.
    """
    # Function-scope import: this utility module does not otherwise depend on
    # torch, and the tensor type is only needed for the pass-through check.
    import torch

    supported = (bool, int, float, str, torch.Tensor)
    return {
        k: v if v is None or isinstance(v, supported) else str(v)
        for k, v in hyperparameters.items()
    }