From 3583cd559afa7a059f4f13d8e90becd486da2439 Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Sat, 22 Jun 2024 15:09:27 +0100 Subject: [PATCH] Add norm --- src/editor/models/__init__.py | 4 ++-- src/editor/models/residual.py | 4 +++- src/editor/models/simple_cnn.py | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/editor/models/__init__.py b/src/editor/models/__init__.py index d4252d8..f9b026d 100644 --- a/src/editor/models/__init__.py +++ b/src/editor/models/__init__.py @@ -13,11 +13,11 @@ import json MODELS = { - "v1": v1, + # "v1": v1, "Dummy": Dummy, "SimpleCNN": SimpleCNN, "Residual": Residual, - # "Residual2": Residual2, + "Residual2": Residual2, "Residual3": Residual3, } diff --git a/src/editor/models/residual.py b/src/editor/models/residual.py index c17cd2e..6703af7 100644 --- a/src/editor/models/residual.py +++ b/src/editor/models/residual.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn @@ -48,4 +49,5 @@ class Residual(nn.Module): out = self.relu(self.deconv2(out)) out = self.relu(self.deconv3(out)) - return out + sum = torch.sum(out, dim=(2, 3, 4), keepdim=True) + return out / sum diff --git a/src/editor/models/simple_cnn.py b/src/editor/models/simple_cnn.py index f5bd7f9..dc356bf 100644 --- a/src/editor/models/simple_cnn.py +++ b/src/editor/models/simple_cnn.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn import torch.nn.functional as F @@ -32,4 +33,5 @@ class SimpleCNN(nn.Module): x = F.relu(self.conv4(x)) x = F.relu(self.conv5(x)) x = self.conv6(x) - return x + sum = torch.sum(x, dim=(2, 3, 4), keepdim=True) + return x / sum