This commit is contained in:
Andras Schmelczer 2024-04-28 12:19:19 +01:00
parent eec9ee0275
commit 294f2fab12
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
9 changed files with 62140 additions and 11540 deletions

View file

@ -0,0 +1,3 @@
from .display_images import display_images
from .plot_histograms_in_3d import plot_histograms_in_3d
from .plot_histograms_in_2d import plot_histograms_in_2d

View file

@ -0,0 +1,25 @@
import matplotlib.pyplot as plt
from typing import Dict
from PIL.Image import Image
from math import ceil
def display_images(images: Dict[str, Image], images_per_row: int = 3):
fig, axes = plt.subplots(
nrows=ceil(len(images) / images_per_row),
ncols=min(images_per_row, len(images)),
figsize=(12, 8),
)
axes = axes.flatten()
for i, (title, image) in enumerate(images.items()):
axes[i].imshow(image)
axes[i].axis("off")
axes[i].set_title(title)
for i in range(len(images), len(axes)):
axes[i].axis("off")
plt.tight_layout()
plt.show()

View file

@ -0,0 +1,32 @@
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict
def plot_histograms_in_2d(histograms: Dict[str, np.ndarray]):
fig = plt.figure(figsize=(15, 5))
for i, (title, histogram) in enumerate(histograms.items(), 1):
ax = fig.add_subplot(1, 3, i, projection="3d")
size = histogram.shape[0]
x, y, z = np.indices(histogram.shape)
x = x.flatten()
y = y.flatten()
z = z.flatten()
values = histogram.flatten()
sizes = values * 5000
colors = np.vstack((x, y, z)).T / (size - 1)
sc = ax.scatter(x, y, z, c=colors, s=sizes, marker="o", alpha=0.5)
ax.set_xlim([0, (size - 1)])
ax.set_ylim([0, (size - 1)])
ax.set_zlim([0, (size - 1)])
ax.set_title(title)
return fig

View file

@ -0,0 +1,54 @@
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from math import ceil
from typing import Dict
import numpy as np
def plot_histograms_in_3d(
histograms: Dict[str, np.ndarray], histogram_per_row: int = 3
):
cols = min(histogram_per_row, len(histograms))
rows = ceil(len(histograms) / histogram_per_row)
fig = make_subplots(
rows=rows,
cols=cols,
specs=[[{"type": "scatter3d"} for _ in range(cols)] for _ in range(rows)],
)
for i, (title, histogram) in enumerate(histograms.items()):
fig.add_trace(
_get_3d_scatter_plot_from_histogram(title, histogram),
row=(i // (histogram_per_row + 1)) + 1,
col=(i % histogram_per_row) + 1,
)
fig.show()
def _get_3d_scatter_plot_from_histogram(title, histogram):
x, y, z, marker_size = [], [], [], []
bins = len(histogram)
for i, row in enumerate(histogram):
for j, col in enumerate(row):
for k, value in enumerate(col):
if value > 0:
x.append(i)
y.append(j)
z.append(k)
marker_size.append(value)
return go.Scatter3d(
x=x,
y=y,
z=z,
mode="markers",
marker=dict(
size=[min(20, ms * 10000) for ms in marker_size],
color=[
f"rgb({xi*256/bins},{yi*256/bins},{zi*256/bins})"
for xi, yi, zi in zip(x, y, z)
],
opacity=0.8,
),
name=title,
)