fix training
This commit is contained in:
parent
b87c1dd859
commit
4349f5af41
4 changed files with 86227 additions and 120070 deletions
|
|
@ -1 +1,3 @@
|
|||
from .histogram_dataset import HistogramDataset
|
||||
from .random_edit import random_edit
|
||||
from .progressive_pooling_loss import ProgressivePoolingLoss
|
||||
|
|
|
|||
|
|
@ -61,13 +61,14 @@ class HistogramDataset(Dataset):
|
|||
|
||||
edited = random_edit(original, seed=idx)
|
||||
|
||||
original_histogram = compute_histogram(
|
||||
original, bins=self._bin_count, normalize=True
|
||||
)
|
||||
edited_histogram = compute_histogram(
|
||||
edited, bins=self._bin_count, normalize=True
|
||||
)
|
||||
|
||||
original_histogram = compute_histogram(
|
||||
original, bins=self._bin_count, normalize=True
|
||||
)
|
||||
|
||||
result = (
|
||||
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
|
||||
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
|
||||
|
|
|
|||
|
|
@ -1,31 +1,38 @@
|
|||
from typing import List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ProgressivePoolingLoss(nn.Module):
|
||||
def __init__(self, initial_pool_size: int = 2, damping=1.8):
|
||||
def __init__(self, target_sizes: List[int], damping: float):
|
||||
super(ProgressivePoolingLoss, self).__init__()
|
||||
self._initial_pool_size = initial_pool_size
|
||||
self._target_sizes = target_sizes
|
||||
self._damping = damping
|
||||
|
||||
def forward(self, tensor_a, tensor_b):
|
||||
assert (
|
||||
tensor_a.size() == tensor_b.size()
|
||||
), "Input tensors must have the same size."
|
||||
), f"Input tensors must have the same size, got {tensor_a.size()} and {tensor_b.size()}"
|
||||
|
||||
max_pool_size = min(tensor_a.size(1), tensor_a.size(2), tensor_a.size(3))
|
||||
assert (
|
||||
len(tensor_a.size()) == 5
|
||||
), f"Input tensors must have 5 dimensions, got {tensor_a.size()}"
|
||||
|
||||
_minibatch_size, _channels, depth, height, width = tensor_a.size()
|
||||
assert depth == height == width, "Input tensors must be cubes."
|
||||
|
||||
loss = 0.0
|
||||
damping = 1
|
||||
weight = 1
|
||||
|
||||
for pool_size in range(self._initial_pool_size, max_pool_size):
|
||||
for target_size in self._target_sizes:
|
||||
pool_size = depth // target_size
|
||||
pooled_a = F.avg_pool3d(tensor_a, pool_size) * (pool_size**3)
|
||||
pooled_b = F.avg_pool3d(tensor_b, pool_size) * (pool_size**3)
|
||||
|
||||
diff = torch.square(pooled_a - pooled_b)
|
||||
diff = torch.abs(pooled_a - pooled_b)
|
||||
|
||||
loss += diff.mean() / damping
|
||||
damping *= self._damping
|
||||
loss += diff.mean() * weight
|
||||
weight *= self._damping
|
||||
|
||||
return loss
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue