Source code for plenoptic.tools.display

"""various helpful utilities for plotting or displaying information"""

import warnings

import matplotlib.pyplot as plt
import numpy as np
import pyrtools as pt
import torch

from .data import to_numpy


def imshow(
    image,
    vrange="indep1",
    zoom=None,
    title="",
    col_wrap=None,
    ax=None,
    cmap=None,
    plot_complex="rectangular",
    batch_idx=None,
    channel_idx=None,
    as_rgb=False,
    **kwargs,
):
    """Show image(s) correctly.

    This function shows images correctly, making sure that each element in the
    tensor corresponds to a pixel or an integer number of pixels, to avoid
    aliasing (NOTE: this guarantee only holds for the saved image; it should
    generally hold in notebooks as well, but will fail if, e.g., you plot an
    image that's 2000 pixels wide on a monitor 1000 pixels wide; the notebook
    handles the rescaling in a way we can't control).

    Parameters
    ----------
    image : torch.Tensor or list
        The images to display. Tensors should be 4d (batch, channel, height,
        width). List of tensors should be used for tensors of different height
        and width: all images will automatically be rescaled so they're
        displayed at the same height and width, thus, their heights and widths
        must be scalar multiples of each other.
    vrange : `tuple` or `str`
        If a 2-tuple, specifies the image values vmin/vmax that are mapped to
        the minimum and maximum value of the colormap, respectively. If a
        string:

        * `'auto0'`: all images have same vmin/vmax, which have the same
          absolute value, and come from the minimum or maximum across all
          images, whichever has the larger absolute value
        * `'auto/auto1'`: all images have same vmin/vmax, which are the
          minimum/maximum values across all images
        * `'auto2'`: all images have same vmin/vmax, which are the mean
          (across all images) minus/plus 2 std dev (across all images)
        * `'auto3'`: all images have same vmin/vmax, chosen so as to map the
          10th/90th percentile values to the 10th/90th percentile of the
          display intensity range. For example: vmin is the 10th percentile
          image value minus 1/8 times the difference between the 90th and
          10th percentile
        * `'indep0'`: each image has an independent vmin/vmax, which have the
          same absolute value, which comes from either their minimum or
          maximum value, whichever has the larger absolute value.
        * `'indep1'`: each image has an independent vmin/vmax, which are their
          minimum/maximum values
        * `'indep2'`: each image has an independent vmin/vmax, which is their
          mean minus/plus 2 std dev
        * `'indep3'`: each image has an independent vmin/vmax, chosen so that
          the 10th/90th percentile values map to the 10th/90th percentile
          intensities.
    zoom : `float` or `None`
        Ratio of display pixels to image pixels. If >1, must be an integer. If
        <1, must be 1/d where d is a divisor of the size of the largest image.
        If None and ``ax`` is given, we pick the largest zoom at which the
        biggest image fits inside ``ax``'s bounding box (in both height and
        width); if None and ``ax`` is None, we use 1.
    title : `str`, `list`, or None, optional
        Title for the plot. In addition to the specified title, we add a
        subtitle giving the plotted range and dimensionality (with zoom)

        * if `str`, will put the same title on every plot.
        * if `list`, all values must be `str`, must be the same length as img,
          assigning each title to corresponding image.
        * if None, no title will be printed (and subtitle will be removed).
    col_wrap : `int` or None, optional
        Number of axes to have in each row. If None, will fit all axes in a
        single row.
    ax : `matplotlib.pyplot.axis` or None, optional
        If None, we make the appropriate figure. Otherwise, we resize the axes
        so that it's the appropriate number of pixels (done by shrinking the
        bbox - if the bbox is already too small, this will throw an
        Exception!, so first define a large enough figure using either
        make_figure or plt.figure)
    cmap : matplotlib colormap, optional
        Colormap to use when showing these images
    plot_complex : {'rectangular', 'polar', 'logpolar'}
        Specifies handling of complex values.

        * `'rectangular'`: plot real and imaginary components as separate
          images
        * `'polar'`: plot amplitude and phase as separate images
        * `'logpolar'`: plot log_2 amplitude and phase as separate images

        For any other value, we raise a warning and default to rectangular.
    batch_idx : int or None, optional
        Which element from the batch dimension to plot (negative indices count
        from the end). If None, we plot all. Ignored for tensors whose batch
        dimension has a single element.
    channel_idx : int or None, optional
        Which element from the channel dimension to plot (negative indices
        count from the end). If None, we plot all. Note if this is an int,
        then `as_rgb=True` will fail, because we restrict the channels.
    as_rgb : bool, optional
        Whether to consider the channels as encoding RGB(A) values. If True,
        we attempt to plot the image in color, so your tensor must have 3 (or
        4 if you want the alpha channel) elements in the channel dimension, or
        this will raise an Exception. If False, we plot each channel as a
        separate grayscale image.
    kwargs :
        Passed to ``pyrtools.imshow``.

    Returns
    -------
    fig : `PyrFigure`
        Figure containing the plotted images.

    Raises
    ------
    Exception
        If ``as_rgb`` is True and the channel dimension doesn't have 3 or 4
        elements, or if ``as_rgb`` is False and a tensor has more than one
        element in both its batch and channel dimensions.
    ValueError
        If ``zoom`` is None, ``ax`` is given, and no allowed zoom makes the
        largest image fit inside ``ax``.
    """
    if not isinstance(image, list):
        image = [image]
    images_to_plot = []
    heights, widths = [], []
    for im in image:
        im = to_numpy(im)
        if im.shape[0] > 1 and batch_idx is not None:
            # indexing with a one-element list keeps the batch dimension and,
            # unlike im[idx : idx + 1], also works for negative indices
            im = im[[batch_idx]]
        if channel_idx is not None:
            # same trick for the channel dimension
            im = im[:, [channel_idx]]
        # allow RGB and RGBA
        if as_rgb:
            if im.shape[1] not in [3, 4]:
                raise Exception(
                    "If as_rgb is True, then channel must have 3 "
                    "or 4 elements!"
                )
            im = im.transpose(0, 2, 3, 1)
            # want to insert a fake "channel" dimension here, so our putting it
            # into a list below works as expected
            im = im.reshape((im.shape[0], 1, *im.shape[1:]))
        elif im.shape[1] > 1 and im.shape[0] > 1:
            raise Exception(
                "Don't know how to plot images with more than one channel and"
                " batch! Use batch_idx / channel_idx to choose a subset for"
                " plotting"
            )
        # by iterating through it twice, we make sure to peel apart the batch
        # and channel dimensions so that they each show up as a separate image.
        # because of how we've handled everything above, we know that im will
        # be (b,c,h,w) or (b,c,h,w,r) where r is the RGB(A) values
        for per_batch in im:
            for single in per_batch:
                # single is (h,w) or (h,w,r); we don't squeeze, which could
                # accidentally drop a dimension if h or w is a singleton
                images_to_plot.append(single)
                heights.append(single.shape[0])
                widths.append(single.shape[1])

    if zoom is None:
        if ax is not None:
            zoom = _zoom_to_fit(heights, widths, ax.bbox.height, ax.bbox.width)
        else:
            zoom = 1
    return pt.imshow(
        images_to_plot,
        vrange=vrange,
        zoom=zoom,
        title=title,
        col_wrap=col_wrap,
        ax=ax,
        cmap=cmap,
        plot_complex=plot_complex,
        **kwargs,
    )


def _find_zoom(x, limit):
    """Return ``1 / d`` for the smallest divisor ``2 <= d < x`` of ``x`` with ``x / d <= limit``.

    Only meant for ``limit < x``, i.e. when the image must be shrunk.

    Raises
    ------
    ValueError
        If no such divisor exists (e.g., ``x`` is prime).
    """
    divisors = [d for d in range(2, x) if not x % d]
    candidates = [d for d in divisors if x / d <= limit]
    if not candidates:
        raise ValueError(
            f"Can't find a zoom that shrinks an image of size {x} to fit within "
            f"{limit} pixels; use a larger axis or pass zoom explicitly."
        )
    return 1 / min(candidates)


def _zoom_to_fit(heights, widths, ax_height, ax_width):
    """Return the largest zoom at which the biggest image fits in an axis of the given size."""
    max_height, max_width = max(heights), max(widths)
    # zoom in by an integer factor when there's room, otherwise shrink by a divisor
    if ax_height >= max_height:
        height_zoom = ax_height // max_height
    else:
        height_zoom = _find_zoom(max_height, ax_height)
    if ax_width >= max_width:
        width_zoom = ax_width // max_width
    else:
        width_zoom = _find_zoom(max_width, ax_width)
    # the zoom must satisfy both dimensions, so use the more restrictive one
    return min(height_zoom, width_zoom)
def animshow(
    video,
    framerate=2.0,
    repeat=False,
    vrange="indep1",
    zoom=1,
    title="",
    col_wrap=None,
    ax=None,
    cmap=None,
    plot_complex="rectangular",
    batch_idx=None,
    channel_idx=None,
    as_rgb=False,
    **kwargs,
):
    """Animate video(s) correctly.

    This function animates videos correctly, making sure that each element in
    the tensor corresponds to a pixel or an integer number of pixels, to avoid
    aliasing (NOTE: this guarantee only holds for the saved animation
    (assuming video compression doesn't interfere); it should generally hold
    in notebooks as well, but will fail if, e.g., your video is 2000 pixels
    wide on a monitor 1000 pixels wide; the notebook handles the rescaling in
    a way we can't control).

    This function returns a matplotlib FuncAnimation object. See our
    documentation (e.g., `Quickstart
    <https://docs.plenoptic.org/docs/branch/main/tutorials/00_quickstart.html>`_)
    for examples on how to view it in a Jupyter notebook. In order to save,
    use ``anim.save(filename)``. In either case, this can take a while and
    you'll need the appropriate writer installed and on your path, e.g.,
    ffmpeg, imagemagick, etc. See the `matplotlib documentation
    <https://matplotlib.org/stable/api/animation_api.html>`_ for more details.

    Parameters
    ----------
    video : torch.Tensor or list
        The videos to display. Tensors should be 5d (batch, channel, time,
        height, width). List of tensors should be used for tensors of
        different height and width: all videos will automatically be rescaled
        so they're displayed at the same height and width, thus, their heights
        and widths must be scalar multiples of each other. Videos must all
        have the same number of frames as well.
    framerate : `float`
        Temporal resolution of the video, in Hz (frames per second).
    repeat : `bool`
        Whether to loop the animation or just play it once.
    vrange : `tuple` or `str`
        If a 2-tuple, specifies the image values vmin/vmax that are mapped to
        the minimum and maximum value of the colormap, respectively. If a
        string:

        * `'auto0'`: all images have same vmin/vmax, which have the same
          absolute value, and come from the minimum or maximum across all
          images, whichever has the larger absolute value
        * `'auto/auto1'`: all images have same vmin/vmax, which are the
          minimum/maximum values across all images
        * `'auto2'`: all images have same vmin/vmax, which are the mean
          (across all images) minus/plus 2 std dev (across all images)
        * `'auto3'`: all images have same vmin/vmax, chosen so as to map the
          10th/90th percentile values to the 10th/90th percentile of the
          display intensity range. For example: vmin is the 10th percentile
          image value minus 1/8 times the difference between the 90th and
          10th percentile
        * `'indep0'`: each image has an independent vmin/vmax, which have the
          same absolute value, which comes from either their minimum or
          maximum value, whichever has the larger absolute value.
        * `'indep1'`: each image has an independent vmin/vmax, which are their
          minimum/maximum values
        * `'indep2'`: each image has an independent vmin/vmax, which is their
          mean minus/plus 2 std dev
        * `'indep3'`: each image has an independent vmin/vmax, chosen so that
          the 10th/90th percentile values map to the 10th/90th percentile
          intensities.
    zoom : `float`
        Ratio of display pixels to image pixels. If >1, must be an integer. If
        <1, must be 1/d where d is a divisor of the size of the largest image.
    title : `str`, `list`, or None, optional
        Title for the plot. In addition to the specified title, we add a
        subtitle giving the plotted range and dimensionality (with zoom)

        * if `str`, will put the same title on every plot.
        * if `list`, all values must be `str`, must be the same length as img,
          assigning each title to corresponding image.
        * if None, no title will be printed (and subtitle will be removed).
    col_wrap : `int` or None, optional
        Number of axes to have in each row. If None, will fit all axes in a
        single row.
    ax : `matplotlib.pyplot.axis` or None, optional
        If None, we make the appropriate figure. Otherwise, we resize the axes
        so that it's the appropriate number of pixels (done by shrinking the
        bbox - if the bbox is already too small, this will throw an
        Exception!, so first define a large enough figure using either
        pyrtools.make_figure or plt.figure)
    cmap : matplotlib colormap, optional
        Colormap to use when showing these images
    plot_complex : {'rectangular', 'polar', 'logpolar'}
        Specifies handling of complex values.

        * `'rectangular'`: plot real and imaginary components as separate
          images
        * `'polar'`: plot amplitude and phase as separate images
        * `'logpolar'`: plot log_2 amplitude and phase as separate images

        For any other value, we raise a warning and default to rectangular.
    batch_idx : int or None, optional
        Which element from the batch dimension to plot (negative indices count
        from the end). If None, we plot all. Ignored for tensors whose batch
        dimension has a single element.
    channel_idx : int or None, optional
        Which element from the channel dimension to plot (negative indices
        count from the end). If None, we plot all. Note if this is an int,
        then `as_rgb=True` will fail, because we restrict the channels.
    as_rgb : bool, optional
        Whether to consider the channels as encoding RGB(A) values. If True,
        we attempt to plot the image in color, so your tensor must have 3 (or
        4 if you want the alpha channel) elements in the channel dimension, or
        this will raise an Exception. If False, we plot each channel as a
        separate grayscale image.
    kwargs :
        Passed to ``pyrtools.animshow``.

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        The animation object. In order to view, must convert to HTML or save.

    Raises
    ------
    Exception
        If ``as_rgb`` is True and the channel dimension doesn't have 3 or 4
        elements, or if ``as_rgb`` is False and a tensor has more than one
        element in both its batch and channel dimensions.

    Notes
    -----
    By default, we use the ffmpeg backend, which requires that you have ffmpeg
    installed and on your path (https://ffmpeg.org/download.html). To use a
    different writer, use the matplotlib rcParams:
    `matplotlib.rcParams['animation.writer'] = writer`, see
    https://matplotlib.org/stable/api/animation_api.html#writer-classes for
    more details.
    """
    if not isinstance(video, list):
        video = [video]
    videos_to_show = []
    for vid in video:
        videos_to_show.extend(
            _split_videos(to_numpy(vid), batch_idx, channel_idx, as_rgb)
        )
    return pt.animshow(
        videos_to_show,
        framerate=framerate,
        as_html5=False,
        repeat=repeat,
        vrange=vrange,
        zoom=zoom,
        title=title,
        col_wrap=col_wrap,
        ax=ax,
        cmap=cmap,
        plot_complex=plot_complex,
        **kwargs,
    )


def _split_videos(vid, batch_idx, channel_idx, as_rgb):
    """Split one (b, c, t, h, w) array into a list of (t, h, w) or (t, h, w, r) videos."""
    if vid.shape[0] > 1 and batch_idx is not None:
        # a one-element index list keeps the dimension and supports negative indices
        vid = vid[[batch_idx]]
    if channel_idx is not None:
        vid = vid[:, [channel_idx]]
    # allow RGB and RGBA
    if as_rgb:
        if vid.shape[1] not in [3, 4]:
            raise Exception(
                "If as_rgb is True, then channel must have 3 " "or 4 elements!"
            )
        vid = vid.transpose(0, 2, 3, 4, 1)
        # insert a fake "channel" dimension so the peeling below works as expected
        vid = vid.reshape((vid.shape[0], 1, *vid.shape[1:]))
    elif vid.shape[1] > 1 and vid.shape[0] > 1:
        raise Exception(
            "Don't know how to plot images with more than one channel and"
            " batch! Use batch_idx / channel_idx to choose a subset for"
            " plotting"
        )
    # peel apart batch and channel so each shows up as a separate video. We
    # deliberately don't squeeze: that would drop a singleton t, h, or w.
    return [single for per_batch in vid for single in per_batch]
def pyrshow(
    pyr_coeffs,
    vrange="indep1",
    zoom=1,
    show_residuals=True,
    cmap=None,
    plot_complex="rectangular",
    batch_idx=0,
    channel_idx=0,
    **kwargs,
):
    r"""Display steerable pyramid coefficients in orderly fashion.

    This function uses ``imshow`` to show the coefficients of the steerable
    pyramid, such that each scale shows up on a single row, with each
    orientation in its own column.

    Note that unlike imshow, we can only show one batch or channel at a time.

    Parameters
    ----------
    pyr_coeffs : `dict`
        Pyramid coefficients in the standard dictionary format as returned by
        ``SteerablePyramidFreq.forward()``
    vrange : `tuple` or `str`
        If a 2-tuple, specifies the image values vmin/vmax that are mapped to
        the minimum and maximum value of the colormap, respectively. If a
        string:

        * `'auto0'`: all images have same vmin/vmax, which have the same
          absolute value, and come from the minimum or maximum across all
          images, whichever has the larger absolute value
        * `'auto/auto1'`: all images have same vmin/vmax, which are the
          minimum/maximum values across all images
        * `'auto2'`: all images have same vmin/vmax, which are the mean
          (across all images) minus/plus 2 std dev (across all images)
        * `'auto3'`: all images have same vmin/vmax, chosen so as to map the
          10th/90th percentile values to the 10th/90th percentile of the
          display intensity range. For example: vmin is the 10th percentile
          image value minus 1/8 times the difference between the 90th and
          10th percentile
        * `'indep0'`: each image has an independent vmin/vmax, which have the
          same absolute value, which comes from either their minimum or
          maximum value, whichever has the larger absolute value.
        * `'indep1'`: each image has an independent vmin/vmax, which are their
          minimum/maximum values
        * `'indep2'`: each image has an independent vmin/vmax, which is their
          mean minus/plus 2 std dev
        * `'indep3'`: each image has an independent vmin/vmax, chosen so that
          the 10th/90th percentile values map to the 10th/90th percentile
          intensities.
    zoom : `float`
        Ratio of display pixels to image pixels. If >1, must be an integer. If
        <1, must be 1/d where d is a divisor of the size of the largest image.
    show_residuals : `bool`
        Whether to display the residual bands (lowpass, highpass depending on
        the pyramid type)
    cmap : matplotlib colormap, optional
        Colormap to use when showing these images
    plot_complex : {'rectangular', 'polar', 'logpolar'}
        Specifies handling of complex values.

        * `'rectangular'`: plot real and imaginary components as separate
          images
        * `'polar'`: plot amplitude and phase as separate images
        * `'logpolar'`: plot log_2 amplitude and phase as separate images

        For any other value, we raise a warning and default to rectangular.
    batch_idx : int, optional
        Which element from the batch dimension to plot (negative indices count
        from the end).
    channel_idx : int, optional
        Which element from the channel dimension to plot (negative indices
        count from the end).
    kwargs :
        Passed on to ``pyrtools.pyrshow``

    Returns
    -------
    fig: `PyrFigure`
        The figure displaying the coefficients.
    """
    pyr_coeffvis, is_complex = _select_pyr_coeffs(
        {k: to_numpy(v) for k, v in pyr_coeffs.items()}, batch_idx, channel_idx
    )
    return pt.pyrshow(
        pyr_coeffvis,
        is_complex=is_complex,
        vrange=vrange,
        zoom=zoom,
        cmap=cmap,
        plot_complex=plot_complex,
        show_residuals=show_residuals,
        **kwargs,
    )


def _select_pyr_coeffs(coeffs, batch_idx, channel_idx):
    """Pick one batch/channel element from each (b, c, h, w) array; also report complexity."""
    # check the dtype, not the values: a complex band whose imaginary part
    # happens to be all zero is still complex
    is_complex = any(np.iscomplexobj(v) for v in coeffs.values())
    # integer indexing drops exactly the batch and channel dims (-> (h, w)) and
    # supports negative indices, unlike slicing idx:idx+1 then squeezing
    selected = {k: v[batch_idx, channel_idx] for k, v in coeffs.items()}
    return selected, is_complex
def clean_up_axes(
    ax,
    ylim=None,
    spines_to_remove=None,
    axes_to_remove=None,
):
    r"""Clean up an axis, as desired when making a stem plot of the representation.

    Parameters
    ----------
    ax : `matplotlib.pyplot.axis`
        The axis to clean up.
    ylim : `tuple`, False, or None
        If a (non-empty) tuple, the y-limits to use for this plot. If None,
        the y-limits are left unchanged. If False (or any other falsy value),
        we keep the current upper limit and set the lower limit to 0.
    spines_to_remove : `list` or None
        Some combination of 'top', 'right', 'bottom', and 'left'. The spines
        we remove from the axis. If None, we use
        ``['top', 'right', 'bottom']``.
    axes_to_remove : `list` or None
        Some combination of 'x', 'y'. The axes to set as invisible. If None,
        we use ``['x']``.

    Returns
    -------
    ax : matplotlib.pyplot.axis
        The cleaned-up axis
    """
    # None defaults avoid sharing a mutable list between calls
    if spines_to_remove is None:
        spines_to_remove = ["top", "right", "bottom"]
    if axes_to_remove is None:
        axes_to_remove = ["x"]
    # NOTE(review): this documentation previously described None and False the
    # other way round; the behavior below is unchanged -- confirm it's what
    # callers expect.
    if ylim is not None:
        if ylim:
            ax.set_ylim(ylim)
        else:
            ax.set_ylim((0, ax.get_ylim()[1]))
    if "x" in axes_to_remove:
        ax.xaxis.set_visible(False)
    if "y" in axes_to_remove:
        ax.yaxis.set_visible(False)
    for s in spines_to_remove:
        ax.spines[s].set_visible(False)
    return ax
def update_stem(stem_container, ydata):
    r"""Update the information in a stem plot.

    Move the markers of a single stem plot, and the ends of the lines joining
    them to the baseline, to the heights given by ``ydata``. The baseline is
    not touched, and the x-data is assumed not to change.

    Parameters
    ----------
    stem_container : `matplotlib.container.StemContainer`
        Single container for the artists created in a ``plt.stem`` plot. It
        can be treated like a namedtuple ``(markerline, stemlines, baseline)``.
        From an axis ``ax`` it can usually be retrieved as
        ``ax.containers[0]`` (if the axis holds more than one container, it
        may not be the first one).
    ydata : array_like
        The new y-data to show on the plot. Must be the same length as the
        existing y-data.

    Returns
    -------
    stem_container : `matplotlib.container.StemContainer`
        The same container, whose artists now show ``ydata``.
    """
    stem_container.markerline.set_ydata(ydata)
    stems = stem_container.stemlines
    updated = stems.get_segments().copy()
    for segment, height in zip(updated, ydata):
        try:
            # y-coordinate of the segment's second vertex
            segment[1, 1] = height
        except IndexError:
            # a NaN data point yields a 1x2 segment instead of 2x2, so there's
            # no second vertex to move
            pass
    stems.set_segments(updated)
    return stem_container
def rescale_ylim(axes, data):
    r"""Rescale y-limits symmetrically around zero.

    We set each axis' limits to ``(-y_max, y_max)``, where ``y_max`` is the
    largest absolute value in ``data``, ignoring NaNs.

    Parameters
    ----------
    axes : `list`
        A list of matplotlib axes to rescale.
    data : array_like, torch.Tensor, or dict
        The data to use when rescaling (or a dictionary of those values).
    """
    y_max = _abs_max(data)
    for ax in axes:
        ax.set_ylim((-y_max, y_max))


def _abs_max(data):
    """Return the largest absolute, non-NaN value in an array, tensor, or dict of them."""
    if isinstance(data, dict):
        return np.nanmax([_abs_max(v) for v in data.values()])
    if isinstance(data, torch.Tensor):
        # tensors may live on the GPU or require grad, so move them to the cpu
        # and let to_numpy detach and convert them
        data = to_numpy(data.cpu())
    # nanmax: stem-plot data may contain NaNs, which would give NaN limits
    return np.nanmax(np.abs(data))
def clean_stem_plot(data, ax=None, title="", ylim=None, xvals=None, **kwargs):
    r"""Convenience wrapper for plotting stem plots.

    This plots the data, baseline, cleans up the axis, and sets the title.

    Should not be called by users directly, but is a helper function for the
    various plot_representation() functions.

    By default, stem plot would have a baseline that covers the entire range
    of the data. We want to be able to break that up visually (so there's a
    line from 0 to 9, from 10 to 19, etc), and passing xvals separately allows
    us to do that. If you want the default stem plot behavior, leave xvals as
    None.

    Parameters
    ----------
    data : `np.ndarray`
        The data to plot (as a stem plot)
    ax : `matplotlib.pyplot.axis` or `None`, optional
        The axis to plot the data on. If None, we plot on the current axis
    title : str or None, optional
        The title to put on the axis if not None. If None, we don't call
        ``ax.set_title`` (useful if you want to avoid changing the title on an
        existing plot)
    ylim : tuple, False, or None, optional
        If a tuple, the y-limits to use for this plot. If None, the y-limits
        chosen by matplotlib are left unchanged. If False, we keep the current
        upper limit and set the lower limit to 0.
    xvals : `tuple` or `None`, optional
        A 2-tuple of lists, containing the start (``xvals[0]``) and stop
        (``xvals[1]``) x values for plotting. If None, we use the default stem
        plot behavior.
    kwargs :
        Passed to ``ax.stem``

    Returns
    -------
    ax : `matplotlib.pyplot.axis`
        The axis with the plot

    Examples
    --------
    We allow for breaks in the baseline value if we want to visually break up
    the plot, as we see below.

    .. plot::
        :include-source:

        import plenoptic as po
        import numpy as np
        import matplotlib.pyplot as plt
        # ylim=False sets the minimum y-value to 0, so we want to make sure
        # our values are all positive
        y = np.abs(np.random.randn(55))
        y[15:20] = np.nan
        y[35:40] = np.nan
        # we want to draw the baseline from 0 to 14, 20 to 34, and 40 to
        # 54, everywhere that we have non-NaN values for y
        xvals = ([0, 20, 40], [14, 34, 54])
        po.tools.display.clean_stem_plot(y, xvals=xvals, ylim=False)
        plt.show()

    If we don't care about breaking up the x-axis, you can simply use the
    default xvals (``None``). In this case, this function will just clean up
    the plot a little bit

    .. plot::
        :include-source:

        import plenoptic as po
        import numpy as np
        import matplotlib.pyplot as plt
        # with the default ylim=None, the y-limits are left as matplotlib
        # chose them
        y = np.abs(np.random.randn(55))
        po.tools.display.clean_stem_plot(y)
        plt.show()
    """
    if ax is None:
        ax = plt.gca()
    if xvals is not None:
        # hide stem's own baseline and draw one zero-height segment per
        # (start, stop) pair instead
        basefmt = " "
        ax.hlines(len(xvals[0]) * [0], xvals[0], xvals[1], colors="C3", zorder=10)
    else:
        # this is the default basefmt value
        basefmt = None
    ax.stem(data, basefmt=basefmt, **kwargs)
    ax = clean_up_axes(ax, ylim, ["top", "right", "bottom"])
    if title is not None:
        ax.set_title(title)
    return ax
def _get_artists_from_axes(axes, data): """Grab artists from axes. For now, we only grab containers (stem plots), images, or lines See the docstring of :meth:`update_plot()` for details on how `axes` and `data` should be structured Parameters ---------- axes : list or matplotlib.axes.Axes The axis/axes to update. data : torch.Tensor or dict The new data to plot. Returns ------- artists : dict dictionary of artists for updating plots. values are the artists to use, keys are the corresponding keys for data """ if not hasattr(axes, "__iter__"): # then we only have one axis, so we may be able to update more than one # data element. if len(axes.containers) > 0: data_check = 1 artists = axes.containers elif len(axes.images) > 0: # images are weird, so don't check them like this data_check = None artists = axes.images elif len(axes.lines) > 0: data_check = 1 artists = axes.lines elif len(axes.collections) > 0: data_check = 2 artists = axes.collections if isinstance(data, dict): artists = {ax.get_label(): ax for ax in artists} else: if data_check == 1 and data.shape[1] != len(artists): raise Exception( f"data has {data.shape[1]} things to plot, but " f"your axis contains {len(artists)} plotting artists, " "so unsure how to continue! Pass data as a dictionary" " with keys corresponding to the labels of the artists" " to update to resolve this." ) elif data_check == 2 and data.ndim > 2 and data.shape[-3] != len(artists): raise Exception( f"data has {data.shape[-3]} things to plot, but " f"your axis contains {len(artists)} plotting artists, " "so unsure how to continue! Pass data as a dictionary" " with keys corresponding to the labels of the artists" " to update to resolve this." 
) else: # then we have multiple axes, so we are only updating one data element # per plot artists = [] for ax in axes: if len(ax.containers) == 1: data_check = 1 artists.extend(ax.containers) elif len(ax.images) == 1: # images are weird, so don't check them like this data_check = None artists.extend(ax.images) elif len(ax.lines) == 1: artists.extend(ax.lines) data_check = 1 elif len(ax.collections) == 1: artists.extend(ax.collections) data_check = 2 if isinstance(data, dict): if len(data.keys()) != len(artists): raise Exception( f"data has {len(data.keys())} things to plot, but " f"you passed {len(axes)} axes , so unsure how " "to continue!" ) artists = {k: a for k, a in zip(data.keys(), artists)} else: if data_check == 1 and data.shape[1] != len(artists): raise Exception( f"data has {data.shape[1]} things to plot, but " f"you passed {len(axes)} axes , so unsure how " "to continue!" ) if data_check == 2 and data.ndim > 2 and data.shape[-3] != len(artists): raise Exception( f"data has {data.shape[-3]} things to plot, but " f"you passed {len(axes)} axes , so unsure how " "to continue!" ) if not isinstance(artists, dict): artists = {f"{i:02d}": a for i, a in enumerate(artists)} return artists
def update_plot(axes, data, model=None, batch_idx=0):
    r"""Update the information in some axes.

    This is used for creating an animation over time. In order to create
    the animation, we need to know how to update the matplotlib Artists,
    and this provides a simple way of doing that. It assumes the plot has
    been created by something like ``plot_representation``, which
    initializes all the artists.

    We can update stem plots, lines (as returned by ``plt.plot``), scatter
    plots, or images (RGB, RGBA, or grayscale).

    There are two modes for this:

    - single axis: axes is a single axis, which may contain multiple artists
      (all of the same type) to update. data should be a Tensor with multiple
      channels (one per artist in the same order) or be a dictionary whose
      keys give the label(s) of the corresponding artist(s) and whose values
      are Tensors.

    - multiple axes: axes is a list of axes, each of which contains a single
      artist to update (artists can be different types). data should be a
      Tensor with multiple channels (one per axis in the same order) or a
      dictionary with the same number of keys as axes, which we can iterate
      through in order, and whose values are Tensors.

    In all cases, data Tensors should be 3d (if the plot we're updating is a
    line or stem plot) or 4d (if it's an image or scatter plot).

    RGB(A) images are special, since we store that info along the channel
    dimension, so they only work with single-axis mode (which will only have
    a single artist, because that's how imshow works).

    If you have multiple axes, each with multiple artists you want to update,
    that's too complicated for us, and so you should write a
    ``model.update_plot()`` function which handles that.

    If ``model`` has an ``update_plot`` method, we call
    ``model.update_plot()`` (which must also return artists); any exception
    raised inside that method propagates to the caller. If ``model`` is
    ``None`` or doesn't have an ``update_plot`` method, then we try to figure
    out how to update the axes ourselves, based on the shape of the data.

    Parameters
    ----------
    axes : `list` or `matplotlib.pyplot.axis`
        The axis or list of axes to update. We assume that these are the axes
        created by ``plot_representation`` and so contain stem plots in the
        correct order.
    data : `torch.Tensor` or `dict`
        The new data to plot.
    model : `torch.nn.Module` or `None`, optional
        A differentiable model that tells us how to plot ``data``. See above
        for behavior if ``None``.
    batch_idx : int, optional
        Which index to take from the batch dimension

    Returns
    -------
    artists : `list`
        A list of the artists used to update the information on the plots

    Raises
    ------
    ValueError
        If ``data`` (or any value of ``data``, when it's a dict) is not 3 or
        4 dimensional.

    """
    to_check = data.values() if isinstance(data, dict) else [data]
    for v in to_check:
        if v.ndim not in [3, 4]:
            raise ValueError(
                "update_plot expects 3 or 4 dimensional data"
                "; unexpected behavior will result otherwise!"
                f" Got data of shape {v.shape}"
            )

    # Look the method up explicitly instead of wrapping the call in
    # ``try/except AttributeError``: that pattern also swallowed
    # AttributeErrors raised by bugs *inside* model.update_plot and silently
    # fell back to the generic path.
    model_update_plot = getattr(model, "update_plot", None)
    if model_update_plot is not None:
        artists = model_update_plot(axes=axes, batch_idx=batch_idx, data=data)
    else:
        ax_artists = _get_artists_from_axes(axes, data)
        artists = []
        if not isinstance(data, dict):
            data_dict = {}
            # check for RGBA images
            if len(ax_artists) == 1 and data.shape[1] > 1:
                # can't index into dict.values(), so use this work around
                # instead, as suggested
                # https://stackoverflow.com/questions/43629270/how-to-get-single-value-from-dict-with-single-entry
                try:
                    if next(iter(ax_artists.values())).get_array().data.ndim > 1:
                        # then this is an RGBA image
                        data_dict = {"00": data}
                except Exception as e:
                    raise Exception(
                        "Thought this was an RGB(A) image based on the number"
                        " of artists and data shape, but something is off!"
                        f" Original exception: {e}"
                    ) from e
            else:
                for i, d in enumerate(data.unbind(1)):
                    # need to keep the shape the same because of how we
                    # check for shape below (unbinding removes a dimension,
                    # so we add it back)
                    data_dict[f"{i:02d}"] = d.unsqueeze(1)
            data = data_dict
        for k, d in data.items():
            try:
                art = ax_artists[k]
            except KeyError:
                # If we're grabbing these labels from the line labels and
                # they were originally ints, they will get converted to
                # strings. this catches that
                art = ax_artists[str(k)]
            d = to_numpy(d[batch_idx]).squeeze()
            if d.ndim == 1:
                try:
                    # then it's a line
                    x, _ = art.get_data()
                    art.set_data(x, d)
                    artists.append(art)
                except AttributeError:
                    # then it's a stemplot
                    sc = update_stem(art, d)
                    artists.extend([sc.markerline, sc.stemlines])
            elif d.ndim == 2:
                try:
                    # then it's a grayscale image
                    art.set_data(d)
                    artists.append(art)
                except AttributeError:
                    # then it's a scatterplot
                    art.set_offsets(d)
                    artists.append(art)
            else:
                # then it's an RGB(A) image. for tensors, we put that dimension
                # in channel, but for images, it should be at the end
                art.set_data(np.moveaxis(d, 0, -1))
                artists.append(art)
    # make sure to always return a list
    if not isinstance(artists, list):
        artists = [artists]
    return artists
def plot_representation(
    model=None,
    data=None,
    ax=None,
    figsize=(5, 5),
    ylim=False,
    batch_idx=0,
    title="",
    as_rgb=False,
):
    r"""Helper function for plotting model representation

    We are trying to plot ``data`` on ``ax``, using
    ``model.plot_representation`` method, if it has it, and otherwise default
    to a function that makes sense based on the shape of ``data``.

    All of these arguments are optional, but at least some of them need to be
    set:

    - If ``model`` is ``None`` or has no ``plot_representation`` method, we
      fall-back to a type of plot based on the shape of ``data``. If it looks
      image-like, we'll use ``plenoptic.imshow`` and if it looks vector-like,
      we'll use ``plenoptic.clean_stem_plot``. If it's a dictionary, we'll
      assume each key, value pair gives the title and data to plot on a
      separate sub-plot. Any exception raised inside
      ``model.plot_representation`` propagates to the caller.

    - If ``data`` is ``None``, we can only do something if
      ``model.plot_representation`` has some default behavior when
      ``data=None``; this is probably to plot its own ``representation``
      attribute. Thus, this will raise an Exception if both ``model`` and
      ``data`` are ``None``, because we have no idea what to plot then.

    - If ``ax`` is ``None``, we create a one-subplot figure using ``figsize``.
      If ``ax`` is not ``None``, we therefore ignore ``figsize``.

    - If ``ylim`` is ``None``, we call ``rescale_ylim``, which sets the axes'
      y-limits to be ``(-y_max, y_max)``, where ``y_max=np.abs(data).max()``.
      If it's ``False``, we do nothing.

    Parameters
    ----------
    model : `torch.nn.Module` or None, optional
        A differentiable model that tells us how to plot ``data``. See above
        for behavior if ``None``.
    data : `array_like`, `dict`, or `None`, optional
        The data to plot. See above for behavior if ``None``.
    ax : matplotlib.pyplot.axis or None, optional
        The axis to plot on. See above for behavior if ``None``.
    figsize : `tuple`, optional
        The size of the figure to create. Ignored if ``ax`` is not ``None``.
    ylim : `tuple`, `None`, or `False`, optional
        If not None, the y-limits to use for this plot. See above for behavior
        if ``None``. If False, we do nothing.
    batch_idx : `int`, optional
        Which index to take from the batch dimension
    title : `str`, optional
        The title to put above this axis. If you want no title, pass
        the empty string (``''``)
    as_rgb : bool, optional
        The representation can be image-like with multiple channels, and we
        have no way to determine whether it should be represented as an RGB
        image or not, so the user must set this flag to tell us. It will be
        ignored if the representation doesn't look image-like or if the model
        has its own plot_representation() method. Else, it will be passed to
        `po.imshow()`, see that method's docstring for details.

    Returns
    -------
    axes : list
        List of created axes.

    Raises
    ------
    ValueError
        If we fall back to the generic plotting path and the (first) data
        value is neither 3d (vector-like) nor 4d (image-like).

    """
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        warnings.warn("ax is not None, so we're ignoring figsize...")
        fig = ax.figure
    # Look the method up explicitly instead of wrapping the call in
    # ``try/except AttributeError``: that pattern also swallowed
    # AttributeErrors raised by bugs *inside* model.plot_representation and
    # silently fell back to the generic path.
    model_plot_representation = getattr(model, "plot_representation", None)
    if model_plot_representation is not None:
        # no point in passing figsize, because we've already created
        # and are passing an axis or are passing the user-specified one
        fig, axes = model_plot_representation(
            ylim=ylim, ax=ax, title=title, batch_idx=batch_idx, data=data
        )
    else:
        if data is None:
            data = model.representation
        if not isinstance(data, dict):
            if title is None:
                title = "Representation"
            data_dict = {}
            if not as_rgb:
                # then we peel apart the channels
                for i, d in enumerate(data.unbind(1)):
                    # need to keep the shape the same because of how we
                    # check for shape below (unbinding removes a dimension,
                    # so we add it back)
                    data_dict[title + f"_{i:02d}"] = d.unsqueeze(1)
            else:
                data_dict[title] = data
            data = data_dict
        else:
            warnings.warn("data has keys, so we're ignoring title!")
        # we decide how to plot based on the first value only. Validate it
        # before modifying the axis, so a bad input leaves the figure as it
        # was (previously the error path read ``data.shape`` on the dict,
        # which raised AttributeError instead of this message)
        data_shape = list(data.values())[0].shape
        if len(data_shape) not in (3, 4):
            raise ValueError(f"Don't know what to do with data of shape {data_shape}")
        # want to make sure the axis we're taking over is basically invisible.
        ax = clean_up_axes(ax, False, ["top", "right", "bottom", "left"], ["x", "y"])
        axes = []
        if len(data_shape) == 3:
            # then this is 'vector-like'
            gs = ax.get_subplotspec().subgridspec(
                min(4, len(data)), int(np.ceil(len(data) / 4))
            )
            for i, (k, v) in enumerate(data.items()):
                ax = fig.add_subplot(gs[i % 4, i // 4])
                # only plot the specified batch, but plot each channel
                # in a separate call. there should probably only be one,
                # and if there's not you probably want to do things
                # differently
                for d in v[batch_idx]:
                    ax = clean_stem_plot(to_numpy(d), ax, k, ylim)
                axes.append(ax)
        else:
            # then this is 'image-like'
            gs = ax.get_subplotspec().subgridspec(
                int(np.ceil(len(data) / 4)), min(4, len(data))
            )
            for i, (k, v) in enumerate(data.items()):
                ax = fig.add_subplot(gs[i // 4, i % 4])
                ax = clean_up_axes(
                    ax, False, ["top", "right", "bottom", "left"], ["x", "y"]
                )
                # only plot the specified batch
                imshow(
                    v,
                    batch_idx=batch_idx,
                    title=k,
                    ax=ax,
                    vrange="indep0",
                    as_rgb=as_rgb,
                )
                axes.append(ax)
            # because we're plotting image data, don't want to change
            # ylim at all
            ylim = False
    if ylim is None:
        if isinstance(data, dict):
            data = torch.cat(list(data.values()), dim=2)
        rescale_ylim(axes, data)
    return axes