diff --git a/src/histogram_transfer/__init__.py b/src/histogram_transfer/__init__.py index bfc8e2a..dda0123 100644 --- a/src/histogram_transfer/__init__.py +++ b/src/histogram_transfer/__init__.py @@ -1,4 +1,3 @@ from .regrain import regrain -from .pdf_transfer_1d import pdf_transfer_1d from .pdf_transfer_3d import pdf_transfer_3d from .apply_histogram import apply_histogram diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 54ec8fc..63de494 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -5,3 +5,4 @@ from .kldiv import kldiv from .set_up_logging import set_up_logging from .serialise_hparams import serialise_hparams from .get_device import get_device +from .delete_corrupt_images import delete_corrupt_images diff --git a/src/visualisation/plot_histograms_in_3d.py b/src/visualisation/plot_histograms_in_3d.py index d3834d2..c5e699c 100644 --- a/src/visualisation/plot_histograms_in_3d.py +++ b/src/visualisation/plot_histograms_in_3d.py @@ -6,7 +6,7 @@ import numpy as np def plot_histograms_in_3d( - histograms: Dict[str, np.ndarray], histograms_per_row: int = 3 + histograms: Dict[str, np.ndarray], histograms_per_row: int = 3, height: int = 300 ): column_count = min(histograms_per_row, len(histograms)) row_count = ceil(len(histograms) / histograms_per_row) @@ -35,7 +35,7 @@ def plot_histograms_in_3d( ) for i in range(1, len(histograms) + 1) } - fig.update_layout(**scenes, height=300 * column_count) + fig.update_layout(**scenes, height=height * column_count) fig.update_layout() # You can adjust the height as needed fig.show()