Display and animate functions

plenoptic contains a variety of code for visualizing the outputs and the process of synthesis. This notebook details how to make use of that code, which has largely been written with the following goals:

1. If you follow the model API (and that of Synthesis, if creating a new synthesis method), display code should plot something reasonably useful automatically.
2. The code is flexible enough to allow for customization for more useful visualizations.
3. If the plotting code works, the animate code should also.

[1]:
import matplotlib.pyplot as plt
import numpy as np
import torch

import plenoptic as po

# so that relative sizes of axes created by po.imshow and others look right
plt.rcParams["figure.dpi"] = 72
# Animation-related settings
plt.rcParams["animation.html"] = "html5"
# use single-threaded ffmpeg for animation writer
plt.rcParams["animation.writer"] = "ffmpeg"
plt.rcParams["animation.ffmpeg_args"] = ["-threads", "1"]


%load_ext autoreload
%autoreload 2
%matplotlib inline
[2]:
plt.rcParams["figure.dpi"] = 72

General

We include two wrappers of display code from pyrtools, adapting them for use with tensors. These are imshow and animshow, which accept tensors of real- or complex-valued images or videos (respectively) and properly convert them to arrays for display purposes. These are not the most flexible functions (for example, imshow requires that real-valued tensors be 4d) but, assuming you follow our API, should work relatively painlessly. The main reason for using them (over the image-display code from matplotlib) is that we guarantee fidelity to image size: when upsampling, each value in the tensor corresponds to a pixel or an integer number of pixels in the image; when downsampling, we can only downsample by factors of two. This way, you can be sure that any strange appearance of the image is not due to aliasing in the plotting.
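
To make the size rule concrete, here is a minimal NumPy sketch of it. This is not how imshow or pyrtools is implemented: check_zoom and upsample are hypothetical helpers written only for this illustration, and they assume the rule exactly as stated above (integer factors when upsampling, powers of two when downsampling).

```python
import numpy as np


def check_zoom(zoom):
    """Hypothetical validator for the zoom rule described above."""
    if zoom <= 0:
        raise ValueError("zoom must be positive")
    if zoom >= 1:
        # upsampling: every value must map to a whole number of pixels
        return float(zoom).is_integer()
    # downsampling: only 1/2, 1/4, 1/8, ... are allowed
    step = 1 / zoom
    return float(step).is_integer() and (int(step) & (int(step) - 1)) == 0


def upsample(image, zoom):
    """Turn each value into a zoom x zoom block of identical pixels."""
    if not (zoom >= 1 and check_zoom(zoom)):
        raise ValueError("zoom must be a positive integer")
    z = int(zoom)
    return np.repeat(np.repeat(image, z, axis=0), z, axis=1)


img = np.arange(6).reshape(2, 3)
big = upsample(img, 2)
print(big.shape)  # (4, 6)
```

Because each value is copied into a whole block of pixels, no interpolation takes place, which is the property the paragraph above relies on.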

For imshow, we require that real-valued tensors be 4d: (batch, channel, height, width). If you’re showing images, they’re likely to be grayscale (in which case there’s only 1 channel) or RGB(A) (in which case there’s 3 or 4, depending on whether it includes the alpha channel). We plot grayscale images without a problem:

[3]:
img = torch.cat([po.data.einstein(), po.data.curie()], axis=0)
print(img.shape)
fig = po.imshow(img)
torch.Size([2, 1, 256, 256])
../../_images/tutorials_advanced_Display_4_1.png

We need to tell imshow that the image(s) are RGB in order for them to be plotted correctly.

[4]:
rgb = torch.rand(2, 3, 256, 256)
print(rgb.shape)
fig = po.imshow(rgb, as_rgb=True)
torch.Size([2, 3, 256, 256])
../../_images/tutorials_advanced_Display_6_1.png

This is because we don’t want to assume that a tensor with 3 or 4 channels is always RGB. To pick a somewhat-contrived example, imagine the following steerable pyramid:

[5]:
pyr = po.simul.SteerablePyramidFreq(img.shape[-2:], downsample=False, height=1, order=2)
[6]:
coeffs, _ = pyr.convert_pyr_to_tensor(pyr(img), split_complex=False)

print(coeffs.shape)
torch.Size([2, 5, 256, 256])

The first and last channels are residuals, so if we only wanted to look at the coefficients, we’d do the following:

[7]:
po.imshow(coeffs[:, 1:-1], batch_idx=0)
po.imshow(coeffs[:, 1:-1], batch_idx=1)
../../_images/tutorials_advanced_Display_11_0.png
../../_images/tutorials_advanced_Display_11_1.png

We really don’t want to interpret those values as RGB.

Note that in the above imshow calls, we had to specify the batch_idx. This function expects a 4d tensor, but if it has more than one channel and more than one batch (and it’s not RGB), we can’t display everything. The user must therefore specify either batch_idx or channel_idx.

[8]:
po.imshow(coeffs[:, 1:-1], channel_idx=0)
../../_images/tutorials_advanced_Display_13_0.png
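
The rule about batch_idx and channel_idx can be summarized in a few lines of plain Python. The function below is a hypothetical sketch, not part of plenoptic: it only counts how many images a 4d shape would produce under the behavior described above, and raises when the request is ambiguous.

```python
def images_to_show(shape, batch_idx=None, channel_idx=None, as_rgb=False):
    """Hypothetical sketch: how many images a 4d tensor of `shape` yields."""
    if len(shape) != 4:
        raise ValueError("expected (batch, channel, height, width)")
    n_batch, n_channel = shape[:2]
    if as_rgb:
        # the channel dimension is consumed by color: one image per batch
        n_channel = 1
    if batch_idx is not None:
        n_batch = 1
    if channel_idx is not None:
        n_channel = 1
    if n_batch > 1 and n_channel > 1:
        raise ValueError("ambiguous: specify batch_idx or channel_idx")
    return n_batch * n_channel


print(images_to_show((2, 1, 256, 256)))               # 2
print(images_to_show((2, 3, 256, 256), as_rgb=True))  # 2
print(images_to_show((2, 3, 256, 256), batch_idx=0))  # 3
```

The three calls correspond to the grayscale, RGB, and pyramid-coefficient examples above; calling it on a (2, 3, 256, 256) shape with no index and as_rgb=False raises, which is the case where imshow needs you to choose.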

animshow works analogously to imshow, wrapping around the pyrtools version but expecting a 5d tensor: (batch, channel, time, height, width). It returns a matplotlib.animation.FuncAnimation object, which can be saved as an mp4 or converted to an html object for display in a Jupyter notebook (because of the matplotlib configuration options set in the first cell of this notebook, and others in our documentation that make use of them, this happens automatically).

[9]:
pyr = po.simul.SteerablePyramidFreq(
    img.shape[-2:],
    downsample=False,
    height="auto",
    order=3,
    is_complex=True,
    tight_frame=False,
)
coeffs, _ = pyr.convert_pyr_to_tensor(pyr(img), split_complex=False)
print(coeffs.shape)
# because coeffs is 4d, we add a dummy dimension for the channel in order to make
# animshow happy
po.animshow(coeffs.unsqueeze(1), batch_idx=0, vrange="indep1")
torch.Size([2, 26, 256, 256])
/home/billbrod/micromamba/envs/plenoptic/lib/python3.10/site-packages/pyrtools/tools/display.py:119: UserWarning: Ignoring dpi argument: with PyrFigure, we do not use the dpi argument for saving, use dpi_multiple instead (this is done to prevent aliasing)
  warnings.warn("Ignoring dpi argument: with PyrFigure, we do not use the dpi argument"
[9]:
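
The shape manipulation in the cell above is easy to miss, so here it is in isolation. This sketch uses NumPy as a stand-in for torch (np.expand_dims plays the role of unsqueeze) and a small spatial size; the array contents are placeholders.

```python
import numpy as np

# stand-in for the 4d coefficient tensor: (batch, channel, height, width)
coeffs = np.zeros((2, 26, 8, 8))
# insert a singleton channel axis, so the old channel axis is read as time
video = np.expand_dims(coeffs, axis=1)
print(video.shape)  # (2, 1, 26, 8, 8)
```

In other words, each of the 26 pyramid channels becomes one frame of a single-channel video.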

Synthesis-specific

Each synthesis method has a variety of display code to visualize the state and progress of synthesis, as well as to ease understanding of the process and look for ways to improve. For example, in metamer synthesis, it can be useful to determine what component of the model has the largest error.

[10]:
img = po.data.einstein()
model = po.simul.OnOff((7, 7))
rep = model(img)
/home/billbrod/micromamba/envs/plenoptic/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]

As long as your model returns a 3d or 4d tensor (first two dimensions corresponding to batch and channel), then our plotting code should work automatically. If it returns a 3d representation, we plot a stem plot; if it’s 4d, an image.

[11]:
po.tools.display.plot_representation(data=rep, figsize=(11, 5))
../../_images/tutorials_advanced_Display_19_0.png

This also gets used in the plotting code built into our synthesis methods.

[12]:
po.tools.remove_grad(model)
met = po.synth.Metamer(img, model)
met.synthesize(
    max_iter=100,
    store_progress=True,
)
/home/billbrod/Documents/plenoptic/src/plenoptic/tools/validate.py:178: UserWarning: model is in training mode, you probably want to call eval() to switch to evaluation mode
  warnings.warn(
/home/billbrod/Documents/plenoptic/src/plenoptic/synthesize/metamer.py:195: UserWarning: Loss has converged, stopping synthesis
  warnings.warn("Loss has converged, stopping synthesis")

After we’ve run synthesis for a while, we want to investigate how close we are. We can examine the numbers printed out above, but it’s probably useful to plot something. We provide the plot_synthesis_status() function for doing this. By default, it includes the synthesized image, the loss, and the representation error. That last plot is the same as the plot_representation one above, except it plots data = base_representation - synthesized_representation.

[13]:
# we have two image plots for representation error, so that bit should be 2x wider
fig = po.synth.metamer.plot_synthesis_status(
    met, width_ratios={"plot_representation_error": 2.1}
)
/home/billbrod/Documents/plenoptic/src/plenoptic/tools/display.py:930: UserWarning: ax is not None, so we're ignoring figsize...
  warnings.warn("ax is not None, so we're ignoring figsize...")
../../_images/tutorials_advanced_Display_23_1.png
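
The quantity shown in the representation error plot can also be inspected numerically, for instance to find which channel is matched least well. The sketch below uses random NumPy arrays as stand-ins for the two model outputs; the per-channel mean squared error is an illustrative summary chosen here, not something plot_synthesis_status computes for you.

```python
import numpy as np

rng = np.random.default_rng(0)
# stand-ins for model(image) and model(metamer): (batch, channel, height, width)
base_rep = rng.random((1, 2, 16, 16))
synth_rep = base_rep.copy()
synth_rep[:, 1] += 0.1  # pretend the second channel is matched less well

rep_error = base_rep - synth_rep
# mean squared error per channel, averaging over batch and space
per_channel = (rep_error**2).mean(axis=(0, 2, 3))
worst = int(per_channel.argmax())
print(worst)  # 1
```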

You can also create this plot at different iterations, in order to try and better understand what’s happening.

[14]:
fig = po.synth.metamer.plot_synthesis_status(
    met, iteration=10, width_ratios={"plot_representation_error": 2.1}
)
../../_images/tutorials_advanced_Display_25_0.png

The appearance of this figure is very customizable. There are several additional plots that can be included, and all plots are optional. The additional plot below is two histograms comparing the pixel values of the synthesized and base signal.

[15]:
fig = po.synth.metamer.plot_synthesis_status(
    met,
    included_plots=[
        "display_metamer",
        "plot_loss",
        "plot_representation_error",
        "plot_pixel_values",
    ],
    width_ratios={"plot_representation_error": 2.1},
)
../../_images/tutorials_advanced_Display_27_0.png

In addition to being able to customize which plots to include, you can also pre-create the figure (with axes, if you’d like) and pass it in. By default, we try and create an appropriate-looking figure, with appropriately-sized plots, but this allows for more flexibility:

[16]:
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
fig = po.synth.metamer.plot_synthesis_status(
    met,
    included_plots=["display_metamer", "plot_loss", "plot_pixel_values"],
    fig=fig,
)
../../_images/tutorials_advanced_Display_29_0.png

For even more flexibility, you can specify which plot should go in which axes by creating an axes_idx dictionary. You can give a key for every plot or for only a subset, in which case each remaining plot gets added to the next available axes, like above when axes_idx is unset (see the docstring for key names):

[17]:
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes_idx = {"display_metamer": 3, "plot_pixel_values": 0}
fig = po.synth.metamer.plot_synthesis_status(
    met,
    included_plots=["display_metamer", "plot_loss", "plot_pixel_values"],
    fig=fig,
    axes_idx=axes_idx,
)
../../_images/tutorials_advanced_Display_31_0.png

This enables you to create more complicated figures, with axes containing other plots, arrows and other annotations, etc.

[18]:
fig, axes = plt.subplots(2, 3, figsize=(17, 12))
# to tell plot_synthesis_status to leave axes alone, add their indices to the misc key
axes_idx = {"display_metamer": 5, "misc": [0, 4]}
axes[0, 0].text(0.5, 0.5, "SUPER COOL TEXT", color="r")
axes[1, 0].arrow(
    0,
    0,
    0.25,
    0.25,
)
axes[0, 0].plot(np.linspace(0, 1), np.random.rand(50))
fig = po.synth.metamer.plot_synthesis_status(
    met,
    included_plots=["display_metamer", "plot_loss", "plot_pixel_values"],
    fig=fig,
    axes_idx=axes_idx,
)
../../_images/tutorials_advanced_Display_33_0.png
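
The "next available axes" behavior used in the last two examples can be sketched in plain Python. assign_axes is a hypothetical function written for this illustration, and it assumes that every plot takes exactly one axes index and that unspecified plots receive the lowest unused index; check the plot_synthesis_status docstring for the actual behavior.

```python
def assign_axes(included_plots, axes_idx=None):
    """Hypothetical sketch of filling in unspecified axes indices."""
    axes_idx = dict(axes_idx or {})
    # indices already claimed, either by a plot or by the "misc" list
    taken = set(axes_idx.get("misc", []))
    taken.update(v for k, v in axes_idx.items() if k != "misc")
    for name in included_plots:
        if name in axes_idx:
            continue
        # lowest index that nothing else is using
        idx = 0
        while idx in taken:
            idx += 1
        axes_idx[name] = idx
        taken.add(idx)
    return axes_idx


plots = ["display_metamer", "plot_loss", "plot_pixel_values"]
print(assign_axes(plots, {"display_metamer": 3, "plot_pixel_values": 0}))
print(assign_axes(plots, {"display_metamer": 5, "misc": [0, 4]}))
```

Under these assumptions, plot_loss lands in axes 1 in both cases, and in the second case plot_pixel_values lands in axes 2 because 0 is reserved by misc.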

We similarly have an animate function, which animates the above plots over time, and everything said above about plot_synthesis_status also holds for it. Note that animate will take a fair amount of time to run and requires ffmpeg on your system for most file formats (see matplotlib docs for more details).
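
Since ffmpeg is an external dependency, it can be worth checking for it before running a long animation. The helper below is a hypothetical sketch using only the standard library: it suggests a value for the matplotlib setting rcParams["animation.html"], falling back to "jshtml" (matplotlib's JavaScript player, which does not encode a video) when ffmpeg is not on the PATH.

```python
import shutil


def suggested_animation_html():
    """Hypothetical helper: pick a value for rcParams["animation.html"]."""
    # "html5" embeds an encoded video and therefore needs a video writer
    # such as ffmpeg; "jshtml" uses matplotlib's JavaScript player instead
    return "html5" if shutil.which("ffmpeg") else "jshtml"


print(suggested_animation_html())
```

This only affects how the returned animation is displayed in a notebook; saving to a video file such as mp4 still needs a video writer.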

[19]:
fig, axes = plt.subplots(2, 3, figsize=(17, 12))
# to tell animate to leave axes alone, add their indices to the misc key
axes_idx = {"display_metamer": 5, "misc": [0, 4]}
axes[0, 0].text(0.5, 0.5, "SUPER COOL TEXT", color="r")
axes[1, 0].arrow(
    0,
    0,
    0.25,
    0.25,
)
axes[0, 0].plot(np.linspace(0, 1), np.random.rand(50))
po.synth.metamer.animate(
    met,
    included_plots=["display_metamer", "plot_loss", "plot_pixel_values"],
    fig=fig,
    axes_idx=axes_idx,
)
[19]:

More complicated model representation plots

While this provides a starting point, it’s not always super useful. In the example above, the OnOff model returns the output of several convolutional kernels across the image, and so plotting as a series of images is pretty decent. The representation of the PortillaSimoncelli model below, however, has several distinct components at multiple spatial scales and orientations. That structure is lost in a single stem plot:

[20]:
img = po.data.reptile_skin()
ps = po.simul.PortillaSimoncelli(img.shape[-2:])
rep = ps(img)
po.tools.display.plot_representation(data=rep);
../../_images/tutorials_advanced_Display_37_0.png

Trying to guess this advanced structure would be impossible for our generic plotting functions. However, if your model has a plot_representation() method, we can make use of it:

[21]:
ps.plot_representation(data=rep, ylim=False);
../../_images/tutorials_advanced_Display_39_0.png

Our display.plot_representation function can make use of this method if you pass it the model; note how the plot below is identical to the one above. This might not seem very useful, but we make use of this in the different plotting methods used by our synthesis classes explained above.

[22]:
po.tools.display.plot_representation(ps, rep, figsize=(15, 15));
../../_images/tutorials_advanced_Display_41_0.png
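
The fallback logic described in this section can be summarized as follows. This is a hypothetical sketch that only returns a label for the kind of plot that would be made; describe_plot and ToyModel are invented for this illustration and are not plenoptic's code.

```python
import numpy as np


def describe_plot(data, model=None):
    """Hypothetical sketch of the fallback logic, returning a label only."""
    if model is not None and hasattr(model, "plot_representation"):
        # the model knows its own structure, so defer to it
        return "model.plot_representation"
    if data.ndim == 3:
        return "stem plot"
    if data.ndim == 4:
        return "image"
    raise ValueError("representation must be 3d or 4d")


class ToyModel:
    # stand-in for a model that knows how to plot its own representation
    def plot_representation(self, data):
        pass


print(describe_plot(np.zeros((1, 2, 10))))              # stem plot
print(describe_plot(np.zeros((1, 2, 8, 8))))            # image
print(describe_plot(np.zeros((1, 2, 10)), ToyModel()))  # model.plot_representation
```
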
[23]:
met = po.synth.MetamerCTF(
    img, ps, loss_function=po.tools.optim.l2_norm, coarse_to_fine="together"
)
met.synthesize(
    max_iter=400,
    store_progress=10,
    change_scale_criterion=None,
    ctf_iters_to_check=10,
);
/home/billbrod/Documents/plenoptic/src/plenoptic/tools/validate.py:178: UserWarning: model is in training mode, you probably want to call eval() to switch to evaluation mode
  warnings.warn(
/home/billbrod/Documents/plenoptic/src/plenoptic/tools/validate.py:211: UserWarning: Validating whether model can work with coarse-to-fine synthesis -- this can take a while!
  warnings.warn("Validating whether model can work with coarse-to-fine synthesis -- this can take a while!")
[24]:
fig, _ = po.synth.metamer.plot_synthesis_status(met)
/home/billbrod/Documents/plenoptic/src/plenoptic/tools/display.py:931: UserWarning: ax is not None, so we're ignoring figsize...
  fig = ax.figure
../../_images/tutorials_advanced_Display_43_1.png

And again, we can animate this over time:

[25]:
po.synth.metamer.animate(met)
[25]: