Improve rendering
This commit is contained in:
parent
015a77a99e
commit
d2ce23fa33
3 changed files with 64 additions and 20 deletions
|
|
@ -4,11 +4,29 @@ from PIL.Image import Image
|
|||
from math import ceil
|
||||
|
||||
|
||||
def display_images(images: Dict[str, Image], images_per_row: int = 3, figsize=(24, 16)):
|
||||
def display_images(
|
||||
images: Dict[str, Image], images_per_row: int = 3, img_size_inches: int = 2
|
||||
) -> plt.Figure:
|
||||
row_count = ceil(len(images) / images_per_row)
|
||||
|
||||
an_image = next(iter(images.values()))
|
||||
aspect_ratio = an_image.size[0] / an_image.size[1]
|
||||
|
||||
unit_height = (
|
||||
img_size_inches
|
||||
if an_image.size[1] > an_image.size[0]
|
||||
else img_size_inches / aspect_ratio
|
||||
)
|
||||
unit_width = (
|
||||
img_size_inches
|
||||
if an_image.size[0] > an_image.size[1]
|
||||
else img_size_inches * aspect_ratio
|
||||
)
|
||||
|
||||
fig, axes = plt.subplots(
|
||||
nrows=ceil(len(images) / images_per_row),
|
||||
ncols=min(images_per_row, len(images)),
|
||||
figsize=figsize,
|
||||
nrows=row_count,
|
||||
ncols=images_per_row,
|
||||
figsize=(unit_width * images_per_row, unit_height * row_count),
|
||||
)
|
||||
|
||||
axes = axes.flatten()
|
||||
|
|
@ -21,5 +39,8 @@ def display_images(images: Dict[str, Image], images_per_row: int = 3, figsize=(2
|
|||
for i in range(len(images), len(axes)):
|
||||
axes[i].axis("off")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
fig.subplots_adjust(
|
||||
hspace=0.25, wspace=0.15, top=0.95, bottom=0.05, left=0.05, right=0.95
|
||||
)
|
||||
|
||||
return fig
|
||||
|
|
|
|||
|
|
@ -1,14 +1,21 @@
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def plot_histograms_in_2d(histograms: Dict[str, np.ndarray], figsize=(15, 5)):
|
||||
fig = plt.figure(figsize=figsize)
|
||||
def plot_histograms_in_2d(
|
||||
histograms: Dict[str, np.ndarray], histograms_per_row=3, histograms_size_inches=4
|
||||
):
|
||||
row_count = max(1, len(histograms) // histograms_per_row)
|
||||
fig = plt.figure(
|
||||
figsize=(
|
||||
histograms_per_row * histograms_size_inches,
|
||||
row_count * histograms_size_inches,
|
||||
)
|
||||
)
|
||||
|
||||
for i, (title, histogram) in enumerate(histograms.items(), 1):
|
||||
ax = fig.add_subplot(1, 3, i, projection="3d")
|
||||
ax = fig.add_subplot(row_count, histograms_per_row, i, projection="3d")
|
||||
|
||||
size = histogram.shape[0]
|
||||
|
||||
|
|
@ -18,7 +25,7 @@ def plot_histograms_in_2d(histograms: Dict[str, np.ndarray], figsize=(15, 5)):
|
|||
z = z.flatten()
|
||||
values = histogram.flatten()
|
||||
|
||||
sizes = values * 5000
|
||||
sizes = values * 5000 # this is just an arbitrary scaling factor
|
||||
|
||||
colors = np.vstack((x, y, z)).T / (size - 1)
|
||||
|
||||
|
|
@ -27,6 +34,17 @@ def plot_histograms_in_2d(histograms: Dict[str, np.ndarray], figsize=(15, 5)):
|
|||
ax.set_xlim([0, (size - 1)])
|
||||
ax.set_ylim([0, (size - 1)])
|
||||
ax.set_zlim([0, (size - 1)])
|
||||
ax.set_xlabel("Red", labelpad=-10)
|
||||
ax.set_ylabel("Green", labelpad=-10)
|
||||
ax.set_zlabel("Blue", labelpad=-12, rotation=90)
|
||||
|
||||
ax.set_xticklabels([])
|
||||
ax.set_yticklabels([])
|
||||
ax.set_zticklabels([])
|
||||
|
||||
ax.set_title(title)
|
||||
|
||||
fig.subplots_adjust(
|
||||
hspace=0.25, wspace=0.15, top=0.95, bottom=0.05, left=0.05, right=0.95
|
||||
)
|
||||
return fig
|
||||
|
|
|
|||
|
|
@ -6,21 +6,24 @@ import numpy as np
|
|||
|
||||
|
||||
def plot_histograms_in_3d(
|
||||
histograms: Dict[str, np.ndarray], histogram_per_row: int = 3
|
||||
histograms: Dict[str, np.ndarray], histograms_per_row: int = 3
|
||||
):
|
||||
cols = min(histogram_per_row, len(histograms))
|
||||
rows = ceil(len(histograms) / histogram_per_row)
|
||||
column_count = min(histograms_per_row, len(histograms))
|
||||
row_count = ceil(len(histograms) / histograms_per_row)
|
||||
fig = make_subplots(
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(rows)],
|
||||
rows=row_count,
|
||||
cols=column_count,
|
||||
specs=[
|
||||
[{"type": "scatter3d"} for _ in range(column_count)]
|
||||
for _ in range(row_count)
|
||||
],
|
||||
)
|
||||
|
||||
for i, (title, histogram) in enumerate(histograms.items()):
|
||||
fig.add_trace(
|
||||
_get_3d_scatter_plot_from_histogram(title, histogram),
|
||||
row=(i // (histogram_per_row + 1)) + 1,
|
||||
col=(i % histogram_per_row) + 1,
|
||||
row=(i // histograms_per_row) + 1,
|
||||
col=(i % histograms_per_row) + 1,
|
||||
)
|
||||
|
||||
scenes = {
|
||||
|
|
@ -32,7 +35,9 @@ def plot_histograms_in_3d(
|
|||
)
|
||||
for i in range(1, len(histograms) + 1)
|
||||
}
|
||||
fig.update_layout(**scenes)
|
||||
fig.update_layout(**scenes, height=300 * column_count)
|
||||
fig.update_layout() # You can adjust the height as needed
|
||||
|
||||
fig.show()
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue