diff --git a/src/editor/visualisation/display_images.py b/src/editor/visualisation/display_images.py index b3e2457..a4cd2ee 100644 --- a/src/editor/visualisation/display_images.py +++ b/src/editor/visualisation/display_images.py @@ -4,11 +4,11 @@ from PIL.Image import Image from math import ceil -def display_images(images: Dict[str, Image], images_per_row: int = 3): +def display_images(images: Dict[str, Image], images_per_row: int = 3, figsize=(24, 16)): fig, axes = plt.subplots( nrows=ceil(len(images) / images_per_row), ncols=min(images_per_row, len(images)), - figsize=(24, 16), + figsize=figsize, ) axes = axes.flatten() diff --git a/src/editor/visualisation/plot_histograms_in_2d.py b/src/editor/visualisation/plot_histograms_in_2d.py index 0891e1f..8598271 100644 --- a/src/editor/visualisation/plot_histograms_in_2d.py +++ b/src/editor/visualisation/plot_histograms_in_2d.py @@ -4,8 +4,8 @@ import numpy as np from typing import Dict -def plot_histograms_in_2d(histograms: Dict[str, np.ndarray]): - fig = plt.figure(figsize=(15, 5)) +def plot_histograms_in_2d(histograms: Dict[str, np.ndarray], figsize=(15, 5)): + fig = plt.figure(figsize=figsize) for i, (title, histogram) in enumerate(histograms.items(), 1): ax = fig.add_subplot(1, 3, i, projection="3d")