Add norm
This commit is contained in:
parent
b7755bfea8
commit
3583cd559a
3 changed files with 8 additions and 4 deletions
|
|
@ -13,11 +13,11 @@ import json
|
|||
|
||||
|
||||
MODELS = {
|
||||
"v1": v1,
|
||||
# "v1": v1,
|
||||
"Dummy": Dummy,
|
||||
"SimpleCNN": SimpleCNN,
|
||||
"Residual": Residual,
|
||||
# "Residual2": Residual2,
|
||||
"Residual2": Residual2,
|
||||
"Residual3": Residual3,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
|
|
@ -48,4 +49,5 @@ class Residual(nn.Module):
|
|||
out = self.relu(self.deconv2(out))
|
||||
out = self.relu(self.deconv3(out))
|
||||
|
||||
return out
|
||||
sum = torch.sum(out, dim=(2, 3, 4), keepdim=True)
|
||||
return out / sum
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
|
@ -32,4 +33,5 @@ class SimpleCNN(nn.Module):
|
|||
x = F.relu(self.conv4(x))
|
||||
x = F.relu(self.conv5(x))
|
||||
x = self.conv6(x)
|
||||
return x
|
||||
sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
return x / sum
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue