import logging from typing import Any, Dict, Optional from torch.utils.tensorboard import SummaryWriter from pathlib import Path from torch.optim import Adam from tqdm.notebook import tqdm from visualisation import plot_histograms_in_2d from models import create_model from datetime import timedelta, datetime from torch.utils.data import DataLoader import torch from utils import serialise_hparams EPSILON = 1e-5 def train( hyperparameters: Dict[str, Any], train_data_loader: DataLoader, test_data_loader: DataLoader, log_dir: Path, max_duration: Optional[timedelta], use_tqdm: bool, device: torch.device, model_type: str, bin_count: int, learning_rate: float, scheduler_gamma: float, num_epochs: int, **_, ) -> torch.nn.Module: start_time = datetime.now() with SummaryWriter(log_dir) as writer: model = create_model( type=model_type, bin_count=bin_count, device=device, ).train() writer.add_graph(model, next(iter(train_data_loader))[0].to(device)) optimizer = Adam(model.parameters(), lr=learning_rate) scheduler = torch.optim.lr_scheduler.StepLR( optimizer, step_size=1, gamma=scheduler_gamma ) loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device) for epoch in range(num_epochs): epoch_loss = 0 writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch) for batch_id, (edited_histogram, original_histogram) in enumerate( tqdm(train_data_loader, desc=f"Epoch {epoch}", unit="batch") if use_tqdm else train_data_loader ): current_time = datetime.now() if ( max_duration is not None and current_time - start_time > max_duration ): raise TimeoutError(f"Time limit {max_duration} exceeded") optimizer.zero_grad() predicted_original = model(edited_histogram.to(device)) loss = loss_function( torch.log(torch.clamp(predicted_original, EPSILON, 1)), original_histogram.to(device), ) epoch_loss += loss.item() writer.add_scalar( "Loss/train/batch", loss, global_step=epoch * len(train_data_loader) + batch_id, ) loss.backward() optimizer.step() logging.info(f"Epoch {epoch} train loss: {epoch_loss}") with torch.no_grad(): model.eval() loader = iter(test_data_loader) edited_histogram, original_histogram = next(loader) predicted_original = model(edited_histogram.to(device)) writer.add_figure( "histogram", plot_histograms_in_2d( { "original": original_histogram[0].numpy().squeeze(), "edited": edited_histogram.cpu()[0].numpy().squeeze(), "predicted": predicted_original.cpu()[0].numpy().squeeze(), } ), epoch, ) epoch_test_loss = 0 for batch_id, (edited_histogram, original_histogram) in enumerate( test_data_loader ): predicted_original = model(edited_histogram.to(device)) epoch_test_loss += loss_function( torch.log(torch.clamp(predicted_original, EPSILON, 1)), original_histogram.to(device), ).item() writer.add_hparams( serialise_hparams(hyperparameters), { "Loss/test/epoch": epoch_test_loss, "Loss/train/epoch": epoch_loss, }, global_step=epoch, run_name=log_dir.absolute(), ) logging.info(f"Epoch {epoch} test loss: {epoch_test_loss}") model.train() scheduler.step() return model