From 2f28356ab90bf89621fb2ed347a222453cdd1315 Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Sat, 29 Jun 2024 11:05:35 +0100 Subject: [PATCH] No timeout --- src/training/random_hparam_search.py | 2 -- src/training/train.py | 14 +------------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/src/training/random_hparam_search.py b/src/training/random_hparam_search.py index cc3b08c..e6c9819 100644 --- a/src/training/random_hparam_search.py +++ b/src/training/random_hparam_search.py @@ -51,8 +51,6 @@ def random_hparam_search( except KeyboardInterrupt as e: logging.info("Interrupted, stopping") break - except TimeoutError as e: - logging.warning(f"Timeout, aborting experiment") except Exception as e: logging.error( f"Error with hparams {current_hyperparameters}:\n\t{e}", stack_info=True diff --git a/src/training/train.py b/src/training/train.py index b8337eb..6ef9f56 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -1,13 +1,11 @@ import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from torch.utils.tensorboard import SummaryWriter from pathlib import Path from torch.optim import Adam from tqdm.notebook import tqdm from visualisation import plot_histograms_in_2d from models import create_model -from datetime import timedelta, datetime -from torch.utils.data import DataLoader import torch from utils import serialise_hparams from .get_data_loader import get_data_loader @@ -20,7 +18,6 @@ def train( train_data_paths: List[Path], test_data_paths: List[Path], log_dir: Path, - max_duration: Optional[timedelta], use_tqdm: bool, device: torch.device, model_type: str, @@ -31,7 +28,6 @@ def train( ) -> torch.nn.Module: train_data_loader = get_data_loader(train_data_paths, **hyperparameters) test_data_loader = get_data_loader(test_data_paths, **hyperparameters) - start_time = datetime.now() with SummaryWriter(log_dir) as writer: model = create_model( @@ -48,7 +44,6 @@ def train( loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device) for epoch in range(num_epochs): - model.print_og_result = True epoch_loss = 0 writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch) for batch_id, (edited_histogram, original_histogram) in enumerate( @@ -56,13 +51,6 @@ def train( if use_tqdm else train_data_loader ): - current_time = datetime.now() - if ( - max_duration is not None - and current_time - start_time > max_duration - ): - raise TimeoutError(f"Time limit {max_duration} exceeded") - optimizer.zero_grad() predicted_original = model(edited_histogram.to(device)) loss = loss_function(