# Source code for plenoptic._synthesize.mad_competition

"""
Maximum Differentiation Competition.

Maximum Differentiation Competition synthesizes images which maximally distinguish
between a pair of metrics. Generally speaking, they are synthesized in pairs (two images
that one metric considers identical and the other considers as different as possible) or
groups of four (a pair of such pairs, one for each of the two metrics). They emphasize
the features that distinguish metrics, highlighting the features that one metric
considers important that the other is invariant to.
"""

import contextlib
import warnings
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, Literal

import numpy as np
import torch
from torch import Tensor
from tqdm.auto import tqdm

from .. import regularize
from ..convergence import _loss_convergence
from ..validate import validate_input, validate_metric
from .synthesis import _OptimizedSynthesis

__all__ = [
    "MADCompetition",
]


def __dir__() -> list[str]:
    """Return the names listed in this module's ``__all__``."""
    public_names = __all__
    return public_names


class MADCompetition(_OptimizedSynthesis):
    r"""
    Synthesize a single maximally-differentiating image for two metrics.

    Following the basic idea in [1]_, this class synthesizes a
    maximally-differentiating image for two given metrics, based on a given
    image. We start by adding noise to this image and then iteratively
    adjusting its pixels so as to either minimize or maximize
    ``optimized_metric`` while holding the value of ``reference_metric``
    constant.

    MADCompetition accepts two metrics as its input. These should be callables
    that take two images and return a single number, and that number should be
    0 if and only if the two images are identical (thus, the larger the number,
    the more different the two images).

    Note that a full set of MAD Competition images consists of two pairs: a
    maximal and a minimal image for each metric. A single instantiation of
    ``MADCompetition`` will generate one of these four images.

    Parameters
    ----------
    image
        A tensor, this is the image we use as the reference point.
    optimized_metric
        The metric whose value you wish to minimize or maximize, which takes
        two tensors and returns a scalar.
    reference_metric
        The metric whose value you wish to keep fixed, which takes two tensors
        and returns a scalar.
    minmax
        Whether you wish to minimize or maximize ``optimized_metric``.
    metric_tradeoff_lambda
        Lambda to multiply by ``reference_metric`` loss and add to
        ``optimized_metric`` loss. If ``None``, we pick a value so the two
        initial losses are approximately equal in magnitude.
    penalty_function
        A function applied to the MAD image during optimization, that returns
        a scalar penalty to be minimized. By penalizing certain properties of
        the image, like pixels values outside an allowed range, we can
        constrain those image properties. See :ref:`metamer-regularization` in
        the documentation for details and examples.
    penalty_lambda
        Weight of the penalty term. Must be non-negative.

    Raises
    ------
    ValueError
        If ``minmax`` is not one of ``"min"`` or ``"max"``.

    Warns
    -----
    UserWarning
        If ``metric_tradeoff_lambda`` is ``None`` and we therefore pick a
        value automatically.

    References
    ----------
    .. [1] Wang, Z., & Simoncelli, E. P. (2008). Maximum differentiation (MAD)
       competition: A methodology for comparing computational models of
       perceptual discriminability. Journal of Vision, 8(12), 1–13.
       https://dx.doi.org/10.1167/8.12.8
    """

    def __init__(
        self,
        image: Tensor,
        optimized_metric: torch.nn.Module | Callable[[Tensor, Tensor], Tensor],
        reference_metric: torch.nn.Module | Callable[[Tensor, Tensor], Tensor],
        minmax: Literal["min", "max"],
        metric_tradeoff_lambda: float | None = None,
        penalty_function: Callable[[Tensor], Tensor] = regularize.penalize_range,
        penalty_lambda: float = 0.1,
    ):
        super().__init__(
            penalty_function=penalty_function, penalty_lambda=penalty_lambda
        )
        validate_input(image)
        validate_metric(
            optimized_metric,
            image_shape=image.shape,
            image_dtype=image.dtype,
            device=image.device,
        )
        validate_metric(
            reference_metric,
            image_shape=image.shape,
            image_dtype=image.dtype,
            device=image.device,
        )
        self._optimized_metric = optimized_metric
        self._reference_metric = reference_metric
        self._image = image.detach()
        self._image_shape = image.shape
        self._scheduler = None
        self._scheduler_step_arg = False
        # per-iteration histories of the two metric values
        self._optimized_metric_loss = []
        self._reference_metric_loss = []
        if minmax not in ["min", "max"]:
            # name the actual argument so the user knows what to fix
            raise ValueError(
                "minmax must be one of {'min', 'max'}, but got "
                f"value {minmax} instead!"
            )
        # these three are filled in by setup() (or load())
        self._mad_image = None
        self._initial_image = None
        self._reference_metric_target = None
        # If no metric_tradeoff_lambda is specified, pick one that gets them to
        # approximately the same magnitude
        if metric_tradeoff_lambda is None:
            other_image = torch.rand_like(image)
            optim_loss = optimized_metric(image, other_image)
            loss_ratio = optim_loss / reference_metric(image, other_image)
            # round the ratio to the nearest power of ten
            metric_tradeoff_lambda = torch.pow(
                torch.as_tensor(10), torch.round(torch.log10(loss_ratio))
            ).item()
            # message text intentionally left unchanged, since users may filter
            # on it; stacklevel=2 attributes the warning to the caller.
            warnings.warn(
                "Since metric_tradeoff_lamda was None, automatically set"
                f" to {metric_tradeoff_lambda} to roughly balance metrics.",
                stacklevel=2,
            )
        self._metric_tradeoff_lambda = metric_tradeoff_lambda
        self._minmax = minmax
        self._store_progress = None
        self._saved_mad_image = []

    def setup(
        self,
        initial_noise: float | None = None,
        optimizer: torch.optim.Optimizer | None = None,
        optimizer_kwargs: dict | None = None,
        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
        scheduler_kwargs: dict | None = None,
    ):
        """
        Initialize the MAD image, optimizer, and scheduler.

        Can only be called once. If ``load()`` has been called,
        ``initial_noise`` must be None.

        Parameters
        ----------
        initial_noise
            :attr:`mad_image` is initialized to ``self.image + initial_noise *
            torch.randn_like(self.image)``, so this gives the standard
            deviation of the Gaussian noise. If ``None``, we use a value of
            0.1.
        optimizer
            The un-initialized optimizer object to use. If ``None``, we use
            Adam.
        optimizer_kwargs
            The keyword arguments to pass to the optimizer on initialization.
            If ``None``, we use ``{"lr": .01}`` and, if optimizer is ``None``,
            ``{"amsgrad": True}``.
        scheduler
            The un-initialized learning rate scheduler object to use. If
            ``None``, we don't use one.
        scheduler_kwargs
            The keyword arguments to pass to the scheduler on initialization.

        Raises
        ------
        ValueError
            If you try to set ``initial_noise`` after calling :func:`load`.
        ValueError
            If ``setup`` is called more than once or after :func:`synthesize`.

        Examples
        --------
        Set initial noise:

        >>> import plenoptic as po
        >>> img = po.data.einstein()
        >>> mad = po.MADCompetition(
        ...     img,
        ...     lambda x, y: 1 - po.metric.ssim(x, y),
        ...     po.metric.mse,
        ...     "min",
        ...     metric_tradeoff_lambda=0.1,
        ... )
        >>> mad.setup(1)
        >>> mad.synthesize(10)

        Set optimizer:

        >>> import plenoptic as po
        >>> img = po.data.einstein()
        >>> mad = po.MADCompetition(
        ...     img,
        ...     lambda x, y: 1 - po.metric.ssim(x, y),
        ...     po.metric.mse,
        ...     "min",
        ...     metric_tradeoff_lambda=0.1,
        ... )
        >>> mad.setup(optimizer=torch.optim.SGD, optimizer_kwargs={"lr": 0.01})
        >>> mad.synthesize(10)

        Use with save/load. Only the optimizer object is necessary, its kwargs
        and the initial noise are handled by load.

        >>> import plenoptic as po
        >>> img = po.data.einstein()
        >>> mad = po.MADCompetition(
        ...     img,
        ...     lambda x, y: 1 - po.metric.ssim(x, y),
        ...     po.metric.mse,
        ...     "min",
        ...     metric_tradeoff_lambda=0.1,
        ... )
        >>> mad.setup(1, optimizer=torch.optim.SGD, optimizer_kwargs={"lr": 0.01})
        >>> mad.synthesize(10)
        >>> mad.save("mad_setup.pt")
        >>> mad = po.MADCompetition(
        ...     img,
        ...     lambda x, y: 1 - po.metric.ssim(x, y),
        ...     po.metric.mse,
        ...     "min",
        ...     metric_tradeoff_lambda=0.1,
        ... )
        >>> mad.load("mad_setup.pt")
        >>> mad.setup(optimizer=torch.optim.SGD)
        >>> mad.synthesize(10)
        """
        if self._mad_image is None:
            # fresh object: build the noisy starting image
            if initial_noise is None:
                initial_noise = 0.1
            mad_image = self.image + initial_noise * torch.randn_like(self.image)
            self._initial_image = mad_image.clone()
            mad_image.requires_grad_()
            self._mad_image = mad_image
            # the reference metric value at initialization is the value we try
            # to hold fixed for the rest of synthesis
            self._reference_metric_target = self.reference_metric(
                self.image, self.mad_image
            ).item()
            self._reference_metric_loss.append(self._reference_metric_target)
            self._optimized_metric_loss.append(
                self.optimized_metric(self.image, self.mad_image).item()
            )
        elif not self._loaded:
            # mad_image exists and did not come from load(): setup already ran
            raise ValueError(
                "setup() can only be called once and must be called"
                " before synthesize()!"
            )
        elif initial_noise is not None:
            # mad_image came from load(), so it must not be re-initialized
            raise ValueError("Cannot set initial_noise after calling load()!")

        # initialize the optimizer
        self._initialize_optimizer(optimizer, self.mad_image, optimizer_kwargs)
        # and scheduler
        self._initialize_scheduler(scheduler, self.optimizer, scheduler_kwargs)
        # reset _loaded, if everything ran successfully
        self._loaded = False

    def synthesize(
        self,
        max_iter: int = 100,
        store_progress: bool | int = False,
        stop_criterion: float = 1e-4,
        stop_iters_to_check: int = 50,
    ):
        r"""
        Synthesize a MAD image.

        Update the pixels of :attr:`initial_image` to maximize or minimize
        (depending on the value of ``minmax``) the value of
        ``optimized_metric(image, mad_image)`` while keeping the value of
        ``reference_metric(image, mad_image)`` constant.

        We run this until either we reach ``max_iter`` or the loss changes
        less than ``stop_criterion`` over the past ``stop_iters_to_check``
        iterations, whichever comes first.

        Parameters
        ----------
        max_iter
            The maximum number of iterations to run before we end synthesis
            (unless we hit the stop criterion).
        store_progress
            Whether we should store the MAD image in progress during
            synthesis. If ``False``, we don't save anything. If True, we save
            every iteration. If an int, we save every ``store_progress``
            iterations (note then that ``0`` is the same as ``False`` and
            ``1`` the same as ``True``).
        stop_criterion
            If the loss over the past ``stop_iters_to_check`` has changed less
            than ``stop_criterion``, we terminate synthesis.
        stop_iters_to_check
            How many iterations back to check in order to see if the loss has
            stopped decreasing (for ``stop_criterion``).

        Raises
        ------
        ValueError
            If the loss becomes non-finite (NaN or infinite) during
            optimization.

        Warns
        -----
        UserWarning
            If the loss has converged and we stop before ``max_iter``.
        """
        # if setup hasn't been called manually, call it now.
        # NOTE(review): the tuple check presumably detects a scheduler
        # placeholder left by load() -- confirm against the base class.
        if self._mad_image is None or isinstance(self._scheduler, tuple):
            self.setup()
        self._current_loss = None
        self._current_penalty = None
        # get ready to store progress
        self.store_progress = store_progress

        pbar = tqdm(range(max_iter))

        for _ in pbar:
            # update saved_* attrs. len(_losses) gives the total number of
            # iterations and will be correct across calls to `synthesize`
            self._store(len(self._losses))

            loss = self._optimizer_step(pbar)

            if not np.isfinite(loss):
                raise ValueError("Found a NaN in loss during optimization.")

            if self._check_convergence(stop_criterion, stop_iters_to_check):
                warnings.warn("Loss has converged, stopping synthesis", stacklevel=2)
                break

        # compute current loss, no need to compute gradient
        with torch.no_grad():
            self._current_loss = self.objective_function().item()
            self._current_penalty = self.penalty_function(self.mad_image).item()

        pbar.close()

    def _objective_function(
        self,
        mad_image: Tensor | None = None,
        image: Tensor | None = None,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """
        Compute objective function components.

        This calls :attr:`optimized_metric`, :attr:`reference_metric`, and
        :attr:`penalty_function` and returns their output, without combining
        them. It is not meant to be called directly, but is used by both
        :func:`_closure` and (public) :func:`objective_function`.

        Parameters
        ----------
        mad_image
            Proposed ``mad_image``. If ``None``, use ``self.mad_image``.
        image
            Proposed ``image``. If ``None``, use ``self.image``.

        Returns
        -------
        sm
            1-element tensor containing optimized_metric(image, mad_image).
        fm
            1-element tensor containing reference_metric(image, mad_image).
        penalty
            1-element tensor containing the penalty on this step.
        """
        if image is None:
            image = self.image
        if mad_image is None:
            mad_image = self.mad_image
        # if this is empty, then self.mad_image hasn't been initialized
        if mad_image.numel() == 0:
            return torch.empty(0), torch.empty(0), torch.empty(0)
        sm = self.optimized_metric(image, mad_image)
        fm = self.reference_metric(image, mad_image)
        penalty = self.penalty_function(mad_image)
        return sm, fm, penalty

    def objective_function(
        self,
        mad_image: Tensor | None = None,
        image: Tensor | None = None,
    ) -> Tensor:
        r"""
        Compute the MADCompetition synthesis loss.

        This computes:

        .. math::

            t L_1(x, \hat{x}) &+ \lambda_1 [L_2(x, x+\epsilon) - L_2(x, \hat{x})]^2 \\
                              &+ \lambda_2 \mathcal{B}(\hat{x})

        where :math:`t` is 1 if :attr:`minmax` is ``'min'`` and -1 if it's
        ``'max'``, :math:`L_1` is :attr:`optimized_metric`, :math:`L_2` is
        :attr:`reference_metric`, :math:`x` is :attr:`image`, :math:`\hat{x}`
        is :attr:`mad_image`, :math:`\epsilon` is the initial noise,
        :math:`\mathcal{B}` is the penalty function, :math:`\lambda_1` is
        :attr:`metric_tradeoff_lambda` and :math:`\lambda_2` is
        :attr:`penalty_lambda`.

        If :meth:`setup` or :meth:`synthesize` has not been called to
        initialize the MAD image, then this will return an empty tensor.

        Parameters
        ----------
        mad_image
            Proposed ``mad_image``, :math:`\hat{x}` in the above equation. If
            ``None``, use ``self.mad_image``.
        image
            Proposed ``image``, :math:`x` in the above equation. If ``None``,
            use ``self.image``.

        Returns
        -------
        loss
            1-element tensor containing the loss on this step.
        """
        # without a target there is nothing to hold the reference metric to
        if self._reference_metric_target is None:
            return torch.empty(0)
        sm, fm, penalty = self._objective_function(mad_image, image)
        # sign flip turns maximization into minimization
        synth_target = {"min": 1, "max": -1}[self.minmax]
        fixed_loss = (self._reference_metric_target - fm).pow(2)
        return (
            synth_target * sm
            + self.metric_tradeoff_lambda * fixed_loss
            + self.penalty_lambda * penalty
        )

    def get_progress(
        self,
        iteration: int | None,
        iteration_selection: Literal["floor", "ceiling", "round"] = "round",
    ) -> dict:
        """
        Return dictionary summarizing synthesis progress at ``iteration``.

        This returns a dictionary containing info from :attr:`losses`,
        :attr:`pixel_change_norm`, :attr:`gradient_norm`, :attr:`penalties`,
        and :attr:`saved_mad_image` corresponding to ``iteration``. If
        synthesis was run with ``store_progress=False`` (and so we did not
        cache anything in :attr:`saved_mad_image`), then that key will be
        missing. If synthesis was run with ``store_progress>1``, we will grab
        the corresponding tensor from :attr:`saved_mad_image`, with behavior
        determined by ``iteration_selection``.

        The returned dictionary will additionally contain the keys:

        - ``"iteration"``: the (0-indexed positive) synthesis iteration that
          the values for :attr:`losses`, :attr:`pixel_change_norm`,
          :attr:`penalties` and :attr:`gradient_norm` come from.

        - If ``self.store_progress``, ``"store_progress_iteration"``: the
          (0-indexed positive) synthesis iteration that the value for
          :attr:`saved_mad_image` comes from.

        Note that for the most recent iteration (``iteration=-1`` or
        ``iteration=None`` or ``iteration==len(self.losses)-1``), we do not
        have values for :attr:`pixel_change_norm` or :attr:`gradient_norm`,
        since in this case we are showing the loss and value for the current
        MAD image.

        Parameters
        ----------
        iteration
            Synthesis iteration to summarize. If ``None``, grab the most
            recent. Negative values are allowed.
        iteration_selection
            How to select the relevant iteration from :attr:`saved_mad_image`
            when the request iteration wasn't stored.

            When synthesis was run with ``store_progress=n`` (where ``n>1``),
            MAD images are only saved every ``n`` iterations. If you request
            an iteration where a MAD image wasn't saved, this determines which
            available iteration is used instead:

            * ``"floor"``: use the closest saved iteration **before** the
              requested one.

            * ``"ceiling"``: use the closest saved iteration **after** the
              requested one.

            * ``"round"``: use the closest saved iteration.

        Returns
        -------
        progress_info
            Dictionary summarizing synthesis progress.

        Raises
        ------
        IndexError
            If ``iteration`` takes an illegal value.

        Warns
        -----
        UserWarning
            If the iteration used for ``saved_mad_image`` is not the same as
            the argument ``iteration`` (because e.g., you set ``iteration=3``
            but ``self.store_progress=2``).
        """
        # the two metric histories are extra attributes on top of whatever
        # the parent class reports
        return super().get_progress(
            iteration,
            iteration_selection,
            ["reference_metric_loss", "optimized_metric_loss"],
            store_progress_attributes=["saved_mad_image"],
        )

    def _closure(self) -> float:
        r"""
        Calculate the gradient, before the optimization step.

        This enables optimization algorithms that perform several evaluations
        of the gradient before taking a step (e.g., second order methods like
        LBFGS or methods with line searches).

        Additionally, this is where ``loss`` is calculated,
        ``loss.backward()`` is called, and ``self._penalties``,
        ``self._reference_metric_loss``, and ``self._optimized_metric_loss``
        are updated (but not ``self._losses``! that happens in
        ``_optimizer_step``).

        Returns
        -------
        loss
            Loss of the current objective function.
        """
        self.optimizer.zero_grad()
        sm, fm, penalty = self._objective_function()
        synth_target = {"min": 1, "max": -1}[self.minmax]
        fixed_loss = (self._reference_metric_target - fm).pow(2)
        loss = (
            synth_target * sm
            + self.metric_tradeoff_lambda * fixed_loss
            + self.penalty_lambda * penalty
        )
        loss.backward(retain_graph=False)
        # stash the components; _optimizer_step appends them to the histories
        self._reference_metric_tmp = fm.item()
        self._optimized_metric_tmp = sm.item()
        self._penalty_tmp = penalty.item()
        return loss.item()

    def _optimizer_step(self, pbar: tqdm) -> float:
        r"""
        Compute and propagate gradients, then step optimizer to update mad_image.

        Parameters
        ----------
        pbar
            A tqdm progress-bar, which we update with a postfix describing the
            current loss, gradient norm, and learning rate (it already tells
            us which iteration and the time elapsed).

        Returns
        -------
        loss
            The loss on this step, as returned by
            ``self.optimizer.step(self._closure)`` (:func:`_closure` returns a
            Python float).
        """
        # numpydoc ignore=ES01
        last_iter_mad_image = self.mad_image.clone()
        loss = self.optimizer.step(self._closure)
        self._losses.append(loss)
        self._penalties.append(self._penalty_tmp)
        self._reference_metric_loss.append(self._reference_metric_tmp)
        self._optimized_metric_loss.append(self._optimized_metric_tmp)

        grad_norm = torch.linalg.vector_norm(
            self.mad_image.grad.data, ord=2, dim=None
        ).item()
        self._gradient_norm.append(grad_norm)

        # optionally step the scheduler, passing loss if needed
        if self.scheduler is not None:
            if self._scheduler_step_arg:
                self.scheduler.step(loss)
            else:
                self.scheduler.step()

        pixel_change_norm = torch.linalg.vector_norm(
            self.mad_image - last_iter_mad_image, ord=2, dim=None
        ).item()
        self._pixel_change_norm.append(pixel_change_norm)
        # add extra info here if you want it to show up in progress bar
        pbar.set_postfix(
            OrderedDict(
                loss=f"{loss:.04e}",
                learning_rate=self.optimizer.param_groups[0]["lr"],
                penalty=f"{self._penalties[-1]:.04e}",
                gradient_norm=f"{grad_norm:.04e}",
                pixel_change_norm=f"{pixel_change_norm:.04e}",
                reference_metric=f"{self._reference_metric_loss[-1]:.04e}",
                optimized_metric=f"{self._optimized_metric_loss[-1]:.04e}",
            )
        )
        return loss

    def _check_convergence(
        self, stop_criterion: float, stop_iters_to_check: int
    ) -> bool:
        r"""
        Check whether the loss has stabilized and, if so, return True.

        Uses :func:`~plenoptic.convergence._loss_convergence`.

        Parameters
        ----------
        stop_criterion
            If the loss over the past ``stop_iters_to_check`` has changed less
            than ``stop_criterion``, we terminate synthesis.
        stop_iters_to_check
            How many iterations back to check in order to see if the loss has
            stopped decreasing (for ``stop_criterion``).

        Returns
        -------
        loss_stabilized
            Whether the loss has stabilized or not.
        """
        return _loss_convergence(self, stop_criterion, stop_iters_to_check)

    def _store(self, i: int) -> bool:
        """
        Store mad_image and model response, if appropriate.

        If it's the right iteration, we update :attr:`saved_mad_image`.

        Parameters
        ----------
        i
            The current iteration.

        Returns
        -------
        stored
            True if we stored this iteration, False if not.
        """
        # a falsy store_progress (False / 0) short-circuits before the modulo
        if self.store_progress and (i % self.store_progress == 0):
            # want these to always be on cpu, to reduce memory use for GPUs
            self._saved_mad_image.append(self.mad_image.clone().to("cpu"))
            stored = True
        else:
            stored = False
        return stored

    def save(self, file_path: str):
        r"""
        Save all relevant variables in .pt file.

        Note that if ``store_progress`` is True, this will probably be very
        large.

        See :func:`load` docstring for an example of use.

        Parameters
        ----------
        file_path
            The path to save the MADCompetition object to.
        """
        # each entry pairs a callable attribute with the attributes holding
        # its inputs; the same pairs are passed to the parent in load()
        save_io_attrs = [
            ("_optimized_metric", ("_image", "_mad_image")),
            ("_reference_metric", ("_image", "_mad_image")),
            ("_penalty_function", ("_image",)),
        ]
        save_state_dict_attrs = ["_optimizer", "_scheduler"]
        super().save(file_path, save_io_attrs, save_state_dict_attrs)

    def to(self, *args: Any, **kwargs: Any):
        r"""
        Move and/or casts the parameters and buffers.

        This can be called as

        .. code:: python

            to(device=None, dtype=None, non_blocking=False)

        .. code:: python

            to(dtype, non_blocking=False)

        .. code:: python

            to(tensor, non_blocking=False)

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point desired ``dtype``. In addition, this method will only
        cast the floating point parameters and buffers to ``dtype`` (if
        given). The integral parameters and buffers will be moved ``device``,
        if that is given, but with dtypes unchanged. When ``non_blocking`` is
        set, it tries to convert/move asynchronously with respect to the host
        if possible, e.g., moving CPU Tensors with pinned memory to CUDA
        devices.

        See :meth:`torch.nn.Module.to` for examples.

        .. note::
            This method modifies the module in-place.

        Parameters
        ----------
        device : torch.device
            The desired device of the parameters and buffers in this module.
        dtype : torch.dtype
            The desired floating point type of the floating point parameters
            and buffers in this module.
        tensor : torch.Tensor
            Tensor whose dtype and device are the desired dtype and device for
            all parameters and buffers in this module.
        """
        # numpydoc ignore=PR01,PR02
        attrs = ["_initial_image", "_image", "_mad_image", "_saved_mad_image"]
        super().to(*args, attrs=attrs, **kwargs)
        # if the metrics are Modules, then we should pass them as well. If
        # they're functions then nothing needs to be done.
        with contextlib.suppress(AttributeError):
            self.reference_metric.to(*args, **kwargs)
        with contextlib.suppress(AttributeError):
            self.optimized_metric.to(*args, **kwargs)

    def load(
        self,
        file_path: str,
        map_location: str | None = None,
        raise_on_checks: bool = True,
        tensor_equality_atol: float = 1e-8,
        tensor_equality_rtol: float = 1e-5,
        **pickle_load_args: Any,
    ):
        r"""
        Load all relevant stuff from a .pt file.

        This must be called by a ``MADCompetition`` object initialized just
        like the saved object.

        Note this operates in place and so doesn't return anything.

        .. versionchanged:: 1.2
           load behavior changed in a backwards-incompatible manner in order
           to be compatible with breaking changes in torch 2.6.

        .. versionchanged:: 2.0.0
           Adds ``raise_on_checks`` argument.

        Parameters
        ----------
        file_path
            The path to load the synthesis object from.
        map_location
            Argument to pass to ``torch.load`` as ``map_location``. If you
            save stuff that was being run on a GPU and are loading onto a CPU,
            you'll need this to make sure everything lines up properly. This
            should be structured like the str you would pass to
            :class:`torch.device`.
        raise_on_checks
            During load, we perform several checks to ensure that the saved
            object was initialized in the same way as the loading object. This
            is to ensure that the model, image, etc. are all the same and
            avoid unpleasant surprises. If ``True``, we raise a ``ValueError``
            if any of these checks fail. If ``False``, we instead raise a
            ``LoadWarning``. The intended use here is if you're loading
            something that was saved with an older version of plenoptic and
            you're sure that you're doing everything correctly. Note that
            different devices or dtypes will always result in a
            ``ValueError``. See :ref:`raise-on-checks` on the "Reproducibility
            and Compatibility" page of the documentation for more info.
            Additionally, note that, if the ``MADCompetition`` object itself
            has changed, we cannot ensure that methods are the same -- proceed
            at your own risk.
        tensor_equality_atol
            Absolute tolerance to use when checking for tensor equality during
            load, passed to :func:`torch.allclose`. It may be necessary to
            increase if you are saving and loading on two machines with torch
            built by different cuda versions. Be careful when changing this!
            See :class:`torch.finfo<torch.torch.finfo>` for more details about
            floating point precision of different data types (especially,
            ``eps``); if you have to increase this by more than 1 or 2
            decades, then you are probably not dealing with a numerical issue.
        tensor_equality_rtol
            Relative tolerance to use when checking for tensor equality during
            load, passed to :func:`torch.allclose`. It may be necessary to
            increase if you are saving and loading on two machines with torch
            built by different cuda versions. Be careful when changing this!
            See :class:`torch.finfo<torch.torch.finfo>` for more details about
            floating point precision of different data types (especially,
            ``eps``); if you have to increase this by more than 1 or 2
            decades, then you are probably not dealing with a numerical issue.
        **pickle_load_args
            Any additional kwargs will be added to ``pickle_module.load`` via
            :func:`torch.load`, see that function's docstring for details.

        Raises
        ------
        ValueError
            If :func:`setup` or :func:`synthesize` has been called before this
            call to ``load``.
        ValueError
            If the object saved at ``file_path`` is not a ``MADCompetition``
            object.
        ValueError
            If the saved and loading ``MADCompetition`` objects have a
            different value for any of :attr:`image`, :attr:`penalty_lambda`,
            :attr:`metric_tradeoff_lambda`, or :attr:`minmax`.
        ValueError
            If the behavior of :attr:`optimized_metric` or
            :attr:`reference_metric` is different between the saved and
            loading objects.

        Warns
        -----
        UserWarning
            If :func:`setup` will need to be called after ``load``, to finish
            initializing :attr:`optimizer` or :attr:`scheduler`.

        See Also
        --------
        :func:`~plenoptic.io.examine_saved_synthesis`
            Examine metadata from saved object: pytorch and plenoptic
            versions, name of the synthesis object, shapes of tensors, etc.

        Examples
        --------
        >>> import plenoptic as po
        >>> img = po.data.einstein()
        >>> def ds_ssim(x, y):
        ...     return 1 - po.metric.ssim(x, y)
        >>> mad = po.MADCompetition(
        ...     img, po.metric.mse, ds_ssim, "min", metric_tradeoff_lambda=10
        ... )
        >>> mad.synthesize(max_iter=5, store_progress=True)
        >>> mad.save("mad.pt")
        >>> mad_copy = po.MADCompetition(
        ...     img, po.metric.mse, ds_ssim, "min", metric_tradeoff_lambda=10
        ... )
        >>> mad_copy.load("mad.pt")
        """
        # attributes that must match between the saved and loading objects
        check_attributes = [
            "_image",
            "_metric_tradeoff_lambda",
            "_penalty_lambda",
            "_minmax",
        ]
        # same (callable, inputs) pairs that save() hands to the parent
        check_io_attrs = [
            ("_optimized_metric", ("_image", "_mad_image")),
            ("_reference_metric", ("_image", "_mad_image")),
            ("_penalty_function", ("_image",)),
        ]
        super().load(
            file_path,
            "_mad_image",
            map_location=map_location,
            check_attributes=check_attributes,
            check_io_attributes=check_io_attrs,
            state_dict_attributes=["_optimizer", "_scheduler"],
            raise_on_checks=raise_on_checks,
            tensor_equality_atol=tensor_equality_atol,
            tensor_equality_rtol=tensor_equality_rtol,
            **pickle_load_args,
        )
        # make this require a grad again
        self.mad_image.requires_grad_()
        # these are always supposed to be on cpu, but may get copied over to
        # gpu on load (which can cause problems when resuming synthesis), so
        # fix that.
        if len(self._saved_mad_image) and self._saved_mad_image[0].device.type != "cpu":
            self._saved_mad_image = [mad.to("cpu") for mad in self._saved_mad_image]

    def __repr__(self) -> str:  # numpydoc ignore=GL08
        return super()._repr_format(
            [
                "image",
                "optimized_metric",
                "reference_metric",
                "minmax",
                "metric_tradeoff_lambda",
                "penalty_function",
                "penalty_lambda",
            ]
        )

    @property
    def mad_image(self) -> Tensor:
        """Maximally-differentiating image, the parameter we are optimizing."""
        # numpydoc ignore=RT01,ES01
        # an empty tensor stands in for "not initialized yet"
        if self._mad_image is None:
            return torch.empty(0)
        return self._mad_image

    @property
    def optimized_metric(self) -> torch.nn.Module | Callable[[Tensor, Tensor], Tensor]:
        """The metric whose value we are minimizing or maximizing."""
        # numpydoc ignore=RT01,ES01
        return self._optimized_metric

    @property
    def reference_metric(self) -> torch.nn.Module | Callable[[Tensor, Tensor], Tensor]:
        """The metric whose value we are keeping constant."""
        # numpydoc ignore=RT01,ES01
        return self._reference_metric

    @property
    def image(self) -> Tensor:
        """The reference image for this MAD Competition."""
        # numpydoc ignore=RT01,ES01
        return self._image

    @property
    def initial_image(self) -> Tensor:
        """
        Initial image for MAD Competition.

        This is the image whose distance to ``image``, the reference, we are
        maximizing/minimizing for ``optimized_metric``, while keeping constant
        for ``reference_metric``.
        """
        # numpydoc ignore=RT01
        return self._initial_image

    @property
    def reference_metric_loss(self) -> Tensor:
        """
        :attr:`reference_metric` loss over iterations.

        That is, the value of ``reference_metric(image, mad_image)``. Ideally,
        this is equal to ``reference_metric(image, initial_image)``.

        This tensor always lives on the CPU, regardless of the device of the
        ``MADCompetition`` object.
        """
        # numpydoc ignore=RT01
        return torch.as_tensor(self._reference_metric_loss)

    @property
    def optimized_metric_loss(self) -> Tensor:
        """
        :attr:`optimized_metric` loss over iterations.

        That is, the value of ``optimized_metric(image, mad_image)``. Ideally,
        this is very different from ``optimized_metric(image,
        initial_image)``.

        This tensor always lives on the CPU, regardless of the device of the
        ``MADCompetition`` object.
        """
        # numpydoc ignore=RT01
        return torch.as_tensor(self._optimized_metric_loss)

    @property
    def metric_tradeoff_lambda(self) -> float:
        """Tradeoff between the two metrics in synthesis loss."""
        # numpydoc ignore=RT01,ES01
        return self._metric_tradeoff_lambda

    @property
    def minmax(self) -> Literal["min", "max"]:
        """Whether we are minimizing or maximizing :attr:`optimized_metric`."""
        # numpydoc ignore=RT01,ES01
        return self._minmax

    @property
    def saved_mad_image(self) -> Tensor:
        """
        :attr:`mad_image`, cached over time for later examination.

        How often the MAD image is cached is determined by the
        ``store_progress`` argument to the :func:`synthesize` function.

        The last entry will always be the current :attr:`mad_image`.

        If ``store_progress==1``, then this corresponds directly to
        :attr:`losses`: ``losses[i]`` is the error for ``saved_mad_image[i]``

        This tensor always lives on the CPU, regardless of the device of the
        ``MADCompetition`` object.
        """
        # numpydoc ignore=RT01
        if self._mad_image is None:
            return torch.empty(0)
        else:
            # for memory purposes, always on CPU
            return torch.stack([*self._saved_mad_image, self.mad_image.to("cpu")])

    @property
    def penalties(self) -> torch.Tensor:
        """
        Penalty function output over iterations.

        Will have ``length=num_iter+1``, where ``num_iter`` is the number of
        iterations of synthesis run so far.

        This tensor always lives on the CPU.
        """
        # numpydoc ignore=RT01
        return super().penalties(self.mad_image)