"""Plots for understanding synthesis objects.""" # numpydoc ignore=EX01
import re
import warnings
from collections.abc import Callable
from typing import Any, Literal
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from .. import tensors
from .._synthesize import Eigendistortion, MADCompetition, Metamer
from . import display
from .metamer import _representation_error, metamer_representation_error
# Public plotting functions exported by this module. The module-level ``__dir__``
# defined right after this list makes ``dir()`` on the module report exactly these.
__all__ = [
    "synthesis_loss",
    "synthesis_imshow",
    "synthesis_animate",
    "synthesis_histogram",
    "synthesis_status",
]


def __dir__() -> list[str]:
    """Return the module's public API, so that ``dir()`` lists only ``__all__``."""
    exported = __all__
    return exported
[docs]
def synthesis_loss(
    synthesis_object: Metamer | MADCompetition,
    iteration: int | None = None,
    plot_penalties: bool = False,
    ax: list[mpl.axes.Axes] | mpl.axes.Axes | None = None,
    **kwargs: Any,
) -> dict[str, mpl.axes.Axes]:
    """
    Plot synthesis loss.

    .. versionadded:: 2.0

        Combines previously separate loss plotting functions for
        :class:`~plenoptic.Metamer` and :class:`~plenoptic.MADCompetition`, and adds
        support for plotting penalties. Note that behavior for
        :class:`~plenoptic.Metamer` is different: we now plot the metamer loss, not
        the objective function value (see below for details).

    The behavior of this function is slightly different depending on the type of
    ``synthesis_object``:

    - :class:`~plenoptic.Metamer`: creates a single axis object whose y-axis is
      log-scaled and shows the metamer loss and, if ``plot_penalties=True``,
      :attr:`~plenoptic.Metamer.penalties`. Returned dictionary has key ``"loss"``.

    - :class:`~plenoptic.MADCompetition`: creates multiple axes objects, one each for
      :attr:`~plenoptic.MADCompetition.reference_metric_loss`,
      :attr:`~plenoptic.MADCompetition.optimized_metric_loss`, and (if
      ``plot_penalties=True``) :attr:`~plenoptic.MADCompetition.penalties`. The
      y-axis is linearly-scaled for all plots. Returned dictionary has keys
      ``"reference_metric_loss"``, ``"optimized_metric_loss"``, and (only if
      ``plot_penalties=True``) ``"penalties"``.

    In all cases, plots a red dot at ``iteration``, to highlight the loss there. If
    ``iteration=None``, then the dot will be at the final iteration.

    .. attention::

        In all cases, we plot the components of the objective function, not the
        objective function itself (whose values are stored in the
        :attr:`plenoptic.Metamer.losses` or :attr:`plenoptic.MADCompetition.losses`
        attribute). See Examples section and
        :func:`plenoptic.Metamer.objective_function` or
        :func:`plenoptic.MADCompetition.objective_function` for more details.

    Parameters
    ----------
    synthesis_object
        Synthesis object whose loss we want to plot.
    iteration
        Which iteration to display. If ``None``, we show the most recent one.
        Negative values are also allowed.
    plot_penalties
        Whether to plot the output of the penalty function as well. See above
        for behavior.
    ax
        Pre-existing axes for plot. If ``None``, we call
        :func:`matplotlib.pyplot.gca()`. If ``synthesis_object`` is
        :class:`~plenoptic.MADCompetition`, then if ``ax`` is a single axis, we split
        it horizontally; if ``ax`` is a list, it must contain two (or three, if
        ``plot_penalties=True``) axes to plot on. If ``synthesis_object`` is
        :class:`~plenoptic.Metamer`, then passing a list will result in a
        ``ValueError``.
    **kwargs
        Passed to :func:`matplotlib.pyplot.plot` for every line drawn.

    Returns
    -------
    axes_dict
        A dictionary whose keys are strings describing the created plots and whose
        values are the corresponding matplotlib axes. See above for details.

    Raises
    ------
    IndexError
        If ``iteration`` takes an illegal value.
    ValueError
        If ``ax`` is a list and ``synthesis_object`` is a :class:`~plenoptic.Metamer`.
    ValueError
        If ``synthesis_object`` is a :class:`~plenoptic.MADCompetition` and ``ax`` is
        a list of the wrong length.
    TypeError
        If ``synthesis_object`` is not :class:`~plenoptic.MADCompetition` or
        :class:`~plenoptic.Metamer`.

    See Also
    --------
    synthesis_status
        Create a figure combining this with other axis-level plots to summarize
        synthesis status at a given iteration.
    synthesis_animate
        Create a video animating this and other axis-level plots changing over
        the course of synthesis.

    Examples
    --------
    Plot loss for :class:`~plenoptic.Metamer` object:

    .. plot::
        :context: reset

        >>> import plenoptic as po
        >>> import matplotlib.pyplot as plt
        >>> import torch
        >>> img = po.data.einstein()
        >>> model = po.models.Gaussian(30).eval()
        >>> po.remove_grad(model)
        >>> met = po.Metamer(img, model)
        >>> met.to(torch.float64)
        >>> met.load(po.data.fetch_data("example_metamer_gaussian.pt"))
        >>> po.plot.synthesis_loss(met)
        {'loss': <Axes: ... ylabel='Loss'>}

    Include the penalties:

    .. plot::
        :context: close-figs

        >>> po.plot.synthesis_loss(met, plot_penalties=True)
        {'loss': <Axes: ... ylabel='Loss'>}

    Specify an iteration:

    .. plot::
        :context: close-figs

        >>> po.plot.synthesis_loss(met, iteration=10, plot_penalties=True)
        {'loss': <Axes: ... ylabel='Loss'>}

    Plot on an axis in an existing figure:

    .. plot::
        :context: close-figs

        >>> fig, axes = plt.subplots(1, 2)
        >>> po.plot.synthesis_loss(met, ax=axes[1], plot_penalties=True)
        {'loss': <Axes: ... ylabel='Loss'>}

    Note that we are not plotting the output of
    :func:`plenoptic.Metamer.objective_function`, which is stored in
    :attr:`plenoptic.Metamer.losses`. Instead, we are plotting the output of
    :attr:`plenoptic.Metamer.loss_function`, which is the "metamer loss" (which does
    not include the penalty). The following example illustrates the difference:

    .. plot::
        :context: close-figs

        >>> axes = po.plot.synthesis_loss(met)
        >>> axes["loss"].plot(met.losses, label="objective function")
        [<matplotlib.lines.Line2D ...>]
        >>> # Some tweaks to the marker and size to aid visibility.
        >>> axes["loss"].plot(
        ...     met.losses - met.penalty_lambda * met.penalties,
        ...     "k.",
        ...     ms=2,
        ...     label="reconstructed metamer loss",
        ... )
        [<matplotlib.lines.Line2D ...>]
        >>> axes["loss"].legend()
        <matplotlib.legend.Legend ...>

    Notice how the objective function line is above the one created by this
    function, and how subtracting the weighted penalty from the objective function
    recovers the metamer loss.

    Plot loss for :class:`~plenoptic.MADCompetition` object:

    .. plot::
        :context: close-figs

        >>> img = po.data.einstein().to(torch.float64)
        >>> def ds_ssim(x, y):
        ...     return 1 - po.metric.ssim(x, y, weighted=True, pad="reflect")
        >>> mad = po.MADCompetition(img, ds_ssim, po.metric.mse, "max", 1e6)
        >>> mad.load(po.data.fetch_data("example_mad.pt"))
        >>> po.plot.synthesis_loss(mad)
        {'reference_metric_loss': <Axes: ...>, 'optimized_metric_loss': <Axes: ...>}

    When plotting :class:`~plenoptic.MADCompetition` loss on an existing figure, you
    can either pass a single axis, in which case we sub-divide it into the necessary
    number of axes, or a list with the appropriate number of axes:

    .. plot::
        :context: close-figs

        >>> fig, axes = plt.subplots(1, 2)
        >>> po.plot.synthesis_loss(mad, ax=axes[1])
        {'reference_metric_loss': <Axes: ...>, 'optimized_metric_loss': <Axes: ...>}
        >>> fig, axes = plt.subplots(1, 2)
        >>> po.plot.synthesis_loss(mad, ax=axes)
        {'reference_metric_loss': <Axes: ...>, 'optimized_metric_loss': <Axes: ...>}

    Note that, as with :class:`~plenoptic.Metamer`, we are not plotting the output of
    :func:`plenoptic.MADCompetition.objective_function`, which is stored in
    :attr:`plenoptic.MADCompetition.losses`. Instead, we are plotting the output of
    the two metrics we are comparing. If you wish to plot the objective function
    output, you can do so directly:

    .. plot::
        :context: close-figs

        >>> plt.plot(mad.losses)
        [<matplotlib.lines.Line2D ...>]
    """
    is_metamer = isinstance(synthesis_object, Metamer)
    if not is_metamer and not isinstance(synthesis_object, MADCompetition):
        raise TypeError(
            "synthesis_object must be a MADCompetition or Metamer object but got"
            f" {type(synthesis_object)}"
        )

    # get_progress can warn that the loss iteration differs from the iteration of
    # the cached image; only losses are plotted here, so that warning is noise.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="loss iteration and iteration for")
        progress = synthesis_object.get_progress(iteration)

    if ax is None:
        ax = plt.gca()

    if is_metamer:
        if hasattr(ax, "__iter__"):
            raise ValueError("if synthesis_object is a Metamer, ax cannot be a list!")
        # The stored losses are objective-function values; remove the weighted
        # penalty to recover the metamer loss itself.
        weight = synthesis_object.penalty_lambda
        metamer_loss = synthesis_object.losses - weight * synthesis_object.penalties
        highlighted_loss = progress["losses"] - weight * progress["penalties"]
        ax.plot(metamer_loss, label="metamer loss", **kwargs)
        ax.scatter(progress["iteration"], highlighted_loss, c="r")
        ax.set(xlabel="Synthesis iteration", ylabel="Loss", yscale="log")
        if plot_penalties:
            ax.plot(synthesis_object.penalties, label="penalty", **kwargs)
            ax.legend()
        return {"loss": ax}

    # MADCompetition: one axis per plotted quantity.
    n_axes = 3 if plot_penalties else 2
    if hasattr(ax, "__iter__"):
        if len(ax) != n_axes:
            raise ValueError(
                f"ax is a list of the wrong length! Must contain {n_axes}"
                " axes."
            )
        target_axes = ax
    else:
        # Hide the single axis we were given and carve it into n_axes sub-axes.
        parent = display._clean_up_axes(
            ax, False, ["top", "right", "bottom", "left"], ["x", "y"]
        )
        subgrid = parent.get_subplotspec().subgridspec(1, n_axes)
        target_axes = [
            parent.figure.add_subplot(subgrid[0, col]) for col in range(n_axes)
        ]

    to_plot = {
        "reference_metric_loss": synthesis_object.reference_metric_loss,
        "optimized_metric_loss": synthesis_object.optimized_metric_loss,
    }
    if plot_penalties:
        to_plot["penalties"] = synthesis_object.penalties

    axes_dict = {}
    for target, (key, values) in zip(target_axes, to_plot.items()):
        target.plot(values, **kwargs)
        target.scatter(progress["iteration"], progress[key], c="r")
        target.set(
            xlabel="Synthesis iteration", ylabel=key.capitalize().replace("_", " ")
        )
        axes_dict[key] = target
    return axes_dict
def _get_synthesis_image(
    synthesis_object: Metamer | MADCompetition | Eigendistortion,
    batch_idx: int | None = None,
    iteration: int | None = None,
    return_ref_image: bool = False,
) -> tuple[torch.Tensor | list[torch.Tensor], int | list[int] | None]:
    """
    Grab images from synthesis objects to plot.

    This function:

    - Grabs the synthesized image tensor.

    - Grabs the correct iteration, if possible, raising an error if ``iteration`` is
      set for an Eigendistortion or when ``store_progress=False``.

    - If ``batch_idx is None``, unpack all batches into list of 4d tensors. If not,
      and ``synthesis_object`` is an Eigendistortion, convert from eigenindex values
      to actual indices (see :func:`plenoptic.Eigendistortion._indexer`).

    - If ``return_ref_image``, return the reference image as well.

    Parameters
    ----------
    synthesis_object
        Synthesis object with the images we want to plot.
    batch_idx
        Which index to take from the batch dimension. Note that for
        :class:`~plenoptic.Eigendistortion`, this is the
        :attr:`~plenoptic.Eigendistortion.eigenindex`. If ``None``, we grab all
        batches.
    iteration
        Which iteration to display, for :class:`~plenoptic.Metamer` and
        :class:`~plenoptic.MADCompetition` objects. If ``None``, we show the most
        recent one. Negative values are also allowed. If ``iteration!=None`` and
        ``synthesis_object.store_progress>1`` (that is, the synthesized image was not
        cached on every iteration), then we use the cached image from the nearest
        iteration. For :class:`~plenoptic.Eigendistortion`, this must be ``None``.
    return_ref_image
        Whether to include the reference image (``synthesis_object.image``).

    Returns
    -------
    images
        The corresponding images. If ``batch_idx`` is ``None`` or
        ``return_ref_image=True``, a list of 4d tensors (with the reference image
        last, if requested). Otherwise, the full synthesized tensor (all batches),
        from which the caller selects using the returned ``batch_idx``.
    batch_idx
        ``None`` if the input ``batch_idx`` was ``None`` (each batch has then been
        unpacked into its own element of ``images``). Otherwise, the input
        ``batch_idx``; if ``synthesis_object`` was an
        :class:`~plenoptic.Eigendistortion`, these eigenindex values have been
        remapped so they're now actual indices.

    Raises
    ------
    IndexError
        If ``iteration`` takes an illegal value, including any non-``None`` value
        when no images were cached during synthesis (``store_progress=False``).
    ValueError
        If ``iteration`` is not ``None`` and ``synthesis_object`` is an
        :class:`~plenoptic.Eigendistortion` object.

    Warns
    -----
    UserWarning
        If the iteration used for cached image is not the same as the argument
        ``iteration`` (because e.g., you set ``iteration=3`` but
        ``synthesis_object.store_progress=2``).
    """
    if isinstance(synthesis_object, Eigendistortion):
        if iteration is not None:
            raise ValueError(
                "When synthesis_object is an Eigendistortion, iteration must be None!"
            )
        image = synthesis_object.eigendistortions
        if batch_idx is not None:
            try:
                batch_idx = [synthesis_object._indexer(i) for i in batch_idx]
            except TypeError:
                # we're here because we can't iterate over batch_idx (can't just check
                # attributes because 0d tensors have both __iter__ and __len__)
                batch_idx = synthesis_object._indexer(batch_idx)
    else:
        progress = synthesis_object.get_progress(iteration)
        if isinstance(synthesis_object, Metamer):
            name = "metamer"
        elif isinstance(synthesis_object, MADCompetition):
            name = "mad_image"
        try:
            image = progress[f"saved_{name}"]
        except KeyError:
            # No cached images were returned, so only the current image is
            # available; a specific iteration can't be honored.
            if iteration is not None:
                raise IndexError(
                    "When synthesis_object.store_progress=False, iteration must be"
                    " None!"
                ) from None
            # Plain attribute lookup; no need for eval here.
            image = getattr(synthesis_object, name)
    if batch_idx is None:
        # Unpack along the batch dimension, keeping each element 4d.
        image = [im.unsqueeze(0) for im in image]
    if return_ref_image:
        if isinstance(image, list):
            image.append(synthesis_object.image)
        else:
            image = [image, synthesis_object.image]
    return image, batch_idx
def _get_synthesis_title(
    synthesis_object: Metamer | MADCompetition | Eigendistortion,
    batch_idx: int | None = None,
    iteration: int | None = None,
    return_ref_image: bool = False,
) -> list[str]:
    """
    Build the titles for the synthesis images we are about to plot.

    Call this before :func:`_get_synthesis_image`: for
    :class:`~plenoptic.Eigendistortion`, that function remaps ``batch_idx``, whereas
    the titles should show the value the user passed in.

    Parameters
    ----------
    synthesis_object
        Synthesis object with the images we want to plot.
    batch_idx
        Which index to take from the batch dimension. Note that for
        :class:`~plenoptic.Eigendistortion`, this is the
        :attr:`~plenoptic.Eigendistortion.eigenindex`. If ``None``, we grab all
        batches.
    iteration
        Which iteration to display, for :class:`~plenoptic.Metamer` and
        :class:`~plenoptic.MADCompetition` objects. If ``None``, we show the most
        recent one. Negative values are also allowed. If ``iteration!=None`` and
        ``synthesis_object.store_progress>1`` (that is, the synthesized image was not
        cached on every iteration), then we use the cached image from the nearest
        iteration. For :class:`~plenoptic.Eigendistortion`, this must be ``None``.
    return_ref_image
        Whether to append a title for the reference image (``synthesis_object.image``).

    Returns
    -------
    titles
        One title per selected batch (showing ``batch_idx`` and, for
        :class:`~plenoptic.Metamer` / :class:`~plenoptic.MADCompetition`, the
        iteration), plus a final reference-image title if ``return_ref_image``.

    Raises
    ------
    IndexError
        If ``iteration`` takes an illegal value, including any non-``None`` value
        when no images were cached during synthesis (``store_progress=False``).
    ValueError
        If ``iteration`` is not ``None`` and ``synthesis_object`` is an
        :class:`~plenoptic.Eigendistortion` object.

    Warns
    -----
    UserWarning
        If the iteration used for cached image is not the same as the argument
        ``iteration`` (because e.g., you set ``iteration=3`` but
        ``synthesis_object.store_progress=2``).
    """
    if isinstance(synthesis_object, Eigendistortion):
        if iteration is not None:
            raise ValueError(
                "When synthesis_object is an Eigendistortion, iteration must be None!"
            )
        template, ref_label = "Eigendistortion[{batch_idx}]", "Reference"
        if batch_idx is None:
            batch_idx = synthesis_object.eigenindex
    else:
        progress = synthesis_object.get_progress(iteration)
        if isinstance(synthesis_object, Metamer):
            template = "Metamer[{batch_idx}] [iteration={iter}]"
            ref_label = "Target"
            n_batches = synthesis_object.metamer.shape[0]
        elif isinstance(synthesis_object, MADCompetition):
            template = "MAD[{batch_idx}] [iteration={iter}]"
            ref_label = "Reference"
            n_batches = synthesis_object.mad_image.shape[0]
        try:
            iteration = progress["store_progress_iteration"]
        except KeyError:
            if iteration is not None:
                raise IndexError(
                    "When synthesis_object.store_progress=False, iteration must be"
                    " None!"
                )
            # losses always holds one extra entry (the current loss), so the most
            # recent iteration number is its length minus one.
            iteration = len(synthesis_object.losses) - 1
        if batch_idx is None:
            batch_idx = range(n_batches)

    try:
        titles = [template.format(batch_idx=idx, iter=iteration) for idx in batch_idx]
    except TypeError:
        # batch_idx is a single index: iterating it failed (0d tensors define both
        # __iter__ and __len__, so we can't just check for those attributes).
        titles = [template.format(batch_idx=batch_idx, iter=iteration)]
    if return_ref_image:
        titles.append(f"{ref_label} image")
    return titles
[docs]
def synthesis_histogram(
    synthesis_object: Metamer | MADCompetition | Eigendistortion,
    batch_idx: int | None = None,
    channel_idx: int | None = None,
    iteration: int | None = None,
    ylim: tuple[float, float] | Literal[False] = False,
    xlim: tuple[float, float] | Literal[False, "range"] = "range",
    ax: mpl.axes.Axes | None = None,
    alpha: float = 0.4,
    **kwargs: Any,
) -> mpl.axes.Axes:
    """
    Plot histogram of values of synthesis objects.

    .. versionadded:: 2.0

        Combines previously separate histogram plotting functions for
        :class:`~plenoptic.Metamer` and :class:`~plenoptic.MADCompetition`, and adds
        support for :class:`~plenoptic.Eigendistortion`.

    Use this to check whether any values fall outside the preferred range.

    The behavior of this function is slightly different depending on the type of
    ``synthesis_object``:

    - :class:`~plenoptic.Metamer` and :class:`~plenoptic.MADCompetition`: compare the
      synthesized tensor against the target / reference image. ``iteration`` can be
      specified.

    - :class:`~plenoptic.Eigendistortion`: create histograms for eigendistortions.
      ``iteration`` must be ``None``.

    Parameters
    ----------
    synthesis_object
        Synthesis object with the images whose values we want to plot.
    batch_idx
        Which index to take from the batch dimension. Note that for
        :class:`~plenoptic.Eigendistortion`, this is the
        :attr:`~plenoptic.Eigendistortion.eigenindex`. If ``None``, we plot all
        batches as separate histograms (intended use-case is for multiple
        eigendistortions).
    channel_idx
        Which index to take from the channel dimension. If ``None``, we use all
        channels (assumed use-case is RGB(A) images).
    iteration
        Which iteration to display, for :class:`~plenoptic.Metamer` and
        :class:`~plenoptic.MADCompetition` objects. If ``None``, we show the most
        recent one. Negative values are also allowed. If ``iteration!=None`` and
        ``synthesis_object.store_progress>1`` (that is, the synthesized image was not
        cached on every iteration), then we use the cached image from the nearest
        iteration. For :class:`~plenoptic.Eigendistortion`, this must be ``None``.
    ylim
        If tuple, the ylimit to set for this axis. If ``False``, we leave
        it untouched.
    xlim
        If ``"range"``, set the xlimits to the range across plotted data.
        If tuple, the xlimit to set for this axis. If ``False``, we leave
        it untouched.
    ax
        Pre-existing axes for plot. If ``None``, we call
        :func:`matplotlib.pyplot.gca()`.
    alpha
        Alpha value for the histogram bars.
    **kwargs
        Passed to :func:`matplotlib.pyplot.hist`.

    Returns
    -------
    ax
        The matplotlib axes containing the histogram.

    Raises
    ------
    IndexError
        If ``iteration`` takes an illegal value.
    ValueError
        If ``iteration`` is not ``None`` and ``synthesis_object`` is an
        :class:`~plenoptic.Eigendistortion` object.

    Warns
    -----
    UserWarning
        If the iteration used for cached image is not the same as the argument
        ``iteration`` (because e.g., you set ``iteration=3`` but
        ``synthesis_object.store_progress=2``).

    See Also
    --------
    :func:`~plenoptic.plot.histogram`
        The plotting function used to create this plot.
    synthesis_status
        Create a figure combining this with other axis-level plots to summarize
        synthesis status at a given iteration.
    synthesis_animate
        Create a video animating this and other axis-level plots changing over
        the course of synthesis.

    Examples
    --------
    Plot histogram for :class:`~plenoptic.Metamer` object:

    .. plot::
        :context: reset

        >>> import plenoptic as po
        >>> import matplotlib.pyplot as plt
        >>> import torch
        >>> img = po.data.einstein()
        >>> model = po.models.Gaussian(30).eval()
        >>> po.remove_grad(model)
        >>> met = po.Metamer(img, model)
        >>> met.to(torch.float64)
        >>> met.load(po.data.fetch_data("example_metamer_gaussian.pt"))
        >>> po.plot.synthesis_histogram(met)
        <Axes: ... 'Histogram of tensor values'...>

    Plot pixel values from a specified iteration (requires setting
    ``store_progress`` when :meth:`~plenoptic.Metamer.synthesize` was called):

    .. plot::
        :context: close-figs

        >>> po.plot.synthesis_histogram(met, iteration=10)
        <Axes: ... 'Histogram of tensor values'...>

    Plot on an existing axis:

    .. plot::
        :context: close-figs

        >>> fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        >>> po.plot.synthesis_histogram(met, ax=axes[1])
        <Axes: ... 'Histogram of tensor values'...>

    Plot histogram for :class:`~plenoptic.MADCompetition` object:

    .. plot::
        :context: close-figs

        >>> img = po.data.einstein().to(torch.float64)
        >>> def ds_ssim(x, y):
        ...     return 1 - po.metric.ssim(x, y, weighted=True, pad="reflect")
        >>> mad = po.MADCompetition(img, ds_ssim, po.metric.mse, "max", 1e6)
        >>> mad.load(po.data.fetch_data("example_mad.pt"))
        >>> po.plot.synthesis_histogram(mad)
        <Axes: ... 'Histogram of tensor values'...>

    Plot histogram for :class:`~plenoptic.Eigendistortion` object. Notice how
    here we plot just the values from the synthesized eigendistortions, not the base
    image.

    .. plot::
        :context: close-figs

        >>> img = po.data.einstein().to(torch.float64)
        >>> lg = po.models.LuminanceGainControl(
        ...     (31, 31), pad_mode="circular", pretrained=True, cache_filt=True
        ... ).eval()
        >>> lg = lg.to(torch.float64)
        >>> po.remove_grad(lg)
        >>> eig = po.Eigendistortion(img, lg)
        >>> eig.load(
        ...     po.data.fetch_data("example_eigendistortion.pt"),
        ...     map_location="cpu",
        ... )
        >>> po.plot.synthesis_histogram(eig)
        <Axes: ... 'Histogram of tensor values'...>
    """
    # Eigendistortions are shown on their own; Metamer / MADCompetition images are
    # compared against their target / reference image.
    include_reference = not isinstance(synthesis_object, Eigendistortion)
    # Titles first: they must be built from the user's (not yet remapped) batch_idx.
    hist_titles = _get_synthesis_title(
        synthesis_object, batch_idx, iteration, include_reference
    )
    hist_images, resolved_idx = _get_synthesis_image(
        synthesis_object, batch_idx, iteration, include_reference
    )
    return display.histogram(
        hist_images,
        hist_titles,
        resolved_idx,
        channel_idx,
        ylim,
        xlim,
        ax=ax,
        alpha=alpha,
        **kwargs,
    )
[docs]
def synthesis_imshow(
    synthesis_object: Metamer | MADCompetition | Eigendistortion,
    batch_idx: int = 0,
    channel_idx: int | None = None,
    distortion_scale: float = 5.0,
    process_image: Callable[[torch.Tensor], torch.Tensor] | None = None,
    zoom: float | None = None,
    iteration: int | None = None,
    ax: mpl.axes.Axes | None = None,
    title: str | None = None,
    **kwargs: Any,
) -> mpl.axes.Axes:
    """
    Display image of synthesis object.

    .. versionadded:: 2.0

        Combines previously separate image-display functions for
        :class:`~plenoptic.Metamer`, :class:`~plenoptic.MADCompetition`, and
        :class:`~plenoptic.Eigendistortion`.

    We use :func:`~plenoptic.plot.imshow` to display the synthesized image and
    attempt to automatically find the most reasonable zoom value. You can override
    this value using the zoom arg, but remember that :func:`~plenoptic.plot.imshow`
    is opinionated about the size of the resulting image and will throw an Exception
    if the axis created is not big enough for the selected zoom.

    The behavior of this function is slightly different depending on the type of
    ``synthesis_object``:

    - :class:`~plenoptic.Metamer` and :class:`~plenoptic.MADCompetition`: process and
      display the synthesized image. ``iteration`` can be specified,
      ``distortion_scale`` must be left at its default value.

    - :class:`~plenoptic.Eigendistortion`: process and display
      ``image + (distortion_scale * eigendistortion)``. ``iteration`` must be
      ``None``, ``distortion_scale`` can be set.

    Parameters
    ----------
    synthesis_object
        Synthesis object with the images we wish to display.
    batch_idx
        Which index to take from the batch dimension. Note that for
        :class:`~plenoptic.Eigendistortion`, this is the
        :attr:`~plenoptic.Eigendistortion.eigenindex`.
    channel_idx
        Which index to take from the channel dimension. If ``None``, plot all
        channels; if image has more than 1 channel, will attempt to plot as RGB(A)
        image.
    distortion_scale
        Amount by which to scale eigendistortion for
        ``image + (distortion_scale * eigendistortion)`` for display. If
        ``synthesis_object`` is not :class:`~plenoptic.Eigendistortion`, must be left
        at its default value.
    process_image
        A function to process the plotted image, e.g., multiplying by the ImageNet
        standard deviation and then adding the ImageNet mean to undo image
        preprocessing, or clamping between 0 and 1. If ``None``, then no processing
        is performed.
    zoom
        How much to zoom in / enlarge the synthesized image, the ratio of display
        pixels to image pixels. If ``None``, we attempt to find the best value
        ourselves.
    iteration
        Which iteration to display, for :class:`~plenoptic.Metamer` and
        :class:`~plenoptic.MADCompetition` objects. If ``None``, we show the most
        recent one. Negative values are also allowed. If ``iteration!=None`` and
        ``synthesis_object.store_progress>1`` (that is, the synthesized image was not
        cached on every iteration), then we use the cached image from the nearest
        iteration. For :class:`~plenoptic.Eigendistortion`, this must be ``None``.
    ax
        Pre-existing axes for plot. If ``None``, we call
        :func:`matplotlib.pyplot.gca`.
    title
        Title to add to axis. If ``None``, we pick appropriate title based on the
        type of ``synthesis_object``.
    **kwargs
        Passed to :func:`~plenoptic.plot.imshow`.

    Returns
    -------
    ax
        The matplotlib axes containing the plot.

    Raises
    ------
    ValueError
        If ``batch_idx`` is not an int.
    ValueError
        If ``distortion_scale`` is changed from its default and
        ``synthesis_object`` is not an :class:`~plenoptic.Eigendistortion`.
    ValueError
        If ``iteration`` is not ``None`` and ``synthesis_object`` is an
        :class:`~plenoptic.Eigendistortion`.
    IndexError
        If ``iteration`` takes an illegal value.

    Warns
    -----
    UserWarning
        If the iteration used for cached image is not the same as the argument
        ``iteration`` (because e.g., you set ``iteration=3`` but
        ``synthesis_object.store_progress=2``).

    See Also
    --------
    :func:`~plenoptic.plot.imshow`
        Function used by this one to visualize the metamer image.
    synthesis_status
        Create a figure combining this with other axis-level plots to summarize
        synthesis status at a given iteration.
    synthesis_animate
        Create a video animating this and other axis-level plots changing over
        the course of synthesis.
    :func:`~plenoptic.plot.mad_imshow_all`
        Display all MAD Competition images from a complete set of four
        :class:`~plenoptic.MADCompetition` instances.
    :func:`~plenoptic.plot.eigendistortion_imshow_all`
        Display base image, eigendistortions alone, and eigendistortions added to
        image together in a single figure.

    Examples
    --------
    Plot for :class:`~plenoptic.Metamer` object:

    .. plot::
        :context: reset

        >>> import plenoptic as po
        >>> import matplotlib.pyplot as plt
        >>> import torch
        >>> img = po.data.einstein()
        >>> model = po.models.Gaussian(30).eval()
        >>> po.remove_grad(model)
        >>> met = po.Metamer(img, model)
        >>> met.to(torch.float64)
        >>> met.load(po.data.fetch_data("example_metamer_gaussian.pt"))
        >>> po.plot.synthesis_imshow(met)
        <Axes: title=...Metamer[0] [iteration=107]...>

    If a matplotlib figure exists, this function will use it (using
    :func:`matplotlib.pyplot.gca`):

    .. plot::
        :context: close-figs

        >>> fig, axes = plt.subplots(1, 2)
        >>> po.plot.synthesis_imshow(met)
        <Axes: title=...Metamer[0] [iteration=107]...>

    Display metamer from a specified iteration (requires setting ``store_progress``
    when :meth:`~plenoptic.Metamer.synthesize` was called):

    .. plot::
        :context: close-figs

        >>> po.plot.synthesis_imshow(met, iteration=10)
        <Axes: title=...Metamer[0] [iteration=10]...>

    Explicitly define the axis to use:

    .. plot::
        :context: close-figs

        >>> fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        >>> po.plot.synthesis_imshow(met, ax=axes[1])
        <Axes: title=...Metamer[0] [iteration=107]...>

    When plotting on an existing axis, if ``zoom=None``, this function will determine
    the best zoom level for the axis size.

    .. plot::
        :context: close-figs

        >>> fig, axes = plt.subplots(1, 1, figsize=(8, 8))
        >>> po.plot.synthesis_imshow(met, ax=axes)
        <Axes: title=...Metamer[0] [iteration=107]...dims: [256, 256] * 2.0'}>

    Plot for :class:`~plenoptic.MADCompetition` object:

    .. plot::
        :context: close-figs

        >>> img = po.data.einstein().to(torch.float64)
        >>> def ds_ssim(x, y):
        ...     return 1 - po.metric.ssim(x, y, weighted=True, pad="reflect")
        >>> mad = po.MADCompetition(img, ds_ssim, po.metric.mse, "max", 1e6)
        >>> mad.load(po.data.fetch_data("example_mad.pt"))
        >>> po.plot.synthesis_imshow(mad)
        <Axes: title=...MAD[0] [iteration=200]...>

    Plot for :class:`~plenoptic.Eigendistortion` object. Note here that we plot
    the distortion multiplied by ``distortion_scale`` and added to the target image.

    .. plot::
        :context: close-figs

        >>> img = po.data.einstein().to(torch.float64)
        >>> lg = po.models.LuminanceGainControl(
        ...     (31, 31), pad_mode="circular", pretrained=True, cache_filt=True
        ... ).eval()
        >>> lg = lg.to(torch.float64)
        >>> po.remove_grad(lg)
        >>> eig = po.Eigendistortion(img, lg)
        >>> eig.load(
        ...     po.data.fetch_data("example_eigendistortion.pt"),
        ...     map_location="cpu",
        ... )
        >>> po.plot.synthesis_imshow(eig)
        <Axes: title=...5.0 * Eigendistortion[0]...range: [-1.4e-01, 1.0e+00]...>

    Use the ``process_image`` argument to apply a preprocessing function to the
    image before plotting it:

    .. plot::
        :context: close-figs

        >>> po.plot.synthesis_imshow(eig, process_image=lambda x: x.clip(0, 1))
        <Axes: title=...5.0 * Eigendistortion[0]...range: [0.0e+00, 1.0e+00]...>

    See :func:`~plenoptic.plot.eigendistortion_imshow_all` for how to set
    ``process_image`` to undo ImageNet normalization.
    """
    try:
        batch_idx = int(batch_idx)
    except (TypeError, ValueError) as err:
        raise ValueError("batch_idx must be a single integer!") from err

    is_eigendistortion = isinstance(synthesis_object, Eigendistortion)
    # distortion_scale only has meaning for Eigendistortion. Validate it before
    # fetching titles / images (which calls get_progress and may emit warnings),
    # so a misuse fails fast with a clear error.
    if not is_eigendistortion and distortion_scale != 5:
        raise ValueError(
            f"If synthesis_object is type {type(synthesis_object)}, "
            "distortion_scale cannot be set"
        )

    if title is None:
        # Built from the user-facing batch_idx, before _get_synthesis_image remaps
        # it (for Eigendistortion) into an actual index.
        title = _get_synthesis_title(synthesis_object, batch_idx, iteration)
        if is_eigendistortion:
            title = [f"{distortion_scale} * {t}" for t in title]

    image, batch_idx = _get_synthesis_image(synthesis_object, batch_idx, iteration)
    if is_eigendistortion:
        # Show the (scaled) distortion added to the base image.
        image = synthesis_object.image + distortion_scale * image
    if process_image is not None:
        image = process_image(image)

    # we're only plotting one image here, so if the user wants multiple
    # channels, they must be RGB(A)
    as_rgb = bool(channel_idx is None and image.shape[1] > 1)
    if ax is None:
        ax = plt.gca()
    display.imshow(
        image,
        ax=ax,
        title=title,
        zoom=zoom,
        batch_idx=batch_idx,
        channel_idx=channel_idx,
        as_rgb=as_rgb,
        **kwargs,
    )
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    return ax
def _check_plot_consistency(
included_plots: list[str],
width_ratios: dict[str, float],
axes_idx: dict[str, int],
):
"""
Raise ValueError if width_ratios or axes_idx reference plots not in included_plots.
Because I'm unsure how to behave then.
""" # noqa: DOC501
# numpydoc ignore=PR01
if extra_plots := set(width_ratios) - set(included_plots):
raise ValueError(
"width_ratios contains keys referencing plots not included in "
f"included_plots! {extra_plots}"
)
extra_plots = set(axes_idx) - set(included_plots)
if extra_plots and extra_plots != {"misc"}:
raise ValueError(
"axes_idx contains keys referencing plots not included in included_plots!"
f" {extra_plots}"
)
def _check_included_plots(
    to_check: list[str] | dict[str, float],
    to_check_name: str,
    synthesis_object: Metamer | MADCompetition | Eigendistortion,
):
    """
    Reject requests for plots that can't be made for this synthesis object.

    Helper function for :func:`synthesis_status` and :func:`synthesis_animate`.
    Every object supports ``"synthesis_imshow"``, ``"synthesis_histogram"`` and
    ``"misc"``. :class:`~plenoptic.Metamer` additionally supports
    ``"synthesis_loss"`` and ``"metamer_representation_error"``, and
    :class:`~plenoptic.MADCompetition` additionally supports ``"synthesis_loss"``.

    Parameters
    ----------
    to_check
        The variable to check. If it has a ``keys()`` method (e.g., a dict), its keys
        are checked; otherwise, its contents are.
    to_check_name
        Name of the ``to_check`` variable, used in the error message.
    synthesis_object
        Synthesis object we're producing the figure for, which determines the
        allowed plots.

    Raises
    ------
    ValueError
        If ``to_check`` contains any value that is not allowed.
    """ # numpydoc ignore=EX01
    allowed_vals = ["synthesis_imshow", "synthesis_histogram", "misc"]
    if isinstance(synthesis_object, Metamer):
        allowed_vals += ["synthesis_loss", "metamer_representation_error"]
    elif isinstance(synthesis_object, MADCompetition):
        allowed_vals += ["synthesis_loss"]

    # Mappings are validated by key, other containers by their elements.
    try:
        requested = to_check.keys()
    except AttributeError:
        requested = to_check

    rejected = [val for val in requested if val not in allowed_vals]
    if not rejected:
        return
    raise ValueError(
        f"{to_check_name} contained value(s) {rejected}! "
        f"For {type(synthesis_object)} only {allowed_vals} are permissible!"
    )
