This commit is contained in:
Andras Schmelczer 2024-06-27 21:35:49 +01:00
parent 7b95d7a2bd
commit 28b8b026a9
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 14 additions and 5 deletions

View file

@ -1,3 +1,4 @@
import logging
import torch
import torch.nn as nn
@ -15,7 +16,7 @@ class Residual3(nn.Module):
use_instance_norm: bool = True,
use_elu: bool = True,
leaky_relu_alpha: float = 0.01,
**_
**_,
):
super(Residual3, self).__init__()
self._elu_alpha = elu_alpha
@ -25,6 +26,7 @@ class Residual3(nn.Module):
self._use_instance_norm = use_instance_norm
self._use_elu = use_elu
self._leaky_relu_alpha = leaky_relu_alpha
self.print_og_result = False
self.conv1 = self._make_conv_layer(1, features[0], kernel_sizes[0])
self.res1 = self._make_resblock(features[0], kernel_sizes[0])
@ -127,11 +129,17 @@ class Residual3(nn.Module):
out = self.deconv2(out)
out = self.deconv3(out)
return out
if self.print_og_result:
logging.info(f"Original result {torch.sum(out)}")
self.print_og_result = False
# def _normalize(self, x):
# x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
# return x / (x_sum + EPSILON)
return self._normalize(out)
@staticmethod
def _normalize(x):
x = torch.clamp(x, min=0)
x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
return x / (x_sum + EPSILON)
def _initialize_weights(self):
for m in self.modules():

View file

@ -47,6 +47,7 @@ def train(
loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
for epoch in range(num_epochs):
model.print_og_result = True
epoch_loss = 0
writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch)
for batch_id, (edited_histogram, original_histogram) in enumerate(