Fix nans
This commit is contained in:
parent
7b95d7a2bd
commit
28b8b026a9
2 changed files with 14 additions and 5 deletions
|
|
@ -47,6 +47,7 @@ def train(
|
|||
loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
model.print_og_result = True
|
||||
epoch_loss = 0
|
||||
writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch)
|
||||
for batch_id, (edited_histogram, original_histogram) in enumerate(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue