diff --git a/src/editor/models/__init__.py b/src/editor/models/__init__.py index d4252d8..f9b026d 100644 --- a/src/editor/models/__init__.py +++ b/src/editor/models/__init__.py @@ -13,11 +13,11 @@ import json MODELS = { - "v1": v1, + # "v1": v1, "Dummy": Dummy, "SimpleCNN": SimpleCNN, "Residual": Residual, - # "Residual2": Residual2, + "Residual2": Residual2, "Residual3": Residual3, } diff --git a/src/editor/models/residual.py b/src/editor/models/residual.py index c17cd2e..6703af7 100644 --- a/src/editor/models/residual.py +++ b/src/editor/models/residual.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn @@ -48,4 +49,5 @@ class Residual(nn.Module): out = self.relu(self.deconv2(out)) out = self.relu(self.deconv3(out)) - return out + sum = torch.sum(out, dim=(2, 3, 4), keepdim=True) + return out / sum diff --git a/src/editor/models/simple_cnn.py b/src/editor/models/simple_cnn.py index f5bd7f9..dc356bf 100644 --- a/src/editor/models/simple_cnn.py +++ b/src/editor/models/simple_cnn.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn import torch.nn.functional as F @@ -32,4 +33,5 @@ class SimpleCNN(nn.Module): x = F.relu(self.conv4(x)) x = F.relu(self.conv5(x)) x = self.conv6(x) - return x + sum = torch.sum(x, dim=(2, 3, 4), keepdim=True) + return x / sum