Extract training functions
This commit is contained in:
parent
c966866abc
commit
d336ec3be6
4 changed files with 183 additions and 0 deletions
57
src/training/random_hparam_search.py
Normal file
57
src/training/random_hparam_search.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
from datetime import timedelta
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from random import choice
|
||||
from itertools import count
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
from .train import train
|
||||
from .get_next_run_name import get_next_run_name
|
||||
from models import save_model
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def random_hparam_search(
    hyperparameters: List[Dict[str, Any]],
    training_data_path: DataLoader,
    test_data_path: DataLoader,
    models_path: Path,
    tensorboard_path: Path,
    timeout_hours: int,
) -> None:
    """Run randomly sampled training experiments until interrupted.

    Every iteration picks one search space from ``hyperparameters`` at
    random, draws a concrete value for each of its keys, trains a model with
    those values via ``train`` and stores the result with ``save_model``.
    The loop has no natural end: it stops only when a ``KeyboardInterrupt``
    is raised while an experiment is running.

    Args:
        hyperparameters: Candidate search spaces. Each one maps a
            hyperparameter name to either an object exposing ``rvs()``
            (sampled by calling it) or a sequence to pick one element from.
        training_data_path: Forwarded unchanged to ``train`` (annotated as a
            ``DataLoader`` despite the name).
        test_data_path: Forwarded unchanged to ``train``.
        models_path: Directory under which trained models are saved; the
            file/run name comes from ``get_next_run_name(models_path)``.
        tensorboard_path: Directory under which each run's log directory is
            created; the run name comes from
            ``get_next_run_name(tensorboard_path)``.
        timeout_hours: Per-experiment time budget, handed to ``train`` as
            ``max_duration``.
    """
    while True:
        # Pick a search space first, then sample every entry of it.
        search_space = choice(hyperparameters)
        current_hyperparameters = {
            name: spec.rvs() if hasattr(spec, "rvs") else choice(spec)
            for name, spec in search_space.items()
        }

        # ``default=str`` keeps logging from aborting the whole search when a
        # sampled value is not JSON-native (the dump is only used for the log
        # message below).
        serialized_hparams = json.dumps(
            current_hyperparameters, indent=2, sort_keys=True, default=str
        )

        # Resolve the run name once so the logged name is guaranteed to be the
        # directory that is actually used.
        run_name = get_next_run_name(tensorboard_path)
        logging.info(f"Starting {run_name} with hparams {serialized_hparams}")
        log_dir = tensorboard_path / run_name

        try:
            model = train(
                hyperparameters=current_hyperparameters,
                training_data_path=training_data_path,
                test_data_path=test_data_path,
                max_duration=timedelta(hours=timeout_hours),
                log_dir=log_dir,
                use_tqdm=False,
                **current_hyperparameters,
            )
            model_path = models_path / get_next_run_name(models_path)
            # Save the values this model was actually trained with, not the
            # whole list of search spaces.
            save_model(model, current_hyperparameters, model_path)
            # Drop the reference before the next experiment allocates a model.
            del model
        except KeyboardInterrupt:
            logging.info("Interrupted, stopping")
            break
        except TimeoutError:
            # A timed-out experiment is abandoned; the search continues.
            logging.warning("Timeout, aborting experiment")
        except Exception as e:
            # Best-effort search: record the failure (with its traceback) and
            # move on to the next sample.
            logging.error(
                f"Error with hparams {current_hyperparameters}:\n\t{e}",
                exc_info=True,
            )
