From 07d926161e1d4fa6b962cc55550c44ac787a1eaf Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Fri, 12 Apr 2024 20:49:09 +0100 Subject: [PATCH] Add loss --- editor/training/progressive_pooling_loss.py | 31 +++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 editor/training/progressive_pooling_loss.py diff --git a/editor/training/progressive_pooling_loss.py b/editor/training/progressive_pooling_loss.py new file mode 100644 index 0000000..3e4cdc5 --- /dev/null +++ b/editor/training/progressive_pooling_loss.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ProgressivePoolingLoss(nn.Module): + def __init__(self, initial_pool_size: int = 2, damping=1.8): + super(ProgressivePoolingLoss, self).__init__() + self._initial_pool_size = initial_pool_size + self._damping = damping + + def forward(self, tensor_a, tensor_b): + assert ( + tensor_a.size() == tensor_b.size() + ), "Input tensors must have the same size." + + max_pool_size = min(tensor_a.size(1), tensor_a.size(2), tensor_a.size(3)) + + loss = 0.0 + damping = 1 + + for pool_size in range(self._initial_pool_size, max_pool_size): + pooled_a = F.avg_pool3d(tensor_a, pool_size) * (pool_size**3) + pooled_b = F.avg_pool3d(tensor_b, pool_size) * (pool_size**3) + + diff = torch.square(pooled_a - pooled_b) + + loss += diff.mean() / damping + damping *= self._damping + + return loss