38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
from typing import List
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class ProgressivePoolingLoss(nn.Module):
|
|
def __init__(self, target_sizes: List[int], damping: float):
|
|
super(ProgressivePoolingLoss, self).__init__()
|
|
self._target_sizes = target_sizes
|
|
self._damping = damping
|
|
|
|
def forward(self, tensor_a, tensor_b):
|
|
assert (
|
|
tensor_a.size() == tensor_b.size()
|
|
), f"Input tensors must have the same size, got {tensor_a.size()} and {tensor_b.size()}"
|
|
|
|
assert (
|
|
len(tensor_a.size()) == 5
|
|
), f"Input tensors must have 5 dimensions, got {tensor_a.size()}"
|
|
|
|
_minibatch_size, _channels, depth, height, width = tensor_a.size()
|
|
assert depth == height == width, "Input tensors must be cubes."
|
|
|
|
loss = 0.0
|
|
weight = 1
|
|
|
|
for target_size in self._target_sizes:
|
|
pool_size = depth // target_size
|
|
pooled_a = F.avg_pool3d(tensor_a, pool_size) * (pool_size**3)
|
|
pooled_b = F.avg_pool3d(tensor_b, pool_size) * (pool_size**3)
|
|
|
|
diff = torch.abs(pooled_a - pooled_b)
|
|
|
|
loss += diff.mean() * weight
|
|
weight *= self._damping
|
|
|
|
return loss
|