This commit is contained in:
Andras Schmelczer 2024-06-27 21:35:49 +01:00
parent 7b95d7a2bd
commit 28b8b026a9
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 14 additions and 5 deletions

View file

@ -47,6 +47,7 @@ def train(
loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
for epoch in range(num_epochs):
model.print_og_result = True
epoch_loss = 0
writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch)
for batch_id, (edited_histogram, original_histogram) in enumerate(