diff --git a/src/models/residual3.py b/src/models/residual3.py index 349c52b..92a63bb 100644 --- a/src/models/residual3.py +++ b/src/models/residual3.py @@ -1,3 +1,4 @@ +import logging import torch import torch.nn as nn @@ -15,7 +16,7 @@ class Residual3(nn.Module): use_instance_norm: bool = True, use_elu: bool = True, leaky_relu_alpha: float = 0.01, - **_ + **_, ): super(Residual3, self).__init__() self._elu_alpha = elu_alpha @@ -25,6 +26,7 @@ class Residual3(nn.Module): self._use_instance_norm = use_instance_norm self._use_elu = use_elu self._leaky_relu_alpha = leaky_relu_alpha + self.print_og_result = False self.conv1 = self._make_conv_layer(1, features[0], kernel_sizes[0]) self.res1 = self._make_resblock(features[0], kernel_sizes[0]) @@ -127,11 +129,17 @@ class Residual3(nn.Module): out = self.deconv2(out) out = self.deconv3(out) - return out + if self.print_og_result: + logging.info(f"Original result {torch.sum(out)}") + self.print_og_result = False - # def _normalize(self, x): - # x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True) - # return x / (x_sum + EPSILON) + return self._normalize(out) + + @staticmethod + def _normalize(x): + x = torch.clamp(x, min=0) + x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True) + return x / (x_sum + EPSILON) def _initialize_weights(self): for m in self.modules(): diff --git a/src/training/train.py b/src/training/train.py index 7de5b33..f2a90a0 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -47,6 +47,7 @@ def train( loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device) for epoch in range(num_epochs): + model.print_og_result = True epoch_loss = 0 writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch) for batch_id, (edited_histogram, original_histogram) in enumerate(