diff --git a/src/visualisation/display_images.py b/src/visualisation/display_images.py index a4cd2ee..9ce31f2 100644 --- a/src/visualisation/display_images.py +++ b/src/visualisation/display_images.py @@ -4,11 +4,29 @@ from PIL.Image import Image from math import ceil -def display_images(images: Dict[str, Image], images_per_row: int = 3, figsize=(24, 16)): +def display_images( + images: Dict[str, Image], images_per_row: int = 3, img_size_inches: int = 2 +) -> plt.Figure: + row_count = ceil(len(images) / images_per_row) + + an_image = next(iter(images.values())) + aspect_ratio = an_image.size[0] / an_image.size[1] + + unit_height = ( + img_size_inches + if an_image.size[1] > an_image.size[0] + else img_size_inches / aspect_ratio + ) + unit_width = ( + img_size_inches + if an_image.size[0] > an_image.size[1] + else img_size_inches * aspect_ratio + ) + fig, axes = plt.subplots( - nrows=ceil(len(images) / images_per_row), - ncols=min(images_per_row, len(images)), - figsize=figsize, + nrows=row_count, + ncols=images_per_row, + figsize=(unit_width * images_per_row, unit_height * row_count), ) axes = axes.flatten() @@ -21,5 +39,8 @@ def display_images(images: Dict[str, Image], images_per_row: int = 3, figsize=(2 for i in range(len(images), len(axes)): axes[i].axis("off") - plt.tight_layout() - plt.show() + fig.subplots_adjust( + hspace=0.25, wspace=0.15, top=0.95, bottom=0.05, left=0.05, right=0.95 + ) + + return fig diff --git a/src/visualisation/plot_histograms_in_2d.py b/src/visualisation/plot_histograms_in_2d.py index 8598271..f598719 100644 --- a/src/visualisation/plot_histograms_in_2d.py +++ b/src/visualisation/plot_histograms_in_2d.py @@ -1,14 +1,21 @@ import numpy as np import matplotlib.pyplot as plt -import numpy as np from typing import Dict -def plot_histograms_in_2d(histograms: Dict[str, np.ndarray], figsize=(15, 5)): - fig = plt.figure(figsize=figsize) +def plot_histograms_in_2d( + histograms: Dict[str, np.ndarray], histograms_per_row=3, histograms_size_inches=4 +): + row_count = max(1, len(histograms) // histograms_per_row) + fig = plt.figure( + figsize=( + histograms_per_row * histograms_size_inches, + row_count * histograms_size_inches, + ) + ) for i, (title, histogram) in enumerate(histograms.items(), 1): - ax = fig.add_subplot(1, 3, i, projection="3d") + ax = fig.add_subplot(row_count, histograms_per_row, i, projection="3d") size = histogram.shape[0] @@ -18,7 +25,7 @@ def plot_histograms_in_2d(histograms: Dict[str, np.ndarray], figsize=(15, 5)): z = z.flatten() values = histogram.flatten() - sizes = values * 5000 + sizes = values * 5000 # this is just an arbitrary scaling factor colors = np.vstack((x, y, z)).T / (size - 1) @@ -27,6 +34,17 @@ def plot_histograms_in_2d(histograms: Dict[str, np.ndarray], figsize=(15, 5)): ax.set_xlim([0, (size - 1)]) ax.set_ylim([0, (size - 1)]) ax.set_zlim([0, (size - 1)]) + ax.set_xlabel("Red", labelpad=-10) + ax.set_ylabel("Green", labelpad=-10) + ax.set_zlabel("Blue", labelpad=-12, rotation=90) + + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + ax.set_title(title) + fig.subplots_adjust( + hspace=0.25, wspace=0.15, top=0.95, bottom=0.05, left=0.05, right=0.95 + ) return fig diff --git a/src/visualisation/plot_histograms_in_3d.py b/src/visualisation/plot_histograms_in_3d.py index 1ea4877..d3834d2 100644 --- a/src/visualisation/plot_histograms_in_3d.py +++ b/src/visualisation/plot_histograms_in_3d.py @@ -6,21 +6,24 @@ import numpy as np def plot_histograms_in_3d( - histograms: Dict[str, np.ndarray], histogram_per_row: int = 3 + histograms: Dict[str, np.ndarray], histograms_per_row: int = 3 ): - cols = min(histogram_per_row, len(histograms)) - rows = ceil(len(histograms) / histogram_per_row) + column_count = min(histograms_per_row, len(histograms)) + row_count = ceil(len(histograms) / histograms_per_row) fig = make_subplots( - rows=rows, - cols=cols, - specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(rows)], + rows=row_count, + cols=column_count, + specs=[ + [{"type": "scatter3d"} for _ in range(column_count)] + for _ in range(row_count) + ], ) for i, (title, histogram) in enumerate(histograms.items()): fig.add_trace( _get_3d_scatter_plot_from_histogram(title, histogram), - row=(i // (histogram_per_row + 1)) + 1, - col=(i % histogram_per_row) + 1, + row=(i // histograms_per_row) + 1, + col=(i % histograms_per_row) + 1, ) scenes = { @@ -32,7 +35,9 @@ def plot_histograms_in_3d( ) for i in range(1, len(histograms) + 1) } - fig.update_layout(**scenes) + fig.update_layout(**scenes, height=300 * column_count) + fig.update_layout() # You can adjust the height as needed + fig.show()