Remove editor module

This commit is contained in:
Andras Schmelczer 2024-06-22 18:37:24 +01:00
parent e5959268c1
commit c966866abc
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
37 changed files with 7752 additions and 7345 deletions

View file

@ -0,0 +1,38 @@
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
class ProgressivePoolingLoss(nn.Module):
def __init__(self, target_sizes: List[int], damping: float):
super(ProgressivePoolingLoss, self).__init__()
self._target_sizes = target_sizes
self._damping = damping
def forward(self, tensor_a, tensor_b):
assert (
tensor_a.size() == tensor_b.size()
), f"Input tensors must have the same size, got {tensor_a.size()} and {tensor_b.size()}"
assert (
len(tensor_a.size()) == 5
), f"Input tensors must have 5 dimensions, got {tensor_a.size()}"
_minibatch_size, _channels, depth, height, width = tensor_a.size()
assert depth == height == width, "Input tensors must be cubes."
loss = 0.0
weight = 1
for target_size in self._target_sizes:
pool_size = depth // target_size
pooled_a = F.avg_pool3d(tensor_a, pool_size) * (pool_size**3)
pooled_b = F.avg_pool3d(tensor_b, pool_size) * (pool_size**3)
diff = torch.abs(pooled_a - pooled_b)
loss += diff.mean() * weight
weight *= self._damping
return loss