def _setup_synthesis_fig(
    included_plots: list[str],
    synthesis_object: Metamer | MADCompetition | Eigendistortion,
    fig: mpl.figure.Figure | None = None,
    axes_idx: dict[str, int | list[int]] | None = None,
    figsize: tuple[float, float] | None = None,
    width_ratios: dict[str, float] | None = None,
) -> tuple[mpl.figure.Figure, dict[str, mpl.axes.Axes | list[mpl.axes.Axes]]]:
    """
    Set up figure for :func:`synthesis_status`.

    Creates figure with enough axes for all the plots you want. Will also create
    an index in ``axes_idx`` for them if you haven't done so already.

    If ``fig=None``, all axes will be on the same row. If you want them to be on
    different rows, you will need to initialize ``fig`` yourself and pass that in.
    For changing width, change the corresponding value in ``width_ratios``, which
    gives width relative to other axes. So if you want the axis for the
    ``synthesis_loss`` plot to be twice as wide as the others, pass
    ``width_ratios={"synthesis_loss": 2}``.

    .. attention::

        This function does not raise errors if ``included_plots``,
        ``width_ratios``, or ``axes_idx`` contains improper values, it assumes
        that validation has already been handled.

    Parameters
    ----------
    included_plots
        Which plots to include.
    synthesis_object
        Synthesis object we're producing the figure for, so we know what widths
        to use if unset.
    fig
        The figure to plot on or ``None``. If ``None``, we create a new figure.
    axes_idx
        Dictionary specifying which axes contains which type of plot, allows for more
        fine-grained control of the resulting figure. Possible keys are the possible
        values of ``included_plots``, plus ``"misc"``. Values should be ints (or,
        for ``"misc"``, an iterable of ints). If you tell this function to create a
        plot that doesn't have a corresponding key, we find the lowest int that is
        not already in the dict, so if you have axes that you want unchanged, place
        their idx in ``"misc"``. ``None`` is treated as an empty dictionary. Neither
        this dictionary nor any list it contains is modified.
    figsize
        The size of the figure to create. It may take a little bit of
        playing around to find a reasonable value. If ``None``, we attempt to
        make our best guess, aiming to have relative width=1 correspond to 5.
    width_ratios
        If ``None`` or an empty dictionary, plot widths will depend on
        ``synthesis_object`` class: for :class:`~plenoptic.MADCompetition`,
        :func:`synthesis_loss` will have double the width of the rest; for other
        classes, all will be the same width. To change that, specify their relative
        widths; keys should be strings (possible values same as ``included_plots``)
        and values should be floats specifying their relative width.

    Returns
    -------
    fig
        The figure to plot on.
    axes_dict
        Dictionary mapping between plot types and axis objects.
    """  # numpydoc ignore=EX01
    # work on a copy so the caller's dictionary is never modified
    axes_idx = {} if axes_idx is None else axes_idx.copy()
    if width_ratios is None:
        width_ratios = {}
    # start with the defaults
    actual_width_ratios = {
        "synthesis_imshow": 1,
        "synthesis_histogram": 1,
        "metamer_representation_error": 1,
        "synthesis_loss": 2 if isinstance(synthesis_object, MADCompetition) else 1,
    }
    # overwrite with any user-specified values
    actual_width_ratios.update(width_ratios)
    all_possible_plots = [
        "synthesis_imshow",
        "synthesis_loss",
        "metamer_representation_error",
        "synthesis_histogram",
    ]
    # make sure that we skip any axes user told us to. Always build a new list: we
    # extend it below, which must not touch the caller's list (the dict copy above is
    # shallow) and would fail outright if the user passed a tuple or range.
    misc_axes = axes_idx.get("misc", [])
    if hasattr(misc_axes, "__iter__"):
        misc_axes = list(misc_axes)
    else:
        misc_axes = [misc_axes]
    n_subplots = len(misc_axes)
    figure_width_ratios = [1] * len(misc_axes)
    for plot in all_possible_plots:
        if plot not in included_plots:
            continue
        n_subplots += 1
        figure_width_ratios.append(actual_width_ratios[plot])
        if plot not in axes_idx:
            axes_idx[plot] = tensors._find_min_int(axes_idx.values())
    if fig is None:
        figure_width_ratios = np.array(figure_width_ratios)
        if figsize is None:
            # we want (5, 5) for each subplot, with a bit of room between
            # each subplot
            figsize = (
                (figure_width_ratios * 5).sum() + figure_width_ratios.sum() - 1,
                5,
            )
        figure_width_ratios = figure_width_ratios / figure_width_ratios.sum()
        fig, axes = plt.subplots(
            1,
            n_subplots,
            figsize=figsize,
            gridspec_kw={"width_ratios": figure_width_ratios},
        )
        # plt.subplots returns a bare Axes (not an array) when there is only one
        if n_subplots == 1:
            axes = [axes]
    else:
        axes = fig.axes
    # make sure misc contains all the unclaimed axes. this will catch additional axes
    # if e.g., the user created a figure with 10 axes and then passed it to this
    # function
    claimed = set()
    for idx in axes_idx.values():
        # values may be single ints or lists of ints
        if hasattr(idx, "__iter__"):
            claimed.update(idx)
        else:
            claimed.add(idx)
    misc_axes.extend(i for i in range(len(fig.axes)) if i not in claimed)
    axes_idx["misc"] = misc_axes
    # now remap from idx to axes objects
    axes_dict = {}
    for name, idx in axes_idx.items():
        if hasattr(idx, "__iter__"):
            axes_dict[name] = [axes[i] for i in idx]
        else:
            axes_dict[name] = axes[idx]
    return fig, axes_dict
def _get_default_included_plots(
    synthesis_object: Metamer | MADCompetition | Eigendistortion,
) -> list[str]:
    """
    Return default value for ``included_plots``, based on ``synthesis_object`` class.

    - :class:`~plenoptic.Metamer`: :func:`synthesis_imshow`,
      :func:`synthesis_loss`, :func:`~plenoptic.plot.metamer_representation_error`
    - :class:`~plenoptic.MADCompetition`: :func:`synthesis_imshow`,
      :func:`synthesis_loss`
    - :class:`~plenoptic.Eigendistortion`: :func:`synthesis_imshow`,
      :func:`synthesis_histogram`

    Parameters
    ----------
    synthesis_object
        Synthesis object we're producing the figure for.

    Returns
    -------
    included_plots
        Included plots.

    Raises
    ------
    TypeError
        If ``synthesis_object`` is not a :class:`~plenoptic.Metamer`,
        :class:`~plenoptic.MADCompetition`, or :class:`~plenoptic.Eigendistortion`
        (previously this silently returned ``None``, which failed later with an
        unhelpful error).
    """
    if isinstance(synthesis_object, Metamer):
        return ["synthesis_imshow", "synthesis_loss", "metamer_representation_error"]
    if isinstance(synthesis_object, MADCompetition):
        return ["synthesis_imshow", "synthesis_loss"]
    if isinstance(synthesis_object, Eigendistortion):
        return ["synthesis_imshow", "synthesis_histogram"]
    raise TypeError(
        "synthesis_object must be a Metamer, MADCompetition, or Eigendistortion "
        f"object but got {type(synthesis_object)}"
    )
def _synthesis_status(
    synthesis_object: Metamer | MADCompetition | Eigendistortion,
    batch_idx: int = 0,
    channel_idx: int | None = None,
    iteration: int | None = None,
    included_plots: list[str] | None = None,
    fig: mpl.figure.Figure | None = None,
    axes_idx: dict[str, int | list[int]] | None = None,
    figsize: tuple[float, float] | None = None,
    width_ratios: dict[str, float] | None = None,
    **kwargs: dict[str, Any],
) -> tuple[mpl.figure.Figure, dict[str, mpl.axes.Axes | list[mpl.axes.Axes]]]:
    r"""
    Help create synthesis status figure, returning extra info.

    This helper function is used by :func:`synthesis_status` and
    :func:`synthesis_animate` and returns additional shared information they need.
    See :func:`synthesis_status` for more complete docstring.

    Parameters
    ----------
    synthesis_object
        Synthesis object with status to plot.
    batch_idx
        Which index to take from the batch dimension.
    channel_idx
        Which index to take from the channel dimension. If ``None``, plot all channels;
        if image has more than 1 channel, will attempt to plot as RGB(A) image.
    iteration
        Which iteration to display, for :class:`~plenoptic.Metamer` and
        :class:`~plenoptic.MADCompetition` objects. If ``None``, we show the most recent
        one. Negative values are also allowed. If ``iteration!=None`` and
        ``synthesis_object.store_progress>1`` (that is, the synthesized image was not
        cached on every iteration), then we use the cached image from the nearest
        iteration.
    included_plots
        Which plots to include. If ``None``, use the defaults for the class of
        ``synthesis_object`` (see :func:`synthesis_status`), otherwise must be a list
        of strings whose values are names of plotting functions that can accept
        ``synthesis_object``.
    fig
        If ``None``, we create a new figure. Otherwise we assume this is
        a figure that has the appropriate size and number of subplots.
    axes_idx
        Dictionary specifying which axes contains which type of plot, allows
        for more fine-grained control of the resulting figure.
        Keys must be strings matching the names of the included plots, or
        ``"misc"``. All axes in ``"misc"`` will be ignored by this function. If you
        tell this function to create a plot that doesn't have a corresponding key,
        we find the lowest int that is not already in the dict, so if you have axes
        that you want unchanged, place their idx in ``"misc"``. ``None`` is treated
        as an empty dictionary.
    figsize
        The size of the figure to create. It may take a little bit of
        playing around to find a reasonable value. If ``None``, we attempt to
        make our best guess, aiming to have each axis be of size ``(5, 5)``.
    width_ratios
        If ``None`` or an empty dictionary, plot widths will depend on
        ``synthesis_object`` class: for :class:`~plenoptic.MADCompetition`,
        :func:`synthesis_loss` will have double the width of the rest; for other
        classes, all will be the same width. To change that, specify their relative
        widths; keys should be strings (possible values same as ``included_plots``)
        and values should be floats specifying their relative width.
    **kwargs
        Additional keyword arguments to pass to plotting functions. Keys must be of
        the form ``{plot_func}_kwargs``, where ``{plot_func}`` is the name of an
        included plotting function. Will raise a ValueError if there are additional
        kwargs.

    Returns
    -------
    fig
        The figure containing this plot.
    axes_dict
        Dictionary mapping between plot types and axis objects.

    Raises
    ------
    ValueError
        If any of ``included_plots``, ``width_ratios``, or ``axes_idx`` reference a
        plot that is incompatible with ``synthesis_object``.
    ValueError
        If ``kwargs`` contains additional keys. This is checked before the figure is
        created.
    """
    if included_plots is None:
        included_plots = _get_default_included_plots(synthesis_object)
    if axes_idx is None:
        axes_idx = {}
    if width_ratios is None:
        width_ratios = {}
    _check_included_plots(included_plots, "included_plots", synthesis_object)
    _check_included_plots(width_ratios, "width_ratios", synthesis_object)
    _check_included_plots(axes_idx, "axes_idx", synthesis_object)
    _check_plot_consistency(included_plots, width_ratios, axes_idx)
    # Each included plot consumes an optional ``{plot}_kwargs`` entry; anything else
    # is unexpected. Validate before creating / drawing the figure, so that a typo
    # doesn't waste the plotting work and leave a stray figure behind.
    plot_funcs = (
        "synthesis_imshow",
        "synthesis_loss",
        "metamer_representation_error",
        "synthesis_histogram",
    )
    consumed_keys = {f"{name}_kwargs" for name in plot_funcs if name in included_plots}
    extra_keys = [key for key in kwargs if key not in consumed_keys]
    if extra_keys:
        raise ValueError(
            f"kwargs has additional keys {extra_keys}, don't know"
            " what to do with them! Did you forget to include a plot?"
        )
    fig, axes_dict = _setup_synthesis_fig(
        included_plots, synthesis_object, fig, axes_idx, figsize, width_ratios
    )
    if "synthesis_imshow" in included_plots:
        synthesis_imshow(
            synthesis_object,
            batch_idx=batch_idx,
            channel_idx=channel_idx,
            iteration=iteration,
            ax=axes_dict["synthesis_imshow"],
            **kwargs.get("synthesis_imshow_kwargs", {}),
        )
    if "synthesis_loss" in included_plots:
        loss_axes = synthesis_loss(
            synthesis_object,
            iteration=iteration,
            ax=axes_dict["synthesis_loss"],
            **kwargs.get("synthesis_loss_kwargs", {}),
        )
        # synthesis_loss may create new axes, so make sure it's up-to-date here
        axes_dict["synthesis_loss"] = list(loss_axes.values())
    if "metamer_representation_error" in included_plots:
        rep_axes = metamer_representation_error(
            synthesis_object,
            batch_idx=batch_idx,
            iteration=iteration,
            ax=axes_dict["metamer_representation_error"],
            **kwargs.get("metamer_representation_error_kwargs", {}),
        )
        # metamer_representation_error may create new axes, so make sure it's
        # up-to-date here
        axes_dict["metamer_representation_error"] = rep_axes
    if "synthesis_histogram" in included_plots:
        synthesis_histogram(
            synthesis_object,
            batch_idx=batch_idx,
            channel_idx=channel_idx,
            iteration=iteration,
            ax=axes_dict["synthesis_histogram"],
            **kwargs.get("synthesis_histogram_kwargs", {}),
        )
    return fig, axes_dict
