Source code for plenoptic.plot.metamer

"""Plots for understanding Metamer objects."""  # numpydoc ignore=EX01

from typing import Any, Literal

import matplotlib as mpl
import matplotlib.pyplot as plt
from torch import Tensor

from .._synthesize import Metamer
from . import display

# Public names exported by this module; ``__dir__`` below returns this same list.
__all__ = [
    "metamer_representation_error",
]


def __dir__() -> list[str]:
    """Return the module's public names, i.e. the ``__all__`` list itself."""
    public_names = __all__
    return public_names


def _representation_error(
    metamer: Metamer,
    iteration: int | None = None,
    iteration_selection: Literal["floor", "ceiling", "round"] = "round",
    **kwargs: Any,
) -> Tensor:
    r"""
    Get the representation error.

    This is ``metamer.model(metamer.metamer) - metamer.target_representation``. If
    ``iteration`` is not ``None``, we use
    ``metamer.model(saved_metamer[iteration])`` instead.

    Parameters
    ----------
    metamer
        Metamer object whose representation error we want to compute.
    iteration
        Which iteration to display. If ``None``, we show the most recent one.
        Negative values are also allowed. If ``iteration!=None`` and
        ``metamer.store_progress>1`` (that is, the metamer was not cached on every
        iteration), then we show the cached metamer from the nearest iteration.
    iteration_selection
        How to select the relevant iteration from :attr:`saved_metamer`
        when the requested iteration wasn't stored. Forwarded to
        ``metamer.get_progress``; ignored when ``iteration`` is ``None``.

        When synthesis was run with ``store_progress=n`` (where ``n>1``),
        metamers are only saved every ``n`` iterations. If you request an
        iteration where a metamer wasn't saved, this determines which available
        iteration is used instead:

        * ``"floor"``: use the closest saved iteration **before** the
          requested one.

        * ``"ceiling"``: use the closest saved iteration **after** the
          requested one.

        * ``"round"``: use the closest saved iteration.

    **kwargs
        Passed to ``metamer.model.forward``.

    Returns
    -------
    representation_error
        The representation error at the specified iteration, for displaying.

    Raises
    ------
    IndexError
        If ``iteration`` takes an illegal value.

    Warns
    -----
    UserWarning
        If the iteration for the used metamer is not the same as the argument
        ``iteration`` (because e.g., you set ``iteration=3`` but
        ``metamer.store_progress=2``).
    """  # numpydoc ignore=EX01
    if iteration is None:
        # No iteration requested: use the current (most recent) metamer.
        image = metamer.metamer
    else:
        # Honor the caller's rounding preference when the exact iteration was
        # not stored; previously this argument was accepted but never used.
        progress = metamer.get_progress(
            iteration, iteration_selection=iteration_selection
        )
        # Move the stored metamer onto the target representation's device
        # before running it through the model.
        image = progress["saved_metamer"].to(metamer.target_representation.device)
    return metamer.model(image, **kwargs) - metamer.target_representation


def metamer_representation_error(
    metamer: Metamer,
    batch_idx: int = 0,
    iteration: int | None = None,
    ylim: tuple[float, float] | None | Literal[False] = None,
    ax: mpl.axes.Axes | None = None,
    as_rgb: bool = False,
    **kwargs: Any,
) -> list[mpl.axes.Axes]:
    r"""
    Plot representation error showing how close we are to convergence.

    We plot ``_representation_error(metamer, iteration)``. For more details, see
    :func:`plenoptic.plot.plot_representation`.

    Parameters
    ----------
    metamer
        Metamer object whose synthesized metamer we want to display.
    batch_idx
        Which index to take from the batch dimension.
    iteration
        Which iteration to display. If ``None``, we show the most recent one.
        Negative values are also allowed. If ``iteration!=None`` and
        ``metamer.store_progress>1`` (that is, the metamer was not cached on every
        iteration), then we show the cached metamer from the nearest iteration.
    ylim
        If ``ylim`` is ``None``, we set the axes' y-limits to be
        ``(-y_max, y_max)``, where ``y_max=np.abs(data).max()``. If it's
        ``False``, we do nothing. If a tuple, we use that range.
    ax
        Pre-existing axes for plot. If ``None``, we call
        :func:`matplotlib.pyplot.gca`.
    as_rgb
        The representation can be image-like with multiple channels, and we have
        no way to determine whether it should be represented as an RGB image or
        not, so the user must set this flag to tell us. It will be ignored if the
        response doesn't look image-like or if the model has its own
        ``plot_representation()`` method. Else, it will be passed to
        :func:`~plenoptic.plot.imshow`, see that function's docstring for details.
    **kwargs
        Passed to ``metamer.model.forward``.

    Returns
    -------
    axes :
        List of created axes.

    Raises
    ------
    IndexError
        If ``iteration`` takes an illegal value.

    Warns
    -----
    UserWarning
        If the iteration for the metamer used to compute the error is not the
        same as the argument ``iteration`` (because e.g., you set
        ``iteration=3`` but ``metamer.store_progress=2``).

    See Also
    --------
    :func:`~plenoptic.plot.plot_representation`
        Function used by this one to plot representation.
    :func:`~plenoptic.plot.synthesis_status`
        Create a figure combining this with other axis-level plots to summarize
        synthesis status at a given iteration.
    :func:`~plenoptic.plot.synthesis_animate`
        Create a video animating this and other axis-level plots changing over
        the course of synthesis.

    Examples
    --------
    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> import torch
      >>> img = po.data.einstein()
      >>> model = po.models.Gaussian(30).eval()
      >>> po.remove_grad(model)
      >>> met = po.Metamer(img, model)
      >>> met.to(torch.float64)
      >>> met.load(po.data.fetch_data("example_metamer_gaussian.pt"))
      >>> po.plot.metamer_representation_error(met)
      [<Axes: title=...Representation error...>]

    Plot on an existing axis:

    .. plot::
      :context: close-figs

      >>> import matplotlib.pyplot as plt
      >>> fig, axes = plt.subplots(1, 2, figsize=(8, 4))
      >>> po.plot.metamer_representation_error(met, ax=axes[1])
      [<Axes: title=...Representation error...>]

    The function uses :func:`~plenoptic.plot.plot_representation`, which
    switches between :func:`~plenoptic.plot.imshow` and
    :func:`~plenoptic.plot.stem_plot` based on the shape of the model's output:

    .. plot::
      :context: close-figs

      >>> # Flatten the last two dimensions of the output, so it looks like a vector.
      >>> class TestModel(po.models.Gaussian):
      ...     def __init__(self, *args, **kwargs):
      ...         super().__init__(*args, **kwargs)
      ...
      ...     def forward(self, x):
      ...         return super().forward(x).flatten(-2)
      >>> model = TestModel(30).eval()
      >>> po.remove_grad(model)
      >>> met = po.Metamer(img, model)
      >>> met.to(torch.float64)
      >>> met.synthesize(5)
      >>> po.plot.metamer_representation_error(met)
      [<Axes: title=...Representation error...>]

    If model has its own ``plot_representation`` method, this function will use
    it, potentially creating multiple axes (see
    :func:`~plenoptic.models.PortillaSimoncelli.plot_representation`):

    .. plot::
      :context: close-figs

      >>> img = po.data.reptile_skin()
      >>> model = po.models.PortillaSimoncelli(img.shape[-2:])
      >>> met = po.MetamerCTF(img, model, po.loss.l2_norm)
      >>> met.to(torch.float64)
      >>> met.load(po.data.fetch_data("example_metamerCTF_ps.pt"))
      >>> po.plot.metamer_representation_error(met)
      [<Axes: ...>, ..., <Axes: ...>]

    If plotting on an existing axis, this function will sub-divide that axis as
    needed:

    .. plot::
      :context: close-figs

      >>> fig, axes = plt.subplots(1, 2, figsize=(8, 4))
      >>> po.plot.synthesis_imshow(met, ax=axes[0])
      <Axes: title=...Metamer[0] [iteration=150]...>
      >>> po.plot.metamer_representation_error(met, ax=axes[1])
      [<Axes: ...>, ..., <Axes: ...>]
    """
    # Compute the error first so an invalid ``iteration`` fails before any
    # matplotlib state (e.g. a new current axes) is created.
    error = _representation_error(metamer=metamer, iteration=iteration, **kwargs)
    target_ax = plt.gca() if ax is None else ax
    plot_options = {
        "title": "Representation error",
        "ylim": ylim,
        "batch_idx": batch_idx,
        "as_rgb": as_rgb,
    }
    return display.plot_representation(
        metamer.model, error, target_ax, **plot_options
    )