120 lines
4.2 KiB
Python
120 lines
4.2 KiB
Python
import logging
|
|
from typing import Any, Dict, Optional
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from pathlib import Path
|
|
from torch.optim import Adam
|
|
from tqdm.notebook import tqdm
|
|
from visualisation import plot_histograms_in_2d
|
|
from models import create_model
|
|
from datetime import timedelta, datetime
|
|
from torch.utils.data import DataLoader
|
|
import torch
|
|
from utils import serialise_hparams
|
|
|
|
|
|
# Lower bound applied to model predictions (via torch.clamp) before taking
# torch.log, so the logarithm is never evaluated at zero.
EPSILON = 1e-5
|
|
|
|
|
|
def _log_clamped(prediction: torch.Tensor) -> torch.Tensor:
    """Return ``log(prediction)`` after clamping the values to ``[EPSILON, 1]``."""
    # The lower clamp bound keeps torch.log away from zero.
    return torch.log(torch.clamp(prediction, EPSILON, 1))


def _train_one_epoch(
    model: torch.nn.Module,
    train_data_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_function: torch.nn.Module,
    writer: SummaryWriter,
    device: torch.device,
    epoch: int,
    use_tqdm: bool,
    start_time: datetime,
    max_duration: Optional[timedelta],
):
    """Run one optimisation pass over *train_data_loader*.

    Returns the sum of the per-batch losses and logs every batch loss under
    "Loss/train/batch".

    Raises:
        TimeoutError: if *max_duration* is set and more than that much time has
            passed since *start_time* when a batch is about to be processed.
    """
    batches = (
        tqdm(train_data_loader, desc=f"Epoch {epoch}", unit="batch")
        if use_tqdm
        else train_data_loader
    )

    epoch_loss = 0
    for batch_id, (edited_histogram, original_histogram) in enumerate(batches):
        # The time budget is checked before each batch, so an overrun aborts
        # training in the middle of an epoch.
        elapsed = datetime.now() - start_time
        if max_duration is not None and elapsed > max_duration:
            raise TimeoutError(f"Time limit {max_duration} exceeded")

        optimizer.zero_grad()
        predicted_original = model(edited_histogram.to(device))
        loss = loss_function(
            _log_clamped(predicted_original),
            original_histogram.to(device),
        )

        epoch_loss += loss.item()
        writer.add_scalar(
            "Loss/train/batch",
            loss,
            global_step=epoch * len(train_data_loader) + batch_id,
        )
        loss.backward()
        optimizer.step()

    return epoch_loss


def _evaluate(
    model: torch.nn.Module,
    test_data_loader: DataLoader,
    loss_function: torch.nn.Module,
    writer: SummaryWriter,
    device: torch.device,
    epoch: int,
):
    """Log a sample figure and return the summed loss over *test_data_loader*.

    The model is put into eval mode for the duration of the call and switched
    back to train mode before returning. No gradients are recorded.
    """
    with torch.no_grad():
        model.eval()

        # Plot the first sample of the first test batch.
        edited_histogram, original_histogram = next(iter(test_data_loader))
        predicted_original = model(edited_histogram.to(device))
        # Index first and then move to the CPU: only one sample is copied, and
        # every tensor (including the target) is guaranteed to be on the CPU
        # before .numpy(), which does not accept non-CPU tensors.
        writer.add_figure(
            "histogram",
            plot_histograms_in_2d(
                {
                    "original": original_histogram[0].cpu().numpy().squeeze(),
                    "edited": edited_histogram[0].cpu().numpy().squeeze(),
                    "predicted": predicted_original[0].cpu().numpy().squeeze(),
                }
            ),
            epoch,
        )

        epoch_test_loss = 0
        for edited_histogram, original_histogram in test_data_loader:
            predicted_original = model(edited_histogram.to(device))
            epoch_test_loss += loss_function(
                _log_clamped(predicted_original),
                original_histogram.to(device),
            ).item()

    model.train()
    return epoch_test_loss


def train(
    hyperparameters: Dict[str, Any],
    train_data_loader: DataLoader,
    test_data_loader: DataLoader,
    log_dir: Path,
    max_duration: Optional[timedelta],
    use_tqdm: bool,
    device: torch.device,
    model_type: str,
    bin_count: int,
    learning_rate: float,
    scheduler_gamma: float,
    num_epochs: int,
    **_,
) -> torch.nn.Module:
    """Create a model, train it, and log progress to TensorBoard.

    Each epoch performs one pass over *train_data_loader*, followed by an
    evaluation pass over *test_data_loader*. Both loaders are expected to yield
    ``(edited_histogram, original_histogram)`` pairs; the model receives the
    edited histogram and its output is compared against the original one.

    Args:
        hyperparameters: Logged through ``serialise_hparams`` together with the
            epoch losses; not otherwise used here.
        train_data_loader: Batches used for optimisation.
        test_data_loader: Batches used for the per-epoch evaluation and for the
            sample figure.
        log_dir: Directory handed to ``SummaryWriter``; its absolute path is
            also used as the hparams run name.
        max_duration: Optional wall-clock budget, measured from the start of
            this call. ``None`` disables the check.
        use_tqdm: Wrap the training loader in a ``tqdm`` progress bar.
        device: Device the model and the batches are moved to.
        model_type: Forwarded to ``create_model`` as ``type``.
        bin_count: Forwarded to ``create_model``.
        learning_rate: Initial learning rate of the Adam optimiser.
        scheduler_gamma: ``gamma`` of the ``StepLR`` scheduler, which is
            stepped once per epoch.
        num_epochs: Number of epochs to run.
        **_: Any further keyword arguments are accepted and ignored.

    Returns:
        The trained model, in train mode.

    Raises:
        TimeoutError: if *max_duration* is exceeded during training.
    """
    start_time = datetime.now()

    with SummaryWriter(log_dir) as writer:
        model = create_model(
            type=model_type,
            bin_count=bin_count,
            device=device,
        ).train()
        # Trace the model graph with the inputs of the first training batch.
        writer.add_graph(model, next(iter(train_data_loader))[0].to(device))

        optimizer = Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=scheduler_gamma
        )
        # The loss is fed the log of the clamped prediction (see _log_clamped);
        # KLDivLoss takes its input in log-space.
        loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)

        for epoch in range(num_epochs):
            writer.add_scalar(
                "Actual learning rate", scheduler.get_last_lr()[0], epoch
            )

            epoch_loss = _train_one_epoch(
                model=model,
                train_data_loader=train_data_loader,
                optimizer=optimizer,
                loss_function=loss_function,
                writer=writer,
                device=device,
                epoch=epoch,
                use_tqdm=use_tqdm,
                start_time=start_time,
                max_duration=max_duration,
            )
            logging.info(f"Epoch {epoch} train loss: {epoch_loss}")

            epoch_test_loss = _evaluate(
                model=model,
                test_data_loader=test_data_loader,
                loss_function=loss_function,
                writer=writer,
                device=device,
                epoch=epoch,
            )
            writer.add_hparams(
                serialise_hparams(hyperparameters),
                {
                    "Loss/test/epoch": epoch_test_loss,
                    "Loss/train/epoch": epoch_loss,
                },
                global_step=epoch,
                run_name=log_dir.absolute(),
            )
            logging.info(f"Epoch {epoch} test loss: {epoch_test_loss}")

            scheduler.step()

    return model
|