Fix loss explosion?

This commit is contained in:
Andras Schmelczer 2024-06-25 09:00:20 +01:00
parent d1b6e52f31
commit 856dc83c77
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 15 additions and 13 deletions

View file

@ -2,6 +2,9 @@ import torch
import torch.nn as nn
EPSILON = 1e-5
class Residual3(nn.Module):
def __init__(
self,
@ -124,11 +127,11 @@ class Residual3(nn.Module):
out = self.deconv2(out)
out = self.deconv3(out)
return self._normalize(out)
return out
def _normalize(self, x):
x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
return x / (x_sum + 1e-6)
# def _normalize(self, x):
# x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
# return x / (x_sum + EPSILON)
def _initialize_weights(self):
for m in self.modules():