No timeout
This commit is contained in:
parent
137ba1c475
commit
2f28356ab9
2 changed files with 1 additions and 15 deletions
|
|
@ -51,8 +51,6 @@ def random_hparam_search(
|
|||
except KeyboardInterrupt as e:
|
||||
logging.info("Interrupted, stopping")
|
||||
break
|
||||
except TimeoutError as e:
|
||||
logging.warning(f"Timeout, aborting experiment")
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error with hparams {current_hyperparameters}:\n\t{e}", stack_info=True
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from pathlib import Path
|
||||
from torch.optim import Adam
|
||||
from tqdm.notebook import tqdm
|
||||
from visualisation import plot_histograms_in_2d
|
||||
from models import create_model
|
||||
from datetime import timedelta, datetime
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
from utils import serialise_hparams
|
||||
from .get_data_loader import get_data_loader
|
||||
|
|
@ -20,7 +18,6 @@ def train(
|
|||
train_data_paths: List[Path],
|
||||
test_data_paths: List[Path],
|
||||
log_dir: Path,
|
||||
max_duration: Optional[timedelta],
|
||||
use_tqdm: bool,
|
||||
device: torch.device,
|
||||
model_type: str,
|
||||
|
|
@ -31,7 +28,6 @@ def train(
|
|||
) -> torch.nn.Module:
|
||||
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
|
||||
test_data_loader = get_data_loader(test_data_paths, **hyperparameters)
|
||||
start_time = datetime.now()
|
||||
|
||||
with SummaryWriter(log_dir) as writer:
|
||||
model = create_model(
|
||||
|
|
@ -48,7 +44,6 @@ def train(
|
|||
loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
model.print_og_result = True
|
||||
epoch_loss = 0
|
||||
writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch)
|
||||
for batch_id, (edited_histogram, original_histogram) in enumerate(
|
||||
|
|
@ -56,13 +51,6 @@ def train(
|
|||
if use_tqdm
|
||||
else train_data_loader
|
||||
):
|
||||
current_time = datetime.now()
|
||||
if (
|
||||
max_duration is not None
|
||||
and current_time - start_time > max_duration
|
||||
):
|
||||
raise TimeoutError(f"Time limit {max_duration} exceeded")
|
||||
|
||||
optimizer.zero_grad()
|
||||
predicted_original = model(edited_histogram.to(device))
|
||||
loss = loss_function(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue