From ab06d979e3c899116acaf424e8c25f728530ac9f Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Tue, 25 Jun 2024 08:25:00 +0100 Subject: [PATCH] House-keeping --- src/__init__.py | 0 src/models/__init__.py | 3 +- src/operations/__init__.py | 1 + src/{training => operations}/random_edit.py | 0 src/training/__init__.py | 5 +-- src/training/progressive_pooling_loss.py | 38 --------------------- src/utils/__init__.py | 2 +- 7 files changed, 7 insertions(+), 42 deletions(-) create mode 100644 src/__init__.py rename src/{training => operations}/random_edit.py (100%) delete mode 100644 src/training/progressive_pooling_loss.py diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/__init__.py b/src/models/__init__.py index de6ff8d..658116b 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -48,7 +48,8 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A model = create_model( type=hyperparameters["model_type"], bin_count=hyperparameters["bin_count"], - ).to(device) + device=device, + ) model.load_state_dict(torch.load(model_path)) model.eval() logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}") diff --git a/src/operations/__init__.py b/src/operations/__init__.py index 5da6215..11d048d 100644 --- a/src/operations/__init__.py +++ b/src/operations/__init__.py @@ -3,3 +3,4 @@ from .add_random_colour_spill import add_random_colour_spill from .gamma import adjust_gamma, get_random_gamma from .get_colour_lut import get_random_saturation_per_hue_lut, get_random_brightness_lut from .apply_pixel_shader import apply_pixel_shader +from .random_edit import random_edit diff --git a/src/training/random_edit.py b/src/operations/random_edit.py similarity index 100% rename from src/training/random_edit.py rename to src/operations/random_edit.py diff --git a/src/training/__init__.py b/src/training/__init__.py index 272d7a2..deb569a 100644 --- a/src/training/__init__.py +++ b/src/training/__init__.py @@ -1,3 +1,4 @@ from .histogram_dataset import HistogramDataset -from .random_edit import random_edit -from .progressive_pooling_loss import ProgressivePoolingLoss +from .get_next_run_name import get_next_run_name +from .random_hparam_search import random_hparam_search +from .train import train diff --git a/src/training/progressive_pooling_loss.py b/src/training/progressive_pooling_loss.py deleted file mode 100644 index c579216..0000000 --- a/src/training/progressive_pooling_loss.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import List -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class ProgressivePoolingLoss(nn.Module): - def __init__(self, target_sizes: List[int], damping: float): - super(ProgressivePoolingLoss, self).__init__() - self._target_sizes = target_sizes - self._damping = damping - - def forward(self, tensor_a, tensor_b): - assert ( - tensor_a.size() == tensor_b.size() - ), f"Input tensors must have the same size, got {tensor_a.size()} and {tensor_b.size()}" - - assert ( - len(tensor_a.size()) == 5 - ), f"Input tensors must have 5 dimensions, got {tensor_a.size()}" - - _minibatch_size, _channels, depth, height, width = tensor_a.size() - assert depth == height == width, "Input tensors must be cubes." - - loss = 0.0 - weight = 1 - - for target_size in self._target_sizes: - pool_size = depth // target_size - pooled_a = F.avg_pool3d(tensor_a, pool_size) * (pool_size**3) - pooled_b = F.avg_pool3d(tensor_b, pool_size) * (pool_size**3) - - diff = torch.abs(pooled_a - pooled_b) - - loss += diff.mean() * weight - weight *= self._damping - - return loss diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 12f8fb1..de54983 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,6 +1,6 @@ from .random import random from .compute_histogram import compute_histogram from .generate_rotation_matrices import generate_rotation_matrices -from .get_next_run_name import get_next_run_name from .kldiv import kldiv from .set_up_logging import set_up_logging +from .serialise_hparams import serialise_hparams