"""Plots for understanding MADCompetition objects."""
from typing import Any
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
from pyrtools.tools.display import make_figure as pt_make_figure
from .._synthesize import MADCompetition
from . import display
from .synthesis import synthesis_imshow, synthesis_loss
# Public API of this module; ``__dir__`` below returns this same list.
__all__ = [
    "mad_imshow_all",
    "mad_loss_all",
]
def __dir__() -> list[str]:
    """
    Return the public names of this module.

    Python calls a module-level ``__dir__`` when ``dir()`` is invoked on the
    module, so only the names listed in ``__all__`` are reported.

    Returns
    -------
    names
        The module's ``__all__`` list (the same object, not a copy).
    """
    return __all__
[docs]
def mad_imshow_all(
    mad_metric1_min: MADCompetition,
    mad_metric2_min: MADCompetition,
    mad_metric1_max: MADCompetition,
    mad_metric2_max: MADCompetition,
    metric1_name: str | None = None,
    metric2_name: str | None = None,
    zoom: int | float = 1,
    **kwargs: Any,
) -> mpl.figure.Figure:
    """
    Display all MAD Competition images.

    To generate a full set of MAD Competition images, you need four instances:
    one for minimizing and maximizing each metric. This helper function creates
    a figure to display the full set of images.

    In addition to the four MAD Competition images, this also plots the
    reference image and the initial image from ``mad_metric1_min``, for
    comparison.

    Parameters
    ----------
    mad_metric1_min
        ``MADCompetition`` object that minimized the first metric.
    mad_metric2_min
        ``MADCompetition`` object that minimized the second metric.
    mad_metric1_max
        ``MADCompetition`` object that maximized the first metric.
    mad_metric2_max
        ``MADCompetition`` object that maximized the second metric.
    metric1_name
        Name of the first metric. If ``None``, we use the name of the
        ``optimized_metric`` function from ``mad_metric1_min``.
    metric2_name
        Name of the second metric. If ``None``, we use the name of the
        ``optimized_metric`` function from ``mad_metric2_min``.
    zoom
        Ratio of display pixels to image pixels. See
        :func:`~plenoptic.plot.synthesis_imshow` for details.
    **kwargs
        Passed to :func:`~plenoptic.plot.synthesis_imshow` (and to the calls
        that draw the reference and initial images).

    Returns
    -------
    fig
        Figure containing the images.

    Raises
    ------
    ValueError
        If the four ``MADCompetition`` instances do not have the same ``image``
        attribute (different shape or different values).

    See Also
    --------
    :func:`~plenoptic.plot.synthesis_imshow`
        Display the image from a single :class:`~plenoptic.MADCompetition` instance.
    :func:`~plenoptic.plot.synthesis_status`
        Create a composite plot showing synthesis status of a single
        :class:`~plenoptic.MADCompetition` instance.

    Examples
    --------
    See the :ref:`MAD Competition <mad-nb>` :ref:`tutorial notebooks <mad-concept>`
    in the User Guide of documentation for examples.
    """
    reference = mad_metric1_min.image
    for other in (mad_metric2_min, mad_metric1_max, mad_metric2_max):
        # Compare shapes before values: torch.allclose broadcasts its inputs,
        # so differently-shaped images would either raise RuntimeError or be
        # compared after broadcasting instead of raising the documented
        # ValueError.
        if other.image.shape != reference.shape or not torch.allclose(
            reference, other.image
        ):
            raise ValueError(
                "All four instances of MADCompetition must have same image!"
            )

    if metric1_name is None:
        metric1_name = mad_metric1_min.optimized_metric.__name__
    if metric2_name is None:
        metric2_name = mad_metric2_min.optimized_metric.__name__

    # 3 x 2 grid of axes, each sized to the (zoomed) height and width of the image
    fig = pt_make_figure(3, 2, [zoom * i for i in reference.shape[-2:]])
    mads = [mad_metric1_min, mad_metric1_max, mad_metric2_min, mad_metric2_max]
    titles = [
        f"Minimize {metric1_name}",
        f"Maximize {metric1_name}",
        f"Minimize {metric2_name}",
        f"Maximize {metric2_name}",
    ]

    # we're only plotting one image here, so if the user wants multiple
    # channels, they must be RGB
    as_rgb = (
        kwargs.get("channel_idx") is None
        and mad_metric1_min.initial_image.shape[1] > 1
    )

    # this is a bit of a hack right now, because they don't all have same
    # initial image: only the one from ``mad_metric1_min`` is shown.
    header_images = [
        (reference, "Reference image"),
        (mad_metric1_min.initial_image, "Initial (noisy) image"),
    ]
    for ax, (img, title) in zip(fig.axes[:2], header_images):
        display.imshow(
            img,
            ax=ax,
            title=title,
            zoom=zoom,
            as_rgb=as_rgb,
            **kwargs,
        )

    # the remaining axes hold the four synthesized images
    for ax, mad, title in zip(fig.axes[2:], mads, titles):
        synthesis_imshow(mad, zoom=zoom, ax=ax, title=title, **kwargs)
    return fig
[docs]
def mad_loss_all(
    mad_metric1_min: MADCompetition,
    mad_metric2_min: MADCompetition,
    mad_metric1_max: MADCompetition,
    mad_metric2_max: MADCompetition,
    metric1_name: str | None = None,
    metric2_name: str | None = None,
    metric1_kwargs: dict = {"c": "C0"},
    metric2_kwargs: dict = {"c": "C1"},
    min_kwargs: dict = {"linestyle": "--"},
    max_kwargs: dict = {"linestyle": "-"},
    figsize: tuple[int, int] = (10, 5),
) -> mpl.figure.Figure:
    """
    Plot loss for full set of MAD Competition instances.

    To generate a full set of MAD Competition images, you need four instances:
    one for minimizing and maximizing each metric. This helper function creates
    a two-axis figure to display the loss for this full set.

    Parameters
    ----------
    mad_metric1_min
        ``MADCompetition`` object that minimized the first metric.
    mad_metric2_min
        ``MADCompetition`` object that minimized the second metric.
    mad_metric1_max
        ``MADCompetition`` object that maximized the first metric.
    mad_metric2_max
        ``MADCompetition`` object that maximized the second metric.
    metric1_name
        Name of the first metric. If ``None``, we use the name of the
        ``optimized_metric`` function from ``mad_metric1_min``.
    metric2_name
        Name of the second metric. If ``None``, we use the name of the
        ``optimized_metric`` function from ``mad_metric2_min``.
    metric1_kwargs
        Dictionary of arguments to pass to :func:`matplotlib.pyplot.plot` to identify
        synthesis instance where the first metric was being optimized.
    metric2_kwargs
        Dictionary of arguments to pass to :func:`matplotlib.pyplot.plot` to identify
        synthesis instance where the second metric was being optimized.
    min_kwargs
        Dictionary of arguments to pass to :func:`matplotlib.pyplot.plot` to identify
        synthesis instance where ``optimized_metric`` was being minimized.
    max_kwargs
        Dictionary of arguments to pass to :func:`matplotlib.pyplot.plot` to identify
        synthesis instance where ``optimized_metric`` was being maximized.
    figsize
        Size of the figure we create.

    Returns
    -------
    fig
        Figure containing the plot.

    Raises
    ------
    ValueError
        If the four ``MADCompetition`` instances do not have the same ``image``
        attribute (different shape or different values).

    See Also
    --------
    :func:`~plenoptic.plot.synthesis_loss`
        Display the loss from a single :class:`~plenoptic.MADCompetition` instance.
    :func:`~plenoptic.plot.synthesis_status`
        Create a composite plot showing synthesis status of a single
        :class:`~plenoptic.MADCompetition` instance.

    Examples
    --------
    See the :ref:`MAD Competition <mad-nb>` :ref:`tutorial notebooks <mad-concept>`
    in the User Guide of documentation for examples.
    """
    reference = mad_metric1_min.image
    for other in (mad_metric2_min, mad_metric1_max, mad_metric2_max):
        # Compare shapes before values: torch.allclose broadcasts its inputs,
        # so differently-shaped images would either raise RuntimeError or be
        # compared after broadcasting instead of raising the documented
        # ValueError.
        if other.image.shape != reference.shape or not torch.allclose(
            reference, other.image
        ):
            raise ValueError(
                "All four instances of MADCompetition must have same image!"
            )

    if metric1_name is None:
        metric1_name = mad_metric1_min.optimized_metric.__name__
    if metric2_name is None:
        metric2_name = mad_metric2_min.optimized_metric.__name__

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # (instance, axes, label, per-metric style, per-direction style).
    # For the metric2 instances we pass the axes backwards because the fixed
    # and synthesis metrics are the opposite as they are in the metric1
    # instances.
    # The keyword dictionaries are only unpacked, never modified, so the
    # shared default dictionaries in the signature are left untouched.
    curves = [
        (mad_metric1_min, axes, f"Minimize {metric1_name}", metric1_kwargs, min_kwargs),
        (mad_metric1_max, axes, f"Maximize {metric1_name}", metric1_kwargs, max_kwargs),
        (
            mad_metric2_min,
            axes[::-1],
            f"Minimize {metric2_name}",
            metric2_kwargs,
            min_kwargs,
        ),
        (
            mad_metric2_max,
            axes[::-1],
            f"Maximize {metric2_name}",
            metric2_kwargs,
            max_kwargs,
        ),
    ]
    for mad, mad_axes, label, metric_style, direction_style in curves:
        synthesis_loss(
            mad,
            ax=mad_axes,
            label=label,
            **metric_style,
            **direction_style,
        )

    axes[0].set(ylabel="Loss", title=metric2_name)
    axes[1].set(ylabel="Loss", title=metric1_name)
    # place the legend outside the plot, to the right of the second axis
    axes[1].legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
    return fig