Source code for plenoptic.synthesize.mad_competition

"""Run MAD Competition."""

import contextlib
import warnings
from collections import OrderedDict
from collections.abc import Callable
from typing import Literal

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from pyrtools.tools.display import make_figure as pt_make_figure
from torch import Tensor
from tqdm.auto import tqdm

from ..tools import data, display, optim
from ..tools.convergence import loss_convergence
from ..tools.validate import validate_input, validate_metric
from .synthesis import OptimizedSynthesis


class MADCompetition(OptimizedSynthesis):
    r"""Synthesize a single maximally-differentiating image for two metrics.

    Following the basic idea in [1]_, this class synthesizes a
    maximally-differentiating image for two given metrics, based on a given
    image. We start by adding noise to this image and then iteratively
    adjusting its pixels so as to either minimize or maximize
    ``optimized_metric`` while holding the value of ``reference_metric``
    constant.

    MADCompetition accepts two metrics as its input. These should be
    callables that take two images and return a single number, and that
    number should be 0 if and only if the two images are identical (thus, the
    larger the number, the more different the two images).

    Note that a full set of MAD Competition images consists of two pairs: a
    maximal and a minimal image for each metric. A single instantiation of
    ``MADCompetition`` will generate one of these four images.

    Parameters
    ----------
    image :
        A 4d tensor, the reference image that both metrics compare the
        synthesized image against. If this is not a tensor, we try to cast it
        as one.
    optimized_metric :
        The metric whose value you wish to minimize or maximize, which takes
        two tensors and returns a scalar. Because of the limitations of
        pickle, you cannot use a lambda function for this if you wish to save
        the MADCompetition object (i.e., it must be one of our built-in
        functions or defined using a `def` statement).
    reference_metric :
        The metric whose value you wish to keep fixed, which takes two
        tensors and returns a scalar. Because of the limitations of pickle,
        you cannot use a lambda function for this if you wish to save the
        MADCompetition object (i.e., it must be one of our built-in functions
        or defined using a `def` statement).
    minmax :
        Whether you wish to minimize or maximize ``optimized_metric``.
    initial_noise :
        Standard deviation of the Gaussian noise used to initialize
        ``mad_image`` from ``image``.
    metric_tradeoff_lambda :
        Lambda to multiply by ``reference_metric`` loss and add to
        ``optimized_metric`` loss. If ``None``, we pick a value so the two
        initial losses are approximately equal in magnitude.
    range_penalty_lambda :
        Lambda to multiply by range penalty and add to loss.
    allowed_range :
        Range (inclusive) of allowed pixel values. Any values outside this
        range will be penalized.

    Attributes
    ----------
    mad_image : torch.Tensor
        The Maximally-Differentiating Image. This may be unfinished depending
        on how many iterations we've run for.
    initial_image : torch.Tensor
        The initial ``mad_image``, which we obtain by adding Gaussian noise
        to ``image``.
    losses : list
        A list of the objective function's loss over iterations.
    gradient_norm : list
        A list of the gradient's L2 norm over iterations.
    pixel_change_norm : list
        A list containing the L2 norm of the pixel change over iterations
        (``pixel_change_norm[i]`` is the pixel change norm in ``mad_image``
        between iterations ``i`` and ``i-1``).
    optimized_metric_loss : list
        A list of the ``optimized_metric`` loss over iterations.
    reference_metric_loss : list
        A list of the ``reference_metric`` loss over iterations.
    saved_mad_image : torch.Tensor
        Saved ``self.mad_image`` for later examination.

    References
    ----------
    .. [1] Wang, Z., & Simoncelli, E. P. (2008). Maximum differentiation
           (MAD) competition: A methodology for comparing computational
           models of perceptual discriminability. Journal of Vision, 8(12),
           1–13. https://dx.doi.org/10.1167/8.12.8
    """

    def __init__(
        self,
        image: Tensor,
        optimized_metric: torch.nn.Module | Callable[[Tensor, Tensor], Tensor],
        reference_metric: torch.nn.Module | Callable[[Tensor, Tensor], Tensor],
        minmax: Literal["min", "max"],
        initial_noise: float = 0.1,
        metric_tradeoff_lambda: float | None = None,
        range_penalty_lambda: float = 0.1,
        allowed_range: tuple[float, float] = (0, 1),
    ):
        super().__init__(range_penalty_lambda, allowed_range)
        validate_input(image, allowed_range=allowed_range)
        validate_metric(
            optimized_metric,
            image_shape=image.shape,
            image_dtype=image.dtype,
            device=image.device,
        )
        validate_metric(
            reference_metric,
            image_shape=image.shape,
            image_dtype=image.dtype,
            device=image.device,
        )
        self._optimized_metric = optimized_metric
        self._reference_metric = reference_metric
        self._image = image.detach()
        self._image_shape = image.shape
        self.scheduler = None
        self._optimized_metric_loss = []
        self._reference_metric_loss = []
        if minmax not in ["min", "max"]:
            raise ValueError(
                "minmax must be one of {'min', 'max'}, but got "
                f"value {minmax} instead!"
            )
        self._minmax = minmax
        self._initialize(initial_noise)
        # If no metric_tradeoff_lambda is specified, pick one that gets them
        # to approximately the same magnitude
        if metric_tradeoff_lambda is None:
            loss_ratio = torch.as_tensor(
                self.optimized_metric_loss[-1] / self.reference_metric_loss[-1],
                dtype=image.dtype,
            )
            metric_tradeoff_lambda = torch.pow(
                torch.as_tensor(10), torch.round(torch.log10(loss_ratio))
            ).item()
            warnings.warn(
                "Since metric_tradeoff_lambda was None, automatically set"
                f" to {metric_tradeoff_lambda} to roughly balance metrics."
            )
        self._metric_tradeoff_lambda = metric_tradeoff_lambda
        self._store_progress = None
        self._saved_mad_image = []

    def _initialize(self, initial_noise: float = 0.1):
        """Initialize the synthesized image.

        Initialize ``self.mad_image`` attribute to be ``image`` plus Gaussian
        noise with user-specified standard deviation.

        Parameters
        ----------
        initial_noise :
            Standard deviation of the Gaussian noise used to initialize
            ``mad_image`` from ``image``.
        """
        mad_image = self.image + initial_noise * torch.randn_like(self.image)
        mad_image = mad_image.clamp(*self.allowed_range)
        self._initial_image = mad_image.clone()
        mad_image.requires_grad_()
        self._mad_image = mad_image
        self._reference_metric_target = self.reference_metric(
            self.image, self.mad_image
        ).item()
        self._reference_metric_loss.append(self._reference_metric_target)
        self._optimized_metric_loss.append(
            self.optimized_metric(self.image, self.mad_image).item()
        )

    def synthesize(
        self,
        max_iter: int = 100,
        optimizer: torch.optim.Optimizer | None = None,
        scheduler: torch.optim.lr_scheduler._LRScheduler | None = None,
        store_progress: bool | int = False,
        stop_criterion: float = 1e-4,
        stop_iters_to_check: int = 50,
    ):
        r"""Synthesize a MAD image.

        Update the pixels of ``initial_image`` to maximize or minimize
        (depending on the value of ``minmax``) the value of
        ``optimized_metric(image, mad_image)`` while keeping the value of
        ``reference_metric(image, mad_image)`` constant.

        We run this until either we reach ``max_iter`` or the change over the
        past ``stop_iters_to_check`` iterations is less than
        ``stop_criterion``, whichever comes first.

        Parameters
        ----------
        max_iter :
            The maximum number of iterations to run before we end synthesis
            (unless we hit the stop criterion).
        optimizer :
            The optimizer to use. If None and this is the first time calling
            synthesize, we use Adam(lr=.01, amsgrad=True); if synthesize has
            been called before, this must be None and we reuse the previous
            optimizer.
        scheduler :
            The learning rate scheduler to use. If None, we don't use one.
        store_progress :
            Whether we should store the MAD image in progress on every
            iteration. If False, we don't save anything. If True, we save
            every iteration. If an int, we save every ``store_progress``
            iterations (note then that 0 is the same as False and 1 the same
            as True).
        stop_criterion :
            If the loss over the past ``stop_iters_to_check`` has changed
            less than ``stop_criterion``, we terminate synthesis.
        stop_iters_to_check :
            How many iterations back to check in order to see if the loss has
            stopped decreasing (for ``stop_criterion``).
        """
        # initialize the optimizer and scheduler
        self._initialize_optimizer(optimizer, scheduler)
        # get ready to store progress
        self.store_progress = store_progress

        pbar = tqdm(range(max_iter))
        for _ in pbar:
            # update saved_* attrs. len(losses) gives the total number of
            # iterations and will be correct across calls to `synthesize`
            self._store(len(self.losses))

            loss = self._optimizer_step(pbar)

            if not torch.isfinite(loss):
                raise ValueError("Found a NaN in loss during optimization.")

            if self._check_convergence(stop_criterion, stop_iters_to_check):
                warnings.warn("Loss has converged, stopping synthesis")
                break

        pbar.close()
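
    # Example (an illustrative sketch, not part of the module): construct an
    # instance with two simple metrics and run a short synthesis. The names
    # ``mse``, ``l2_distance``, and ``img`` below are stand-ins defined here
    # for illustration; any metrics satisfying the requirements in the class
    # docstring (scalar output, 0 iff the two images are identical) will do.
    #
    # >>> import torch
    # >>> def mse(x, y):
    # ...     return (x - y).pow(2).mean()
    # >>> def l2_distance(x, y):
    # ...     return torch.linalg.vector_norm(x - y)
    # >>> img = torch.rand(1, 1, 64, 64)
    # >>> mad = MADCompetition(img, optimized_metric=l2_distance,
    # ...                      reference_metric=mse, minmax="max")
    # >>> mad.synthesize(max_iter=50, store_progress=10)
    # >>> mad.mad_image.shape
    # torch.Size([1, 1, 64, 64])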

    def objective_function(
        self,
        mad_image: Tensor | None = None,
        image: Tensor | None = None,
    ) -> Tensor:
        r"""Compute the MADCompetition synthesis loss.

        This computes:

        .. math::

            t L_1(x, \hat{x}) &+ \lambda_1 [L_2(x, x+\epsilon) - L_2(x, \hat{x})]^2 \\
            &+ \lambda_2 \mathcal{B}(\hat{x})

        where :math:`t` is 1 if ``self.minmax`` is ``'min'`` and -1 if it's
        ``'max'``, :math:`L_1` is ``self.optimized_metric``, :math:`L_2` is
        ``self.reference_metric``, :math:`x` is ``self.image``,
        :math:`\hat{x}` is ``self.mad_image``, :math:`\epsilon` is the
        initial noise, :math:`\mathcal{B}` is the quadratic bound penalty,
        :math:`\lambda_1` is ``self.metric_tradeoff_lambda`` and
        :math:`\lambda_2` is ``self.range_penalty_lambda``.

        Parameters
        ----------
        mad_image :
            Proposed ``mad_image``, :math:`\hat{x}` in the above equation. If
            None, use ``self.mad_image``.
        image :
            Proposed ``image``, :math:`x` in the above equation. If None, use
            ``self.image``.

        Returns
        -------
        loss

        """
        if image is None:
            image = self.image
        if mad_image is None:
            mad_image = self.mad_image
        synth_target = {"min": 1, "max": -1}[self.minmax]
        synthesis_loss = self.optimized_metric(image, mad_image)
        fixed_loss = (
            self._reference_metric_target - self.reference_metric(image, mad_image)
        ).pow(2)
        range_penalty = optim.penalize_range(mad_image, self.allowed_range)
        return (
            synth_target * synthesis_loss
            + self.metric_tradeoff_lambda * fixed_loss
            + self.range_penalty_lambda * range_penalty
        )
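
    # Example (sketch, continuing the ``mad`` object from the sketch above):
    # the objective can be evaluated directly for the current or a candidate
    # MAD image. With ``minmax="max"`` the optimized-metric term enters with
    # a -1 sign, so a lower loss corresponds to a larger optimized-metric
    # value.
    #
    # >>> loss = mad.objective_function()  # uses self.mad_image
    # >>> candidate_loss = mad.objective_function(torch.rand_like(mad.image))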

    def _optimizer_step(self, pbar: tqdm) -> Tensor:
        r"""Compute and propagate gradients, then step the optimizer to update mad_image.

        Parameters
        ----------
        pbar :
            A tqdm progress-bar, which we update with a postfix describing
            the current loss, gradient norm, and learning rate (it already
            tells us which iteration and the time elapsed).

        Returns
        -------
        loss :
            1-element tensor containing the loss on this step.

        """
        last_iter_mad_image = self.mad_image.clone()
        loss = self.optimizer.step(self._closure)
        self._losses.append(loss.item())

        grad_norm = torch.linalg.vector_norm(self.mad_image.grad.data, ord=2, dim=None)
        self._gradient_norm.append(grad_norm.item())

        fm = self.reference_metric(self.image, self.mad_image)
        self._reference_metric_loss.append(fm.item())
        sm = self.optimized_metric(self.image, self.mad_image)
        self._optimized_metric_loss.append(sm.item())

        # optionally step the scheduler
        if self.scheduler is not None:
            self.scheduler.step(loss.item())

        pixel_change_norm = torch.linalg.vector_norm(
            self.mad_image - last_iter_mad_image, ord=2, dim=None
        )
        self._pixel_change_norm.append(pixel_change_norm.item())
        # add extra info here if you want it to show up in progress bar
        pbar.set_postfix(
            OrderedDict(
                loss=f"{loss.item():.04e}",
                learning_rate=self.optimizer.param_groups[0]["lr"],
                gradient_norm=f"{grad_norm.item():.04e}",
                pixel_change_norm=f"{pixel_change_norm.item():.04e}",
                reference_metric=f"{fm.item():.04e}",
                optimized_metric=f"{sm.item():.04e}",
            )
        )
        return loss

    def _check_convergence(self, stop_criterion, stop_iters_to_check):
        r"""Check whether the loss has stabilized and, if so, return True.

        Have we been synthesizing for ``stop_iters_to_check`` iterations?
          |                           |
          no                         yes
          |                           |
          |     Is ``abs(synth.losses[-1] - synth.losses[-stop_iters_to_check]) < stop_criterion``?
          |                           |              |
          |                           no            yes
          |                           |              |
          <---------------------------'              '------> return ``True``
          |
          '---------> return ``False``

        Parameters
        ----------
        stop_criterion :
            If the loss over the past ``stop_iters_to_check`` has changed
            less than ``stop_criterion``, we terminate synthesis.
        stop_iters_to_check :
            How many iterations back to check in order to see if the loss has
            stopped decreasing (for ``stop_criterion``).

        Returns
        -------
        loss_stabilized :
            Whether the loss has stabilized or not.

        """  # noqa: E501
        return loss_convergence(self, stop_criterion, stop_iters_to_check)

    def _initialize_optimizer(self, optimizer, scheduler):
        """Initialize optimizer and scheduler."""
        super()._initialize_optimizer(optimizer, "mad_image")
        self.scheduler = scheduler

    def _store(self, i: int) -> bool:
        """Store mad_image, if appropriate.

        If it's the right iteration, we update ``saved_mad_image``.

        Parameters
        ----------
        i :
            The current iteration.

        Returns
        -------
        stored :
            True if we stored this iteration, False if not.

        """
        if self.store_progress and (i % self.store_progress == 0):
            # want these to always be on cpu, to reduce memory use for GPUs
            self._saved_mad_image.append(self.mad_image.clone().to("cpu"))
            stored = True
        else:
            stored = False
        return stored

    def save(self, file_path: str):
        r"""Save all relevant variables in .pt file.

        Note that if store_progress is True, this will probably be very
        large.

        See ``load`` docstring for an example of use.

        Parameters
        ----------
        file_path : str
            The path to save the MADCompetition object to.

        """
        # this copies the attributes dict so we don't actually remove the
        # metric attributes in the next lines
        attrs = {k: v for k, v in vars(self).items()}
        # if the metrics are Modules, then we don't want to save them. If
        # they're functions then saving them is fine.
        if isinstance(self.optimized_metric, torch.nn.Module):
            attrs.pop("_optimized_metric")
        if isinstance(self.reference_metric, torch.nn.Module):
            attrs.pop("_reference_metric")
        super().save(file_path, attrs=attrs)

    def to(self, *args, **kwargs):
        r"""Move and/or cast the parameters and buffers.

        This can be called as

        .. function:: to(device=None, dtype=None, non_blocking=False)

        .. function:: to(dtype, non_blocking=False)

        .. function:: to(tensor, non_blocking=False)

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point desired :attr:`dtype` s. In addition, this method will
        only cast the floating point parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices.

        .. note::
            This method modifies the module in-place.

        Args:
            device (:class:`torch.device`): the desired device of the
                parameters and buffers in this module
            dtype (:class:`torch.dtype`): the desired floating point type of
                the floating point parameters and buffers in this module
            tensor (torch.Tensor): Tensor whose dtype and device are the
                desired dtype and device for all parameters and buffers in
                this module

        """
        attrs = ["_initial_image", "_image", "_mad_image", "_saved_mad_image"]
        super().to(*args, attrs=attrs, **kwargs)
        # if the metrics are Modules, then we should pass them as well. If
        # they're functions then nothing needs to be done.
        with contextlib.suppress(AttributeError):
            self.reference_metric.to(*args, **kwargs)
        with contextlib.suppress(AttributeError):
            self.optimized_metric.to(*args, **kwargs)

    def load(
        self,
        file_path: str,
        map_location: str | None = None,
        **pickle_load_args,
    ):
        r"""Load all relevant stuff from a .pt file.

        This should be called by an initialized ``MADCompetition`` object --
        we will ensure that ``image``, ``metric_tradeoff_lambda``,
        ``range_penalty_lambda``, ``allowed_range``, and ``minmax`` are all
        identical, and that ``reference_metric`` and ``optimized_metric``
        return identical values.

        Note this operates in place and so doesn't return anything.

        Parameters
        ----------
        file_path : str
            The path to load the synthesis object from.
        map_location : str, optional
            map_location argument to pass to ``torch.load``. If you save
            stuff that was being run on a GPU and are loading onto a CPU,
            you'll need this to make sure everything lines up properly. This
            should be structured like the str you would pass to
            ``torch.device``.
        pickle_load_args :
            Any additional kwargs will be added to ``pickle_module.load`` via
            ``torch.load``, see that function's docstring for details.

        Examples
        --------
        >>> mad = po.synth.MADCompetition(img, metric1, metric2, minmax="min")
        >>> mad.synthesize(max_iter=10, store_progress=True)
        >>> mad.save('mad.pt')
        >>> mad_copy = po.synth.MADCompetition(img, metric1, metric2, minmax="min")
        >>> mad_copy.load('mad.pt')

        Note that you must create a new instance of the Synthesis object and
        *then* load.

        """
        check_attributes = [
            "_image",
            "_metric_tradeoff_lambda",
            "_range_penalty_lambda",
            "_allowed_range",
            "_minmax",
        ]
        check_loss_functions = ["_reference_metric", "_optimized_metric"]
        super().load(
            file_path,
            map_location=map_location,
            check_attributes=check_attributes,
            check_loss_functions=check_loss_functions,
            **pickle_load_args,
        )
        # make this require a grad again
        self.mad_image.requires_grad_()
        # these are always supposed to be on cpu, but may get copied over to
        # gpu on load (which can cause problems when resuming synthesis), so
        # fix that.
        if len(self._saved_mad_image) and self._saved_mad_image[0].device.type != "cpu":
            self._saved_mad_image = [mad.to("cpu") for mad in self._saved_mad_image]
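
    # Example (sketch, continuing the ``mad``, ``img``, ``l2_distance``, and
    # ``mse`` names from the earlier sketch): a save / load round trip
    # requires constructing a new instance with the same arguments before
    # calling ``load``.
    #
    # >>> mad.save("mad.pt")
    # >>> mad_copy = MADCompetition(img, optimized_metric=l2_distance,
    # ...                           reference_metric=mse, minmax="max")
    # >>> mad_copy.load("mad.pt")
    # >>> mad_copy.synthesize(max_iter=10)  # synthesis can then be resumed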

    @property
    def mad_image(self):
        return self._mad_image

    @property
    def optimized_metric(self):
        return self._optimized_metric

    @property
    def reference_metric(self):
        return self._reference_metric

    @property
    def image(self):
        return self._image

    @property
    def initial_image(self):
        return self._initial_image

    @property
    def reference_metric_loss(self):
        return torch.as_tensor(self._reference_metric_loss)

    @property
    def optimized_metric_loss(self):
        return torch.as_tensor(self._optimized_metric_loss)

    @property
    def metric_tradeoff_lambda(self):
        return self._metric_tradeoff_lambda

    @property
    def minmax(self):
        return self._minmax

    @property
    def saved_mad_image(self):
        return torch.stack(self._saved_mad_image)


def plot_loss(
    mad: MADCompetition,
    iteration: int | None = None,
    axes: list[mpl.axes.Axes] | mpl.axes.Axes | None = None,
    **kwargs,
) -> mpl.axes.Axes:
    """Plot metric losses.

    Plots ``mad.optimized_metric_loss`` and ``mad.reference_metric_loss`` on
    two separate axes, over all iterations. Also plots a red dot at
    ``iteration``, to highlight the loss there. If ``iteration=None``, then
    the dot will be at the final iteration.

    Parameters
    ----------
    mad :
        MADCompetition object whose loss we want to plot.
    iteration :
        Which iteration to display. If None, the default, we show the most
        recent one. Negative values are also allowed.
    axes :
        Pre-existing axes for plot. If a list of axes, must be the two axes
        to use for this plot. If a single axis, we'll split it in half
        horizontally. If None, we call ``plt.gca()``.
    kwargs :
        Passed to plt.plot.

    Returns
    -------
    axes :
        The matplotlib axes containing the plot.

    Notes
    -----
    We plot ``abs(mad.losses)`` because if we're maximizing the synthesis
    metric, we minimized its negative. By plotting the absolute value, we get
    them all on the same scale.

    """
    if iteration is None:
        loss_idx = len(mad.losses) - 1
    elif iteration < 0:
        # convert to a positive index so it lines up with the x-values
        loss_idx = len(mad.losses) + iteration
    else:
        loss_idx = iteration
    if axes is None:
        axes = plt.gca()
    if not hasattr(axes, "__iter__"):
        axes = display.clean_up_axes(
            axes, False, ["top", "right", "bottom", "left"], ["x", "y"]
        )
        gs = axes.get_subplotspec().subgridspec(1, 2)
        fig = axes.figure
        axes = [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])]
    losses = [mad.reference_metric_loss, mad.optimized_metric_loss]
    names = ["Reference metric loss", "Optimized metric loss"]
    for ax, loss, name in zip(axes, losses, names):
        ax.plot(loss, **kwargs)
        ax.scatter(loss_idx, loss[loss_idx], c="r")
        ax.set(xlabel="Synthesis iteration", ylabel=name)
    return ax
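
# Example (sketch, using the ``mad`` object from the earlier sketch): plot
# both metric losses, highlighting iteration 25, and save the figure. The
# filename is illustrative.
#
# >>> import matplotlib.pyplot as plt
# >>> ax = plot_loss(mad, iteration=25)
# >>> plt.gcf().savefig("mad_loss.png")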


def display_mad_image(
    mad: MADCompetition,
    batch_idx: int = 0,
    channel_idx: int | None = None,
    zoom: float | None = None,
    iteration: int | None = None,
    ax: mpl.axes.Axes | None = None,
    title: str = "MADCompetition",
    **kwargs,
) -> mpl.axes.Axes:
    """Display MAD image.

    You can specify what iteration to view by using the ``iteration`` arg.
    The default, ``None``, shows the final one.

    We use ``plenoptic.imshow`` to display the synthesized image and attempt
    to automatically find the most reasonable zoom value. You can override
    this value using the zoom arg, but remember that ``plenoptic.imshow`` is
    opinionated about the size of the resulting image and will throw an
    Exception if the axis created is not big enough for the selected zoom.

    Parameters
    ----------
    mad :
        MADCompetition object whose MAD image we want to display.
    batch_idx :
        Which index to take from the batch dimension.
    channel_idx :
        Which index to take from the channel dimension. If None, we assume
        image is RGB(A) and show all channels.
    zoom :
        How much to zoom in / enlarge the synthesized image, the ratio of
        display pixels to image pixels. If None (the default), we attempt to
        find the best value ourselves.
    iteration :
        Which iteration to display. If None, the default, we show the most
        recent one. Negative values are also allowed.
    ax :
        Pre-existing axes for plot. If None, we call ``plt.gca()``.
    title :
        Title of the axis.
    kwargs :
        Passed to ``plenoptic.imshow``.

    Returns
    -------
    ax :
        The matplotlib axes containing the plot.

    """
    image = mad.mad_image if iteration is None else mad.saved_mad_image[iteration]
    if batch_idx is None:
        raise ValueError("batch_idx must be an integer!")
    # we're only plotting one image here, so if the user wants multiple
    # channels, they must be RGB
    as_rgb = bool(channel_idx is None and image.shape[1] > 1)
    if ax is None:
        ax = plt.gca()
    display.imshow(
        image,
        ax=ax,
        title=title,
        zoom=zoom,
        batch_idx=batch_idx,
        channel_idx=channel_idx,
        as_rgb=as_rgb,
        **kwargs,
    )
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    return ax
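
# Example (sketch, continuing the earlier ``mad`` sketch, which stored
# progress every 10 iterations): display the third stored MAD image.
#
# >>> ax = display_mad_image(mad, iteration=2, title="MAD image, iteration 20")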


def plot_pixel_values(
    mad: MADCompetition,
    batch_idx: int = 0,
    channel_idx: int | None = None,
    iteration: int | None = None,
    ylim: tuple[float] | Literal[False] = False,
    ax: mpl.axes.Axes | None = None,
    **kwargs,
) -> mpl.axes.Axes:
    r"""Plot histogram of pixel values of reference and MAD images.

    This is a way to check the distributions of pixel intensities and see if
    there are any values outside the allowed range.

    Parameters
    ----------
    mad :
        MADCompetition object with the images whose pixel values we want to
        compare.
    batch_idx :
        Which index to take from the batch dimension.
    channel_idx :
        Which index to take from the channel dimension. If None, we use all
        channels (assumed use-case is RGB(A) images).
    iteration :
        Which iteration to display. If None, the default, we show the most
        recent one. Negative values are also allowed.
    ylim :
        If tuple, the ylimit to set for this axis. If False, we leave it
        untouched.
    ax :
        Pre-existing axes for plot. If None, we call ``plt.gca()``.
    kwargs :
        Passed to plt.hist.

    Returns
    -------
    ax :
        The matplotlib axes containing the plot.

    """

    def _freedman_diaconis_bins(a):
        """Calculate number of hist bins using Freedman-Diaconis rule.

        Copied from seaborn.
        """
        # From https://stats.stackexchange.com/questions/798/
        a = np.asarray(a)
        iqr = np.diff(np.percentile(a, [25, 75]))[0]
        if len(a) < 2:
            return 1
        h = 2 * iqr / (len(a) ** (1 / 3))
        # fall back to sqrt(a) bins if iqr is 0
        if h == 0:
            return int(np.sqrt(a.size))
        else:
            return int(np.ceil((a.max() - a.min()) / h))

    kwargs.setdefault("alpha", 0.4)
    if iteration is None:
        mad_image = mad.mad_image[batch_idx]
    else:
        mad_image = mad.saved_mad_image[iteration, batch_idx]
    image = mad.image[batch_idx]
    if channel_idx is not None:
        image = image[channel_idx]
        mad_image = mad_image[channel_idx]
    if ax is None:
        ax = plt.gca()
    image = data.to_numpy(image).flatten()
    mad_image = data.to_numpy(mad_image).flatten()
    ax.hist(
        image,
        bins=min(_freedman_diaconis_bins(image), 50),
        label="Reference image",
        **kwargs,
    )
    ax.hist(
        mad_image,
        bins=min(_freedman_diaconis_bins(image), 50),
        label="MAD image",
        **kwargs,
    )
    ax.legend()
    if ylim:
        ax.set_ylim(ylim)
    ax.set_title("Histogram of pixel values")
    return ax
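
# Example (sketch, using the earlier ``mad`` object): compare pixel-value
# histograms of the reference and MAD images, e.g. to spot values pushed
# outside ``allowed_range``.
#
# >>> ax = plot_pixel_values(mad)
# >>> ax.set_yscale("log")  # long tails are easier to see on a log axis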


def _check_included_plots(to_check: list[str] | dict[str, int], to_check_name: str):
    """Check whether the user wanted us to create plots that we can't.

    Helper function for plot_synthesis_status and animate.

    Raises a ValueError if to_check contains any values that are not allowed.

    Parameters
    ----------
    to_check :
        The variable to check. We ensure that it doesn't contain any extra
        (not allowed) values. If a list, we check its contents. If a dict, we
        check its keys.
    to_check_name :
        Name of the `to_check` variable, used in the error message.

    """
    allowed_vals = [
        "display_mad_image",
        "plot_loss",
        "plot_pixel_values",
        "misc",
    ]
    try:
        vals = to_check.keys()
    except AttributeError:
        vals = to_check
    not_allowed = [v for v in vals if v not in allowed_vals]
    if not_allowed:
        raise ValueError(
            f"{to_check_name} contained value(s) {not_allowed}! "
            f"Only {allowed_vals} are permissible!"
        )


def _setup_synthesis_fig(
    fig: mpl.figure.Figure | None = None,
    axes_idx: dict[str, int] = {},
    figsize: tuple[float] | None = None,
    included_plots: list[str] = [
        "display_mad_image",
        "plot_loss",
        "plot_pixel_values",
    ],
    display_mad_image_width: float = 1,
    plot_loss_width: float = 2,
    plot_pixel_values_width: float = 1,
) -> tuple[mpl.figure.Figure, list[mpl.axes.Axes], dict[str, int]]:
    """Set up figure for plot_synthesis_status.

    Creates figure with enough axes for all the plots you want. Will also
    create index in axes_idx for them if you haven't done so already.

    By default, all axes will be on the same row and have the same width. If
    you want them to be on different rows, you will need to initialize fig
    yourself and pass that in. For changing width, change the corresponding
    *_width arg, which gives width relative to other axes. So if you want the
    axis for the loss plot to be three times as wide as the others, set
    plot_loss_width=3.

    Parameters
    ----------
    fig :
        The figure to plot on or None. If None, we create a new figure.
    axes_idx :
        Dictionary specifying which axes contains which type of plot, allows
        for more fine-grained control of the resulting figure. Probably only
        helpful if fig is also defined. Possible keys: 'display_mad_image',
        'plot_loss', 'plot_pixel_values', 'misc'. Values should all be ints.
        If you tell this function to create a plot that doesn't have a
        corresponding key, we find the lowest int that is not already in the
        dict, so if you have axes that you want unchanged, place their idx in
        misc.
    figsize :
        The size of the figure to create. It may take a little bit of playing
        around to find a reasonable value. If None, we attempt to make our
        best guess, aiming to have relative width=1 correspond to 5 inches.
    included_plots :
        Which plots to include. Must be some subset of ``'display_mad_image',
        'plot_loss', 'plot_pixel_values'``.
    display_mad_image_width :
        Relative width of the axis for the synthesized image.
    plot_loss_width :
        Relative width of the axis for loss plot.
    plot_pixel_values_width :
        Relative width of the axis for image pixel intensities histograms.

    Returns
    -------
    fig :
        The figure to plot on.
    axes :
        List or array of axes contained in fig.
    axes_idx :
        Dictionary identifying the idx for each plot type.

    """
    n_subplots = 0
    axes_idx = axes_idx.copy()
    width_ratios = []
    if "display_mad_image" in included_plots:
        n_subplots += 1
        width_ratios.append(display_mad_image_width)
        if "display_mad_image" not in axes_idx:
            axes_idx["display_mad_image"] = data._find_min_int(axes_idx.values())
    if "plot_loss" in included_plots:
        n_subplots += 1
        width_ratios.append(plot_loss_width)
        if "plot_loss" not in axes_idx:
            axes_idx["plot_loss"] = data._find_min_int(axes_idx.values())
    if "plot_pixel_values" in included_plots:
        n_subplots += 1
        width_ratios.append(plot_pixel_values_width)
        if "plot_pixel_values" not in axes_idx:
            axes_idx["plot_pixel_values"] = data._find_min_int(axes_idx.values())
    if fig is None:
        width_ratios = np.array(width_ratios)
        if figsize is None:
            # we want (5, 5) for each subplot, with a bit of room between
            # each subplot
            figsize = ((width_ratios * 5).sum() + width_ratios.sum() - 1, 5)
        width_ratios = width_ratios / width_ratios.sum()
        fig, axes = plt.subplots(
            1,
            n_subplots,
            figsize=figsize,
            gridspec_kw={"width_ratios": width_ratios},
        )
        if n_subplots == 1:
            axes = [axes]
    else:
        axes = fig.axes
    # make sure misc contains all the empty axes
    misc_axes = axes_idx.get("misc", [])
    if not hasattr(misc_axes, "__iter__"):
        misc_axes = [misc_axes]
    all_axes = []
    for i in axes_idx.values():
        # so if it's a list of ints
        if hasattr(i, "__iter__"):
            all_axes.extend(i)
        else:
            all_axes.append(i)
    misc_axes += [i for i, _ in enumerate(fig.axes) if i not in all_axes]
    axes_idx["misc"] = misc_axes
    return fig, axes, axes_idx


def plot_synthesis_status(
    mad: MADCompetition,
    batch_idx: int = 0,
    channel_idx: int | None = None,
    iteration: int | None = None,
    vrange: tuple[float] | str = "indep1",
    zoom: float | None = None,
    fig: mpl.figure.Figure | None = None,
    axes_idx: dict[str, int] = {},
    figsize: tuple[float] | None = None,
    included_plots: list[str] = [
        "display_mad_image",
        "plot_loss",
        "plot_pixel_values",
    ],
    width_ratios: dict[str, float] = {},
) -> tuple[mpl.figure.Figure, dict[str, int]]:
    r"""Make a plot showing synthesis status.

    We create several subplots to analyze this. By default, we create two
    subplots on a new figure: the first one contains the MAD image and the
    second contains the loss. There is an optional additional plot:
    plot_pixel_values, a histogram of pixel values of the synthesized and
    target images.

    All of these (including the default plots) can be included or excluded
    via ``included_plots``, and each can be created separately using the
    module-level function of the same name.

    Parameters
    ----------
    mad :
        MADCompetition object whose status we want to plot.
    batch_idx :
        Which index to take from the batch dimension.
    channel_idx :
        Which index to take from the channel dimension. If None, we use all
        channels (assumed use-case is RGB(A) image).
    iteration :
        Which iteration to display. If None, the default, we show the most
        recent one. Negative values are also allowed.
    vrange :
        The vrange option to pass to ``display_mad_image()``. See docstring
        of ``imshow`` for possible values.
    zoom :
        How much to zoom in / enlarge the synthesized image, the ratio of
        display pixels to image pixels. If None (the default), we attempt to
        find the best value ourselves.
    fig :
        If None, we create a new figure. Otherwise we assume this is an empty
        figure that has the appropriate size and number of subplots.
    axes_idx :
        Dictionary specifying which axes contains which type of plot, allows
        for more fine-grained control of the resulting figure. Probably only
        helpful if fig is also defined. Possible keys:
        ``'display_mad_image', 'plot_loss', 'plot_pixel_values', 'misc'``.
        Values should all be ints. If you tell this function to create a plot
        that doesn't have a corresponding key, we find the lowest int that is
        not already in the dict, so if you have axes that you want unchanged,
        place their idx in ``'misc'``.
    figsize :
        The size of the figure to create. It may take a little bit of playing
        around to find a reasonable value. If None, we attempt to make our
        best guess, aiming to have each axis be of size (5, 5).
    included_plots :
        Which plots to include. Must be some subset of ``'display_mad_image',
        'plot_loss', 'plot_pixel_values'``.
    width_ratios :
        By default, all plot axes will have the same width. To change that,
        specify their relative widths using the keys ['display_mad_image',
        'plot_loss', 'plot_pixel_values'] and floats specifying their
        relative width. Any not included will be assumed to be 1.

    Returns
    -------
    fig :
        The figure containing this plot.
    axes_idx :
        Dictionary giving index of each plot.

    """
    if iteration is not None and not mad.store_progress:
        raise ValueError(
            "synthesize() was run with store_progress=False, "
            "cannot specify which iteration to plot (only"
            " last one, with iteration=None)"
        )
    if mad.mad_image.ndim not in [3, 4]:
        raise ValueError(
            "plot_synthesis_status() expects 3 or 4d data;"
            " unexpected behavior will result otherwise!"
        )
    _check_included_plots(included_plots, "included_plots")
    _check_included_plots(width_ratios, "width_ratios")
    _check_included_plots(axes_idx, "axes_idx")
    width_ratios = {f"{k}_width": v for k, v in width_ratios.items()}
    fig, axes, axes_idx = _setup_synthesis_fig(
        fig, axes_idx, figsize, included_plots, **width_ratios
    )

    if "display_mad_image" in included_plots:
        display_mad_image(
            mad,
            batch_idx=batch_idx,
            channel_idx=channel_idx,
            iteration=iteration,
            ax=axes[axes_idx["display_mad_image"]],
            zoom=zoom,
            vrange=vrange,
        )
    if "plot_loss" in included_plots:
        plot_loss(mad, iteration=iteration, axes=axes[axes_idx["plot_loss"]])
        # _setup_synthesis_fig creates a single axis for loss, which
        # plot_loss then splits into two. this makes sure the right two axes
        # are present in the dict
        all_axes = []
        for i in axes_idx.values():
            # so if it's a list of ints
            if hasattr(i, "__iter__"):
                all_axes.extend(i)
            else:
                all_axes.append(i)
        new_axes = [i for i, _ in enumerate(fig.axes) if i not in all_axes]
        axes_idx["plot_loss"] = new_axes
    if "plot_pixel_values" in included_plots:
        plot_pixel_values(
            mad,
            batch_idx=batch_idx,
            channel_idx=channel_idx,
            iteration=iteration,
            ax=axes[axes_idx["plot_pixel_values"]],
        )
    return fig, axes_idx
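
# Example (sketch, using the earlier ``mad`` object): a status figure with
# only the MAD image and the loss, giving the loss axes twice the relative
# width of the image axis.
#
# >>> fig, axes_idx = plot_synthesis_status(
# ...     mad,
# ...     included_plots=["display_mad_image", "plot_loss"],
# ...     width_ratios={"plot_loss": 2},
# ... )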


def animate(
    mad: MADCompetition,
    framerate: int = 10,
    batch_idx: int = 0,
    channel_idx: int | None = None,
    zoom: float | None = None,
    fig: mpl.figure.Figure | None = None,
    axes_idx: dict[str, int] = {},
    figsize: tuple[float] | None = None,
    included_plots: list[str] = [
        "display_mad_image",
        "plot_loss",
        "plot_pixel_values",
    ],
    width_ratios: dict[str, float] = {},
) -> mpl.animation.FuncAnimation:
    r"""Animate synthesis progress.

    This is essentially the figure produced by ``plot_synthesis_status``
    animated over time, for each stored iteration.

    This function returns a matplotlib FuncAnimation object. See our
    documentation (e.g., the
    [Quickstart](https://docs.plenoptic.org/docs/branch/main/tutorials/00_quickstart.html)
    tutorial) for examples of how to view it in a Jupyter notebook. In order
    to save, use ``anim.save(filename)``. In either case, this can take a
    while and you'll need the appropriate writer installed and on your path
    (e.g., ffmpeg, imagemagick). See the [matplotlib
    documentation](https://matplotlib.org/stable/api/animation_api.html) for
    more details.

    Parameters
    ----------
    mad :
        MADCompetition object whose synthesis we want to animate.
    framerate :
        How many frames a second to display.
    batch_idx :
        Which index to take from the batch dimension.
    channel_idx :
        Which index to take from the channel dimension. If None, we use all
        channels (assumed use-case is RGB(A) image).
    zoom :
        How much to zoom in / enlarge the synthesized image, the ratio of
        display pixels to image pixels. If None (the default), we attempt to
        find the best value ourselves.
    fig :
        If None, create the figure from scratch. Else, should be an empty
        figure with enough axes (the expected use here is have same-size
        movies with different plots).
    axes_idx :
        Dictionary specifying which axes contains which type of plot, allows
        for more fine-grained control of the resulting figure. Probably only
        helpful if fig is also defined. Possible keys:
        ``'display_mad_image', 'plot_loss', 'plot_pixel_values', 'misc'``.
        Values should all be ints. If you tell this function to create a plot
        that doesn't have a corresponding key, we find the lowest int that is
        not already in the dict, so if you have axes that you want unchanged,
        place their idx in ``'misc'``.
    figsize :
        The size of the figure to create. It may take a little bit of playing
        around to find a reasonable value. If None, we attempt to make our
        best guess, aiming to have each axis be of size (5, 5).
    width_ratios :
        By default, all plot axes will have the same width. To change that,
        specify their relative widths using the keys ['display_mad_image',
        'plot_loss', 'plot_pixel_values'] and floats specifying their
        relative width. Any not included will be assumed to be 1.

    Returns
    -------
    anim :
        The animation object. In order to view, must convert to HTML or save.

    Notes
    -----
    By default, we use the ffmpeg backend, which requires that you have
    ffmpeg installed and on your path (https://ffmpeg.org/download.html). To
    use a different writer, set the matplotlib rcParams:
    `matplotlib.rcParams['animation.writer'] = writer`, see
    https://matplotlib.org/stable/api/animation_api.html#writer-classes for
    more details.

    For displaying in a jupyter notebook, ffmpeg appears to be required.

    """
    if not mad.store_progress:
        raise ValueError(
            "synthesize() was run with store_progress=False, cannot animate!"
        )
    if mad.mad_image.ndim not in [3, 4]:
        raise ValueError(
            "animate() expects 3 or 4d data; unexpected"
            " behavior will result otherwise!"
        )
    _check_included_plots(included_plots, "included_plots")
    _check_included_plots(width_ratios, "width_ratios")
    _check_included_plots(axes_idx, "axes_idx")
    # we run plot_synthesis_status to initialize the figure if either fig is
    # None or if there are no titles on any axes, which we assume means that
    # it's an empty figure
    if fig is None or not any([ax.get_title() for ax in fig.axes]):
        fig, axes_idx = plot_synthesis_status(
            mad=mad,
            batch_idx=batch_idx,
            channel_idx=channel_idx,
            iteration=0,
            figsize=figsize,
            zoom=zoom,
            fig=fig,
            included_plots=included_plots,
            axes_idx=axes_idx,
            width_ratios=width_ratios,
        )
    # grab the artists for the second plot (we don't need to do this for the
    # MAD image plot, because we use the update_plot function for that)
    if "plot_loss" in included_plots:
        scat = [fig.axes[i].collections[0] for i in axes_idx["plot_loss"]]
    # can also have multiple plots

    def movie_plot(i):
        artists = []
        if "display_mad_image" in included_plots:
            artists.extend(
                display.update_plot(
                    fig.axes[axes_idx["display_mad_image"]],
                    data=mad.saved_mad_image[i],
                    batch_idx=batch_idx,
                )
            )
        if "plot_pixel_values" in included_plots:
            # this is the dumbest way to do this, but it's simple --
            # clearing the axes can cause problems if the user has, for
            # example, changed the tick locator or formatter. not sure how
            # to handle this best right now
            fig.axes[axes_idx["plot_pixel_values"]].clear()
            plot_pixel_values(
                mad,
                batch_idx=batch_idx,
                channel_idx=channel_idx,
                iteration=i,
                ax=fig.axes[axes_idx["plot_pixel_values"]],
            )
        if "plot_loss" in included_plots:
            # loss always contains values from every iteration, but
            # everything else will be subsampled.
            x_val = i * mad.store_progress
            scat[0].set_offsets((x_val, mad.reference_metric_loss[x_val]))
            scat[1].set_offsets((x_val, mad.optimized_metric_loss[x_val]))
            artists.extend(scat)
        # as long as blitting is True, need to return a sequence of artists
        return artists

    # don't need an init_func, since we handle initialization ourselves
    anim = mpl.animation.FuncAnimation(
        fig,
        movie_plot,
        frames=len(mad.saved_mad_image),
        blit=True,
        interval=1000.0 / framerate,
        repeat=False,
    )
    plt.close(fig)
    return anim
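
# Example (sketch, using the earlier ``mad`` object, which was run with
# ``store_progress=10``): build and save the animation. Saving assumes a
# writer such as ffmpeg is installed and on your path; the filename is
# illustrative.
#
# >>> anim = animate(mad, framerate=5)
# >>> anim.save("mad_synthesis.mp4")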


def display_mad_image_all(
    mad_metric1_min: MADCompetition,
    mad_metric2_min: MADCompetition,
    mad_metric1_max: MADCompetition,
    mad_metric2_max: MADCompetition,
    metric1_name: str | None = None,
    metric2_name: str | None = None,
    zoom: int | float = 1,
    **kwargs,
) -> mpl.figure.Figure:
    """Display all MAD Competition images.

    To generate a full set of MAD Competition images, you need four
    instances: one for minimizing and maximizing each metric. This helper
    function creates a figure to display the full set of images.

    In addition to the four MAD Competition images, this also plots the
    initial image from `mad_metric1_min`, for comparison.

    Note that all four MADCompetition instances must have the same `image`.

    Parameters
    ----------
    mad_metric1_min :
        MADCompetition object that minimized the first metric.
    mad_metric2_min :
        MADCompetition object that minimized the second metric.
    mad_metric1_max :
        MADCompetition object that maximized the first metric.
    mad_metric2_max :
        MADCompetition object that maximized the second metric.
    metric1_name :
        Name of the first metric. If None, we use the name of the
        `optimized_metric` function from `mad_metric1_min`.
    metric2_name :
        Name of the second metric. If None, we use the name of the
        `optimized_metric` function from `mad_metric2_min`.
    zoom :
        Ratio of display pixels to image pixels. See `plenoptic.imshow` for
        details.
    kwargs :
        Passed to `plenoptic.imshow`.

    Returns
    -------
    fig :
        Figure containing the images.

    """
    # this is a bit of a hack right now, because they don't all have same
    # initial image
    if not torch.allclose(mad_metric1_min.image, mad_metric2_min.image):
        raise ValueError("All four instances of MADCompetition must have same image!")
    if not torch.allclose(mad_metric1_min.image, mad_metric1_max.image):
        raise ValueError("All four instances of MADCompetition must have same image!")
    if not torch.allclose(mad_metric1_min.image, mad_metric2_max.image):
        raise ValueError("All four instances of MADCompetition must have same image!")
    if metric1_name is None:
        metric1_name = mad_metric1_min.optimized_metric.__name__
    if metric2_name is None:
        metric2_name = mad_metric2_min.optimized_metric.__name__
    fig = pt_make_figure(3, 2, [zoom * i for i in mad_metric1_min.image.shape[-2:]])
    mads = [mad_metric1_min, mad_metric1_max, mad_metric2_min, mad_metric2_max]
    titles = [
        f"Minimize {metric1_name}",
        f"Maximize {metric1_name}",
        f"Minimize {metric2_name}",
        f"Maximize {metric2_name}",
    ]
    # we're only plotting one image here, so if the user wants multiple
    # channels, they must be RGB
    if kwargs.get("channel_idx") is None and mad_metric1_min.initial_image.shape[1] > 1:
        as_rgb = True
    else:
        as_rgb = False
    display.imshow(
        mad_metric1_min.image,
        ax=fig.axes[0],
        title="Reference image",
        zoom=zoom,
        as_rgb=as_rgb,
        **kwargs,
    )
    display.imshow(
        mad_metric1_min.initial_image,
        ax=fig.axes[1],
        title="Initial (noisy) image",
        zoom=zoom,
        as_rgb=as_rgb,
        **kwargs,
    )
    for ax, mad, title in zip(fig.axes[2:], mads, titles):
        display_mad_image(mad, zoom=zoom, ax=ax, title=title, **kwargs)
    return fig
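
# Example (sketch): the full set of four instances needed by
# display_mad_image_all (and plot_loss_all below), built from the same
# ``img`` and the illustrative ``l2_distance`` / ``mse`` metrics defined in
# the earlier sketch. Each instance must be synthesized before plotting.
#
# >>> mad_l2_min = MADCompetition(img, l2_distance, mse, "min")
# >>> mad_l2_max = MADCompetition(img, l2_distance, mse, "max")
# >>> mad_mse_min = MADCompetition(img, mse, l2_distance, "min")
# >>> mad_mse_max = MADCompetition(img, mse, l2_distance, "max")
# >>> for m in [mad_l2_min, mad_l2_max, mad_mse_min, mad_mse_max]:
# ...     m.synthesize(max_iter=50)
# >>> fig = display_mad_image_all(
# ...     mad_l2_min, mad_mse_min, mad_l2_max, mad_mse_max,
# ...     metric1_name="L2", metric2_name="MSE",
# ... )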


def plot_loss_all(
    mad_metric1_min: MADCompetition,
    mad_metric2_min: MADCompetition,
    mad_metric1_max: MADCompetition,
    mad_metric2_max: MADCompetition,
    metric1_name: str | None = None,
    metric2_name: str | None = None,
    metric1_kwargs: dict = {"c": "C0"},
    metric2_kwargs: dict = {"c": "C1"},
    min_kwargs: dict = {"linestyle": "--"},
    max_kwargs: dict = {"linestyle": "-"},
    figsize=(10, 5),
) -> mpl.figure.Figure:
    """Plot loss for full set of MAD Competition instances.

    To generate a full set of MAD Competition images, you need four
    instances: one for minimizing and maximizing each metric. This helper
    function creates a two-axis figure to display the loss for this full
    set.

    Note that all four MADCompetition instances must have the same `image`.

    Parameters
    ----------
    mad_metric1_min :
        MADCompetition object that minimized the first metric.
    mad_metric2_min :
        MADCompetition object that minimized the second metric.
    mad_metric1_max :
        MADCompetition object that maximized the first metric.
    mad_metric2_max :
        MADCompetition object that maximized the second metric.
    metric1_name :
        Name of the first metric. If None, we use the name of the
        `optimized_metric` function from `mad_metric1_min`.
    metric2_name :
        Name of the second metric. If None, we use the name of the
        `optimized_metric` function from `mad_metric2_min`.
    metric1_kwargs :
        Dictionary of arguments to pass to `matplotlib.pyplot.plot` to
        identify synthesis instances where the first metric was being
        optimized.
    metric2_kwargs :
        Dictionary of arguments to pass to `matplotlib.pyplot.plot` to
        identify synthesis instances where the second metric was being
        optimized.
    min_kwargs :
        Dictionary of arguments to pass to `matplotlib.pyplot.plot` to
        identify synthesis instances where `optimized_metric` was being
        minimized.
    max_kwargs :
        Dictionary of arguments to pass to `matplotlib.pyplot.plot` to
        identify synthesis instances where `optimized_metric` was being
        maximized.
    figsize :
        Size of the figure we create.

    Returns
    -------
    fig :
        Figure containing the plot.

    """
    if not torch.allclose(mad_metric1_min.image, mad_metric2_min.image):
        raise ValueError("All four instances of MADCompetition must have same image!")
    if not torch.allclose(mad_metric1_min.image, mad_metric1_max.image):
        raise ValueError("All four instances of MADCompetition must have same image!")
    if not torch.allclose(mad_metric1_min.image, mad_metric2_max.image):
        raise ValueError("All four instances of MADCompetition must have same image!")
    if metric1_name is None:
        metric1_name = mad_metric1_min.optimized_metric.__name__
    if metric2_name is None:
        metric2_name = mad_metric2_min.optimized_metric.__name__
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    plot_loss(
        mad_metric1_min,
        axes=axes,
        label=f"Minimize {metric1_name}",
        **metric1_kwargs,
        **min_kwargs,
    )
    plot_loss(
        mad_metric1_max,
        axes=axes,
        label=f"Maximize {metric1_name}",
        **metric1_kwargs,
        **max_kwargs,
    )
    # we pass the axes backwards here because the fixed and synthesis metrics
    # are the opposite of those in the instances above.
    plot_loss(
        mad_metric2_min,
        axes=axes[::-1],
        label=f"Minimize {metric2_name}",
        **metric2_kwargs,
        **min_kwargs,
    )
    plot_loss(
        mad_metric2_max,
        axes=axes[::-1],
        label=f"Maximize {metric2_name}",
        **metric2_kwargs,
        **max_kwargs,
    )
    axes[0].set(ylabel="Loss", title=metric2_name)
    axes[1].set(ylabel="Loss", title=metric1_name)
    axes[1].legend(loc="center left", bbox_to_anchor=(1.1, 0.5))
    return fig
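
# Example (sketch, continuing the four instances built above): plot the loss
# curves for the whole set on one two-axis figure.
#
# >>> fig = plot_loss_all(
# ...     mad_l2_min, mad_mse_min, mad_l2_max, mad_mse_max,
# ...     metric1_name="L2", metric2_name="MSE",
# ... )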