"""Various helpful utilities for plotting or displaying information."""
# numpydoc ignore=ES01
import warnings
from typing import Any, Literal
import einops
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pyrtools as pt
import torch
from ..tensors import to_numpy
# Names that make up the public API of this module. The module-level
# ``__dir__`` defined below returns this same list.
__all__ = [
"animshow",
"stem_plot",
"imshow",
"plot_representation",
"pyrshow",
"update_plot",
"histogram",
]
def __dir__() -> list[str]:
    """
    List the public names of this module.

    Returns
    -------
    names
        The module's ``__all__`` list.
    """  # numpydoc ignore=ES01
    return __all__
def _find_zoom(
    image_heights: list[int], image_widths: list[int], ax: mpl.axes.Axes
) -> float:
    """
    Find best-fitting zoom based on image and axes sizes.

    The zoom is computed independently along the height and the width, and the
    smaller of the two is returned, so that the largest image fits within ``ax``
    along both dimensions.

    Along a given dimension, if ``ax`` is bigger than the largest image, we use the
    largest integer zoom that fits. Otherwise, we use the largest float of form
    ``1/d`` that fits, where ``d`` is an integer divisor of the image size.

    Parameters
    ----------
    image_heights, image_widths
        The last two dimensions of all images to plot.
    ax
        The existing axis we will use for imshow.

    Returns
    -------
    zoom
        Our best guess at zoom.
    """

    def find_zoom_helper(x: int, limit: float) -> float:
        """
        Find zoom that works. This is only for limit <= x.

        Parameters
        ----------
        x
            The size to consider.
        limit
            The max possible size.

        Returns
        -------
        zoom
            The valid zoom level.
        """  # numpydoc ignore=ES01
        x = int(x)
        # walk the divisors of x from smallest to largest: the first one whose
        # zoomed size fits within the limit gives the largest valid zoom. 1 is
        # included so an exact fit yields zoom=1, and x itself is included so
        # sizes without a suitable proper divisor (e.g., primes) still work.
        for divisor in range(1, x + 1):
            if x % divisor == 0 and x / divisor <= limit:
                return 1 / divisor
        # not even a single display pixel fits; fall back to the smallest zoom
        # we can express rather than failing
        return 1 / x

    def zoom_along(size: int, available: float) -> float:
        """
        Find the zoom along a single dimension.

        Parameters
        ----------
        size
            The size of the largest image along this dimension.
        available
            The size of the axis along this dimension.

        Returns
        -------
        zoom
            The largest valid zoom along this dimension.
        """  # numpydoc ignore=ES01
        if available > size:
            return available // size
        return find_zoom_helper(size, available)

    # the zoom must satisfy both dimensions, so take the more restrictive one
    return min(
        zoom_along(max(image_heights), ax.bbox.height),
        zoom_along(max(image_widths), ax.bbox.width),
    )
[docs]
def imshow(
    image: torch.Tensor | list[torch.Tensor],
    vrange: tuple[float, float] | str = "indep1",
    zoom: float | None = None,
    title: str | list[str] | None = "",
    col_wrap: int | None = None,
    ax: mpl.axes.Axes | None = None,
    cmap: mpl.colors.Colormap | None = None,
    plot_complex: Literal["rectangular", "polar", "logpolar"] = "rectangular",
    batch_idx: int | None = None,
    channel_idx: int | None = None,
    as_rgb: bool = False,
    **kwargs: Any,
) -> pt.tools.display.PyrFigure:
    """
    Show image(s), avoiding interpolation.

    This function shows images carefully, avoiding interpolation: each element in the
    input ``image`` will correspond to a pixel or an integer number of pixels. When
    ``zoom<1``, an integer number of input elements will be averaged into a single
    pixel.

    Parameters
    ----------
    image
        The images to display. Tensors should be 4d (batch, channel, height,
        width). List of tensors should be used for tensors of different height
        and width: all images will automatically be rescaled so they're
        displayed at the same height and width, thus, their heights and widths
        must be scalar multiples of each other.
    vrange
        If a 2-tuple, specifies the image values vmin/vmax that are mapped to
        the minimum and maximum value of the colormap, respectively. If a
        string:

        * ``"auto0"``: All images have same vmin/vmax, which have the same
          absolute value, and come from the minimum or maximum across all
          images, whichever has the larger absolute value.
        * ``"auto1"``: All images have same vmin/vmax, which are the
          minimum/maximum values across all images.
        * ``"auto2"``: All images have same vmin/vmax, which are the mean (across
          all images) minus/plus 2 std dev (across all images).
        * ``"auto3"``: All images have same vmin/vmax, chosen so as to map the
          10th/90th percentile values to the 10th/90th percentile of the display
          intensity range. For example: vmin is the 10th percentile image value
          minus 1/8 times the difference between the 90th and 10th percentile.
        * ``"indep0"``: Each image has an independent vmin/vmax, which have the same
          absolute value, which comes from either their minimum or maximum value,
          whichever has the larger absolute value.
        * ``"indep1"``: Each image has an independent vmin/vmax, which are their
          minimum/maximum values.
        * ``"indep2"``: Each image has an independent vmin/vmax, which is their
          mean minus/plus 2 std dev.
        * ``"indep3"``: Each image has an independent vmin/vmax, chosen so that the
          10th/90th percentile values map to the 10th/90th percentile intensities.
    zoom
        Ratio of display pixels to image pixels. If greater than 1, must be an
        integer. If less than 1, must be ``1/d`` where ``d`` is a divisor of the
        size of the largest image. If ``None``, we try to determine the best zoom.
    title
        Title for the plot. In addition to the specified title, we add a
        subtitle giving the plotted range and dimensionality (with zoom).

        * If ``str``, will put the same title on every plot.
        * If ``list``, all values must be ``str``, must be the same length as the
          number of plotted images, and each title will be assigned to the
          corresponding plot.
        * If ``None``, no title will be printed and subtitle will be removed.
    col_wrap
        Number of axes to have in each row. If ``None``, will fit all axes in a
        single row.
    ax
        If ``None``, we make the appropriate figure. Otherwise, we shrink the axes
        so that it's the appropriate number of pixels.
    cmap
        Colormap to use when showing these images. If ``None``, then behavior is
        determined by ``vrange``: if ``vrange in ["auto0", "indep0"]``, we use
        ``"RdBu_r"``, else we use ``"gray"`` (see
        :external+matplotlib:ref:`matplotlib documentation <colormaps>`).
    plot_complex
        Specifies handling of complex values.

        * ``"rectangular"``: plot real and imaginary components as separate images.
        * ``"polar"``: plot amplitude and phase as separate images.
        * ``"logpolar"``: plot log (base 2) amplitude and phase as separate images.
    batch_idx
        Which element from the batch dimension to plot. If ``None``, we plot all.
        Negative values index from the end.
    channel_idx
        Which element from the channel dimension to plot. If ``None``, we plot all.
        Negative values index from the end. Note if this is not ``None``, then
        ``as_rgb=True`` will fail, because we restrict the channels.
    as_rgb
        Whether to consider the channels as encoding RGB(A) values. If ``True``, we
        attempt to plot the image in color, so your tensor must have 3 (or 4 if
        you want the alpha channel) elements in the channel dimension. If ``False``,
        we plot each channel as a separate grayscale image.
    **kwargs
        Passed to :func:`matplotlib.pyplot.imshow`.

    Returns
    -------
    fig
        Figure containing the plotted images.

    Raises
    ------
    ValueError
        If ``image`` is not a 4d tensor or list of 4d tensors.
    TypeError
        If ``batch_idx`` or ``channel_idx`` are not an int or ``None``.
    IndexError
        If ``batch_idx`` or ``channel_idx`` are out of bounds.
    ValueError
        If ``zoom`` takes an illegal value.
    ValueError
        If ``as_rgb=True`` and the input ``image`` does not have 3 or 4 channels.
    ValueError
        If ``as_rgb=False``, ``image`` has more than one channel and more than
        one batch and neither ``batch_idx`` nor ``channel_idx`` is set.
    Exception
        If ``plot_complex`` takes an illegal value.

    See Also
    --------
    :func:`~plenoptic.plot.synthesis_imshow`
        Show the image synthesized by a synthesis object.
    animshow
        Animate a video.
    pyrshow
        Display steerable pyramid coefficients.

    Notes
    -----
    This interpolation avoidance is only guaranteed for the saved image; it should
    generally hold in notebooks as well, but will fail if, e.g., you plot an image
    that's 2000 pixels wide on a monitor 1000 pixels wide; the browser handles the
    rescaling in a way we can't control.

    Examples
    --------
    Plot a single grayscale image:

    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> einstein = po.data.einstein()
      >>> einstein.shape
      torch.Size([1, 1, 256, 256])
      >>> po.plot.imshow(einstein)
      <PyrFigure size ... with 1 Axes>

    For an image tensor with multiple elements along the batch dimension and a
    single channel element, this function will plot each batch independently as
    grayscale images:

    .. plot::
      :context: close-figs

      >>> import torch
      >>> curie = po.data.curie()
      >>> imgs = torch.cat([einstein, curie])
      >>> print(imgs.shape)
      torch.Size([2, 1, 256, 256])
      >>> po.plot.imshow(imgs)
      <PyrFigure size ... with 2 Axes>

    A list of 4d tensors will be concatenated along the batch dimension before plotting.
    Thus, the following example is the same as above:

    .. plot::
      :context: close-figs

      >>> po.plot.imshow([einstein, curie])
      <PyrFigure size ... with 2 Axes>

    You may use the ``title`` argument for any number of images, either as a string
    applied to all images or as a list the length of images. Additionally, ``col_wrap``
    specifies the number of images per row:

    .. plot::
      :context: close-figs

      >>> po.plot.imshow(imgs, title=["einstein", "curie"], col_wrap=1)
      <PyrFigure size ... with 2 Axes>

    Specifying ``batch_idx`` will plot the corresponding element in
    the batch dimension (i.e., ``imgs[batch_idx]``):

    .. plot::
      :context: close-figs

      >>> print(imgs.shape)
      torch.Size([2, 1, 256, 256])
      >>> po.plot.imshow(imgs, batch_idx=1)
      <PyrFigure size ... with 1 Axes>

    The vrange argument allows control over the min and max values of the color range.
    In addition to a 2-tuple of floats, this function accepts several special strings
    (see docstring for details). For example, ``"auto1"`` sets all images to have the
    same range:

    .. plot::
      :context: close-figs

      >>> einsteins_scaled = torch.cat([einstein, einstein * 2])
      >>> po.plot.imshow(einsteins_scaled, vrange="auto1")
      <PyrFigure size ... with 2 Axes>

    Meanwhile, ``"indep1"`` sets each image's range independently. Note the different
    ranges in the titles!

    .. plot::
      :context: close-figs

      >>> po.plot.imshow(einsteins_scaled, vrange="indep1")
      <PyrFigure size ... with 2 Axes>

    The ``zoom`` argument allows users to set the ratio of display to image pixels,
    increasing or decreasing the size of the resulting plot:

    .. plot::
      :context: close-figs

      >>> po.plot.imshow(einstein, zoom=0.5)
      <PyrFigure size ... with 1 Axes>

    Note that if ``zoom<1`` and the value is not a divisor of the largest image size,
    this function will raise an error:

    >>> print(einstein.shape)
    torch.Size([1, 1, 256, 256])
    >>> po.plot.imshow(einstein, zoom=0.7)
    Traceback (most recent call last):
    Exception: zoom * signal.shape must result in integers!

    You can use the ``plot_complex`` argument to control how complex tensors are
    plotted:

    .. plot::
      :context: close-figs

      >>> einstein_fft = torch.fft.fft2(einstein)
      >>> po.plot.imshow([einstein, einstein_fft], plot_complex="logpolar")
      <PyrFigure size ... with 3 Axes>

    To plot a RGB(A) image in color, you must set ``as_rgb=True``:

    .. plot::
      :context: close-figs

      >>> color_wheel = po.data.color_wheel()
      >>> print(color_wheel.shape)
      torch.Size([1, 3, 600, 600])
      >>> po.plot.imshow(color_wheel, as_rgb=True, zoom=0.5)
      <PyrFigure size ... with 1 Axes>

    Otherwise, images with multiple channels will have each channel plotted as a
    separate grayscale image:

    .. plot::
      :context: close-figs

      >>> po.plot.imshow(color_wheel, zoom=0.5, title=["R", "G", "B"])
      <PyrFigure size ... with 3 Axes>

    This function will raise a ``ValueError`` if ``as_rgb=True`` and the input image
    doesn't have the required number of channels:

    >>> print(einstein.shape)
    torch.Size([1, 1, 256, 256])
    >>> po.plot.imshow(einstein, as_rgb=True)
    Traceback (most recent call last):
    ValueError: If as_rgb is True, then channel must have 3 or 4 elements!

    Images will be automatically rescaled to be displayed at the same heights and widths
    if sizes are scalar multiples of each other:

    .. plot::
      :context: close-figs

      >>> einstein_cropped = po.process.center_crop(einstein, 32)
      >>> po.plot.imshow([einstein, einstein_cropped])
      <PyrFigure size ... with 2 Axes>
    """
    if not isinstance(image, list):
        image = [image]
    if any(im.ndim != 4 for im in image):
        raise ValueError("imshow only accepts images as 4d tensors!")
    images_to_plot = []
    heights, widths = [], []
    for im in image:
        im = to_numpy(im)
        orig_shape = im.shape
        if batch_idx is not None:
            try:
                # slicing (rather than indexing) preserves the number of
                # dimensions. for an index of -1, the stop would be 0, which
                # gives an empty slice, so we use an open-ended stop instead.
                im = im[batch_idx : (batch_idx + 1) or None]
            except TypeError as err:
                raise TypeError(
                    f"batch_idx must be an int or None but got {batch_idx}"
                ) from err
            # slicing never raises for out-of-bounds values, it just returns
            # an empty array
            if im.shape[0] == 0:
                raise IndexError(
                    f"{batch_idx=} is out of bounds for dimension 0 with size "
                    f"{orig_shape[0]}"
                )
        if channel_idx is not None:
            try:
                # same logic as for batch_idx above
                im = im[:, channel_idx : (channel_idx + 1) or None]
            except TypeError as err:
                raise TypeError(
                    f"channel_idx must be an int or None but got {channel_idx}"
                ) from err
            if im.shape[1] == 0:
                raise IndexError(
                    f"{channel_idx=} is out of bounds for dimension 1 with size "
                    f"{orig_shape[1]}"
                )
        # allow RGB and RGBA
        if as_rgb:
            if im.shape[1] not in [3, 4]:
                raise ValueError(
                    "If as_rgb is True, then channel must have 3 or 4 elements!"
                )
            im = im.transpose(0, 2, 3, 1)
            # want to insert a fake "channel" dimension here, so our putting it
            # into a list below works as expected
            im = im.reshape((im.shape[0], 1, *im.shape[1:]))
        elif im.shape[1] > 1 and im.shape[0] > 1:
            raise ValueError(
                "Don't know how to plot non-rgb images with more than one channel"
                " and batch! Use batch_idx / channel_idx to choose a subset for"
                " plotting."
            )
        # by iterating through it twice, we make sure to peel apart the batch
        # and channel dimensions so that they each show up as a separate image.
        # because of how we've handled everything above, we know that im will
        # be (b,c,h,w) or (b,c,h,w,r) where r is the RGB(A) values
        for batch_element in im:
            # at this point, the elements are all shape (h,w) or (h,w,r) and so
            # we don't squeeze, which could accidentally drop a dimension if h
            # or w is a singleton dimension
            for single_image in batch_element:
                images_to_plot.append(single_image)
                heights.append(single_image.shape[0])
                widths.append(single_image.shape[1])
    if zoom is None and ax is not None:
        zoom = _find_zoom(heights, widths, ax)
    elif zoom is None:
        zoom = 1
    elif zoom <= 0:
        raise ValueError("zoom must be positive!")
    return pt.imshow(
        images_to_plot,
        vrange=vrange,
        zoom=zoom,
        title=title,
        col_wrap=col_wrap,
        ax=ax,
        cmap=cmap,
        plot_complex=plot_complex,
        **kwargs,
    )
