House-keeping
This commit is contained in:
parent
d336ec3be6
commit
ab06d979e3
7 changed files with 7 additions and 42 deletions
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
|
|
@ -48,7 +48,8 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
|
||||||
model = create_model(
|
model = create_model(
|
||||||
type=hyperparameters["model_type"],
|
type=hyperparameters["model_type"],
|
||||||
bin_count=hyperparameters["bin_count"],
|
bin_count=hyperparameters["bin_count"],
|
||||||
).to(device)
|
device=device,
|
||||||
|
)
|
||||||
model.load_state_dict(torch.load(model_path))
|
model.load_state_dict(torch.load(model_path))
|
||||||
model.eval()
|
model.eval()
|
||||||
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
|
logging.info(f"Parameter count: {sum(p.numel() for p in model.parameters())}")
|
||||||
|
|
|
||||||
|
|
@ -3,3 +3,4 @@ from .add_random_colour_spill import add_random_colour_spill
|
||||||
from .gamma import adjust_gamma, get_random_gamma
|
from .gamma import adjust_gamma, get_random_gamma
|
||||||
from .get_colour_lut import get_random_saturation_per_hue_lut, get_random_brightness_lut
|
from .get_colour_lut import get_random_saturation_per_hue_lut, get_random_brightness_lut
|
||||||
from .apply_pixel_shader import apply_pixel_shader
|
from .apply_pixel_shader import apply_pixel_shader
|
||||||
|
from .random_edit import random_edit
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
from .histogram_dataset import HistogramDataset
|
from .histogram_dataset import HistogramDataset
|
||||||
from .random_edit import random_edit
|
from .get_next_run_name import get_next_run_name
|
||||||
from .progressive_pooling_loss import ProgressivePoolingLoss
|
from .random_hparam_search import random_hparam_search
|
||||||
|
from .train import train
|
||||||
|
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
||||||
from typing import List
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class ProgressivePoolingLoss(nn.Module):
|
|
||||||
def __init__(self, target_sizes: List[int], damping: float):
|
|
||||||
super(ProgressivePoolingLoss, self).__init__()
|
|
||||||
self._target_sizes = target_sizes
|
|
||||||
self._damping = damping
|
|
||||||
|
|
||||||
def forward(self, tensor_a, tensor_b):
|
|
||||||
assert (
|
|
||||||
tensor_a.size() == tensor_b.size()
|
|
||||||
), f"Input tensors must have the same size, got {tensor_a.size()} and {tensor_b.size()}"
|
|
||||||
|
|
||||||
assert (
|
|
||||||
len(tensor_a.size()) == 5
|
|
||||||
), f"Input tensors must have 5 dimensions, got {tensor_a.size()}"
|
|
||||||
|
|
||||||
_minibatch_size, _channels, depth, height, width = tensor_a.size()
|
|
||||||
assert depth == height == width, "Input tensors must be cubes."
|
|
||||||
|
|
||||||
loss = 0.0
|
|
||||||
weight = 1
|
|
||||||
|
|
||||||
for target_size in self._target_sizes:
|
|
||||||
pool_size = depth // target_size
|
|
||||||
pooled_a = F.avg_pool3d(tensor_a, pool_size) * (pool_size**3)
|
|
||||||
pooled_b = F.avg_pool3d(tensor_b, pool_size) * (pool_size**3)
|
|
||||||
|
|
||||||
diff = torch.abs(pooled_a - pooled_b)
|
|
||||||
|
|
||||||
loss += diff.mean() * weight
|
|
||||||
weight *= self._damping
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from .random import random
|
from .random import random
|
||||||
from .compute_histogram import compute_histogram
|
from .compute_histogram import compute_histogram
|
||||||
from .generate_rotation_matrices import generate_rotation_matrices
|
from .generate_rotation_matrices import generate_rotation_matrices
|
||||||
from .get_next_run_name import get_next_run_name
|
|
||||||
from .kldiv import kldiv
|
from .kldiv import kldiv
|
||||||
from .set_up_logging import set_up_logging
|
from .set_up_logging import set_up_logging
|
||||||
|
from .serialise_hparams import serialise_hparams
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue