Fix nans
This commit is contained in:
parent
7b95d7a2bd
commit
28b8b026a9
2 changed files with 14 additions and 5 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
|
@ -15,7 +16,7 @@ class Residual3(nn.Module):
|
|||
use_instance_norm: bool = True,
|
||||
use_elu: bool = True,
|
||||
leaky_relu_alpha: float = 0.01,
|
||||
**_
|
||||
**_,
|
||||
):
|
||||
super(Residual3, self).__init__()
|
||||
self._elu_alpha = elu_alpha
|
||||
|
|
@ -25,6 +26,7 @@ class Residual3(nn.Module):
|
|||
self._use_instance_norm = use_instance_norm
|
||||
self._use_elu = use_elu
|
||||
self._leaky_relu_alpha = leaky_relu_alpha
|
||||
self.print_og_result = False
|
||||
|
||||
self.conv1 = self._make_conv_layer(1, features[0], kernel_sizes[0])
|
||||
self.res1 = self._make_resblock(features[0], kernel_sizes[0])
|
||||
|
|
@ -127,11 +129,17 @@ class Residual3(nn.Module):
|
|||
out = self.deconv2(out)
|
||||
out = self.deconv3(out)
|
||||
|
||||
return out
|
||||
if self.print_og_result:
|
||||
logging.info(f"Original result {torch.sum(out)}")
|
||||
self.print_og_result = False
|
||||
|
||||
# def _normalize(self, x):
|
||||
# x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
# return x / (x_sum + EPSILON)
|
||||
return self._normalize(out)
|
||||
|
||||
@staticmethod
|
||||
def _normalize(x):
|
||||
x = torch.clamp(x, min=0)
|
||||
x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
return x / (x_sum + EPSILON)
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ def train(
|
|||
loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
model.print_og_result = True
|
||||
epoch_loss = 0
|
||||
writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch)
|
||||
for batch_id, (edited_histogram, original_histogram) in enumerate(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue