Improve rendering

This commit is contained in:
Andras Schmelczer 2024-06-30 16:15:05 +01:00
parent 015a77a99e
commit d2ce23fa33
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
3 changed files with 64 additions and 20 deletions

View file

@ -4,11 +4,29 @@ from PIL.Image import Image
from math import ceil
def display_images(images: Dict[str, Image], images_per_row: int = 3, figsize=(24, 16)):
def display_images(
images: Dict[str, Image], images_per_row: int = 3, img_size_inches: int = 2
) -> plt.Figure:
row_count = ceil(len(images) / images_per_row)
an_image = next(iter(images.values()))
aspect_ratio = an_image.size[0] / an_image.size[1]
unit_height = (
img_size_inches
if an_image.size[1] > an_image.size[0]
else img_size_inches / aspect_ratio
)
unit_width = (
img_size_inches
if an_image.size[0] > an_image.size[1]
else img_size_inches * aspect_ratio
)
fig, axes = plt.subplots(
nrows=ceil(len(images) / images_per_row),
ncols=min(images_per_row, len(images)),
figsize=figsize,
nrows=row_count,
ncols=images_per_row,
figsize=(unit_width * images_per_row, unit_height * row_count),
)
axes = axes.flatten()
@ -21,5 +39,8 @@ def display_images(images: Dict[str, Image], images_per_row: int = 3, figsize=(2
for i in range(len(images), len(axes)):
axes[i].axis("off")
plt.tight_layout()
plt.show()
fig.subplots_adjust(
hspace=0.25, wspace=0.15, top=0.95, bottom=0.05, left=0.05, right=0.95
)
return fig

View file

@ -1,14 +1,21 @@
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict
def plot_histograms_in_2d(histograms: Dict[str, np.ndarray], figsize=(15, 5)):
fig = plt.figure(figsize=figsize)
def plot_histograms_in_2d(
histograms: Dict[str, np.ndarray], histograms_per_row=3, histograms_size_inches=4
):
row_count = max(1, len(histograms) // histograms_per_row)
fig = plt.figure(
figsize=(
histograms_per_row * histograms_size_inches,
row_count * histograms_size_inches,
)
)
for i, (title, histogram) in enumerate(histograms.items(), 1):
ax = fig.add_subplot(1, 3, i, projection="3d")
ax = fig.add_subplot(row_count, histograms_per_row, i, projection="3d")
size = histogram.shape[0]
@ -18,7 +25,7 @@ def plot_histograms_in_2d(histograms: Dict[str, np.ndarray], figsize=(15, 5)):
z = z.flatten()
values = histogram.flatten()
sizes = values * 5000
sizes = values * 5000 # this is just an arbitrary scaling factor
colors = np.vstack((x, y, z)).T / (size - 1)
@ -27,6 +34,17 @@ def plot_histograms_in_2d(histograms: Dict[str, np.ndarray], figsize=(15, 5)):
ax.set_xlim([0, (size - 1)])
ax.set_ylim([0, (size - 1)])
ax.set_zlim([0, (size - 1)])
ax.set_xlabel("Red", labelpad=-10)
ax.set_ylabel("Green", labelpad=-10)
ax.set_zlabel("Blue", labelpad=-12, rotation=90)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
ax.set_title(title)
fig.subplots_adjust(
hspace=0.25, wspace=0.15, top=0.95, bottom=0.05, left=0.05, right=0.95
)
return fig

View file

@ -6,21 +6,24 @@ import numpy as np
def plot_histograms_in_3d(
histograms: Dict[str, np.ndarray], histogram_per_row: int = 3
histograms: Dict[str, np.ndarray], histograms_per_row: int = 3
):
cols = min(histogram_per_row, len(histograms))
rows = ceil(len(histograms) / histogram_per_row)
column_count = min(histograms_per_row, len(histograms))
row_count = ceil(len(histograms) / histograms_per_row)
fig = make_subplots(
rows=rows,
cols=cols,
specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(rows)],
rows=row_count,
cols=column_count,
specs=[
[{"type": "scatter3d"} for _ in range(column_count)]
for _ in range(row_count)
],
)
for i, (title, histogram) in enumerate(histograms.items()):
fig.add_trace(
_get_3d_scatter_plot_from_histogram(title, histogram),
row=(i // (histogram_per_row + 1)) + 1,
col=(i % histogram_per_row) + 1,
row=(i // histograms_per_row) + 1,
col=(i % histograms_per_row) + 1,
)
scenes = {
@ -32,7 +35,9 @@ def plot_histograms_in_3d(
)
for i in range(1, len(histograms) + 1)
}
fig.update_layout(**scenes)
fig.update_layout(**scenes, height=300 * column_count)
fig.update_layout() # You can adjust the height as needed
fig.show()