[docs]
def animshow(
    video: torch.Tensor | list[torch.Tensor],
    framerate: float = 2.0,
    repeat: bool = False,
    vrange: tuple[float, float] | str = "indep1",
    zoom: float | None = None,
    title: str | list[str] | None = "",
    col_wrap: int | None = None,
    ax: mpl.axes.Axes | None = None,
    cmap: mpl.colors.Colormap | None = None,
    plot_complex: Literal["rectangular", "polar", "logpolar"] = "rectangular",
    batch_idx: int | None = None,
    channel_idx: int | None = None,
    as_rgb: bool = False,
    **kwargs: Any,
) -> mpl.animation.FuncAnimation:
    """
    Animate video(s), avoiding interpolation.

    This function shows videos carefully, avoiding interpolation: each element in the
    input ``video`` will correspond to a pixel or an integer number of pixels. When
    ``zoom<1``, an integer number of input elements will be averaged into a single
    pixel.

    Parameters
    ----------
    video
        The video(s) to display. Tensors should be 5d (batch, channel, time,
        height, width). List of tensors should be used for tensors of different
        height and width: all videos will automatically be rescaled so they're
        displayed at the same height and width, thus, their heights and widths
        must be scalar multiples of each other. Videos must all have the same
        number of frames.
    framerate
        Temporal resolution of the video, in Hz (frames per second).
    repeat
        Whether to loop the animation or just play it once.
    vrange
        If a 2-tuple, specifies the image values vmin/vmax that are mapped to
        the minimum and maximum value of the colormap, respectively. If a
        string:

        * ``"auto0"``: All images have same vmin/vmax, which have the same
          absolute value, and come from the minimum or maximum across all
          images, whichever has the larger absolute value.
        * ``"auto1"``: All images have same vmin/vmax, which are the
          minimum/maximum values across all images.
        * ``"auto2"``: All images have same vmin/vmax, which are the mean (across
          all images) minus/plus 2 std dev (across all images).
        * ``"auto3"``: All images have same vmin/vmax, chosen so as to map the
          10th/90th percentile values to the 10th/90th percentile of the display
          intensity range. For example: vmin is the 10th percentile image value
          minus 1/8 times the difference between the 90th and 10th percentile.
        * ``"indep0"``: Each image has an independent vmin/vmax, which have the same
          absolute value, which comes from either their minimum or maximum value,
          whichever has the larger absolute value.
        * ``"indep1"``: Each image has an independent vmin/vmax, which are their
          minimum/maximum values.
        * ``"indep2"``: Each image has an independent vmin/vmax, which is their
          mean minus/plus 2 std dev.
        * ``"indep3"``: Each image has an independent vmin/vmax, chosen so that the
          10th/90th percentile values map to the 10th/90th percentile intensities.
    zoom
        Ratio of display pixels to image pixels. If greater than 1, must be an
        integer. If less than 1, must be ``1/d`` where ``d`` is a divisor of the
        size of the largest image. If ``None``, we try to determine the best zoom.
    title
        Title for the plot. In addition to the specified title, we add a
        subtitle giving the plotted range and dimensionality (with zoom).

        * If ``str``, will put the same title on every plot.
        * If ``list``, all values must be ``str``, must be the same length as the
          number of plotted videos, and each title will be assigned to the
          corresponding plot.
        * If ``None``, no title will be printed and subtitle will be removed.
    col_wrap
        Number of axes to have in each row. If ``None``, will fit all axes in a
        single row.
    ax
        If ``None``, we make the appropriate figure. Otherwise, we shrink the axes
        so that it's the appropriate number of pixels.
    cmap
        Colormap to use when showing these images. If ``None``, then behavior is
        determined by ``vrange``: if ``vrange in ["auto0", "indep0"]``, we use
        ``"RdBu_r"``, else we use ``"gray"`` (see
        :external+matplotlib:ref:`matplotlib documentation <colormaps>`).
    plot_complex
        Specifies handling of complex values.

        * ``"rectangular"``: plot real and imaginary components as separate images.
        * ``"polar"``: plot amplitude and phase as separate images.
        * ``"logpolar"``: plot log (base 2) amplitude and phase as separate images.
    batch_idx
        Which element from the batch dimension to plot. If ``None``, we plot all.
        Negative values index from the end. Ignored for videos whose batch
        dimension has a single element.
    channel_idx
        Which element from the channel dimension to plot. If ``None``, we plot all.
        Negative values index from the end. Note if this is not ``None``, then
        ``as_rgb=True`` will fail, because we restrict the channels.
    as_rgb
        Whether to consider the channels as encoding RGB(A) values. If ``True``, we
        attempt to plot the video in color, so your tensor must have 3 (or 4 if
        you want the alpha channel) elements in the channel dimension. If ``False``,
        we plot each channel as a separate grayscale video.
    **kwargs
        Passed to :func:`matplotlib.pyplot.imshow`.

    Returns
    -------
    anim
        The animation object. In order to view, must convert to HTML
        or save.

    Raises
    ------
    ValueError
        If ``video`` is not a 5d tensor or list of 5d tensors.
    TypeError
        If ``batch_idx`` or ``channel_idx`` takes an illegal value.
    ValueError
        If ``zoom`` is not positive.
    ValueError
        If ``as_rgb=True`` and the input ``video`` does not have 3 or 4 channels.
    ValueError
        If ``as_rgb=False``, ``video`` has more than one channel and more than
        one batch and neither ``batch_idx`` nor ``channel_idx`` is set.
    Exception
        If ``plot_complex`` takes an illegal value.

    See Also
    --------
    imshow
        Display an image.
    :func:`~plenoptic.plot.synthesis_animate`
        Animate synthesis process for a :class:`~plenoptic.Metamer` or a
        :class:`~plenoptic.MADCompetition` object.

    Notes
    -----
    - This function returns a matplotlib FuncAnimation object. See below for how
      to view it in a Jupyter notebook; to write it to disk, use the animation's
      ``save`` method. In either case, this can take a while and you'll need the
      appropriate writer installed and on your path (e.g., ffmpeg, imagemagick,
      etc). See :doc:`matplotlib documentation <matplotlib:api/animation_api>`
      for more details.

    - Unless specified, we use the ffmpeg backend, which requires that you have
      ffmpeg installed and on your path (https://ffmpeg.org/download.html). To
      use a different one, use the matplotlib rcParams:
      ``matplotlib.rcParams['animation.writer'] = writer``, see `matplotlib
      documentation
      <https://matplotlib.org/stable/api/animation_api.html#writer-classes>`_
      for more details.

    - To view in a Jupyter notebook, we recommend adding the following to the first
      cell of your notebook (requires ffmpeg):

      .. code:: python

         import matplotlib.pyplot as plt

         plt.rcParams["animation.html"] = "html5"
         # use single-threaded ffmpeg for animation writer
         plt.rcParams["animation.writer"] = "ffmpeg"
         plt.rcParams["animation.ffmpeg_args"] = ["-threads", "1"]

    - This interpolation avoidance is only guaranteed for the saved image; it should
      generally hold in notebooks as well, but will fail if, e.g., you plot an image
      that's 2000 pixels wide on a monitor 1000 pixels wide; the browser handles the
      rescaling in a way we can't control.
    """
    if not isinstance(video, list):
        video = [video]
    if any(vid.ndim != 5 for vid in video):
        raise ValueError("animshow only accepts videos as 5d tensors!")
    videos_to_show = []
    heights, widths = [], []
    for vid in video:
        vid = to_numpy(vid)
        if vid.shape[0] > 1 and batch_idx is not None:
            try:
                # slicing (rather than indexing) preserves the number of
                # dimensions. for an index of -1, the stop would be 0, which
                # gives an empty slice, so we use an open-ended stop instead.
                vid = vid[batch_idx : (batch_idx + 1) or None]
            except TypeError as err:
                raise TypeError(
                    f"batch_idx must be an int or None but got {batch_idx}"
                ) from err
        if channel_idx is not None:
            try:
                # same logic as for batch_idx above
                vid = vid[:, channel_idx : (channel_idx + 1) or None]
            except TypeError as err:
                raise TypeError(
                    f"channel_idx must be an int or None but got {channel_idx}"
                ) from err
        # allow RGB and RGBA
        if as_rgb:
            if vid.shape[1] not in [3, 4]:
                raise ValueError(
                    "If as_rgb is True, then channel must have 3 or 4 elements!"
                )
            vid = vid.transpose(0, 2, 3, 4, 1)
            # want to insert a fake "channel" dimension here, so our putting it
            # into a list below works as expected
            vid = vid.reshape((vid.shape[0], 1, *vid.shape[1:]))
        elif vid.shape[1] > 1 and vid.shape[0] > 1:
            raise ValueError(
                "Don't know how to plot non-rgb images with more than one channel and"
                " batch! Use batch_idx / channel_idx to choose a subset for"
                " plotting"
            )
        # by iterating through it twice, we make sure to peel apart the batch
        # and channel dimensions so that they each show up as a separate video.
        # because of how we've handled everything above, we know that vid will
        # be (b,c,t,h,w) or (b,c,t,h,w,r) where r is the RGB(A) values
        for batch_element in vid:
            # at this point, the elements are all shape (t,h,w) or (t,h,w,r)
            # and so we don't squeeze, which could accidentally drop a
            # dimension if t, h, or w is a singleton dimension (and would then
            # disagree with the heights and widths recorded here)
            for single_video in batch_element:
                videos_to_show.append(single_video)
                heights.append(single_video.shape[1])
                widths.append(single_video.shape[2])
    if zoom is None and ax is not None:
        zoom = _find_zoom(heights, widths, ax)
    elif zoom is None:
        zoom = 1
    elif zoom <= 0:
        raise ValueError("zoom must be positive!")
    return pt.animshow(
        videos_to_show,
        framerate=framerate,
        as_html5=False,
        repeat=repeat,
        vrange=vrange,
        zoom=zoom,
        title=title,
        col_wrap=col_wrap,
        ax=ax,
        cmap=cmap,
        plot_complex=plot_complex,
        **kwargs,
    )