[docs]
def synthesis_status(
    synthesis_object: Metamer | MADCompetition | Eigendistortion,
    batch_idx: int = 0,
    channel_idx: int | None = None,
    iteration: int | None = None,
    included_plots: list[str] | None = None,
    fig: mpl.figure.Figure | None = None,
    axes_idx: dict[str, int | list[int]] = {},
    figsize: tuple[float, float] | None = None,
    width_ratios: dict[str, float] = {},
    **kwargs: dict[str, Any],
) -> mpl.figure.Figure:
    r"""
    Make a plot showing synthesis status.

    .. versionadded:: 2.0

        Combines previously separate status plotting functions for
        :class:`~plenoptic.Metamer` and :class:`~plenoptic.MADCompetition`, and adds
        support for :class:`~plenoptic.Eigendistortion`.

    We create several subplots to analyze this. The plots to include are
    specified by including their name in the ``included_plots`` list. All plots
    can be created separately using the function with the individual plot name (see
    See Also section below).

    This function's behavior when ``included_plots is None``, and allowed values for
    that variable, depends upon the type of ``synthesis_object``:

    - :class:`~plenoptic.Metamer`: :func:`synthesis_imshow`,
      :func:`synthesis_loss`, :func:`~plenoptic.plot.metamer_representation_error`.
      Additional allowed values: :func:`synthesis_histogram`.

    - :class:`~plenoptic.MADCompetition`: :func:`synthesis_imshow`,
      :func:`synthesis_loss`. Additional allowed values: :func:`synthesis_histogram`.

    - :class:`~plenoptic.Eigendistortion`: :func:`synthesis_imshow`,
      :func:`synthesis_histogram`.

    Parameters
    ----------
    synthesis_object
        Synthesis object with status to plot.
    batch_idx
        Which index to take from the batch dimension.
    channel_idx
        Which index to take from the channel dimension. If ``None``, plot all channels.
    iteration
        Which iteration to display, for :class:`~plenoptic.Metamer` and
        :class:`~plenoptic.MADCompetition` objects. If ``None``, we show the most recent
        one. Negative values are also allowed. If ``iteration!=None`` and
        ``synthesis_object.store_progress>1`` (that is, the synthesized image was not
        cached on every iteration), then we use the cached image from the nearest
        iteration.
    included_plots
        Which plots to include. See above for behavior if ``None``, otherwise must be a
        list of strings whose values are names of plotting functions that can accept
        ``synthesis_object``, see above for list.
    fig
        If ``None``, we create a new figure. Otherwise we assume this is
        a figure that has the appropriate size and number of subplots.
    axes_idx
        Dictionary specifying which axes contains which type of plot, allows
        for more fine-grained control of the resulting figure.
        Keys must be strings matching the names of the included plots, see above
        for possible values, or ``"misc"``. All axes in ``"misc"`` will be ignored
        by this function. If you tell this function to create a plot that doesn't
        have a corresponding key, we find the lowest int that is not already in
        the dict, so if you have axes that you want unchanged, place their idx
        in ``"misc"``.
    figsize
        The size of the figure to create. It may take a little bit of
        playing around to find a reasonable value. If ``None``, we attempt to
        make our best guess, aiming to have each axis be of size ``(5, 5)``.
    width_ratios
        If ``width_ratios`` is an empty dictionary, plot widths will depend on
        ``synthesis_object`` class: for :class:`~plenoptic.MADCompetition`,
        :func:`synthesis_loss` will have double the width of the rest; for other
        classes, all will be the same width. To change that, specify their relative
        widths; keys should be strings (possible values same as ``included_plots``)
        and values should be floats specifying their relative width.
    **kwargs
        Additional keyword arguments to pass to plotting functions. Keys must be of
        the form ``{plot_func}_kwargs``, where ``{plot_func}`` is the name of the
        plotting function. See Examples for examples. Will raise a ValueError if
        there are additional kwargs.

    Returns
    -------
    fig
        The figure containing this plot.

    Raises
    ------
    ValueError
        If the ``iteration is not None`` and the given ``synthesis_object`` object is
        :class:`~plenoptic.Eigendistortion` or synthesis was run with
        ``store_progress=False``.
    ValueError
        If any of ``width_ratios``, ``included_plots``, or ``axes_idx`` reference a
        plot that is incompatible with ``synthesis_object``. See list at top of
        docstring for compatible plots for each class.
    ValueError
        If ``kwargs`` contains additional keys.

    Warns
    -----
    UserWarning
        If the iteration used for cached image is not the same as the argument
        ``iteration`` (because e.g., you set ``iteration=3`` but
        ``synthesis_object.store_progress=2``).

    See Also
    --------
    synthesis_imshow
        One of this function's axis-level component functions: display synthesized
        image at a given synthesis iteration.
    synthesis_loss
        One of this function's axis-level component functions: plot synthesis loss
        over iterations.
    :func:`~plenoptic.plot.metamer_representation_error`
        One of this function's axis-level component functions: plot error in model
        representation at a given metamer synthesis iteration.
    synthesis_histogram
        One of this function's axis-level component functions: plot histogram of
        values from synthesized object.
    synthesis_animate
        Create a video that animates this figure over synthesis iteration.

    Examples
    --------
    Plot for a :class:`~plenoptic.Metamer` object:

    .. plot::
        :context: reset

        >>> import plenoptic as po
        >>> import torch
        >>> img = po.data.einstein()
        >>> model = po.models.Gaussian(30).eval()
        >>> po.remove_grad(model)
        >>> met = po.Metamer(img, model)
        >>> met.to(torch.float64)
        >>> met.load(po.data.fetch_data("example_metamer_gaussian.pt"))
        >>> po.plot.synthesis_status(met)
        <Figure size ...>

    If model has its own ``plot_representation`` method, this function will use it
    for plotting the representation error (see
    :func:`~plenoptic.models.PortillaSimoncelli.plot_representation`):

    .. plot::
        :context: close-figs

        >>> img = po.data.reptile_skin()
        >>> model = po.models.PortillaSimoncelli(img.shape[-2:])
        >>> met = po.MetamerCTF(img, model, po.loss.l2_norm)
        >>> met.to(torch.float64)
        >>> met.load(po.data.fetch_data("example_metamerCTF_ps.pt"))
        >>> po.plot.synthesis_status(met)
        <Figure size ...>

    Plot a different iteration of synthesis:

    .. plot::
        :context: close-figs

        >>> po.plot.synthesis_status(met, iteration=10)
        <Figure size ...>

    Change the included plots:

    .. plot::
        :context: close-figs

        >>> included_plots = ["synthesis_loss", "synthesis_histogram"]
        >>> po.plot.synthesis_status(met, included_plots=included_plots)
        <Figure size ...>

    Adjust width of included plots:

    .. plot::
        :context: close-figs

        >>> width_ratios = {"metamer_representation_error": 3}
        >>> po.plot.synthesis_status(met, width_ratios=width_ratios)
        <Figure size ...>

    Change the arrangement of the plots, creating some empty axes:

    .. plot::
        :context: close-figs

        >>> axes_idx = {"misc": [0, 3], "synthesis_loss": 4}
        >>> po.plot.synthesis_status(met, axes_idx=axes_idx)
        <Figure size ...>

    Plot on an existing figure, with already existing plots:

    .. plot::
        :context: close-figs

        >>> fig, axes = plt.subplots(1, 4, figsize=(16, 4))
        >>> axes[0].plot(torch.rand(100))
        [<matplotlib.lines.Line2D ...>]
        >>> # specify misc: 0 so we don't plot on top of this axis.
        >>> axes_idx = {"misc": 0}
        >>> po.plot.synthesis_status(met, fig=fig, axes_idx=axes_idx)
        <Figure size ...>

    Note that if you pass a figure, it must already have axes created:

    >>> fig = plt.figure()
    >>> po.plot.synthesis_status(met, fig=fig)
    Traceback (most recent call last):
    IndexError: list index out of range

    Specify additional keyword arguments to one of the underlying plots:

    .. plot::
        :context: close-figs

        >>> po.plot.synthesis_status(
        ...     met,
        ...     synthesis_loss_kwargs={"plot_penalties": True},
        ...     synthesis_imshow_kwargs={"zoom": 0.5},
        ... )
        <Figure size ...>

    Plot for :class:`~plenoptic.MADCompetition` object. Note the plots
    are different:

    .. plot::
        :context: close-figs

        >>> img = po.data.einstein().to(torch.float64)
        >>> def ds_ssim(x, y):
        ...     return 1 - po.metric.ssim(x, y, weighted=True, pad="reflect")
        >>> mad = po.MADCompetition(img, ds_ssim, po.metric.mse, "max", 1e6)
        >>> mad.load(po.data.fetch_data("example_mad.pt"))
        >>> po.plot.synthesis_status(mad)
        <Figure size ...>

    Plot for :class:`~plenoptic.Eigendistortion` object. Note the plots
    are different:

    .. plot::
        :context: close-figs

        >>> img = po.data.einstein().to(torch.float64)
        >>> lg = po.models.LuminanceGainControl(
        ...     (31, 31), pad_mode="circular", pretrained=True, cache_filt=True
        ... ).eval()
        >>> lg = lg.to(torch.float64)
        >>> po.remove_grad(lg)
        >>> eig = po.Eigendistortion(img, lg)
        >>> eig.load(
        ...     po.data.fetch_data("example_eigendistortion.pt"),
        ...     map_location="cpu",
        ... )
        >>> po.plot.synthesis_status(eig)
        <Figure size ...>
    """
    # forward by keyword so this stays correct if the helper's parameter order changes
    fig, _ = _synthesis_status(
        synthesis_object,
        batch_idx=batch_idx,
        channel_idx=channel_idx,
        iteration=iteration,
        included_plots=included_plots,
        fig=fig,
        axes_idx=axes_idx,
        figsize=figsize,
        width_ratios=width_ratios,
        **kwargs,
    )
    return fig
def _get_rescale_ylim(
    metamer: Metamer,
    rescale_ylim: str | Literal[False] = "rescale",
) -> int:
    """
    Prepare rescale_ylim_interval for :func:`synthesis_animate`.

    This only works with a :class:`~plenoptic.Metamer` object, and is intended to be
    used with the :func:`~plenoptic.plot.metamer_representation_error` plot. It thus
    checks ``metamer.target_representation``.

    Parameters
    ----------
    metamer
        Metamer object we'll be animating.
    rescale_ylim
        How to rescale y-limits of plots over time. Currently only applies to
        :func:`~plenoptic.plot.metamer_representation_error` plot. Must be one of:

        - ``False``: never rescale y-limits.

        - the string ``"rescale"``: rescale y-limits 10 times over the course of the
          animation.

        - a string of the form ``"rescaleN"``, with ``N`` a positive integer: rescale
          y-limits every N frames.

    Returns
    -------
    rescale_ylim_interval
        How often to update the ylim, in frames. Always at least 1, since the caller
        uses it as the divisor of a modulo.

    Raises
    ------
    ValueError
        If ``metamer.target_representation`` is 4d and ``rescale_ylim`` is neither
        ``"rescale"`` nor ``False`` -- we do not know how to best rescale color
        ranges.
    ValueError
        If ``rescale_ylim`` takes any other illegal value, including ``"rescaleN"``
        where ``N`` is not a positive integer.
    """
    # then they've changed rescale_ylim to an illegal value
    if metamer.target_representation.ndimension() == 4 and rescale_ylim not in [
        "rescale",
        False,
    ]:
        raise ValueError(
            "Looks like representation is image-like, haven't fully"
            " thought out how to best handle rescaling color ranges yet!"
        )
    n_frames = len(metamer.saved_metamer)
    # check if rescale_ylim is exactly False, not False-y
    if rescale_ylim is False:
        # this way we'll never rescale
        return n_frames + 1
    if not isinstance(rescale_ylim, str) or not rescale_ylim.startswith("rescale"):
        raise ValueError(f"Don't know how to handle {rescale_ylim=}!")
    interval_str = rescale_ylim.removeprefix("rescale")
    if not interval_str:
        # plain "rescale": aim for ~10 rescalings over the animation. With fewer than
        # ~10 frames this is 0, so fall back to once per animation, and never go
        # below 1, which would cause a ZeroDivisionError when used as a modulus
        interval = (n_frames - 1) // 10
        if interval == 0:
            interval = n_frames - 1
        return max(interval, 1)
    try:
        interval = int(interval_str)
    except ValueError as err:
        raise ValueError(f"Don't know how to handle {rescale_ylim=}!") from err
    if interval < 1:
        raise ValueError(
            f"Don't know how to handle {rescale_ylim=}! N in 'rescaleN' must be a"
            " positive integer."
        )
    return interval
