This commit is contained in:
Andras Schmelczer 2024-06-22 15:09:27 +01:00
parent b7755bfea8
commit 3583cd559a
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
3 changed files with 8 additions and 4 deletions

View file

@ -13,11 +13,11 @@ import json
MODELS = {
"v1": v1,
# "v1": v1,
"Dummy": Dummy,
"SimpleCNN": SimpleCNN,
"Residual": Residual,
# "Residual2": Residual2,
"Residual2": Residual2,
"Residual3": Residual3,
}

View file

@ -1,3 +1,4 @@
import torch
import torch.nn as nn
@ -48,4 +49,5 @@ class Residual(nn.Module):
out = self.relu(self.deconv2(out))
out = self.relu(self.deconv3(out))
return out
sum = torch.sum(out, dim=(2, 3, 4), keepdim=True)
return out / sum

View file

@ -1,3 +1,4 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -32,4 +33,5 @@ class SimpleCNN(nn.Module):
x = F.relu(self.conv4(x))
x = F.relu(self.conv5(x))
x = self.conv6(x)
return x
sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
return x / sum