[docs]
def pyrshow(
    pyr_coeffs: dict,
    vrange: tuple[float, float] | str = "indep1",
    zoom: float = 1,
    show_residuals: bool = True,
    cmap: mpl.colors.Colormap | None = None,
    plot_complex: Literal["rectangular", "polar", "logpolar"] = "rectangular",
    batch_idx: int = 0,
    channel_idx: int = 0,
    **kwargs: Any,
) -> pt.tools.display.PyrFigure:
    r"""
    Display steerable pyramid coefficients in orderly fashion.

    This function uses :func:`~plenoptic.plot.imshow` to show the coefficients
    of the steerable pyramid (or any dictionary in the standard format),
    such that each scale shows up on a single row, with each scale in a given column.

    Note that unlike :func:`~plenoptic.plot.imshow`, we can only show one batch
    or channel at a time.

    Parameters
    ----------
    pyr_coeffs
        Pyramid coefficients in the standard dictionary format as returned by the
        steerable pyramid's
        :func:`~plenoptic.process.SteerablePyramidFreq.forward`
        method. Values stored under string keys are used as-is and should be
        4d (batch, channel, height, width); values under any other key should be
        5d (batch, channel, orientation, height, width) and are split along the
        orientation dimension.
    vrange
        If a 2-tuple, specifies the image values vmin/vmax that are mapped to
        the minimum and maximum value of the colormap, respectively. If a
        string:

        * ``"auto0"``: All images have same vmin/vmax, which have the same
          absolute value, and come from the minimum or maximum across all
          images, whichever has the larger absolute value.
        * ``"auto1"``: All images have same vmin/vmax, which are the
          minimum/maximum values across all images.
        * ``"auto2"``: All images have same vmin/vmax, which are the mean (across
          all images) minus/plus 2 std dev (across all images).
        * ``"auto3"``: All images have same vmin/vmax, chosen so as to map the
          10th/90th percentile values to the 10th/90th percentile of the display
          intensity range. For example: vmin is the 10th percentile image value
          minus 1/8 times the difference between the 90th and 10th percentile.
        * ``"indep0"``: Each image has an independent vmin/vmax, which have the same
          absolute value, which comes from either their minimum or maximum value,
          whichever has the larger absolute value.
        * ``"indep1"``: Each image has an independent vmin/vmax, which are their
          minimum/maximum values.
        * ``"indep2"``: Each image has an independent vmin/vmax, which is their
          mean minus/plus 2 std dev.
        * ``"indep3"``: Each image has an independent vmin/vmax, chosen so that the
          10th/90th percentile values map to the 10th/90th percentile intensities.
    zoom
        Ratio of display pixels to image pixels. If greater than 1, must be an
        integer. If less than 1, must be ``1/d`` where ``d`` is a divisor of the
        size of the largest image.
    show_residuals
        Whether to display the residual bands.
    cmap
        Colormap to use when showing these images.
    plot_complex
        Specifies handling of complex values.

        * ``"rectangular"``: plot real and imaginary components as separate images.
        * ``"polar"``: plot amplitude and phase as separate images.
        * ``"logpolar"``: plot log (base 2) amplitude and phase as separate images.
    batch_idx
        Which element from the batch dimension to plot. Negative values index
        from the end.
    channel_idx
        Which element from the channel dimension to plot. Negative values index
        from the end.
    **kwargs
        Passed on to :func:`pyrtools.tools.display.pyrshow`.

    Returns
    -------
    fig
        The figure displaying the coefficients.

    Raises
    ------
    TypeError
        If ``batch_idx`` or ``channel_idx`` takes an illegal value.
    """
    pyr_coeffvis = {}
    is_complex = False
    for k, v in pyr_coeffs.items():
        if isinstance(k, str):
            # string keys hold a single band, which is used as-is
            ims = [v]
            keys = [k]
        else:
            # all other keys hold one band per element of the third dimension;
            # bring it to the front so we can iterate over it
            ims = einops.rearrange(v, "b c o h w -> o b c h w")
            keys = [(k, i) for i in range(len(ims))]
        for key, im in zip(keys, ims):
            im = to_numpy(im)
            if np.iscomplex(im).any():
                is_complex = True
            try:
                # this removes only the first (batch) dimension. for an index
                # of -1, the stop would be 0, which gives an empty slice, so
                # we use an open-ended stop instead.
                im = im[batch_idx : (batch_idx + 1) or None].squeeze(0)
            except TypeError as err:
                raise TypeError(
                    f"batch_idx must be an int but got {batch_idx}"
                ) from err
            try:
                # this removes only the first (now channel) dimension
                im = im[channel_idx : (channel_idx + 1) or None].squeeze(0)
            except TypeError as err:
                raise TypeError(
                    f"channel_idx must be an int but got {channel_idx}"
                ) from err
            # because of how we've handled everything above, we know that im will
            # be (h,w).
            pyr_coeffvis[key] = im
    return pt.pyrshow(
        pyr_coeffvis,
        is_complex=is_complex,
        vrange=vrange,
        zoom=zoom,
        cmap=cmap,
        plot_complex=plot_complex,
        show_residuals=show_residuals,
        **kwargs,
    )
def _clean_up_axes(
    ax: mpl.axes.Axes,
    ylim: tuple[float, float] | None | Literal[False] = None,
    spines_to_remove: list[Literal["top", "right", "bottom", "left"]] | None = None,
    axes_to_remove: list[Literal["x", "y"]] | None = None,
) -> mpl.axes.Axes:
    r"""
    Clean up an axis, as desired when making a stem plot of the representation.

    This function can:

    - Remove the spines from axis (the axis lines and tick marks).

    - Set axis objects themselves invisible (includes not just spines but also tick
      labels and axis label).

    - Set ylim.

    Parameters
    ----------
    ax
        The axis to clean up.
    ylim
        If a tuple, the y-limits to use for this plot. If ``None``, we use
        the original limits, slightly adjusted so that the minimum is 0. If
        ``False``, we do not change y-limits.
    spines_to_remove
        The spines to remove from the axis. If ``None``, we remove
        ``["top", "right", "bottom"]``.
    axes_to_remove
        The axes to set as invisible. If ``None``, we remove ``["x"]``.

    Returns
    -------
    ax
        The cleaned-up axis.
    """
    # defaults are created per call (instead of living in the signature) so no
    # list object is shared between calls
    if spines_to_remove is None:
        spines_to_remove = ["top", "right", "bottom"]
    if axes_to_remove is None:
        axes_to_remove = ["x"]
    if ylim is None:
        # keep the current upper limit, but anchor the lower limit at 0
        ax.set_ylim((0, ax.get_ylim()[1]))
    elif ylim:
        ax.set_ylim(ylim)
    # any other falsy value (in particular ``False``) leaves the limits alone
    if "x" in axes_to_remove:
        ax.xaxis.set_visible(False)
    if "y" in axes_to_remove:
        ax.yaxis.set_visible(False)
    for s in spines_to_remove:
        ax.spines[s].set_visible(False)
    return ax
def _update_stem(
    stem_container: mpl.container.StemContainer, ydata: np.ndarray | torch.Tensor
) -> mpl.container.StemContainer:
    r"""
    Refresh the y-values shown by an existing stem plot (for an animation).

    The marker heights and the top end of each stem line are moved to match
    ``ydata``. The baseline and the x-positions are left untouched.

    Parameters
    ----------
    stem_container
        Single container for the artists created in a :func:`matplotlib.pyplot.stem`
        plot. It can be treated like a namedtuple ``(markerline, stemlines, baseline)``.
        In order to get this from an axis ``ax``, try ``ax.containers[0]`` (if you have
        more than one container in that axis, it may not be the first one).
    ydata
        The new y-data to show on the plot. Importantly, must be the same length as
        the existing y-data.

    Returns
    -------
    stem_container
        The same StemContainer, with its artists updated.
    """
    markers = stem_container.markerline
    stems = stem_container.stemlines
    markers.set_ydata(ydata)
    updated_segments = stems.get_segments().copy()
    for segment, height in zip(updated_segments, ydata):
        try:
            # second point of the segment is the tip of the stem; set its y
            segment[1, 1] = height
        except IndexError:
            # a segment that is 1x2 rather than 2x2 has no tip to move; this
            # is the case when the data there is nan, so leave it alone
            pass
    stems.set_segments(updated_segments)
    return stem_container
def _rescale_ylim(
    axes: list[mpl.axes.Axes], data: np.ndarray | torch.Tensor | dict
) -> None:
    r"""
    Rescale y-limits nicely.

    We take the axes and set their limits to be ``(-y_max, y_max)``,
    where ``y_max`` is the largest absolute value found in ``data``.

    Parameters
    ----------
    axes
        A list of matplotlib axes to rescale.
    data
        The data to use when rescaling (or a dictionary of such values, in which
        case the maximum is taken across all of the dictionary's values).
    """

    def find_ymax(values: np.ndarray | torch.Tensor) -> float:
        """
        Find appropriate ymax.

        Parameters
        ----------
        values
            The tensor or array whose ymax we should grab.

        Returns
        -------
        ymax
            The appropriate ymax.
        """  # numpydoc ignore=ES01
        if isinstance(values, torch.Tensor):
            # detach so this works for tensors that require grad; ``item``
            # returns a python number regardless of the tensor's device
            return values.detach().abs().max().item()
        return float(np.abs(values).max())

    # dispatch on the container type explicitly: neither dictionaries nor
    # arrays have a ``cpu`` method, so we cannot call tensor methods up front
    if isinstance(data, dict):
        y_max = max(find_ymax(d) for d in data.values())
    else:
        y_max = find_ymax(data)
    for ax in axes:
        ax.set_ylim((-y_max, y_max))
[docs]
def stem_plot(
    data: torch.Tensor,
    ax: mpl.axes.Axes | None = None,
    title: str | None = "",
    ylim: tuple | None | Literal[False] = None,
    xvals: tuple[list[float], list[float]] | None = None,
    **kwargs: Any,
) -> mpl.axes.Axes:
    r"""
    Create a simple stem plot.

    This draws the stems and the baseline, tidies up the axis, and sets the
    title.

    Helper function for :func:`~plenoptic.plot.plot_representation()`.

    With ``xvals=None``, the stem plot gets a single baseline spanning the whole
    range of the data. To break the baseline up visually (so there's a line from
    0 to 9, from 10 to 19, etc), pass the start and stop points via ``xvals``.

    Parameters
    ----------
    data
        The data to plot (as a stem plot).
    ax
        The axis to plot the data on. If ``None``, we plot on the current axis
        (grabbed with :func:`matplotlib.pyplot.gca`).
    title
        The title to put on the axis. If ``None``, we don't call ``ax.set_title``
        (useful if you want to avoid changing the title on an existing plot).
    ylim
        If a tuple, the y-limits to use for this plot. If ``None``, we use
        the original limits, slightly adjusted so that the minimum is 0. If
        ``False``, we do not change y-limits.
    xvals
        A 2-tuple of lists, containing the start (``xvals[0]``) and stop
        (``xvals[1]``) x values for plotting. If ``None``, baseline will cover
        full range.
    **kwargs
        Passed to :func:`matplotlib.pyplot.stem`.

    Returns
    -------
    ax
        The axis with the plot.

    Examples
    --------
    We allow for breaks in the baseline value if we want to visually
    break up the plot, as we see below.

    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> import numpy as np
      >>> # if ylim=None, as in this example, the minimum y-value will get
      >>> # set to 0, so we want to make sure our values are all positive
      >>> y = np.abs(np.random.randn(55))
      >>> y[15:20] = np.nan
      >>> y[35:40] = np.nan
      >>> # we want to draw the baseline from 0 to 14, 20 to 34, and 40 to
      >>> # 54, everywhere that we have non-NaN values for y
      >>> xvals = ([0, 20, 40], [14, 34, 54])
      >>> po.plot.stem_plot(y, xvals=xvals)
      <Axes: >

    If we don't care about breaking up the x-axis, you can set ``xvals=None``.
    In this case, this function will just clean up the plot a little bit.

    .. plot::
      :context: close-figs

      >>> # if ylim=None, as in this example, the minimum y-value will get
      >>> # set to 0, so we want to make sure our values are all positive
      >>> y = np.abs(np.random.randn(55))
      >>> po.plot.stem_plot(y)
      <Axes: >
    """
    axis = plt.gca() if ax is None else ax
    split_baseline = xvals is not None
    if split_baseline:
        # draw the broken-up baseline ourselves, one horizontal segment at y=0
        # per (start, stop) pair
        starts, stops = xvals[0], xvals[1]
        axis.hlines([0] * len(starts), starts, stops, colors="C3", zorder=10)
    # a blank format hides stem's own baseline when we have drawn our own;
    # None is stem's default baseline format
    axis.stem(data, basefmt=" " if split_baseline else None, **kwargs)
    axis = _clean_up_axes(axis, ylim, ["top", "right", "bottom"])
    if title is not None:
        axis.set_title(title)
    return axis
def _get_artists_from_axes(
    axes: mpl.axes.Axes | list[mpl.axes.Axes],
    data: torch.Tensor | dict,
) -> dict:
    """
    Grab artists from axes.

    For now, we only grab containers (stem plots), images, lines, or collections.
    For each axis we look at these in that order and use the first kind we find.

    See the docstring of :func:`~plenoptic.plot.update_plot` for details on
    how ``axes`` and ``data`` should be structured.

    Parameters
    ----------
    axes
        The axis/axes to update.
    data
        The new data to plot.

    Returns
    -------
    artists
        Dictionary of artists for updating plots. Values are the artists to
        use, keys are the corresponding keys from data. If ``data`` is a
        tensor, keys are the zero-padded artist indices (``"00"``, ``"01"``,
        ...).

    Raises
    ------
    ValueError
        If the number of artists in ``axes`` is different from the size of the
        dimension-to-plot of ``data``.
    ValueError
        If ``axes`` is a single axis without any artist we know how to update,
        or if ``data`` is a tensor and none of the ``axes`` contains exactly
        one such artist.
    """
    # Each entry is (axis attribute holding the artists, how to check data):
    # 1 means compare against data.shape[1], 2 means compare against
    # data.shape[-3] (only if data has more than 2 dimensions), None means
    # don't check (images are weird, so don't check them like this).
    artist_kinds = (
        ("containers", 1),
        ("images", None),
        ("lines", 1),
        ("collections", 2),
    )

    def _n_things_to_plot(check: int | None) -> int | None:
        # number of elements in data that need their own artist, or None if
        # we don't know how to check this.
        if check == 1:
            return data.shape[1]
        if check == 2 and data.ndim > 2:
            return data.shape[-3]
        return None

    if not hasattr(axes, "__iter__"):
        # then we only have one axis, so we may be able to update more than one
        # data element.
        for attr, check in artist_kinds:
            found = getattr(axes, attr)
            if len(found) > 0:
                data_check = check
                artists = found
                break
        else:
            # without this, artists would be undefined below
            raise ValueError(
                "axes contains no containers, images, lines, or collections, "
                "so there are no artists to update!"
            )
        if isinstance(data, dict):
            artists = {artist.get_label(): artist for artist in artists}
        else:
            n_to_plot = _n_things_to_plot(data_check)
            if n_to_plot is not None and n_to_plot != len(artists):
                raise ValueError(
                    f"data has {n_to_plot} things to plot, but "
                    f"your axis contains {len(artists)} plotting artists, "
                    "so unsure how to continue! Pass data as a dictionary"
                    " with keys corresponding to the labels of the artists"
                    " to update to resolve this."
                )
    else:
        # then we have multiple axes, so we are only updating one data element
        # per plot
        artists = []
        data_check = None
        for ax in axes:
            for attr, check in artist_kinds:
                found = getattr(ax, attr)
                if len(found) == 1:
                    # the check used below is the one from the last axis that
                    # contributed an artist
                    data_check = check
                    artists.extend(found)
                    break
        if isinstance(data, dict):
            if len(data.keys()) != len(artists):
                raise ValueError(
                    f"data has {len(data.keys())} things to plot, but "
                    f"you passed {len(axes)} axes , so unsure how "
                    "to continue!"
                )
            artists = dict(zip(data.keys(), artists))
        else:
            if not artists:
                # without this, we would have nothing to check data against
                raise ValueError(
                    "None of the axes contain exactly one container, image, "
                    "line, or collection, so there are no artists to update!"
                )
            n_to_plot = _n_things_to_plot(data_check)
            if n_to_plot is not None and n_to_plot != len(artists):
                raise ValueError(
                    f"data has {n_to_plot} things to plot, but "
                    f"you passed {len(axes)} axes , so unsure how "
                    "to continue!"
                )
    if not isinstance(artists, dict):
        artists = {f"{i:02d}": a for i, a in enumerate(artists)}
    return artists