[docs]
def synthesis_animate(
synthesis_object: Metamer | MADCompetition,
framerate: int = 10,
batch_idx: int = 0,
channel_idx: int | None = None,
included_plots: list[str] | None = None,
fig: mpl.figure.Figure | None = None,
axes_idx: dict[str, int] = {},
figsize: tuple[float, float] | None = None,
width_ratios: dict[str, float] = {},
rescale_ylim: str | Literal[False] = "rescale",
**kwargs: dict[str, Any],
) -> mpl.animation.FuncAnimation:
r"""
Animate synthesis progress.
.. versionadded:: 2.0
Combines previously separate loss plotting functions for
:class:`~plenoptic.Metamer` and :class:`~plenoptic.MADCompetition`.
This animates the figure produced by :func:`synthesis_status` over time, for each
stored iteration. It begins by calling that function to initialize the first frame
of the movie.
This function's behavior when ``included_plots is None``, and allowed values for
that variable, depends upon the type of ``synthesis_object``:
- :class:`~plenoptic.Metamer`: :func:`synthesis_imshow`,
:func:`synthesis_loss`, :func:`~plenoptic.plot.metamer_representation_error`.
Additional allowed values: :func:`synthesis_histogram`.
- :class:`~plenoptic.MADCompetition`: :func:`synthesis_imshow`,
:func:`synthesis_loss`. Additional allowed values: :func:`synthesis_histogram`.
Parameters
----------
synthesis_object
Synthesis object with the images we wish to display.
framerate
How many frames a second to display.
batch_idx
Which index to take from the batch dimension.
channel_idx
Which index to take from the channel dimension. If ``None``, plot all channels.
included_plots
Which plots to include. See above for behavior if ``None``, otherwise must be a
list of strings whose values are names of plotting functions that can accept
``synthesis_object``, see above for list.
fig
If ``None``, we create a new figure. Otherwise we assume this is
a figure that has the appropriate size and number of subplots.
axes_idx
Dictionary specifying which axes contains which type of plot, allows
for more fine-grained control of the resulting figure.
Keys must be strings matching the names of the included plots, see above
for possible values, or ``"misc"``. All axes in ``"misc"`` will be ignored
by this function. If you tell this function to create a plot that doesn't
have a corresponding key, we find the lowest int that is not already in
the dict, so if you have axes that you want unchanged, place their idx
in ``'misc'``.
figsize
The size of the figure to create. It may take a little bit of
playing around to find a reasonable value. If ``None``, we attempt to
make our best guess, aiming to have each axis be of size ``(5, 5)``.
width_ratios
If ``width_ratios`` is an empty dictionary, plot widths will depend on
``synthesis_object`` class: for :class:`~plenoptic.MADCompetition`,
:func:`synthesis_loss` will have double the width of the rest; for other
classes, all will be the same width. To change that, specify their relative
widths; keys should be strings (possible values same as ``included_plots``)
and values should be floats specifying their relative width.
rescale_ylim
How to rescale y-limits of plots over time. Currently only applies to
:func:`~plenoptic.plot.metamer_representation_error` plot. Must be one of:
- ``False``: never rescale y-limits.
- the string ``"rescale"``: rescale y-limits 10 times over the course of the
animation.
- a string of the form ``"rescaleN"``: rescale y-limits every N frames.
**kwargs
Additional keyword arguments to pass to plotting functions. Keys must be the
of the form ``{plot_func}_kwargs``, where ``{plot_func}`` name of the
plotting function. See Examples for examples. Will raise a ValueError if
there are additional kwargs.
Returns
-------
anim
The animation object. In order to view, must convert to HTML or save.
Raises
------
ValueError
If synthesis for this ``synthesis_object`` was run with
``store_progress=False``.
ValueError
If ``rescale_ylim`` takes an illegal value.
ValueError
If ``kwargs`` contains additional keys.
ValueError
If ``synthesis_object`` is a :class:`~plenoptic.Metamer` object whose
:attr:`~plenoptic.Metamer.target_representation` is 4d and ``rescale_ylim`` has
been set -- we do not know how to best rescale color ranges.
ValueError
If any of ``width_ratios``, ``included_plots``, or ``axes_idx`` reference an
plot that is incompatible with ``synthesis_object``. See list at top of
docstring for compatible plots for each class.
TypeError
If ``synthesis_object`` is not :class:`~plenoptic.MADCompetition` or
:class:`~plenoptic.Metamer`
See Also
--------
:func:`~plenoptic.plot.update_plot`
Function used by this one to update ``synthesis_imshow`` and
``metamer_representation_error`` plots.
synthesis_imshow
One of this function's axis-level component functions: display synthesized
image at a given synthesis iteration.
synthesis_loss
One of this function's axis-level component functions: plot synthesis loss
over iterations.
metamer_representation_error
One of this function's axis-level component functions: plot error in model
representation at a given synthesis iteration.
synthesis_histogram
One of this function's axis-level component functions: plot histogram of
values from synthesized object.
synthesis_status
Create a figure that shows a frame from this movie: the synthesis status at
a given iteration.
Notes
-----
- This functions returns a matplotlib FuncAnimation object. See below for how
to view to view it in a Jupyter notebook. See Examples section for how to save to
disk. In either case, this can take a while and you'll need the appropriate writer
installed and on your path, e.g., ffmpeg, imagemagick, etc). See :doc:`matplotlib
documentation <matplotlib:api/animation_api>` for more details.
- Unless specified, we use the ffmpeg backend, which requires that you have
ffmpeg installed and on your path (https://ffmpeg.org/download.html). To use
a different, use the matplotlib rcParams:
``matplotlib.rcParams['animation.writer'] = writer``, see `matplotlib
documentation
<https://matplotlib.org/stable/api/animation_api.html#writer-classes>`_ for
more details.
- To view in a Jupyter notebook, we recommend adding the following to the first cell
of your notebook (requires ffmpeg):
.. code:: python
import matplotlib.pyplot as plt
plt.rcParams["animation.html"] = "html5"
# use single-threaded ffmpeg for animation writer
plt.rcParams["animation.writer"] = "ffmpeg"
plt.rcParams["animation.ffmpeg_args"] = ["-threads", "1"]
Examples
--------
.. plot::
:context: reset
>>> import plenoptic as po
>>> import torch
>>> img = po.data.einstein()
>>> model = po.models.Gaussian(30).eval()
>>> po.remove_grad(model)
>>> met = po.Metamer(img, model)
>>> met.to(torch.float64)
>>> met.load(po.data.fetch_data("example_metamer_gaussian.pt"))
>>> ani = po.plot.synthesis_animate(met)
>>> # Save the video (here we're saving it as a .gif)
>>> ani.save("animate-example-1.gif")
.. image:: animate-example-1.gif
This function can only be used if
:meth:`~plenoptic.Metamer.synthesize` was called with
``store_progress``.
>>> import plenoptic as po
>>> img = po.data.einstein()
>>> model = po.models.Gaussian(30).eval()
>>> po.remove_grad(model)
>>> met = po.Metamer(img, model)
>>> met.to(torch.float64)
>>> met.synthesize(5)
>>> ani = po.plot.synthesis_animate(met)
Traceback (most recent call last):
ValueError: When synthesis_object.store_progress=False, cannot animate!
If model has its own ``plot_representation`` method, this function will use it
for plotting the representation error (see
:func:`~plenoptic.models.PortillaSimoncelli.plot_representation`
):
.. plot::
:context: close-figs
>>> img = po.data.reptile_skin()
>>> model = po.models.PortillaSimoncelli(img.shape[-2:])
>>> met = po.MetamerCTF(img, model, po.loss.l2_norm)
>>> met.to(torch.float64)
>>> met.load(po.data.fetch_data("example_metamerCTF_ps.pt"))
>>> ani = po.plot.synthesis_animate(met)
>>> # Save the video (here we're saving it as a .gif)
>>> ani.save("animate-example-2.gif")
.. image:: animate-example-2.gif
Change the included plots:
.. plot::
:context: close-figs
>>> included_plots = ["synthesis_loss", "synthesis_histogram"]
>>> ani = po.plot.synthesis_animate(met, included_plots=included_plots)
>>> # Save the video (here we're saving it as a .gif)
>>> ani.save("animate-example-3.gif")
.. image:: animate-example-3.gif
Adjust width of included plots:
.. plot::
:context: close-figs
>>> width_ratios = {"metamer_representation_error": 3}
>>> ani = po.plot.synthesis_animate(met, width_ratios=width_ratios)
>>> # Save the video (here we're saving it as a .gif)
>>> ani.save("animate-example-4.gif")
.. image:: animate-example-4.gif
Change the arrangement of the plots, creating some empty axes:
.. plot::
:context: close-figs
>>> axes_idx = {"misc": [0, 3], "synthesis_loss": 4}
>>> ani = po.plot.synthesis_animate(met, axes_idx=axes_idx)
>>> # Save the video (here we're saving it as a .gif)
>>> ani.save("animate-example-5.gif")
.. image:: animate-example-5.gif
Plot on an existing figure, with already existing plots:
.. plot::
:context: close-figs
>>> fig, axes = plt.subplots(1, 4, figsize=(16, 4))
>>> axes[0].plot(torch.rand(100))
[<matplotlib.lines.Line2D ...>]
>>> # specify misc: 0 so we don't plot on top of this axis.
>>> axes_idx = {"misc": 0}
>>> ani = po.plot.synthesis_animate(met, fig=fig, axes_idx=axes_idx)
>>> # Save the video (here we're saving it as a .gif)
>>> ani.save("animate-example-6.gif")
.. image:: animate-example-6.gif
Specify additional keyword arguments to one of the underlying plots:
.. plot::
:context: close-figs
>>> ani = po.plot.synthesis_animate(
... met,
... synthesis_loss_kwargs={"plot_penalties": True},
... synthesis_imshow_kwargs={"zoom": 0.5},
... )
>>> # Save the video (here we're saving it as a .gif)
>>> ani.save("animate-example-7.gif")
.. image:: animate-example-7.gif
Plot for :class:`~plenoptic.MADCompetition` object. Note the plots
are different:
.. plot::
:context: close-figs
>>> img = po.data.einstein().to(torch.float64)
>>> def ds_ssim(x, y):
... return 1 - po.metric.ssim(x, y, weighted=True, pad="reflect")
>>> mad = po.MADCompetition(img, ds_ssim, po.metric.mse, "max", 1e6)
>>> mad.load(po.data.fetch_data("example_mad.pt"))
>>> ani = po.plot.synthesis_animate(mad)
>>> # Save the video (here we're saving it as a .gif)
>>> ani.save("animate-example-8.gif")
.. image:: animate-example-8.gif
"""
if not isinstance(synthesis_object, (Metamer, MADCompetition)):
raise TypeError(
"synthesis_object must be a MADCompetition or Metamer object but got"
f" {type(synthesis_object)}"
)
if not synthesis_object.store_progress:
raise ValueError("When synthesis_object.store_progress=False, cannot animate!")
# rescale_ylim only relevant for metamer_representation_error plot
if isinstance(synthesis_object, Metamer):
rescale_ylim_interval = _get_rescale_ylim(synthesis_object, rescale_ylim)
fig, axes_dict = _synthesis_status(
synthesis_object,
batch_idx,
channel_idx,
0,
included_plots,
fig,
axes_idx,
figsize,
width_ratios,
**kwargs,
)
# grab the artist for the loss plot (we don't need to do this for the
# metamer or representation plot, because we use the update_plot
# function for that)
if "synthesis_loss" in axes_dict:
scat = [ax.collections[0] for ax in axes_dict["synthesis_loss"]]
if "synthesis_imshow" in axes_dict:
ax = axes_dict["synthesis_imshow"]
# replace the bit of the title that specifies the range,
# since we don't make any promises about that. we have to do
# this here because we need the figure to have been created
ax.set_title(re.sub(r"\n range: .* \n", "\n\n", ax.get_title()))
if (
"metamer_representation_error" in axes_dict
and synthesis_object.target_representation.ndimension() == 4
):
# replace the bit of the title that specifies the range,
# since we don't make any promises about that. we have to do
# this here because we need the figure to have been created
for ax in axes_dict["metamer_representation_error"]:
ax.set_title(re.sub(r"\n range: .* \n", "\n\n", ax.get_title()))
if isinstance(synthesis_object, Metamer):
saved_synth = synthesis_object.saved_metamer
# this plot shows the metamer loss, which requires subtracting the penalty off
# of the objective function
met_loss = (
synthesis_object.losses
- synthesis_object.penalty_lambda * synthesis_object.penalties
)
losses = [met_loss]
elif isinstance(synthesis_object, MADCompetition):
saved_synth = synthesis_object.saved_mad_image
losses = [
synthesis_object.reference_metric_loss,
synthesis_object.optimized_metric_loss,
]
def movie_plot(i: int) -> list[mpl.artist.Artist]:
    """
    Redraw the panels of the status figure for animation frame ``i``.

    Used as the per-frame callback of ``FuncAnimation``. Only the panels
    present in ``axes_dict`` are touched: the synthesized image, the metamer
    representation error, the pixel histogram, and the loss marker(s).

    Parameters
    ----------
    i
        Index of the frame to draw; used to index the saved synthesis
        outputs.

    Returns
    -------
    artists
        The matplotlib artists changed while drawing this frame.
    """ # numpydoc ignore=EX01
    frame_artists = []

    def _frame_iteration() -> int:
        # frame i corresponds to iteration i * store_progress; clamp to the
        # last recorded loss so we never point past the end of synthesis
        return min(
            i * synthesis_object.store_progress,
            len(synthesis_object.losses) - 1,
        )

    if "synthesis_imshow" in axes_dict:
        image_ax = axes_dict["synthesis_imshow"]
        frame_artists.extend(
            display.update_plot(
                image_ax,
                data=saved_synth[i],
                batch_idx=batch_idx,
            )
        )

    if "metamer_representation_error" in axes_dict:
        error_axes = axes_dict["metamer_representation_error"]
        # the iteration-mismatch warning is irrelevant when animating
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message="loss iteration and iteration for"
            )
            frame_error = _representation_error(
                synthesis_object,
                iteration=_frame_iteration(),
            )
        # the artists to update were grabbed when the figure was built
        frame_artists.extend(
            display.update_plot(
                error_axes,
                batch_idx=batch_idx,
                model=synthesis_object.model,
                data=frame_error,
            )
        )
        # periodically rescale the y-limits (only for 3d representations)
        if (i + 1) % rescale_ylim_interval == 0 and (
            synthesis_object.target_representation.ndimension() == 3
        ):
            display._rescale_ylim(error_axes, frame_error)

    if "synthesis_histogram" in axes_dict:
        hist_ax = axes_dict["synthesis_histogram"]
        # Crude but simple: wipe the axes and redraw the histogram from
        # scratch. This discards user customizations (e.g. tick locators or
        # formatters) on this axis.
        # NOTE(review): the redrawn histogram artists are not added to the
        # returned list; with blit=True, live (non-saved) playback may not
        # refresh this panel -- confirm.
        hist_ax.clear()
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message="loss iteration and iteration for"
            )
            synthesis_histogram(
                synthesis_object,
                batch_idx=batch_idx,
                channel_idx=channel_idx,
                iteration=_frame_iteration(),
                ax=hist_ax,
            )

    if "synthesis_loss" in axes_dict:
        # losses hold a value for every iteration while the other outputs
        # are subsampled, so map the frame index onto the loss index
        loss_idx = synthesis_object._convert_iteration(i, False)
        for marker, curve in zip(scat, losses):
            marker.set_offsets((loss_idx, curve[loss_idx]))
            frame_artists.append(marker)

    # blitting requires the callback to return the sequence of changed artists
    return frame_artists
# don't need an init_func, since we handle initialization ourselves
anim = mpl.animation.FuncAnimation(
fig,
movie_plot,
frames=len(saved_synth),
blit=True,
interval=1000.0 / framerate,
repeat=False,
)
plt.close(fig)
return anim