Fix loss explosion?
This commit is contained in:
parent
d1b6e52f31
commit
856dc83c77
2 changed files with 15 additions and 13 deletions
|
|
@ -2,6 +2,9 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
EPSILON = 1e-5
|
||||
|
||||
|
||||
class Residual3(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -124,11 +127,11 @@ class Residual3(nn.Module):
|
|||
out = self.deconv2(out)
|
||||
out = self.deconv3(out)
|
||||
|
||||
return self._normalize(out)
|
||||
return out
|
||||
|
||||
def _normalize(self, x):
|
||||
x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
return x / (x_sum + 1e-6)
|
||||
# def _normalize(self, x):
|
||||
# x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
# return x / (x_sum + EPSILON)
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue