Display and animate functions
plenoptic contains a variety of code for visualizing the outputs and the process of synthesis. This notebook details how to make use of that code, which has largely been written with the following goals:
1. If you follow the model API (and that of Synthesis, if creating a new synthesis method), the display code should plot something reasonably useful automatically.
2. The code is flexible enough to allow customization for more useful visualizations.
3. If the plotting code works, the animate code should work as well.
[1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import plenoptic as po
# so that relative sizes of axes created by po.imshow and others look right
plt.rcParams["figure.dpi"] = 72
# Animation-related settings
plt.rcParams["animation.html"] = "html5"
# use single-threaded ffmpeg for animation writer
plt.rcParams["animation.writer"] = "ffmpeg"
plt.rcParams["animation.ffmpeg_args"] = ["-threads", "1"]
%load_ext autoreload
%autoreload 2
%matplotlib inline
General
We include two wrappers of display code from pyrtools, adapted for use with tensors. These are imshow and animshow, which accept tensors of real- or complex-valued images or videos (respectively) and properly convert them to arrays for display. They are not the most flexible functions (for example, imshow requires that real-valued tensors be 4d) but, assuming you follow our API, they should work relatively painlessly. The main reason for using them (over the image-display code from matplotlib) is that we guarantee fidelity to image size: each value in the tensor corresponds to a pixel or an integer number of pixels in the displayed image (if upsampling); if downsampling, we can only downsample by factors of two. This way, you can be sure that any strange appearance of the image is not due to aliasing in the plotting.
For imshow, we require that real-valued tensors be 4d: (batch, channel, height, width). If you’re showing images, they’re likely to be grayscale (in which case there’s only 1 channel) or RGB(A) (in which case there are 3 or 4 channels, depending on whether the alpha channel is included). We plot grayscale images without a problem:
[3]:
img = torch.cat([po.data.einstein(), po.data.curie()], axis=0)
print(img.shape)
fig = po.imshow(img)
torch.Size([2, 1, 256, 256])
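Because of that size guarantee, you control the display scaling explicitly with the zoom argument rather than letting the plotting backend resample for you. A minimal sketch, assuming zoom accepts positive integers for upsampling and 1/2, 1/4, ... for downsampling, consistent with the description above:
# each tensor value becomes a 2x2 block of display pixels
po.imshow(img, zoom=2)
# downsample by a factor of two for display
po.imshow(img, zoom=0.5)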
We need to tell imshow that the image(s) are RGB in order for them to be plotted correctly.
[4]:
rgb = torch.rand(2, 3, 256, 256)
print(rgb.shape)
fig = po.imshow(rgb, as_rgb=True)
torch.Size([2, 3, 256, 256])
This is because we don’t want to assume that a tensor with 3 or 4 channels is always RGB. To pick a somewhat-contrived example, imagine the following steerable pyramid:
[5]:
pyr = po.simul.SteerablePyramidFreq(img.shape[-2:], downsample=False, height=1, order=2)
[6]:
coeffs, _ = pyr.convert_pyr_to_tensor(pyr(img), split_complex=False)
print(coeffs.shape)
torch.Size([2, 5, 256, 256])
The first and last channels are residuals, so if we only wanted to look at the coefficients, we’d do the following:
[7]:
po.imshow(coeffs[:, 1:-1], batch_idx=0)
po.imshow(coeffs[:, 1:-1], batch_idx=1)
We really don’t want to interpret those values as RGB.
Note that in the above imshow calls, we had to specify batch_idx. The function expects a 4d tensor, but if the tensor has more than one channel and more than one batch element (and is not RGB), we can’t display everything at once. The user must therefore specify either batch_idx or channel_idx.
[8]:
po.imshow(coeffs[:, 1:-1], channel_idx=0)
animshow works analogously to imshow, wrapping the pyrtools version but expecting a 5d tensor: (batch, channel, time, height, width). It returns a matplotlib.animation.FuncAnimation object, which can be saved as an mp4 or converted to an HTML object for display in a Jupyter notebook (this happens automatically here because of the matplotlib configuration options set in the first cell of this notebook, which the other notebooks in our documentation also use).
[9]:
pyr = po.simul.SteerablePyramidFreq(
img.shape[-2:],
downsample=False,
height="auto",
order=3,
is_complex=True,
tight_frame=False,
)
coeffs, _ = pyr.convert_pyr_to_tensor(pyr(img), split_complex=False)
print(coeffs.shape)
# because coeffs is 4d, we add a dummy dimension for the channel in order to make
# animshow happy
po.animshow(coeffs.unsqueeze(1), batch_idx=0, vrange="indep1")
torch.Size([2, 26, 256, 256])
/home/billbrod/micromamba/envs/plenoptic/lib/python3.10/site-packages/pyrtools/tools/display.py:119: UserWarning: Ignoring dpi argument: with PyrFigure, we do not use the dpi argument for saving, use dpi_multiple instead (this is done to prevent aliasing)
warnings.warn("Ignoring dpi argument: with PyrFigure, we do not use the dpi argument"
[9]:
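If you would rather handle the animation yourself, the returned FuncAnimation supports the standard matplotlib methods. A minimal sketch (the filename is just an example; writing an mp4 requires ffmpeg):
from IPython.display import HTML

anim = po.animshow(coeffs.unsqueeze(1), batch_idx=0, vrange="indep1")
# save to disk as an mp4 (requires ffmpeg)
anim.save("pyramid_coefficients.mp4")
# or embed it in the notebook explicitly
HTML(anim.to_html5_video())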
Synthesis-specific
Each synthesis method has a variety of display code to visualize the state and progress of synthesis, to help you understand the process, and to look for ways to improve it. For example, in metamer synthesis, it can be useful to determine which component of the model representation has the largest error.
[10]:
img = po.data.einstein()
model = po.simul.OnOff((7, 7))
rep = model(img)
/home/billbrod/micromamba/envs/plenoptic/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
As long as your model returns a 3d or 4d tensor (with the first two dimensions corresponding to batch and channel), our plotting code should work automatically. If the representation is 3d, we plot a stem plot; if it’s 4d, an image.
[11]:
po.tools.display.plot_representation(data=rep, figsize=(11, 5))
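The OnOff representation is 4d, so we got images above. To see the stem-plot behavior, pass any 3d tensor; a minimal sketch using random values purely for illustration (this is not a real model output):
# any 3d (batch, channel, feature) tensor is drawn as a stem plot
rep_3d = torch.rand(1, 1, 25)
po.tools.display.plot_representation(data=rep_3d, figsize=(11, 5))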
This also gets used in the plotting code built into our synthesis methods.
[12]:
po.tools.remove_grad(model)
met = po.synth.Metamer(img, model)
met.synthesize(
max_iter=100,
store_progress=True,
)
/home/billbrod/Documents/plenoptic/src/plenoptic/tools/validate.py:178: UserWarning: model is in training mode, you probably want to call eval() to switch to evaluation mode
warnings.warn(
/home/billbrod/Documents/plenoptic/src/plenoptic/synthesize/metamer.py:195: UserWarning: Loss has converged, stopping synthesis
warnings.warn("Loss has converged, stopping synthesis")
After we’ve run synthesis for a while, we want to investigate how close we are. We can examine the numbers printed out above, but it’s probably more useful to plot something. We provide the plot_synthesis_status() function for doing this. By default, it includes the synthesized image, the loss, and the representation error. That last plot is the same as the one above, except it plots data = base_representation - synthesized_representation.
[13]:
# we have two image plots for representation error, so that bit should be 2x wider
fig = po.synth.metamer.plot_synthesis_status(
met, width_ratios={"plot_representation_error": 2.1}
)
/home/billbrod/Documents/plenoptic/src/plenoptic/tools/display.py:930: UserWarning: ax is not None, so we're ignoring figsize...
warnings.warn("ax is not None, so we're ignoring figsize...")
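If you want to dig into that representation error yourself, for example to find which component of the model has the largest error, you can compute the same quantity directly. A sketch, assuming the Metamer object exposes the target image as met.image and the current synthesized image as met.metamer (attribute names may differ between plenoptic versions):
# compute the representation error by hand
error = model(met.image) - model(met.metamer)
# location of the largest-magnitude error within each channel's map
print(error.abs().flatten(start_dim=2).argmax(dim=-1))
po.tools.display.plot_representation(data=error, figsize=(11, 5))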
You can also create this plot at different iterations, in order to better understand what’s happening:
[14]:
fig = po.synth.metamer.plot_synthesis_status(
met, iteration=10, width_ratios={"plot_representation_error": 2.1}
)
The appearance of this figure is very customizable. There are several additional plots that can be included, and all plots are optional. The additional plot below shows two histograms comparing the pixel values of the synthesized image and the base signal.
[15]:
fig = po.synth.metamer.plot_synthesis_status(
met,
included_plots=[
"display_metamer",
"plot_loss",
"plot_representation_error",
"plot_pixel_values",
],
width_ratios={"plot_representation_error": 2.1},
)
In addition to customizing which plots to include, you can also pre-create the figure (with axes, if you’d like) and pass it in. By default, we try to create an appropriate-looking figure with appropriately-sized plots, but this allows for more flexibility:
[16]:
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
fig = po.synth.metamer.plot_synthesis_status(
met,
included_plots=["display_metamer", "plot_loss", "plot_pixel_values"],
fig=fig,
)
For even more flexibility, you can specify which plot goes in which axes by creating an axes_idx dictionary. You can provide a key for every plot or only a subset (in which case the remaining plots get added to the next available axes, as above when axes_idx is unset); see the docstring for the key names:
[17]:
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes_idx = {"display_metamer": 3, "plot_pixel_values": 0}
fig = po.synth.metamer.plot_synthesis_status(
met,
included_plots=["display_metamer", "plot_loss", "plot_pixel_values"],
fig=fig,
axes_idx=axes_idx,
)
This enables you to create more complicated figures, with axes containing other plots, arrows, and other annotations, etc.
[18]:
fig, axes = plt.subplots(2, 3, figsize=(17, 12))
# to tell plot_synthesis_status to ignore plots, add them to the misc keys
axes_idx = {"display_metamer": 5, "misc": [0, 4]}
axes[0, 0].text(0.5, 0.5, "SUPER COOL TEXT", color="r")
axes[1, 0].arrow(
0,
0,
0.25,
0.25,
)
axes[0, 0].plot(np.linspace(0, 1), np.random.rand(50))
fig = po.synth.metamer.plot_synthesis_status(
met,
included_plots=["display_metamer", "plot_loss", "plot_pixel_values"],
fig=fig,
axes_idx=axes_idx,
)
We similarly have an animate function, which animates the above plots over time, and everything said above also holds for it. Note that animate takes a fair amount of time to run and requires ffmpeg on your system for most file formats (see the matplotlib docs for more details).
[19]:
fig, axes = plt.subplots(2, 3, figsize=(17, 12))
# to tell plot_synthesis_status to ignore plots, add them to the misc keys
axes_idx = {"display_metamer": 5, "misc": [0, 4]}
axes[0, 0].text(0.5, 0.5, "SUPER COOL TEXT", color="r")
axes[1, 0].arrow(
0,
0,
0.25,
0.25,
)
axes[0, 0].plot(np.linspace(0, 1), np.random.rand(50))
po.synth.metamer.animate(
met,
included_plots=["display_metamer", "plot_loss", "plot_pixel_values"],
fig=fig,
axes_idx=axes_idx,
)
[19]:
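If you would rather save that animation to disk than display it inline, keep the returned FuncAnimation and call its save method, just as with animshow above. A minimal sketch (the filename and fps are only examples; mp4 output requires ffmpeg):
anim = po.synth.metamer.animate(
    met,
    included_plots=["display_metamer", "plot_loss", "plot_pixel_values"],
)
anim.save("metamer_synthesis.mp4", fps=10)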
More complicated model representation plots
While this default provides a starting point, it’s not always that useful. In the example above, the OnOff model returns the output of several convolutional kernels applied across the image, so plotting the representation as a series of images works reasonably well. The representation of the PortillaSimoncelli model below, however, has several distinct components at multiple spatial scales and orientations. That structure is lost in a single stem plot:
[20]:
img = po.data.reptile_skin()
ps = po.simul.PortillaSimoncelli(img.shape[-2:])
rep = ps(img)
po.tools.display.plot_representation(data=rep);
Our generic plotting functions cannot guess this kind of structure. However, if your model has a plot_representation() method, we can make use of it:
[21]:
ps.plot_representation(data=rep, ylim=False);
Our display.plot_representation function can make use of this method if you pass it the model; note how the plot below is identical to the one above. This might not seem very useful on its own, but it is what lets the synthesis plotting methods described above use your model’s own visualization.
[22]:
po.tools.display.plot_representation(ps, rep, figsize=(15, 15));
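If you want your own model’s structure to show up this way, give it a plot_representation method. The sketch below is a hypothetical example, not plenoptic’s required interface: the argument names (data, ax, figsize, ylim, batch_idx, title) and the (figure, axes) return value are modeled on PortillaSimoncelli.plot_representation, so check that source or the plot_representation docstring for what your plenoptic version actually passes.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class MyStatsModel(nn.Module):
    """Hypothetical model whose 3d representation is per-channel mean and std."""

    def forward(self, x):
        # representation shape: (batch, channel, 2)
        return torch.stack([x.mean(dim=(-2, -1)), x.std(dim=(-2, -1))], dim=-1)

    def plot_representation(self, data, ax=None, figsize=(5, 5), ylim=None,
                            batch_idx=0, title=None, **kwargs):
        # label the two statistics so their structure isn't lost in a generic
        # stem plot (assumes a single-channel input for simplicity)
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=figsize)
        vals = data[batch_idx].flatten().detach().cpu().numpy()
        ax.stem(vals)
        ax.set_xticks(range(len(vals)), labels=["mean", "std"])
        if title is not None:
            ax.set_title(title)
        if ylim:
            ax.set_ylim(ylim)
        # return value mimics PortillaSimoncelli.plot_representation: (fig, axes)
        return ax.figure, [ax]

Once such a method exists, both po.tools.display.plot_representation(MyStatsModel(), MyStatsModel()(img)) and the synthesis status plots above should be able to use it.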
[23]:
met = po.synth.MetamerCTF(
img, ps, loss_function=po.tools.optim.l2_norm, coarse_to_fine="together"
)
met.synthesize(
max_iter=400,
store_progress=10,
change_scale_criterion=None,
ctf_iters_to_check=10,
);
/home/billbrod/Documents/plenoptic/src/plenoptic/tools/validate.py:178: UserWarning: model is in training mode, you probably want to call eval() to switch to evaluation mode
warnings.warn(
/home/billbrod/Documents/plenoptic/src/plenoptic/tools/validate.py:211: UserWarning: Validating whether model can work with coarse-to-fine synthesis -- this can take a while!
warnings.warn("Validating whether model can work with coarse-to-fine synthesis -- this can take a while!")
[24]:
fig, _ = po.synth.metamer.plot_synthesis_status(met)
/home/billbrod/Documents/plenoptic/src/plenoptic/tools/display.py:931: UserWarning: ax is not None, so we're ignoring figsize...
fig = ax.figure
And again, we can animate this over time:
[25]:
po.synth.metamer.animate(met)
[25]: