No timeout

This commit is contained in:
Andras Schmelczer 2024-06-29 11:05:35 +01:00
parent 137ba1c475
commit 2f28356ab9
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 1 additions and 15 deletions

View file

@ -51,8 +51,6 @@ def random_hparam_search(
except KeyboardInterrupt as e:
logging.info("Interrupted, stopping")
break
except TimeoutError as e:
logging.warning(f"Timeout, aborting experiment")
except Exception as e:
logging.error(
f"Error with hparams {current_hyperparameters}:\n\t{e}", stack_info=True

View file

@ -1,13 +1,11 @@
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from torch.optim import Adam
from tqdm.notebook import tqdm
from visualisation import plot_histograms_in_2d
from models import create_model
from datetime import timedelta, datetime
from torch.utils.data import DataLoader
import torch
from utils import serialise_hparams
from .get_data_loader import get_data_loader
@ -20,7 +18,6 @@ def train(
train_data_paths: List[Path],
test_data_paths: List[Path],
log_dir: Path,
max_duration: Optional[timedelta],
use_tqdm: bool,
device: torch.device,
model_type: str,
@ -31,7 +28,6 @@ def train(
) -> torch.nn.Module:
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
test_data_loader = get_data_loader(test_data_paths, **hyperparameters)
start_time = datetime.now()
with SummaryWriter(log_dir) as writer:
model = create_model(
@ -48,7 +44,6 @@ def train(
loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
for epoch in range(num_epochs):
model.print_og_result = True
epoch_loss = 0
writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch)
for batch_id, (edited_histogram, original_histogram) in enumerate(
@ -56,13 +51,6 @@ def train(
if use_tqdm
else train_data_loader
):
current_time = datetime.now()
if (
max_duration is not None
and current_time - start_time > max_duration
):
raise TimeoutError(f"Time limit {max_duration} exceeded")
optimizer.zero_grad()
predicted_original = model(edited_histogram.to(device))
loss = loss_function(