[docs]
def update_plot(
    axes: mpl.axes.Axes | list[mpl.axes.Axes],
    data: torch.Tensor | dict,
    model: torch.nn.Module | None = None,
    batch_idx: int = 0,
) -> list:
    r"""
    Update the information in some axes (for an animation).

    This is used for creating an animation over time. In order to create the
    animation, we need to know how to update the matplotlib Artists, and this
    provides a simple way of doing that. It assumes the plot has been created by
    something like :func:`~plenoptic.plot.plot_representation`, which
    initializes all the artists.

    We can update stem plots, lines (as returned by
    :func:`matplotlib.pyplot.plot`), scatter plots, or images (RGB, RGBA, or
    grayscale).

    There are two modes for this:

    - Single axis: ``axes`` is a single axis, which may contain multiple artists
      (all of the same type) to update. ``data`` should be a
      :class:`torch.Tensor` with multiple channels (one per artist in the same
      order) or be a dictionary whose keys give the label(s) of the
      corresponding artist(s) and whose values are :class:`torch.Tensor`.

    - Multiple axes: ``axes`` is a list of axes, each of which contains a single
      artist to update (artists can be different types). ``data`` should be a
      :class:`torch.Tensor` with multiple channels (one per axis in the same
      order) or a dictionary with the same number of keys as ``axes``, which we
      can iterate through in order, and whose values are :class:`torch.Tensor`.

    In all cases, ``data`` Tensors should be 3d (if the plot we're updating is a
    line or stem plot) or 4d (if it's an image or scatter plot).

    RGB(A) images are special, since we store that info along the channel
    dimension, so they only work with single-axis mode (which will only have a
    single artist, because that's how imshow works).

    If you have multiple axes, each with multiple artists you want to update,
    that's too complicated for us, and so you should write a
    ``model.update_plot()`` function which handles that (see
    :func:`plenoptic.models.PortillaSimoncelli.update_plot`
    for an example).

    If ``model`` has an ``update_plot`` method, we call it (it must also return
    artists) and any exception it raises is propagated. If ``model`` is ``None``
    or doesn't have an ``update_plot`` method, then we try to figure out how to
    update the axes ourselves, based on the shape of the data.

    Parameters
    ----------
    axes
        The axis or list of axes to update. We assume that these are the axes
        created by :func:`~plenoptic.plot.plot_representation` and so contain
        artists in the correct order.
    data
        The new data to plot.
    model
        A differentiable model that tells us how to plot ``data``. See above for
        behavior if ``None``.
    batch_idx
        Which index to take from the batch dimension.

    Returns
    -------
    artists
        A list of the artists used to update the information on the
        plots.

    Raises
    ------
    ValueError
        If ``data`` (or its values, if it's a ``dict``) are not 3 or 4
        dimensional.
    ValueError
        If ``data`` looks like an RGB(A) image (single artist, multiple
        channels) but we cannot inspect the artist's current array.
    """
    tensors = data.values() if isinstance(data, dict) else [data]
    for v in tensors:
        if v.ndim not in [3, 4]:
            raise ValueError(
                "update_plot expects 3 or 4 dimensional data"
                "; unexpected behavior will result otherwise!"
                f" Got data of shape {v.shape}"
            )
    # look the method up explicitly (rather than calling it and catching
    # AttributeError), so that an AttributeError raised *inside* the model's
    # update_plot is not mistaken for "model has no update_plot method"
    model_update = getattr(model, "update_plot", None)
    if model_update is not None:
        artists = model_update(axes=axes, batch_idx=batch_idx, data=data)
    else:
        ax_artists = _get_artists_from_axes(axes, data)
        artists = []
        if not isinstance(data, dict):
            data_dict = {}
            # check for RGBA images
            if len(ax_artists) == 1 and data.shape[1] > 1:
                # can't index into dict.values(), so use this work around
                # instead, as suggested
                # https://stackoverflow.com/questions/43629270/how-to-get-single-value-from-dict-with-single-entry
                try:
                    if next(iter(ax_artists.values())).get_array().data.ndim > 1:
                        # then this is an RGBA image
                        data_dict = {"00": data}
                except Exception as e:
                    raise ValueError(
                        "Thought this was an RGB(A) image based on the number"
                        " of artists and data shape, but something is off!"
                        f" Original exception: {e}"
                    ) from e
            else:
                for i, d in enumerate(data.unbind(1)):
                    # need to keep the shape the same because of how we
                    # check for shape below (unbinding removes a dimension,
                    # so we add it back)
                    data_dict[f"{i:02d}"] = d.unsqueeze(1)
            data = data_dict
        for k, d in data.items():
            try:
                art = ax_artists[k]
            except KeyError:
                # If we're grabbing these labels from the line labels and
                # they were originally ints, they will get converted to
                # strings. this catches that
                art = ax_artists[str(k)]
            d = to_numpy(d[batch_idx]).squeeze()
            if d.ndim == 1:
                try:
                    # then it's a line
                    x, _ = art.get_data()
                except AttributeError:
                    # then it's a stemplot
                    sc = _update_stem(art, d)
                    artists.extend([sc.markerline, sc.stemlines])
                else:
                    art.set_data(x, d)
                    artists.append(art)
            elif d.ndim == 2:
                try:
                    # then it's a grayscale image
                    art.set_data(d)
                except AttributeError:
                    # then it's a scatterplot
                    art.set_offsets(d)
                artists.append(art)
            else:
                # then it's an RGB(A) image. for tensors, we put that dimension
                # in channel, but for images, it should be at the end
                art.set_data(np.moveaxis(d, 0, -1))
                artists.append(art)
    # make sure to always return a list
    if not isinstance(artists, list):
        artists = [artists]
    return artists
