This commit is contained in:
Andras Schmelczer 2024-04-12 20:49:09 +01:00
parent 38b21135e2
commit 07d926161e
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C

View file

@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ProgressivePoolingLoss(nn.Module):
def __init__(self, initial_pool_size: int = 2, damping=1.8):
super(ProgressivePoolingLoss, self).__init__()
self._initial_pool_size = initial_pool_size
self._damping = damping
def forward(self, tensor_a, tensor_b):
assert (
tensor_a.size() == tensor_b.size()
), "Input tensors must have the same size."
max_pool_size = min(tensor_a.size(1), tensor_a.size(2), tensor_a.size(3))
loss = 0.0
damping = 1
for pool_size in range(self._initial_pool_size, max_pool_size):
pooled_a = F.avg_pool3d(tensor_a, pool_size) * (pool_size**3)
pooled_b = F.avg_pool3d(tensor_b, pool_size) * (pool_size**3)
diff = torch.square(pooled_a - pooled_b)
loss += diff.mean() / damping
damping *= self._damping
return loss