|
||||
121
src/training/train.py
Normal file
121
src/training/train.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from pathlib import Path
|
||||
from torch.optim import Adam
|
||||
from tqdm.notebook import tqdm
|
||||
from .get_next_run_name import get_next_run_name
|
||||
from visualisation import plot_histograms_in_2d
|
||||
from models import create_model, save_model
|
||||
from datetime import timedelta, datetime
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
from utils import serialise_hparams
|
||||
|
||||
|
||||
def train(
    hyperparameters: Dict[str, Any],
    training_data_path: DataLoader,
    test_data_path: DataLoader,
    log_dir: Path,
    max_duration: Optional[timedelta],
    use_tqdm: bool,
    device: torch.device,
    model_type: str,
    bin_count: int,
    learning_rate: float,
    scheduler_gamma: float,
    num_epochs: int,
    **_,
) -> torch.nn.Module:
    """Train a model with a KL-divergence loss and log progress to TensorBoard.

    The model is built with ``create_model``, optimised with Adam and a
    ``StepLR`` scheduler stepped once per epoch. After every epoch the model
    is evaluated on the test loader, a histogram figure for the first test
    batch is logged, and the summed train/test losses are written together
    with the hyperparameters.

    Args:
        hyperparameters: Only used for logging; passed through
            ``serialise_hparams`` to ``SummaryWriter.add_hparams``.
        training_data_path: Loader iterated for training; each batch is
            unpacked as ``(edited_histogram, original_histogram)``.
        test_data_path: Loader iterated for evaluation, with the same batch
            layout.
        log_dir: Directory the ``SummaryWriter`` writes to.
        max_duration: Optional wall-clock budget measured from the start of
            this call; ``None`` disables the check.
        use_tqdm: Wrap the training loader in a ``tqdm`` progress bar.
        device: Device the model and batches are moved to.
        model_type: Forwarded to ``create_model`` as ``type``.
        bin_count: Forwarded to ``create_model``.
        learning_rate: Initial learning rate for Adam.
        scheduler_gamma: ``gamma`` of the ``StepLR`` scheduler.
        num_epochs: Number of epochs to run.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        The trained model, switched back to training mode.

    Raises:
        TimeoutError: If ``max_duration`` is exceeded; checked before each
            training batch.
    """
    start_time = datetime.now()

    with SummaryWriter(log_dir) as writer:
        train_data_loader = training_data_path
        test_data_loader = test_data_path

        model = create_model(
            type=model_type,
            bin_count=bin_count,
            device=device,
        ).train()
        # Trace the graph with the inputs of one training batch.
        writer.add_graph(model, next(iter(train_data_loader))[0].to(device))

        optimizer = Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=scheduler_gamma
        )
        loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)

        for epoch in range(num_epochs):
            epoch_loss = 0
            writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch)

            batches = (
                tqdm(train_data_loader, desc=f"Epoch {epoch}", unit="batch")
                if use_tqdm
                else train_data_loader
            )
            for batch_id, (edited_histogram, original_histogram) in enumerate(batches):
                if (
                    max_duration is not None
                    and datetime.now() - start_time > max_duration
                ):
                    raise TimeoutError(f"Time limit {max_duration} exceeded")

                optimizer.zero_grad()
                predicted_original = model(edited_histogram.to(device))
                # KLDivLoss takes log-probabilities as input; clamp before the
                # log so zero predictions do not produce -inf.
                loss = loss_function(
                    torch.log(torch.clamp(predicted_original, 1e-5, 1)),
                    original_histogram.to(device),
                )

                # Read the scalar once and reuse it for both the running sum
                # and the TensorBoard entry.
                batch_loss = loss.item()
                epoch_loss += batch_loss
                writer.add_scalar(
                    "Loss/train/batch",
                    batch_loss,
                    global_step=epoch * len(train_data_loader) + batch_id,
                )
                loss.backward()
                optimizer.step()

            logging.info(f"Epoch {epoch} train loss: {epoch_loss}")

            with torch.no_grad():
                model.eval()

                # Plot the first sample of the first test batch.
                edited_histogram, original_histogram = next(iter(test_data_loader))
                predicted_original = model(edited_histogram.to(device))
                writer.add_figure(
                    "histogram",
                    plot_histograms_in_2d(
                        {
                            # .cpu() on all three so .numpy() also works when
                            # the loader yields non-CPU tensors.
                            "original": original_histogram.cpu()[0].numpy().squeeze(),
                            "edited": edited_histogram.cpu()[0].numpy().squeeze(),
                            "predicted": predicted_original.cpu()[0].numpy().squeeze(),
                        }
                    ),
                    epoch,
                )

                epoch_test_loss = 0
                for edited_histogram, original_histogram in test_data_loader:
                    predicted_original = model(edited_histogram.to(device))
                    # NOTE(review): the lower clamp bound here (1e-10) differs
                    # from the training one (1e-5), so train and test losses
                    # are not computed identically - confirm this is intended.
                    epoch_test_loss += loss_function(
                        torch.log(torch.clamp(predicted_original, 1e-10, 1)),
                        original_histogram.to(device),
                    ).item()

                writer.add_hparams(
                    serialise_hparams(hyperparameters),
                    {
                        "Loss/test/epoch": epoch_test_loss,
                        "Loss/train/epoch": epoch_loss,
                    },
                    global_step=epoch,
                    # add_hparams documents run_name as a string.
                    run_name=str(Path(log_dir).absolute()),
                )
                logging.info(f"Epoch {epoch} test loss: {epoch_test_loss}")

            model.train()
            scheduler.step()

    return model
|
||||
5
src/utils/serialise_hparams.py
Normal file
5
src/utils/serialise_hparams.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from typing import Any, Dict
|
||||
|
||||
|
||||
def serialise_hparams(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``hyperparameters`` with container values stringified.

    TensorBoard's ``add_hparams`` only accepts scalar-like values, so builtin
    containers (lists, tuples, sets, frozensets and dicts) are replaced by
    their ``str()`` representation. All other values are passed through
    unchanged, and the input mapping is not modified.

    Args:
        hyperparameters: Mapping of hyperparameter names to values.

    Returns:
        A new dict with the same keys.
    """
    container_types = (list, tuple, set, frozenset, dict)
    return {
        name: str(value) if isinstance(value, container_types) else value
        for name, value in hyperparameters.items()
    }
|
||||
Loading…
Add table
Add a link
Reference in a new issue