[docs]
def plot_representation(
    model: torch.nn.Module | None = None,
    data: np.ndarray | dict | None = None,
    ax: mpl.axes.Axes | None = None,
    figsize: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None | Literal[False] = False,
    batch_idx: int = 0,
    title: str = "",
    as_rgb: bool = False,
) -> list[mpl.axes.Axes]:
    r"""
    Plot model representation.

    We try to plot ``data`` on ``ax``, using the ``model.plot_representation``
    method, if it has it, and otherwise use a function that makes sense based on
    the shape of ``data``.

    All of these arguments are optional, but at least some of them need
    to be set:

    - If ``model`` is ``None`` (or has no ``plot_representation`` method), we
      fall-back to a type of plot based on the shape of ``data``. If it looks
      image-like, we'll use :func:`~plenoptic.plot.imshow` and if it looks
      vector-like, we'll use :func:`~plenoptic.plot.stem_plot`. If it's a
      dictionary, we'll assume each key, value pair gives the title and data to
      plot on a separate sub-plot.

    - If ``data`` is ``None``, we can only do something if
      ``model.plot_representation`` has some behavior when
      ``data=None``; this is probably to plot its own ``representation``
      attribute. Thus, this will raise an Exception if both ``model`` and
      ``data`` are ``None``, because we have no idea what to plot then.

    - If ``ax`` is ``None``, we create a one-subplot figure using ``figsize``.
      If ``ax`` is not ``None``, we therefore ignore ``figsize``.

    - If ``ylim`` is ``None``, we set the axes' y-limits to be
      ``(-y_max, y_max)``, where ``y_max=np.abs(data).max()``. If it's
      ``False``, we do nothing.

    Parameters
    ----------
    model
        A differentiable model that tells us how to plot ``data``. See
        above for behavior if ``None``.
    data
        The data to plot. See above for behavior if ``None``.
    ax
        The axis to plot on. See above for behavior if ``None``.
    figsize
        The size of the figure to create. Must be ``None`` if ax is not
        ``None``. If both figsize and ax are ``None``, then we use
        ``figsize=(5, 5)``.
    ylim
        The y-limits to use for this plot. See above for behavior if ``None``.
        If ``False``, we do nothing. Ignored if ``data`` looks image-like.
    batch_idx
        Which index to take from the batch dimension.
    title
        The title to put above this axis. If you want no title, pass
        the empty string (``""``).
    as_rgb
        Whether to consider the channels as encoding RGB(A) values. It will be
        ignored if the representation doesn't look image-like or if the
        model has its own ``plot_representation()`` method. Else, it will
        be passed to :func:`~plenoptic.plot.imshow`, see that method's
        docstring for details.

    Returns
    -------
    axes
        List of created axes.

    Raises
    ------
    ValueError
        If both ``figsize`` and ``ax`` are not ``None``.
    ValueError
        If ``data`` (or its values, if it's a ``dict``) are not 3 or 4
        dimensional.

    See Also
    --------
    :func:`~plenoptic.plot.metamer_representation_error`
        Plot representation error for a :class:`~plenoptic.Metamer` object at a
        specified iteration.
    stem_plot
        If ``model`` does not have a ``plot_representation`` method and its
        output is 3d, the function used to visualize its output.
    imshow
        If ``model`` does not have a ``plot_representation`` method and its
        output is 4d, the function used to visualize its output.
    """
    if ax is None:
        if figsize is None:
            figsize = (5, 5)
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        if figsize is not None:
            raise ValueError("figsize can't be set if ax is not None")
        fig = ax.figure
    # look the method up explicitly (rather than calling it and catching
    # AttributeError), so that an AttributeError raised *inside* the model's
    # plot_representation is not mistaken for "model has no such method"
    model_plot = getattr(model, "plot_representation", None)
    if model_plot is not None:
        # no point in passing figsize, because we've already created
        # and are passing an axis or are passing the user-specified one
        fig, axes = model_plot(
            ylim=ylim, ax=ax, title=title, batch_idx=batch_idx, data=data
        )
    else:
        if data is None:
            data = model.representation
        if not isinstance(data, dict):
            if title is None:
                title = "Representation"
            data_dict = {}
            if not as_rgb:
                # then we peel apart the channels
                for i, d in enumerate(data.unbind(1)):
                    # need to keep the shape the same because of how we
                    # check for shape below (unbinding removes a dimension,
                    # so we add it back)
                    data_dict[title + f"_{i:02d}"] = d.unsqueeze(1)
            else:
                data_dict[title] = data
            data = data_dict
        else:
            warnings.warn("data has keys, so we're ignoring title!")
        # want to make sure the axis we're taking over is basically invisible.
        ax = _clean_up_axes(ax, False, ["top", "right", "bottom", "left"], ["x", "y"])
        axes = []
        # the first value decides whether we treat everything as vector-like
        # or image-like
        first_shape = list(data.values())[0].shape
        if len(first_shape) == 3:
            # then this is 'vector-like'
            gs = ax.get_subplotspec().subgridspec(
                min(4, len(data)), int(np.ceil(len(data) / 4))
            )
            for i, (k, v) in enumerate(data.items()):
                ax = fig.add_subplot(gs[i % 4, i // 4])
                # only plot the specified batch, but plot each channel
                # in a separate call. there should probably only be one,
                # and if there's not you probably want to do things
                # differently
                for d in v[batch_idx]:
                    ax = stem_plot(to_numpy(d), ax, k, ylim)
                axes.append(ax)
        elif len(first_shape) == 4:
            # then this is 'image-like'
            gs = ax.get_subplotspec().subgridspec(
                int(np.ceil(len(data) / 4)), min(4, len(data))
            )
            for i, (k, v) in enumerate(data.items()):
                ax = fig.add_subplot(gs[i // 4, i % 4])
                ax = _clean_up_axes(
                    ax, False, ["top", "right", "bottom", "left"], ["x", "y"]
                )
                # only plot the specified batch
                imshow(
                    v,
                    batch_idx=batch_idx,
                    title=k,
                    ax=ax,
                    vrange="indep0",
                    as_rgb=as_rgb,
                )
                axes.append(ax)
            # because we're plotting image data, don't want to change
            # ylim at all
            ylim = False
        else:
            # data is always a dict by now, so report the shape of the value
            # we inspected (data itself has no shape attribute)
            raise ValueError(
                f"Don't know what to do with data of shape {first_shape}"
            )
    if ylim is None:
        if isinstance(data, dict):
            data = torch.cat(list(data.values()), dim=2)
        _rescale_ylim(axes, data)
    return axes
[docs]
def histogram(
    data: torch.Tensor | list[torch.Tensor],
    labels: str | list[str] | None = None,
    batch_idx: int | None = 0,
    channel_idx: int | None = None,
    ylim: tuple[float, float] | Literal[False] = False,
    xlim: tuple[float, float] | Literal[False, "range"] = "range",
    xlabel: str | Literal[False] = "Values",
    ax: mpl.axes.Axes | None = None,
    title: str = "Histogram of tensor values",
    alpha: float = 0.4,
    **kwargs: Any,
) -> mpl.axes.Axes:
    r"""
    Plot histogram of values from tensor.

    Intended use for this is to plot distributions of pixel values. The number
    of bins for each tensor is chosen with the Freedman-Diaconis rule, capped
    at 50.

    Parameters
    ----------
    data
        The data to plot. Must either be a single tensor or a list of tensors.
    labels
        Labels to use for legend. Must match ``data``: if ``data`` is a single
        tensor, must be a single string; if ``data`` is a list of tensors, must
        be a list of the same length. If ``None``, no legend is created.
    batch_idx
        Which index to take from the batch (first) dimension. If ``None``, we
        use all batches.
    channel_idx
        Which index to take from the channel (second) dimension. If ``None``,
        we use all channels.
    ylim
        If tuple, the ylimit to set for this axis. If ``False``, we leave
        it untouched.
    xlim
        If ``"range"``, set the xlimits to the range across plotted data.
        If tuple, the xlimit to set for this axis. If ``False``, we leave
        it untouched.
    xlabel
        Label to put on the x-axis. If ``False``, no label is set.
    ax
        Pre-existing axes for plot. If ``None``, we call
        :func:`matplotlib.pyplot.gca()`.
    title
        Title for the axis.
    alpha
        Alpha value for the histogram bars.
    **kwargs
        Passed to :func:`matplotlib.pyplot.hist`.

    Returns
    -------
    ax
        Created axes.

    Raises
    ------
    ValueError
        If the number of ``labels`` does not match the number of tensors in
        ``data``.

    See Also
    --------
    :func:`~plenoptic.plot.synthesis_histogram`
        Use this function to plot histogram of values from a synthesis object.

    Examples
    --------
    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> import torch
      >>> img = po.data.einstein()
      >>> po.plot.histogram(img)
      <Axes: ... 'Histogram of tensor values'...>

    Plot on an existing axis:

    .. plot::
      :context: close-figs

      >>> fig, axes = plt.subplots(1, 2, figsize=(8, 4))
      >>> po.plot.histogram(img, ax=axes[1])
      <Axes: ... 'Histogram of tensor values'...>
    """

    def _freedman_diaconis_bins(a: np.ndarray) -> int:
        """
        Calculate number of hist bins using Freedman-Diaconis rule.

        Copied from seaborn.

        Parameters
        ----------
        a
            The array to histogram.

        Returns
        -------
        n_bins
            Number of bins to use for histogram.
        """  # numpydoc ignore=EX01
        # From https://stats.stackexchange.com/questions/798/
        a = np.asarray(a)
        # check this first: the percentile of an empty array is undefined
        if len(a) < 2:
            return 1
        # np.percentile expects percentages in [0, 100], so the interquartile
        # range is the difference between the 25th and 75th percentiles
        iqr = np.diff(np.percentile(a, [25, 75]))[0]
        h = 2 * iqr / (len(a) ** (1 / 3))
        # fall back to sqrt(a) bins if iqr is 0
        if h == 0:
            return int(np.sqrt(a.size))
        else:
            return int(np.ceil((a.max() - a.min()) / h))

    if not isinstance(data, list):
        data = [data]
    if labels is None:
        create_legend = False
        # so we can iterate through labels along with data
        labels = len(data) * [""]
    else:
        create_legend = True
        if not isinstance(labels, list):
            labels = [labels]
        if len(labels) != len(data):
            raise ValueError("labels must have the same length as data!")
    if batch_idx is not None:
        data = [d[batch_idx] for d in data]
    if channel_idx is not None:
        if batch_idx is None:
            # the batch dimension is still present, so the channel dimension
            # is the second one
            data = [d[:, channel_idx] for d in data]
        else:
            # indexing the batch above removed the first dimension, so the
            # channel dimension is now the first one
            data = [d[channel_idx] for d in data]
    data = [to_numpy(d).flatten() for d in data]
    if xlim == "range":
        tmp_data = np.concatenate(data)
        xlim = (tmp_data.min(), tmp_data.max())
    if ax is None:
        ax = plt.gca()
    for d, lab in zip(data, labels):
        ax.hist(
            d,
            bins=min(_freedman_diaconis_bins(d), 50),
            label=lab,
            alpha=alpha,
            **kwargs,
        )
    if create_legend:
        ax.legend()
    if ylim:
        ax.set_ylim(ylim)
    if xlim:
        ax.set_xlim(xlim)
    if xlabel:
        ax.set_xlabel(xlabel)
    ax.set_title(title)
    return ax