Source code for plenoptic._synthesize.metamer

"""
Model metamers.

Model metamers are images whose pixel values differ but whose model outputs are
identical. They allow researchers to better understand the information which has no
effect on a model's output, also known as the model's invariances.
"""  # numpydoc ignore=EX01

import warnings
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, Literal

import numpy as np
import torch
from torch import Tensor
from tqdm.auto import tqdm

from .. import loss, regularize
from ..convergence import _coarse_to_fine_enough, _loss_convergence
from ..process import signal
from ..validate import validate_coarse_to_fine, validate_input, validate_model
from .synthesis import _OptimizedSynthesis

# Public names exported by this module (what ``from ... import *`` exposes).
__all__ = [
    "Metamer",
    "MetamerCTF",
]


def __dir__() -> list[str]:
    """
    Return the names listed by ``dir()`` on this module.

    Returns
    -------
    names
        The module-level ``__all__`` list itself (not a copy).
    """
    return __all__


class Metamer(_OptimizedSynthesis):
    r"""
    Synthesize metamers for image-computable differentiable models.

    Following the basic idea in [1]_, this class creates a metamer for a given
    model on a given image. The pixel values of :attr:`metamer` are adjusted
    iteratively so that its model representation matches that of :attr:`image`.

    Parameters
    ----------
    image
        A tensor, this is the image whose representation we wish to match.
    model
        A visual model.
    loss_function
        The loss function used to compare the representations of the models
        in order to determine their loss.
    penalty_function
        A function applied to the metamer during optimization, that returns a
        scalar penalty to be minimized. By penalizing certain properties of the
        image, like pixels values outside an allowed range, we can constrain
        those image properties. See :ref:`metamer-regularization` in the
        documentation for details and examples.
    penalty_lambda
        Weight of the penalty term. Must be non-negative.

    References
    ----------
    .. [1] J Portilla and E P Simoncelli. A Parametric Texture Model
       based on Joint Statistics of Complex Wavelet Coefficients. Int'l
       Journal of Computer Vision. 40(1):49-71, October, 2000.
       https://www.cns.nyu.edu/~eero/ABSTRACTS/portilla99-abstract.html
       https://www.cns.nyu.edu/~lcv/texture/

    Examples
    --------
    Synthesize and visualize a metamer for a simple model:

    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> import matplotlib.pyplot as plt
      >>> img = po.data.einstein()
      >>> model = po.models.Gaussian(30).eval()
      >>> po.remove_grad(model)
      >>> met = po.Metamer(img, model)
      >>> met.synthesize(110)
      >>> fig, axes = plt.subplots(1, 4, figsize=(16, 4))
      >>> po.plot.imshow(img, ax=axes[0], title="Target image")
      <Figure size ... with 4 Axes>
      >>> axes[0].xaxis.set_visible(False)
      >>> axes[0].yaxis.set_visible(False)
      >>> po.plot.synthesis_status(met, fig=fig, axes_idx={"misc": 0})
      <Figure size ...>
    """

    def __init__(
        self,
        image: Tensor,
        model: torch.nn.Module,
        loss_function: Callable[[Tensor, Tensor], Tensor] = loss.mse,
        penalty_function: Callable[[Tensor], Tensor] = regularize.penalize_range,
        penalty_lambda: float = 0.1,
    ):
        # the parent class owns the penalty configuration
        super().__init__(
            penalty_function=penalty_function, penalty_lambda=penalty_lambda
        )

        # check the inputs before storing anything
        validate_input(image)
        validate_model(
            model,
            image_shape=image.shape,
            image_dtype=image.dtype,
            device=image.device,
        )

        # user-supplied pieces
        self._loss_function = loss_function
        self._model = model
        self._image = image
        self._image_shape = image.shape

        # the target is computed once, up front; this must come after _model
        # and _image are set, since it goes through self.model / self.image
        self._target_representation = self.model(self.image)

        # optimization state, filled in by setup() / synthesize() / load()
        self._metamer = None
        self._scheduler = None
        self._scheduler_step_arg = False
        self._store_progress = None
        self._saved_metamer = []
[docs] def setup( self, initial_image: Tensor | None = None, optimizer: torch.optim.Optimizer | None = None, optimizer_kwargs: dict | None = None, scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, scheduler_kwargs: dict | None = None, ): """ Initialize the metamer, optimizer, and scheduler. Can only be called once. If ``load()`` has been called, ``initial_image`` must be ``None``. Parameters ---------- initial_image The tensor we use to initialize the metamer. If ``None``, we initialize with random noise uniformly-distributed in [0,1]. optimizer The un-initialized optimizer object to use. If ``None``, we use :class:`torch.optim.Adam`. optimizer_kwargs The keyword arguments to pass to the optimizer on initialization. If ``None``, we use ``{"lr": .01}`` and, if optimizer is ``None``, ``{"amsgrad": True}``. scheduler The un-initialized learning rate scheduler object to use. If ``None``, we don't use one. scheduler_kwargs The keyword arguments to pass to the scheduler on initialization. Raises ------ ValueError If you try to set ``initial_image`` after calling :func:`load`. ValueError If ``setup`` is called more than once or after :func:`synthesize`. ValueError If you try to set ``optimizer_kwargs`` after calling :func:`load`. TypeError If the loaded object had a non-Adam optimizer, but the ``optimizer`` arg is not specified. ValueError If the loaded object had an optimizer, and the ``optimizer`` arg is a different type. ValueError If you try to set ``scheduler_kwargs`` after calling :func:`load`. TypeError If the loaded object had a scheduler, but the ``scheduler`` arg is not specified. ValueError If the loaded object had a scheduler, but the ``scheduler`` arg is a different type. Warns ----- UserWarning If ``initial_image`` is a different shape than ``self.image``. 
Examples -------- Set initial image: >>> import plenoptic as po >>> img = po.data.einstein() >>> model = po.models.Gaussian(30).eval() >>> po.remove_grad(model) >>> met = po.Metamer(img, model) >>> met.setup(po.data.curie()) Set optimizer: >>> met = po.Metamer(img, model) >>> met.setup(optimizer=torch.optim.SGD, optimizer_kwargs={"lr": 0.01}) Set optimizer and scheduler: >>> met = po.Metamer(img, model) >>> met.setup( ... optimizer=torch.optim.SGD, ... optimizer_kwargs={"lr": 0.01}, ... scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau, ... ) Use with save/load. We only pass the optimizer/scheduler objects when calling setup after load, their kwargs and the initial image are handled during the load. >>> met = po.Metamer(img, model) >>> met.setup( ... po.data.curie(), ... optimizer=torch.optim.SGD, ... optimizer_kwargs={"lr": 0.01}, ... scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau, ... ) >>> met.synthesize(5) >>> met.save("metamer_setup.pt") >>> met = po.Metamer(img, model) >>> met.load("metamer_setup.pt") >>> met.setup( ... optimizer=torch.optim.SGD, ... scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau, ... ) """ if self._metamer is None: if initial_image is None: metamer = torch.rand_like(self.image) metamer = signal.rescale(metamer, 0, 1) else: validate_input(initial_image) if initial_image.size() != self.image.size(): warnings.warn( "initial_image and image are different sizes! This " "has not been tested as much, open an issue if you have " "any problems! https://github.com/plenoptic-org/plenoptic/" "issues/new?template=bug_report.md" ) metamer = initial_image.clone().detach() metamer = metamer.to(dtype=self.image.dtype, device=self.image.device) metamer.requires_grad_() self._metamer = metamer else: if self._loaded: if initial_image is not None: raise ValueError("Cannot set initial_image after calling load()!") else: raise ValueError( "setup() can only be called once and must be called" " before synthesize()!" 
) # initialize the optimizer self._initialize_optimizer(optimizer, self.metamer, optimizer_kwargs) # and scheduler self._initialize_scheduler(scheduler, self.optimizer, scheduler_kwargs) # reset _loaded, if everything ran successfully self._loaded = False
def synthesize(
    self,
    max_iter: int = 100,
    store_progress: bool | int = False,
    stop_criterion: float = 1e-4,
    stop_iters_to_check: int = 50,
):
    r"""
    Synthesize a metamer.

    Update the pixels of :attr:`metamer` until its representation matches that
    of :attr:`image`.

    We run this until either we reach ``max_iter`` or the loss changes less
    than ``stop_criterion`` over the past ``stop_iters_to_check`` iterations,
    whichever comes first.

    Parameters
    ----------
    max_iter
        The maximum number of iterations to run before we end synthesis
        (unless we hit the stop criterion).
    store_progress
        Whether we should store the metamer image in progress during
        synthesis. If ``False``, we don't save anything. If True, we save every
        iteration. If an int, we save every ``store_progress`` iterations
        (note then that 0 is the same as False and 1 the same as True). This is
        primarily useful for using :func:`~plenoptic.plot.synthesis_animate`
        to create a video of the course of synthesis.
    stop_criterion
        If the loss over the past ``stop_iters_to_check`` has changed less
        than ``stop_criterion``, we terminate synthesis.
    stop_iters_to_check
        How many iterations back to check in order to see if the loss has
        stopped decreasing (for ``stop_criterion``).

    Raises
    ------
    ValueError
        If we find a NaN during optimization.

    See Also
    --------
    :func:`~plenoptic.plot.synthesis_status`
        Create a plot summarizing synthesis status at a given iteration.
    :func:`~plenoptic.plot.synthesis_animate`
        Create a video of the metamer changing over the course of synthesis.

    Examples
    --------
    >>> import plenoptic as po
    >>> po.set_seed(0)
    >>> img = po.data.einstein()
    >>> model = po.models.Gaussian(30).eval()
    >>> po.remove_grad(model)
    >>> met = po.Metamer(img, model)
    >>> # this isn't enough to run synthesis to completion, just an example
    >>> met.synthesize(5)
    >>> met.losses
    tensor([0.0194, 0.0198, 0.0179, 0.0160, 0.0145, 0.0132])

    Synthesize a metamer, using ``store_progress`` so we can examine progress
    later. (This also enables us to create a video of the metamer changing over
    the course of synthesis, see :func:`~plenoptic.plot.synthesis_animate`.)

    >>> met = po.Metamer(img, model)
    >>> # this isn't enough to run synthesis to completion, just an example
    >>> met.synthesize(5, store_progress=2)
    >>> met.saved_metamer.shape
    torch.Size([4, 1, 1, 256, 256])
    >>> # see loss, etc on the 4th iteration
    >>> progress = met.get_progress(4)
    >>> progress.keys()
    dict_keys(['losses', ..., 'saved_metamer', 'store_progress_iteration'])
    >>> progress["losses"]
    tensor(0.0139)

    Adjust ``stop_criterion`` and ``stop_iters_to_check`` to change how
    convergence is determined. In this case, we stop early by making
    ``stop_criterion`` fairly large. In practice, you're more likely to make
    ``stop_criterion`` smaller to let synthesis run for longer.

    >>> met = po.Metamer(img, model)
    >>> # this isn't enough to run synthesis to completion, just an example
    >>> met.synthesize(12, stop_criterion=0.001, stop_iters_to_check=2)
    >>> len(met.losses)
    9
    """
    # if setup hasn't been called manually, call it now.
    if self._metamer is None or isinstance(self._scheduler, tuple):
        self.setup()
    self._current_loss = None
    self._current_penalty = None

    # get ready to store progress
    self.store_progress = store_progress

    pbar = tqdm(range(max_iter))
    # try/finally so the progress bar is released even when an iteration
    # raises (e.g. the non-finite-loss ValueError below)
    try:
        for _ in pbar:
            # update saved_* attrs. len(_losses) gives the total number of
            # iterations and will be correct across calls to `synthesize`
            self._store(len(self._losses))

            loss = self._optimizer_step(pbar)

            if not np.isfinite(loss):
                raise ValueError("Found a NaN in loss during optimization.")

            if self._check_convergence(stop_criterion, stop_iters_to_check):
                warnings.warn("Loss has converged, stopping synthesis")
                break

        # compute current loss, no need to compute gradient
        with torch.no_grad():
            self._current_loss = self.objective_function().item()
            self._current_penalty = self.penalty_function(self.metamer).item()
    finally:
        pbar.close()
def _objective_function( self, metamer: Tensor | None = None, target_representation: Tensor | None = None, **analyze_kwargs: Any, ) -> tuple[Tensor, Tensor]: """ Compute objective function components. This calls :attr:`loss_function` and :attr:`penalty_function` and returns their output, without combining them. It is not meant to be called directly, but is used by both :func:`_closure` and (public) :func:`objective_function`. Parameters ---------- metamer Current ``metamer``. If ``None``, we use ``self.metamer``. target_representation Model response to ``image``. If ``None``, we use ``self.target_representation``. **analyze_kwargs Additional kwargs to pass to ``self.model(metamer)``. Returns ------- loss 1-element tensor containing the metamer loss on this step (without penalty). penalty 1-element tensor containing the penalty on this step. """ if metamer is None: metamer = self.metamer # if this is empty, then self.metamer hasn't been initialized if metamer.numel() == 0: return torch.empty(0), torch.empty(0) if target_representation is None: target_representation = self.target_representation metamer_representation = self.model(metamer, **analyze_kwargs) loss = self.loss_function(metamer_representation, target_representation) penalty = self.penalty_function(metamer) return loss, penalty
[docs] def objective_function( self, metamer: Tensor | None = None, target_representation: Tensor | None = None, **analyze_kwargs: Any, ) -> Tensor: """ Compute the metamer synthesis loss. This calls :attr:`loss_function` on ``self.model(metamer, **analyze_kwargs)`` and ``target_representation`` and then adds :attr:`penalty_lambda` times :attr:`penalty_function` on ``metamer``. Its output over time is stored in :attr:`losses`. Parameters ---------- metamer Current ``metamer``. If ``None``, we use ``self.metamer``. target_representation Model response to ``image``. If ``None``, we use ``self.target_representation``. **analyze_kwargs Additional kwargs to pass to ``self.model(metamer)``. Returns ------- loss 1-element tensor containing the loss on this step. Examples -------- >>> import plenoptic as po >>> po.set_seed(0) >>> img = po.data.einstein() >>> model = po.models.Gaussian(30).eval() >>> po.remove_grad(model) >>> met = po.Metamer(img, model) Before :meth:`setup` or :meth:`synthesize` is called, this returns an empty tensor because the metamer attribute hasn't been initialized: >>> met.objective_function() tensor([]) >>> met.synthesize(5, store_progress=True) When called without any arguments, this returns the current loss: >>> met.objective_function() tensor(0.0132, grad_fn=<AddBackward0>) >>> met.losses[-1] tensor(0.0132) Can be called with a different image. (Note that, because we called :meth:`synthesize` with ``store_progress=True``, we cached the metamer over the course of synthesis): >>> met.objective_function(met.saved_metamer[0]) tensor(0.0194, grad_fn=<AddBackward0>) >>> met.losses[0] tensor(0.0194) This method differs from the :attr:`loss_function` attribute because of its inclusion of the penalty. 
In the following block, the pixels of ``rand_img`` all lie within [0, 1], and so the outputs of :attr:`objective_function` and :attr:`loss_function` are the same: >>> rand_img = torch.rand_like(img) >>> rand_img.min(), rand_img.max() (tensor(7.9870e-06), tensor(1.0000)) >>> met.objective_function(rand_img) tensor(0.0190) >>> met.loss_function(model(img), model(rand_img)) tensor(0.0190) In this block, the image's lie outside [0, 1], and so the outputs of :attr:`objective_function` and :attr:`loss_function` are different: >>> rand_img *= 2 >>> rand_img.min(), rand_img.max() (tensor(0.0001), tensor(2.0000)) >>> met.objective_function(rand_img) tensor(1100.9663) >>> loss = met.loss_function(model(img), model(rand_img)) >>> loss tensor(0.3133) To compute the output of the objective function, we take the output of :attr:`loss_function` and add the output of :attr:`penalty_function` times :attr:`penalty_lambda`: >>> penalty = met.penalty_function(rand_img) >>> penalty tensor(11006.5293) >>> loss + met.penalty_lambda * penalty tensor(1100.9663) """ loss, penalty = self._objective_function( metamer, target_representation, **analyze_kwargs ) return loss + self.penalty_lambda * penalty
def get_progress(
    self,
    iteration: int | None,
    iteration_selection: Literal["floor", "ceiling", "round"] = "round",
) -> dict:
    r"""
    Return dictionary summarizing synthesis progress at ``iteration``.

    This returns a dictionary containing info from :attr:`losses`,
    :attr:`pixel_change_norm`, :attr:`gradient_norm`, :attr:`penalties`, and
    :attr:`saved_metamer` corresponding to ``iteration``. If synthesis was run
    with ``store_progress=False`` (and so we did not cache anything in
    :attr:`saved_metamer`), then that key will be missing. If synthesis was run
    with ``store_progress>1``, we will grab the corresponding tensor from
    :attr:`saved_metamer`, with behavior determined by ``iteration_selection``.

    The returned dictionary will additionally contain the keys:

    - ``"iteration"``: the (0-indexed positive) synthesis iteration that the
      values for :attr:`losses`, :attr:`pixel_change_norm`, :attr:`penalties`,
      and :attr:`gradient_norm` come from.

    - If ``self.store_progress``, ``"store_progress_iteration"``: the
      (0-indexed positive) synthesis iteration that the value for
      :attr:`saved_metamer` comes from.

    Note that for the most recent iteration (``iteration=-1`` or
    ``iteration=None`` or ``iteration==len(self.losses)-1``), we do not have
    values for :attr:`pixel_change_norm` or :attr:`gradient_norm`, since in
    this case we are showing the loss and value for the current metamer.

    Parameters
    ----------
    iteration
        Synthesis iteration to summarize. If ``None``, grab the most recent.
        Negative values are allowed.
    iteration_selection
        How to select the relevant iteration from :attr:`saved_metamer` when
        the requested iteration wasn't stored.

        When synthesis was run with ``store_progress=n`` (where ``n>1``),
        metamers are only saved every ``n`` iterations. If you request an
        iteration where a metamer wasn't saved, this determines which
        available iteration is used instead:

        * ``"floor"``: use the closest saved iteration **before** the
          requested one.

        * ``"ceiling"``: use the closest saved iteration **after** the
          requested one.

        * ``"round"``: use the closest saved iteration.

    Returns
    -------
    progress_info
        Dictionary summarizing synthesis progress.

    Raises
    ------
    IndexError
        If ``iteration`` takes an illegal value.

    Warns
    -----
    UserWarning
        If the iteration used for ``saved_metamer`` is not the same as the
        argument ``iteration`` (because e.g., you set ``iteration=3`` but
        ``self.store_progress=2``).

    See Also
    --------
    :func:`~plenoptic.plot.synthesis_status`
        Create a plot summarizing synthesis status at a given iteration.
    :func:`~plenoptic.plot.synthesis_animate`
        Create a video of the metamer changing over the course of synthesis.

    Examples
    --------
    >>> import plenoptic as po
    >>> po.set_seed(0)
    >>> img = po.data.einstein()
    >>> model = po.models.Gaussian(30).eval()
    >>> po.remove_grad(model)
    >>> met = po.Metamer(img, model)
    >>> met.synthesize(5)

    Get values from the first iteration:

    >>> met.get_progress(0)
    {'losses': tensor(0.0194),
     'iteration': 0,
     'penalties': tensor(0.),
     'pixel_change_norm': tensor(2.5326),
     'gradient_norm': tensor(0.0010)}

    Get values from last iteration of synthesis:

    >>> print(met.get_progress(-2))
    {'losses': tensor(0.0145),
     'iteration': 4,
     'penalties': tensor(0.0180),
     'pixel_change_norm': tensor(2.2698),
     'gradient_norm': tensor(0.0268)}

    Get current values:

    >>> print(met.get_progress(-1))
    {'losses': tensor(0.0132),
     'iteration': 5,
     'penalties': tensor(0.0174),
     'pixel_change_norm': None,
     'gradient_norm': None}

    When synthesis is run with ``store_progress=True``, this function also
    returns the metamer from the corresponding iteration:

    >>> met = po.Metamer(img, model)
    >>> met.synthesize(5, store_progress=True)
    >>> print(met.get_progress(-1))
    {'losses': tensor(0.0124),
     'iteration': 5,
     'penalties': tensor(0.0168),
     'pixel_change_norm': None,
     'gradient_norm': None,
     'saved_metamer': tensor([[[[0.4554, ...]]]], grad_fn=<SelectBackward0>),
     'store_progress_iteration': 5}
    >>> torch.equal(
    ...     met.saved_metamer[-1], met.get_progress(-1)["saved_metamer"]
    ... )
    True

    When synthesis is run with ``store_progress>1``, this function returns the
    metamer from the closest iteration:

    >>> met = po.Metamer(img, model)
    >>> met.synthesize(5, store_progress=2)
    >>> print(met.get_progress(-3))
    {'losses': tensor(0.0152),
     'iteration': 3,
     'penalties': tensor(0.0182),
     'pixel_change_norm': tensor(2.3592),
     'gradient_norm': tensor(0.0269),
     'saved_metamer': tensor([[[[0.8532, ...]]]], grad_fn=<SelectBackward0>),
     'store_progress_iteration': 4}

    When we cannot grab the saved metamer corresponding to the requested
    iteration, ``iteration_selection`` controls how we determine "closest":

    >>> print(met.get_progress(-3, iteration_selection="floor"))
    {'losses': tensor(0.0152),
     'iteration': 3,
     'penalties': tensor(0.0182),
     'pixel_change_norm': tensor(2.3592),
     'gradient_norm': tensor(0.0269),
     'saved_metamer': tensor([[[[ 0.8730, ...]]]], grad_fn=<SelectBackward0>),
     'store_progress_iteration': 2}
    """
    # the only attribute this class caches via store_progress
    cached_attributes = ["saved_metamer"]
    return super().get_progress(
        iteration,
        iteration_selection,
        store_progress_attributes=cached_attributes,
    )
def _closure(self) -> float:
    r"""
    Calculate the gradient, before the optimization step.

    This enables optimization algorithms that perform several evaluations of
    the gradient before taking a step (ie. second order methods like LBFGS).

    Additionally, this is where ``loss`` is calculated, ``loss.backward()`` is
    called, and the penalty for this step is stashed in ``self._penalty_tmp``
    (``self._penalties`` and ``self._losses`` are appended to in
    ``_optimizer_step``).

    Returns
    -------
    loss
        Loss of the current objective function, as a Python float.
    """
    self.optimizer.zero_grad()
    loss, penalty = self._objective_function()
    objective = loss + self.penalty_lambda * penalty
    objective.backward(retain_graph=False)
    # stash the un-weighted penalty so _optimizer_step can record it
    self._penalty_tmp = penalty.item()
    return objective.item()


def _optimizer_step(self, pbar: tqdm) -> float:
    r"""
    Compute and propagate gradients, then step the optimizer to update metamer.

    Parameters
    ----------
    pbar
        A tqdm progress-bar, which we update with a postfix describing the
        current loss, gradient norm, and learning rate (it already tells us
        which iteration and the time elapsed).

    Returns
    -------
    loss
        The loss on this step: whatever ``self.optimizer.step`` returns for
        the closure, i.e. the float produced by :func:`_closure`.
    """  # numpydoc ignore=ES01,EX01
    # snapshot taken before the step, to measure how far the pixels moved
    previous_metamer = self.metamer.clone()
    loss = self.optimizer.step(self._closure)
    self._losses.append(loss)
    self._penalties.append(self._penalty_tmp)

    grad_norm = torch.linalg.vector_norm(
        self.metamer.grad.data, ord=2, dim=None
    ).item()
    self._gradient_norm.append(grad_norm)

    # optionally step the scheduler, passing loss if needed
    if self.scheduler is not None:
        if self._scheduler_step_arg:
            self.scheduler.step(loss)
        else:
            self.scheduler.step()

    pixel_change_norm = torch.linalg.vector_norm(
        self.metamer - previous_metamer, ord=2, dim=None
    ).item()
    self._pixel_change_norm.append(pixel_change_norm)

    # add extra info here if you want it to show up in progress bar
    postfix = OrderedDict(
        loss=f"{self._losses[-1]:.04e}",
        learning_rate=self.optimizer.param_groups[0]["lr"],
        penalty=f"{self._penalties[-1]:.04e}",
        gradient_norm=f"{grad_norm:.04e}",
        pixel_change_norm=f"{pixel_change_norm:.04e}",
    )
    pbar.set_postfix(postfix)
    return loss


def _check_convergence(
    self, stop_criterion: float, stop_iters_to_check: int
) -> bool:
    r"""
    Check whether the loss has stabilized and, if so, return True.

    Uses :func:`~plenoptic.convergence._loss_convergence`.

    Parameters
    ----------
    stop_criterion
        If the loss over the past ``stop_iters_to_check`` has changed less
        than ``stop_criterion``, we terminate synthesis.
    stop_iters_to_check
        How many iterations back to check in order to see if the loss has
        stopped decreasing (for ``stop_criterion``).

    Returns
    -------
    loss_stabilized
        Whether the loss has stabilized or not.
    """  # numpydoc ignore=EX01
    return _loss_convergence(self, stop_criterion, stop_iters_to_check)


def _store(self, i: int) -> bool:
    """
    Store metamer, if appropriate.

    If it's the right iteration, we update :attr:`saved_metamer`.

    Parameters
    ----------
    i
        The current iteration.

    Returns
    -------
    stored
        True if we stored this iteration, False if not.
    """  # numpydoc ignore=EX01
    # nothing to do when storing is disabled or this isn't a stored iteration
    if not self.store_progress or i % self.store_progress != 0:
        return False
    # want these to always be on cpu, to reduce memory use for GPUs
    self._saved_metamer.append(self.metamer.clone().to("cpu"))
    return True
def save(self, file_path: str):
    r"""
    Save all relevant variables in .pt file.

    Note that if ``store_progress`` is True, this will probably be very large.

    Parameters
    ----------
    file_path
        The path to save the metamer object to.

    See Also
    --------
    load
        Method to load in saved ``Metamer`` objects.

    Examples
    --------
    >>> import plenoptic as po
    >>> img = po.data.einstein()
    >>> model = po.models.Gaussian(30).eval()
    >>> po.remove_grad(model)
    >>> met = po.Metamer(img, model)
    >>> met.synthesize(max_iter=5, store_progress=True)
    >>> met.save("metamers.pt")
    """
    # Each entry pairs a callable attribute with the attribute expressions
    # used as its inputs. NOTE(review): presumably the parent class records
    # the resulting input/output behavior so load() can compare it -- confirm
    # against _OptimizedSynthesis.save.
    io_attrs = [
        (
            "_loss_function",
            ("_target_representation", "2 * _target_representation"),
        ),
        ("_penalty_function", ("_image",)),
        ("_model", ("_image",)),
    ]
    # attributes handed to the parent as ``save_state_dict_attrs``
    state_dict_attrs = ["_optimizer", "_scheduler"]
    super().save(file_path, io_attrs, state_dict_attrs)
def to(self, *args: Any, **kwargs: Any):
    r"""
    Move and/or cast the parameters and buffers.

    This can be called as

    .. code:: python

        to(device=None, dtype=None, non_blocking=False)

    .. code:: python

        to(dtype, non_blocking=False)

    .. code:: python

        to(tensor, non_blocking=False)

    Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
    floating point desired ``dtype``. In addition, this method will only cast
    the floating point parameters and buffers to ``dtype`` (if given). The
    integral parameters and buffers will be moved to ``device``, if that is
    given, but with dtypes unchanged. When ``non_blocking`` is set, it tries to
    convert/move asynchronously with respect to the host if possible, e.g.,
    moving CPU Tensors with pinned memory to CUDA devices.

    See :meth:`torch.nn.Module.to` for examples.

    .. note::
        This method modifies the module in-place.

    Parameters
    ----------
    device : torch.device
        The desired device of the parameters and buffers in this module.
    dtype : torch.dtype
        The desired floating point type of the floating point parameters and
        buffers in this module.
    tensor : torch.Tensor
        Tensor whose dtype and device are the desired dtype and device for
        all parameters and buffers in this module.

    Examples
    --------
    >>> import plenoptic as po
    >>> img = po.data.einstein()
    >>> model = po.models.Gaussian(30).eval()
    >>> po.remove_grad(model)
    >>> met = po.Metamer(img, model)
    >>> met.image.dtype
    torch.float32
    >>> met.model(met.image).dtype
    torch.float32
    >>> met.to(torch.float64)
    >>> met.image.dtype
    torch.float64
    >>> met.model(met.image).dtype
    torch.float64
    """  # numpydoc ignore=PR01,PR02
    # tensor-valued attributes this class asks the parent to move/cast
    tensor_attrs = ["_image", "_target_representation", "_metamer", "_saved_metamer"]
    super().to(*args, attrs=tensor_attrs, **kwargs)

    # try to call .to() on model. this should work, but it might fail if e.g.,
    # this is a custom model that doesn't inherit torch.nn.Module
    try:
        moved_model = self._model.to(*args, **kwargs)
        self._model = moved_model
    except AttributeError:
        warnings.warn("Unable to call model.to(), so we leave it as is.")
[docs] def load( self, file_path: str, map_location: str | None = None, raise_on_checks: bool = True, tensor_equality_atol: float = 1e-8, tensor_equality_rtol: float = 1e-5, **pickle_load_args: Any, ): r""" Load all relevant stuff from a .pt file. This should be called by an initialized ``Metamer`` object -- we will ensure that ``image``, ``target_representation`` (and thus ``model``), and ``loss_function`` are all identical. Note this operates in place and so doesn't return anything. .. versionchanged:: 1.2 load behavior changed in a backwards-incompatible manner in order to compatible with breaking changes in torch 2.6. .. versionchanged:: 2.0.0 Adds ``raise_on_checks`` argument. Parameters ---------- file_path The path to load the synthesis object from. map_location Argument to pass to :func:`torch.load` as ``map_location``. If you save stuff that was being run on a GPU and are loading onto a CPU, you'll need this to make sure everything lines up properly. This should be structured like the str you would pass to :class:`torch.device`. raise_on_checks During load, we perform several checks to ensure that the saved object was initialized in the same way as the loading object. This is to ensure that the model, image, etc. are all the same and avoid unpleasant surprises. If ``True``, we raise a ``ValueError`` if any of these checks fail. If ``False``, we instead raise a ``LoadWarning``. The intended use here is if you're loading something that was saved with an older version of plenoptic and you're sure that you're doing everything correctly. Note that different devices or dtypes will always result in a ``ValueError``. See :ref:`raise-on-checks` on the "Reproducibility and Compatibility" page of the documentation for more info. Additionally, note that, if the ``Metamer`` object itself has changed, we cannot ensure that methods are the same -- proceed at your own risk. 
tensor_equality_atol Absolute tolerance to use when checking for tensor equality during load, passed to :func:`torch.allclose`. It may be necessary to increase if you are saving and loading on two machines with torch built by different cuda versions. Be careful when changing this! See :class:`torch.finfo<torch.torch.finfo>` for more details about floating point precision of different data types (especially, ``eps``); if you have to increase this by more than 1 or 2 decades, then you are probably not dealing with a numerical issue. tensor_equality_rtol Relative tolerance to use when checking for tensor equality during load, passed to :func:`torch.allclose`. It may be necessary to increase if you are saving and loading on two machines with torch built by different cuda versions. Be careful when changing this! See :class:`torch.finfo<torch.torch.finfo>` for more details about floating point precision of different data types (especially, ``eps``); if you have to increase this by more than 1 or 2 decades, then you are probably not dealing with a numerical issue. **pickle_load_args Any additional kwargs will be added to ``pickle_module.load`` via :func:`torch.load`, see that function's docstring for details. Raises ------ ValueError If :func:`setup` or :func:`synthesize` has been called before this call to ``load``. ValueError If the object saved at ``file_path`` is not a ``Metamer`` object. ValueError If the saved and loading ``Metamer`` objects have a different value for any of :attr:`image` or :attr:`penalty_lambda`, ValueError If the behavior of :attr:`loss_function` or :attr:`model` is different between the saved and loading objects. Warns ----- UserWarning If :func:`setup` will need to be called after load, to finish initializing :attr:`optimizer` or :attr:`scheduler`. See Also -------- :func:`~plenoptic.io.examine_saved_synthesis` Examine metadata from saved object: pytorch and plenoptic versions, name of the synthesis object, shapes of tensors, etc. 
Examples -------- In order to load a saved ``Metamer`` object, we must first initialize one using the same arguments. (We use float64 / "double" precision rather than torch's default float32 because it increases reproducibility, see the :ref:`Reproducibility <reproduce>` page of our documentations for more details.) Here, we load in a cached example: >>> import plenoptic as po >>> img = po.data.einstein().to(torch.float64) >>> model = po.models.Gaussian(30).eval().to(torch.float64) >>> po.remove_grad(model) >>> met = po.Metamer(img, model) >>> print(met.metamer) tensor([]) >>> met.load(po.data.fetch_data("example_metamer_gaussian.pt")) >>> print(met.metamer) tensor([[[[0.0692, ...]]]], dtype=torch.float64, requires_grad=True) If the saved ``Metamer`` object lived on a CUDA device and you do not have CUDA on the loading machine, use ``map_location`` to change device: >>> met = po.Metamer(img, model) >>> met.image.device device(type='cpu') >>> met.load(po.data.fetch_data("example_metamer_gaussian-cuda.pt")) Traceback (most recent call last): RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False... >>> met.load( ... po.data.fetch_data("example_metamer_gaussian-cuda.pt"), ... map_location="cpu", ... ) >>> print(met.metamer) tensor([[[[0.0692, ...]]]], dtype=torch.float64, requires_grad=True) If the loading ``Metamer`` object was not initialized with same values as the saved object, an error will be raised: >>> met = po.Metamer(torch.rand_like(img), model) >>> met.load(po.data.fetch_data("example_metamer_gaussian.pt")) Traceback (most recent call last): ValueError: Saved and initialized attribute image have different values... 
If the loading ``Metamer`` object has a different data type than the saved object, an error will be raised: >>> met = po.Metamer(img, model) >>> met.to(torch.float32) >>> met.load(po.data.fetch_data("example_metamer_gaussian.pt")) Traceback (most recent call last): ValueError: Saved and initialized attribute image have different dtype... """ self._load( file_path, map_location, raise_on_checks=raise_on_checks, tensor_equality_atol=tensor_equality_atol, tensor_equality_rtol=tensor_equality_rtol, **pickle_load_args, )
def _load(
    self,
    file_path: str,
    map_location: str | None = None,
    additional_check_attributes: list[str] | None = None,
    additional_check_io_attributes: list | None = None,
    raise_on_checks: bool = True,
    tensor_equality_atol: float = 1e-8,
    tensor_equality_rtol: float = 1e-5,
    **pickle_load_args: Any,
):
    r"""
    Load from a file.

    This is a helper function for loading.

    Users interact with ``load`` (without the underscore), this is to allow
    subclasses to specify additional attributes or loss functions to check.

    Parameters
    ----------
    file_path
        The path to load the synthesis object from.
    map_location
        Argument to pass to :func:`torch.load` as ``map_location``. If you
        save stuff that was being run on a GPU and are loading onto a
        CPU, you'll need this to make sure everything lines up
        properly. This should be structured like the str you would
        pass to :class:`torch.device`.
    additional_check_attributes
        Any additional attributes to check for equality. Intended for use by
        any subclasses, to add other attributes set at initialization. If
        ``None`` (the default), only ``_image`` and ``_penalty_lambda`` are
        checked.
    additional_check_io_attributes
        Any additional attributes whose input/output behavior we should check.
        Intended for use by any subclasses. Entries are appended to the
        built-in ones, which are ``(attribute_name, tuple_of_input_names)``
        tuples. If ``None`` (the default), only the built-in entries for
        ``_loss_function``, ``_penalty_function`` and ``_model`` are used.
    raise_on_checks
        During load, we perform several checks to ensure that the saved object
        was initialized in the same way as the loading object. This is to ensure
        that the model, image, etc. are all the same and avoid unpleasant
        surprises. If ``True``, we raise a ``ValueError`` if any of these
        checks fail. If ``False``, we instead raise a ``LoadWarning``. The
        intended use here is if you're loading something that was saved with an
        older version of plenoptic and you're sure that you're doing everything
        correctly. Note that different devices or dtypes will always result in a
        ``ValueError``. See :ref:`raise-on-checks` on the "Reproducibility and
        Compatibility" page of the documentation for more info. Additionally,
        note that, if the synthesis object itself has changed, we cannot ensure
        that methods are the same -- proceed at your own risk.
    tensor_equality_atol
        Absolute tolerance to use when checking for tensor equality during load,
        passed to :func:`torch.allclose`. It may be necessary to increase if you
        are saving and loading on two machines with torch built by different cuda
        versions. Be careful when changing this! See
        :class:`torch.finfo<torch.torch.finfo>` for more details about floating
        point precision of different data types (especially, ``eps``); if you
        have to increase this by more than 1 or 2 decades, then you are probably
        not dealing with a numerical issue.
    tensor_equality_rtol
        Relative tolerance to use when checking for tensor equality during load,
        passed to :func:`torch.allclose`. It may be necessary to increase if you
        are saving and loading on two machines with torch built by different cuda
        versions. Be careful when changing this! See
        :class:`torch.finfo<torch.torch.finfo>` for more details about floating
        point precision of different data types (especially, ``eps``); if you
        have to increase this by more than 1 or 2 decades, then you are probably
        not dealing with a numerical issue.
    **pickle_load_args
        Any additional kwargs will be added to ``pickle_module.load`` via
        :func:`torch.load`, see that function's docstring for details.
    """  # numpydoc ignore=EX01
    check_attributes = [
        "_image",
        "_penalty_lambda",
    ]
    # None (rather than a shared mutable ``[]`` default) means "nothing extra"
    if additional_check_attributes is not None:
        check_attributes += additional_check_attributes
    check_io_attrs = [
        (
            "_loss_function",
            ("_target_representation", "2 * _target_representation"),
        ),
        ("_penalty_function", ("_image",)),
        ("_model", ("_image",)),
    ]
    if additional_check_io_attributes is not None:
        check_io_attrs += additional_check_io_attributes
    super().load(
        file_path,
        "_metamer",
        map_location=map_location,
        check_attributes=check_attributes,
        check_io_attributes=check_io_attrs,
        state_dict_attributes=["_optimizer", "_scheduler"],
        raise_on_checks=raise_on_checks,
        tensor_equality_atol=tensor_equality_atol,
        tensor_equality_rtol=tensor_equality_rtol,
        **pickle_load_args,
    )
    # make this require a grad again
    self.metamer.requires_grad_()
    # these are always supposed to be on cpu, but may get copied over to
    # gpu on load (which can cause problems when resuming synthesis), so
    # fix that.
    if len(self._saved_metamer) and self._saved_metamer[0].device.type != "cpu":
        self._saved_metamer = [met.to("cpu") for met in self._saved_metamer]

def __repr__(self) -> str:  # numpydoc ignore=GL08
    return super()._repr_format(
        ["image", "model", "loss_function", "penalty_function", "penalty_lambda"]
    )

@property
def loss_function(self) -> Callable[[Tensor, Tensor], Tensor]:
    """Callable which specifies how close metamer representation is to target."""
    # numpydoc ignore=RT01,ES01
    return self._loss_function

@property
def model(self) -> torch.nn.Module:
    """The model for which the metamer is synthesized."""
    # numpydoc ignore=RT01,ES01,EX01
    return self._model

@property
def image(self) -> torch.Tensor:
    """Target image of metamer optimization."""
    # numpydoc ignore=RT01,ES01,EX01
    return self._image

@property
def target_representation(self) -> torch.Tensor:
    """
    :attr:`model` representation of :attr:`image`.

    The goal of synthesis is for ``model(metamer)`` to match this value.

    Examples
    --------
    >>> import plenoptic as po
    >>> img = po.data.einstein()
    >>> model = po.models.Gaussian(30).eval()
    >>> po.remove_grad(model)
    >>> met = po.Metamer(img, model)
    >>> torch.equal(model(img), met.target_representation)
    True
    """
    # numpydoc ignore=RT01
    return self._target_representation

@property
def metamer(self) -> torch.Tensor:
    """Model metamer, the parameter we are optimizing."""
    # numpydoc ignore=RT01,ES01,EX01
    # before initialization there is no metamer yet: return an empty tensor
    if self._metamer is None:
        return torch.empty(0)
    return self._metamer

@property
def saved_metamer(self) -> torch.Tensor:
    """
    :attr:`metamer`, cached over time for later examination.

    How often the metamer is cached is determined by the ``store_progress``
    argument to the :func:`synthesize` function.

    The last entry will always be the current :attr:`metamer`.

    If ``store_progress==1``, then this corresponds directly to :attr:`losses`:
    ``losses[i]`` is the error for ``saved_metamer[i]``

    This tensor always lives on the CPU, regardless of the device of the
    ``Metamer`` object.

    Examples
    --------
    If synthesize is called without ``store_progress``, then this attribute just
    contains the metamer, though the number of dimensions is different:

    >>> import plenoptic as po
    >>> po.set_seed(0)
    >>> img = po.data.einstein()
    >>> model = po.models.Gaussian(30).eval()
    >>> po.remove_grad(model)
    >>> met = po.Metamer(img, model)
    >>> met.saved_metamer
    tensor([])
    >>> met.synthesize(5)
    >>> met.saved_metamer
    tensor([[[[[ 0.0098, ...]]]]], grad_fn=<StackBackward0>)
    >>> met.metamer
    tensor([[[[ 0.0098, ...]]]], requires_grad=True)
    >>> met.saved_metamer.shape
    torch.Size([1, 1, 1, 256, 256])
    >>> met.metamer.shape
    torch.Size([1, 1, 256, 256])

    If synthesize is called with ``store_progress=1``, then this attribute
    contains the metamer at each iteration, and ``losses[i]`` contains the
    error for ``saved_metamer[i]``.

    >>> met = po.Metamer(img, model)
    >>> met.synthesize(5, store_progress=True)
    >>> met.saved_metamer.shape
    torch.Size([6, 1, 1, 256, 256])
    >>> met.objective_function(met.saved_metamer[2])
    tensor(0.0169, grad_fn=<AddBackward0>)
    >>> met.losses[2]
    tensor(0.0169)

    (In the above example, ``saved_metamer`` has 6 elements because it includes
    the metamer at the start of each of the 5 synthesis iterations, plus the
    current one.)
    """
    # numpydoc ignore=RT01,EX01
    if self._metamer is None:
        return torch.empty(0)
    else:
        # for memory purposes, always on CPU
        return torch.stack([*self._saved_metamer, self.metamer.to("cpu")])

@property
def penalties(self) -> torch.Tensor:
    """
    Penalty function output over iterations.

    Will have ``length=num_iter+1``, where ``num_iter`` is the number of
    iterations of synthesis run so far.

    This tensor always lives on the CPU.
    """
    # numpydoc ignore=RT01
    return super().penalties(self.metamer)
class MetamerCTF(Metamer):
    """
    Synthesize model metamers with coarse-to-fine synthesis.

    This is a special case of ``Metamer``, which uses the coarse-to-fine
    synthesis procedure described in [1]_: we start by updating metamer with
    respect to only a subset of the model's representation (generally, that
    which corresponds to the lowest spatial frequencies), and changing which
    subset we consider over the course of synthesis. This is similar to
    optimizing with a blurred version of the objective function and gradually
    adding in finer details. It improves synthesis performance for some models.

    Parameters
    ----------
    image
        A tensor, this is the image whose representation we wish to match.
    model
        A visual model.
    loss_function
        The loss function to use to compare the representations of the models
        in order to determine their loss.
    penalty_function
        A function applied to the metamer during optimization, that returns a
        scalar penalty to be minimized. By penalizing certain properties of the
        image, like pixels values outside an allowed range, we can constrain
        those image properties. See :ref:`metamer-regularization` in the
        documentation for details and examples.
    penalty_lambda
        Strength of the penalty term. Must be non-negative.
    coarse_to_fine
        - ``"together"``: start with the coarsest scale, then gradually
          add each finer scale.
        - ``"separate"``: compute the gradient with respect to each
          scale separately (ignoring the others), then with respect
          to all of them at the end.

        (see :ref:`Metamer tutorial <metamer-nb>` for more details).

    References
    ----------
    .. [1] J Portilla and E P Simoncelli. A Parametric Texture Model
       based on Joint Statistics of Complex Wavelet Coefficients.
       Int'l Journal of Computer Vision. 40(1):49-71, October, 2000.
       https://www.cns.nyu.edu/~eero/ABSTRACTS/portilla99-abstract.html
       https://www.cns.nyu.edu/~lcv/texture/

    Examples
    --------
    Synthesize and visualize a metamer using coarse-to-fine synthesis:

    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> import matplotlib.pyplot as plt
      >>> import torch
      >>> img = po.data.reptile_skin()
      >>> model = po.models.PortillaSimoncelli(img.shape[-2:])
      >>> # to work with MetamerCTF, models must have a scales attribute
      >>> model.scales
      ['pixel_statistics', 'residual_lowpass', 3, 2, 1, 0, 'residual_highpass']
      >>> met = po.MetamerCTF(img, model, loss_function=po.loss.l2_norm)
      >>> # initialize with an image that has a comparable mean and standard deviation
      >>> init_img = (torch.rand_like(img) - 0.5) * 0.1 + img.mean()
      >>> met.setup(init_img)
      >>> met.synthesize(150, change_scale_criterion=None, ctf_iters_to_check=7)
      >>> fig, axes = plt.subplots(1, 4, figsize=(25, 4), width_ratios=[1, 1, 1, 3])
      >>> po.plot.imshow(img, ax=axes[0], title="Target image")
      <Figure size ... with 4 Axes>
      >>> axes[0].xaxis.set_visible(False)
      >>> axes[0].yaxis.set_visible(False)
      >>> po.plot.synthesis_status(met, fig=fig, axes_idx={"misc": 0})
      <Figure size ...>

    Not all models work with ``MetamerCTF``:

    >>> import plenoptic as po
    >>> img = po.data.einstein()
    >>> model = po.models.Gaussian(30).eval()
    >>> po.remove_grad(model)
    >>> met = po.MetamerCTF(img, model)
    Traceback (most recent call last):
    AttributeError: model has no scales attribute ...
    """

    def __init__(
        self,
        image: Tensor,
        model: torch.nn.Module,
        loss_function: Callable[[Tensor, Tensor], Tensor] = loss.mse,
        penalty_function: Callable[[Tensor], Tensor] = regularize.penalize_range,
        penalty_lambda: float = 0.1,
        coarse_to_fine: Literal["together", "separate"] = "together",
    ):
        # everything except the coarse-to-fine state is set up by Metamer
        super().__init__(
            image,
            model,
            loss_function,
            penalty_function,
            penalty_lambda,
        )
        self._init_ctf(coarse_to_fine)

    def _init_ctf(self, coarse_to_fine: Literal["together", "separate"]):
        """
        Initialize stuff related to coarse-to-fine.

        - Validates value of ``coarse_to_fine``

        - Validates ``self.model`` for coarse-to-fine synthesis (calls
          :func:`validate_coarse_to_fine`).

        - Initializes attributes for coarse-to-fine synthesis.

        Parameters
        ----------
        coarse_to_fine
            Which mode of coarse-to-fine to use, see initial docstring for details.

        Raises
        ------
        ValueError
            If ``coarse_to_fine`` takes an illegal value.
        """  # numpydoc ignore=EX01
        if coarse_to_fine not in ("separate", "together"):
            raise ValueError(
                f"Don't know how to handle value {coarse_to_fine}!"
                " Must be one of: 'separate', 'together'"
            )
        # this will hold the reduced representation of the target image.
        self._ctf_target_representation = None
        validate_coarse_to_fine(
            self.model, image_shape=self.image.shape, device=self.image.device
        )
        # build a fresh list, so we never modify model.scales itself. the
        # model's last scale is only optimized on its own in "separate" mode;
        # either way, we always finish with "all".
        model_scales = self.model.scales
        pending_scales = list(model_scales[:-1])
        if coarse_to_fine == "separate":
            pending_scales.append(model_scales[-1])
        pending_scales.append("all")
        self._scales = pending_scales
        # one (initially empty) list of start/stop iterations per scale; the
        # first scale starts at iteration 0
        self._scales_timing = {scale: [] for scale in self.scales}
        self._scales_timing[self.scales[0]].append(0)
        self._scales_loss = []
        self._scales_finished = []
        self._coarse_to_fine = coarse_to_fine
        self._initial_lr = None

    def _initialize_optimizer(
        self,
        optimizer: torch.optim.Optimizer | None,
        synth_attr: torch.Tensor,
        optimizer_kwargs: dict | None = None,
    ):
        """
        Initialize optimizer.

        Calls ``super._initialize_optimizer()``, passing all arguments through,
        and also caches the initial learning rate (``self._initial_lr``), which
        we use when switching scales.

        Parameters
        ----------
        optimizer
            The (un-initialized) optimizer object to use. If ``None``, we use
            :class:`torch.optim.Adam`.
        synth_attr
            The tensor we will optimize.
        optimizer_kwargs
            The keyword arguments to pass to the optimizer on initialization. If
            ``None``, we use ``{"lr": .01}`` and, if optimizer is ``None``,
            ``{"amsgrad": True}``.
        """  # numpydoc ignore=EX01
        super()._initialize_optimizer(optimizer, synth_attr, optimizer_kwargs)
        # remember each parameter group's starting learning rate, so that it
        # can be restored whenever we move on to a new scale
        initial_rates = []
        for group in self.optimizer.param_groups:
            initial_rates.append(group["lr"])
        self._initial_lr = initial_rates
def synthesize(
    self,
    max_iter: int = 100,
    store_progress: bool | int = False,
    stop_criterion: float = 1e-4,
    stop_iters_to_check: int = 50,
    change_scale_criterion: float | None = 1e-2,
    ctf_iters_to_check: int = 50,
):
    r"""
    Synthesize a metamer.

    Update the pixels of ``metamer`` until its representation matches
    that of ``image``.

    We run this until either we reach ``max_iter`` or the change over the
    past ``stop_iters_to_check`` iterations is less than
    ``stop_criterion``, whichever comes first.

    Parameters
    ----------
    max_iter
        The maximum number of iterations to run before we end synthesis
        (unless we hit the stop criterion).
    store_progress
        Whether we should store the metamer image in progress on every
        iteration. If ``False``, we don't save anything. If True, we save
        every iteration. If an int, we save every ``store_progress``
        iterations (note then that 0 is the same as False and 1 the same as
        True). This is primarily useful for using
        :func:`~plenoptic.plot.synthesis_animate` to create a video of the
        course of synthesis.
    stop_criterion
        If the loss over the past ``stop_iters_to_check`` has changed
        less than ``stop_criterion``, we terminate synthesis.
    stop_iters_to_check
        How many iterations back to check in order to see if the
        loss has stopped decreasing (for ``stop_criterion``).
    change_scale_criterion
        Scale-specific analogue of ``stop_criterion``: we consider
        a given scale finished (and move onto the next) if the loss has
        changed less than this in the past ``ctf_iters_to_check``
        iterations. If ``None``, we'll change scales as soon as we've spent
        ``ctf_iters_to_check`` on a given scale.
    ctf_iters_to_check
        Scale-specific analogue of ``stop_iters_to_check``: how many
        iterations back in order to check in order to see if we should
        switch scales.

    Raises
    ------
    ValueError
        If ``stop_criterion >= change_scale_criterion`` -- behavior is
        strange otherwise.
    ValueError
        If we find a NaN during optimization.

    Warns
    -----
    UserWarning
        If the loss has converged and synthesis therefore stops before
        ``max_iter`` iterations.

    See Also
    --------
    :func:`~plenoptic.plot.synthesis_status`
        Create a plot summarizing synthesis status at a given iteration.
    :func:`~plenoptic.plot.synthesis_animate`
        Create a video of the metamer changing over the course of
        synthesis.

    Examples
    --------
    >>> import plenoptic as po
    >>> po.set_seed(0)
    >>> img = po.data.reptile_skin()
    >>> model = po.models.PortillaSimoncelli(img.shape[-2:])
    >>> met = po.MetamerCTF(img, model)
    >>> # this isn't enough to run synthesis to completion, just an example
    >>> met.synthesize(5)
    >>> met.losses
    tensor([0.1062, ..., 0.1038])

    You can examine scales_timing attribute to see when MetamerCTF
    started and stopped optimizing each scale:

    >>> met.scales_timing
    {'pixel_statistics': [0], 'residual_lowpass': [], 3: [], 2: [], 1: [], 0: [], 'all': []}

    Synthesize a metamer, using ``store_progress`` so we can examine progress
    later. (This also enables us to create a video of the metamer changing
    over the course of synthesis, see
    :func:`~plenoptic.plot.synthesis_animate`.)

    >>> met = po.MetamerCTF(img, model)
    >>> # this isn't enough to run synthesis to completion, just an example
    >>> met.synthesize(5, store_progress=2)
    >>> met.saved_metamer.shape
    torch.Size([4, 1, 1, 256, 256])
    >>> # see loss, etc on the 4th iteration
    >>> progress = met.get_progress(4)
    >>> progress.keys()
    dict_keys(['losses', ..., 'saved_metamer', 'store_progress_iteration'])
    >>> progress["losses"]
    tensor(0.1109)

    Set ``change_scale_criterion`` and ``ctf_iters_to_check`` to change
    scale-switching behavior.

    >>> met = po.MetamerCTF(img, model)
    >>> # this isn't enough to run synthesis to completion, just an example
    >>> met.synthesize(5, change_scale_criterion=None, ctf_iters_to_check=2)
    >>> met.losses
    tensor([0.1119, ..., 0.0687])
    >>> met.scales_timing
    {'pixel_statistics': [0, 1], 'residual_lowpass': [2, 3], 3: [4], 2: [], 1: [], 0: [], 'all': []}

    Adjust ``stop_criterion`` and ``stop_iters_to_check`` to change how
    convergence is determined. In this case, we stop early by making
    ``stop_criterion`` fairly large. In practice, you're more likely to make
    ``stop_criterion`` smaller to let synthesis run for longer.

    >>> met = po.MetamerCTF(img, model)
    >>> # this isn't enough to run synthesis to completion, just an example
    >>> met.synthesize(10, stop_criterion=0.001, stop_iters_to_check=2)
    """  # noqa: E501
    if (change_scale_criterion is not None) and (
        stop_criterion >= change_scale_criterion
    ):
        raise ValueError(
            "stop_criterion must be strictly less than "
            "change_scale_criterion, or things get weird!"
        )

    # if setup hasn't been called manually, call it now.
    if self._metamer is None or isinstance(self._scheduler, tuple):
        self.setup()
    self._current_loss = None
    self._current_penalty = None

    # get ready to store progress
    self.store_progress = store_progress

    pbar = tqdm(range(max_iter))
    # the try/finally guarantees the progress bar is closed even if an
    # iteration raises (e.g., the NaN check below)
    try:
        for i in pbar:
            # update saved_* attrs. len(_losses) gives the total number of
            # iterations and will be correct across calls to `synthesize`
            self._store(len(self._losses))

            step_loss = self._optimizer_step(
                pbar, change_scale_criterion, ctf_iters_to_check
            )

            if not np.isfinite(step_loss):
                raise ValueError("Found a NaN in loss during optimization.")

            if self._check_convergence(
                i, stop_criterion, stop_iters_to_check, ctf_iters_to_check
            ):
                warnings.warn("Loss has converged, stopping synthesis")
                break

        # compute current loss, no need to compute gradient.
        with torch.no_grad():
            self._current_loss = self.objective_function().item()
            self._current_penalty = self.penalty_function(self.metamer).item()
    finally:
        pbar.close()
def _optimizer_step(
    self,
    pbar: tqdm,
    change_scale_criterion: float | None,
    ctf_iters_to_check: int,
) -> float:
    r"""
    Compute and propagate gradients, then step the optimizer to update metamer.

    Before stepping, this also checks whether we should switch to the next
    coarse-to-fine scale and, if so, updates the scale bookkeeping, resets the
    optimizer's learning rate(s) to their initial values, and clears the cached
    scale-specific target representation.

    Parameters
    ----------
    pbar
        A tqdm progress-bar, which we update with a postfix describing the
        current loss, gradient norm, and learning rate (it already tells us
        which iteration and the time elapsed).
    change_scale_criterion
        Threshold on the change in scale-specific loss: we move to the next
        scale if the absolute difference between the most recent
        scale-specific loss and the one ``ctf_iters_to_check`` entries back is
        less than this value. If ``None``, this check is skipped and we switch
        as soon as the iteration-count conditions are met.
    ctf_iters_to_check
        Minimum number of iterations coarse-to-fine must run at each scale.

    Returns
    -------
    loss
        The overall loss on this step, computed before the metamer is updated.
        While optimizing an individual scale, this is the ``.item()`` of
        ``objective_function``; once we have reached ``"all"``, it is the
        value returned by ``self.optimizer.step(self._closure)`` (presumably
        the float returned by ``_closure`` -- depends on the optimizer).
    """
    # numpydoc ignore=ES01,EX01
    # copy taken before the update, used below to compute pixel_change_norm
    last_iter_metamer = self.metamer.clone()
    # Check if conditions hold for switching scales:
    # - Check if loss has decreased below the change_scale_criterion and
    # - if we've been optimizing this scale for the required number of iterations
    # - The first check here is because the last scale will be 'all', and
    #   we never remove it
    if (
        len(self.scales) > 1
        and len(self.scales_loss) >= ctf_iters_to_check
        and (
            change_scale_criterion is None
            or abs(self.scales_loss[-1] - self.scales_loss[-ctf_iters_to_check])
            < change_scale_criterion
        )
        and (
            len(self._losses) - self.scales_timing[self.scales[0]][0]
            >= ctf_iters_to_check
        )
    ):
        # record the last iteration of the scale we're leaving, then move it
        # from the pending list to the finished list
        self._scales_timing[self.scales[0]].append(len(self._losses) - 1)
        self._scales_finished.append(self._scales.pop(0))
        # Only append if scales list is still non-empty after the pop
        if self.scales:
            self._scales_timing[self.scales[0]].append(len(self._losses))
        # Reset optimizer's learning rate
        for pg, lr in zip(self.optimizer.param_groups, self._initial_lr):
            pg["lr"] = lr
        # Reset ctf target representation for the next update
        self._ctf_target_representation = None

    # the loss returned by objective_function is from *before* updating the
    # metamer, so to compute the equivalent for display purposes, we need to
    # call this before calling step()
    if self.scales[0] != "all":
        with torch.no_grad():
            overall_loss = self.objective_function(None, None).item()
    loss = self.optimizer.step(self._closure)
    if self.scales[0] == "all":
        # then the loss computed above includes all scales
        overall_loss = loss

    # `loss` is the scale-specific value, `overall_loss` the full-model one
    self._scales_loss.append(loss)
    self._losses.append(overall_loss)
    # _penalty_tmp is set by _closure
    self._penalties.append(self._penalty_tmp)

    grad_norm = torch.linalg.vector_norm(self.metamer.grad.data, ord=2, dim=None)
    self._gradient_norm.append(grad_norm.item())

    # optionally step the scheduler, passing loss if needed
    if self.scheduler is not None:
        if self._scheduler_step_arg:
            self.scheduler.step(loss)
        else:
            self.scheduler.step()

    pixel_change_norm = torch.linalg.vector_norm(
        self.metamer - last_iter_metamer, ord=2, dim=None
    )
    self._pixel_change_norm.append(pixel_change_norm.item())

    # add extra info here if you want it to show up in progress bar
    pbar.set_postfix(
        OrderedDict(
            loss=f"{overall_loss:.04e}",
            learning_rate=self.optimizer.param_groups[0]["lr"],
            penalty=f"{self._penalties[-1]:.04e}",
            gradient_norm=f"{grad_norm.item():.04e}",
            pixel_change_norm=f"{pixel_change_norm.item():.04e}",
            current_scale=self.scales[0],
            current_scale_loss=f"{loss:.04e}",
        )
    )
    return overall_loss

def _closure(self) -> float:
    r"""
    Calculate the gradient, before the optimization step.

    This enables optimization algorithms that perform several evaluations
    of the gradient before taking a step (e.g., second order methods like
    LBFGS or methods with line searches).

    Additionally, this is where:

    - ``metamer_representation`` is calculated, and thus any modifications to
      the model's forward call (e.g., specifying ``scale`` kwarg for
      coarse-to-fine) should happen.

    - ``loss`` is calculated and ``loss.backward()`` is called.

    - ``self._penalty_tmp`` is set to the current penalty (it is appended to
      ``self._penalties`` in ``_optimizer_step``, which is also where
      ``self._losses`` is updated).

    Returns
    -------
    loss
        Loss of the current objective function, including the penalty term
        weighted by ``penalty_lambda``.
    """  # numpydoc ignore=EX01
    self.optimizer.zero_grad()
    analyze_kwargs = {}
    # if we've reached 'all', we use the full model
    if self.scales[0] != "all":
        analyze_kwargs["scales"] = [self.scales[0]]
        # if 'together', then we also want all the coarser
        # scales
        if self.coarse_to_fine == "together":
            analyze_kwargs["scales"] += self.scales_finished

    # if analyze_kwargs is empty, we can just compare
    # metamer_representation against our cached target_representation
    if analyze_kwargs:
        # the scale-specific target is computed once per scale and cached; the
        # cache is cleared in _optimizer_step whenever we change scales
        if self._ctf_target_representation is None:
            target_rep = self.model(self.image, **analyze_kwargs)
            self._ctf_target_representation = target_rep
        else:
            target_rep = self._ctf_target_representation
    else:
        target_rep = None

    loss, penalty = self._objective_function(
        self.metamer, target_rep, **analyze_kwargs
    )
    loss = loss + self.penalty_lambda * penalty
    loss.backward(retain_graph=False)
    self._penalty_tmp = penalty.item()
    return loss.item()

def _check_convergence(
    self,
    i: int,
    stop_criterion: float,
    stop_iters_to_check: int,
    ctf_iters_to_check: int,
) -> bool:
    r"""
    Check whether the loss has stabilized and whether we've synthesized all scales.

    We check whether:

    - We have been synthesizing for ``stop_iters_to_check`` iterations, i.e.
      ``len(synth.losses) > stop_iters_to_check``.

    - Loss has decreased by less than ``stop_criterion`` over the past
      ``stop_iters_to_check`` iterations.

    - We have finished synthesizing each individual scale, i.e.
      ``synth.scales[0] == "all"``.

    - We have been synthesizing all scales for more than ``ctf_iters_to_check``
      iterations, i.e. ``(i - synth.scales_timing["all"][0]) > ctf_iters_to_check``.

    If all conditions are met, we return ``True``. Else, we return ``False``.

    The first two conditions are delegated to ``_loss_convergence`` and the
    last two to ``_coarse_to_fine_enough``.

    Parameters
    ----------
    i
        The current iteration (0-indexed).
    stop_criterion
        If the loss over the past ``stop_iters_to_check`` has changed
        less than ``stop_criterion``, we terminate synthesis.
    stop_iters_to_check
        How many iterations back to check in order to see if the
        loss has stopped decreasing (for ``stop_criterion``).
    ctf_iters_to_check
        Minimum number of iterations coarse-to-fine must run at each scale.

    Returns
    -------
    loss_stabilized
        Whether the loss has stabilized and we've synthesized all scales.
    """  # noqa: E501
    # numpydoc ignore=EX01
    loss_conv = _loss_convergence(self, stop_criterion, stop_iters_to_check)
    return loss_conv and _coarse_to_fine_enough(self, i, ctf_iters_to_check)
def to(self, *args: Any, **kwargs: Any):
    r"""
    Move and/or cast the parameters and buffers.

    This can be called as

    .. code:: python

        to(device=None, dtype=None, non_blocking=False)

    .. code:: python

        to(dtype, non_blocking=False)

    .. code:: python

        to(tensor, non_blocking=False)

    Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
    floating point desired ``dtype``. In addition, this method will
    only cast the floating point parameters and buffers to ``dtype``
    (if given). The integral parameters and buffers will be moved
    ``device``, if that is given, but with dtypes unchanged. When
    ``non_blocking`` is set, it tries to convert/move asynchronously
    with respect to the host if possible, e.g., moving CPU Tensors with
    pinned memory to CUDA devices.

    See :meth:`torch.nn.Module.to` for examples.

    .. note::
        This method modifies the module in-place.

    Parameters
    ----------
    device : torch.device
        The desired device of the parameters and buffers in this module.
    dtype : torch.dtype
        The desired floating point type of the floating point parameters and
        buffers in this module.
    tensor : torch.Tensor
        Tensor whose dtype and device are the desired dtype and device for
        all parameters and buffers in this module.

    Examples
    --------
    >>> import plenoptic as po
    >>> img = po.data.reptile_skin()
    >>> model = po.models.PortillaSimoncelli(img.shape[-2:])
    >>> met = po.MetamerCTF(img, model)
    >>> met.image.dtype
    torch.float32
    >>> met.model(met.image).dtype
    torch.float32
    >>> met.to(torch.float64)
    >>> met.image.dtype
    torch.float64
    >>> met.model(met.image).dtype
    torch.float64
    """  # numpydoc ignore=PR01,PR02
    super().to(*args, **kwargs)
    # if synthesize has been called at least once and we have not finished
    # moving through all scales, the cached scale-specific target is a Tensor
    # that gets passed to objective_function at some point, so it has to be
    # moved/cast along with everything else. otherwise there is nothing to do.
    cached_target = self._ctf_target_representation
    if cached_target is None:
        return
    self._ctf_target_representation = cached_target.to(*args, **kwargs)
def load(
    self,
    file_path: str,
    map_location: str | None = None,
    raise_on_checks: bool = True,
    tensor_equality_atol: float = 1e-8,
    tensor_equality_rtol: float = 1e-5,
    **pickle_load_args: Any,
):
    r"""
    Load all relevant stuff from a .pt file.

    This should be called by an initialized ``MetamerCTF`` object -- we will
    ensure that ``image``, ``target_representation`` (and thus ``model``), and
    ``loss_function`` are all identical.

    Note this operates in place and so doesn't return anything.

    .. versionchanged:: 1.2
       load behavior changed in a backwards-incompatible manner in order to
       be compatible with breaking changes in torch 2.6.

    .. versionchanged:: 2.0.0
       Adds ``raise_on_checks`` argument.

    Parameters
    ----------
    file_path
        The path to load the synthesis object from.
    map_location
        Argument to pass to :func:`torch.load` as ``map_location``. If you
        save stuff that was being run on a GPU and are loading onto a
        CPU, you'll need this to make sure everything lines up
        properly. This should be structured like the str you would
        pass to :class:`torch.device`.
    raise_on_checks
        During load, we perform several checks to ensure that the saved object
        was initialized in the same way as the loading object. This is to ensure
        that the model, image, etc. are all the same and avoid unpleasant
        surprises. If ``True``, we raise a ``ValueError`` if any of these
        checks fail. If ``False``, we instead raise a ``LoadWarning``. The
        intended use here is if you're loading something that was saved with an
        older version of plenoptic and you're sure that you're doing everything
        correctly. Note that different devices or dtypes will always result in a
        ``ValueError``. See :ref:`raise-on-checks` on the "Reproducibility and
        Compatibility" page of the documentation for more info. Additionally,
        note that, if the ``MetamerCTF`` object itself has changed, we cannot
        ensure that methods are the same -- proceed at your own risk.
    tensor_equality_atol
        Absolute tolerance to use when checking for tensor equality during load,
        passed to :func:`torch.allclose`. It may be necessary to increase if you
        are saving and loading on two machines with torch built by different cuda
        versions. Be careful when changing this! See
        :class:`torch.finfo<torch.torch.finfo>` for more details about floating
        point precision of different data types (especially, ``eps``); if you
        have to increase this by more than 1 or 2 decades, then you are probably
        not dealing with a numerical issue.
    tensor_equality_rtol
        Relative tolerance to use when checking for tensor equality during load,
        passed to :func:`torch.allclose`. It may be necessary to increase if you
        are saving and loading on two machines with torch built by different cuda
        versions. Be careful when changing this! See
        :class:`torch.finfo<torch.torch.finfo>` for more details about floating
        point precision of different data types (especially, ``eps``); if you
        have to increase this by more than 1 or 2 decades, then you are probably
        not dealing with a numerical issue.
    **pickle_load_args
        Any additional kwargs will be added to ``pickle_module.load`` via
        :func:`torch.load`, see that function's docstring for details.

    Raises
    ------
    ValueError
        If :func:`setup` or :func:`synthesize` has been called before this call
        to ``load``.
    ValueError
        If the object saved at ``file_path`` is not a ``MetamerCTF`` object.
    ValueError
        If the saved and loading ``MetamerCTF`` objects have a different value
        for any of :attr:`image`, :attr:`penalty_lambda`, or
        :attr:`coarse_to_fine`.
    ValueError
        If the behavior of :attr:`loss_function` or :attr:`model` is different
        between the saved and loading objects.

    Warns
    -----
    UserWarning
        If :func:`setup` will need to be called after load, to finish
        initializing :attr:`optimizer` or :attr:`scheduler`.

    Examples
    --------
    In order to load a saved ``MetamerCTF`` object, we must first initialize
    one using the same arguments. (We use float64 / "double" precision rather
    than torch's default float32 because it increases reproducibility, see the
    :ref:`Reproducibility <reproduce>` page of our documentation for more
    details.) Here, we load in a cached example:

    >>> import plenoptic as po
    >>> img = po.data.reptile_skin().to(torch.float64)
    >>> model = po.models.PortillaSimoncelli(img.shape[-2:])
    >>> met = po.MetamerCTF(img, model, po.loss.l2_norm)
    >>> print(met.metamer)
    tensor([])
    >>> met.load(po.data.fetch_data("example_metamerCTF_ps.pt"))
    >>> print(met.metamer)
    tensor([[[[0.1421, ...]]]], dtype=torch.float64, requires_grad=True)

    If the saved ``MetamerCTF`` object lived on a CUDA device and you do not
    have CUDA on the loading machine, use ``map_location`` to change device:

    >>> met = po.MetamerCTF(img, model, po.loss.l2_norm)
    >>> met.image.device
    device(type='cpu')
    >>> met.load(po.data.fetch_data("example_metamerCTF_ps-cuda.pt"))
    Traceback (most recent call last):
    RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False...
    >>> met.load(
    ...     po.data.fetch_data("example_metamerCTF_ps-cuda.pt"), map_location="cpu"
    ... )
    >>> print(met.metamer)
    tensor([[[[0.1421, ...]]]], dtype=torch.float64, requires_grad=True)

    Loading and saving must both be done with ``MetamerCTF``:

    >>> met = po.Metamer(img, model)
    >>> met.load(po.data.fetch_data("example_metamerCTF_ps.pt"))
    Traceback (most recent call last):
    ValueError: Saved object was a plenoptic.MetamerCTF...

    If the loading ``MetamerCTF`` object was not initialized with same values
    as the saved object, an error will be raised:

    >>> met = po.MetamerCTF(torch.rand_like(img), model, po.loss.l2_norm)
    >>> met.load(po.data.fetch_data("example_metamerCTF_ps.pt"))
    Traceback (most recent call last):
    ValueError: Saved and initialized attribute image have different values...

    If the loading ``MetamerCTF`` object has a different data type than the
    saved object, an error will be raised:

    >>> met = po.MetamerCTF(img, model, po.loss.l2_norm)
    >>> met.to(torch.float32)
    >>> met.load(po.data.fetch_data("example_metamerCTF_ps.pt"))
    Traceback (most recent call last):
    ValueError: Saved and initialized attribute image have different dtype...
    """  # noqa: E501
    # on top of what Metamer checks, the coarse-to-fine mode must also match
    # between the saved and the loading object
    ctf_check_attributes = ["_coarse_to_fine"]
    super()._load(
        file_path,
        map_location=map_location,
        additional_check_attributes=ctf_check_attributes,
        raise_on_checks=raise_on_checks,
        tensor_equality_atol=tensor_equality_atol,
        tensor_equality_rtol=tensor_equality_rtol,
        **pickle_load_args,
    )
def __repr__(self) -> str:  # numpydoc ignore=GL08
    # attributes shown in the repr: those of Metamer, plus coarse_to_fine
    shown_attributes = [
        "image",
        "model",
        "loss_function",
        "penalty_function",
        "penalty_lambda",
        "coarse_to_fine",
    ]
    return super()._repr_format(shown_attributes)

@property
def coarse_to_fine(self) -> str:
    """How scales are handled, see :class:`MetamerCTF` for details."""
    # numpydoc ignore=RT01,ES01,EX01
    return self._coarse_to_fine

@property
def scales(self) -> tuple:
    """Model scales that we've yet to optimize, modified during optimization."""
    # numpydoc ignore=RT01,ES01,EX01
    # hand out an immutable snapshot of the internal list
    return (*self._scales,)

@property
def scales_loss(self) -> tuple:
    """Scale-specific loss at each iteration."""
    # numpydoc ignore=RT01,ES01,EX01
    return (*self._scales_loss,)

@property
def scales_timing(self) -> dict:
    """
    Information about when each scale was started and stopped.

    Keys are the values found in :attr:`scales`, and values are lists
    specifying the iteration where we started and stopped optimizing this
    scale, which are modified during optimization.
    """
    # numpydoc ignore=RT01,EX01
    return self._scales_timing

@property
def scales_finished(self) -> tuple:
    """Model scales that we've finished optimizing, modified during optimization."""
    # numpydoc ignore=RT01,ES01,EX01
    return (*self._scales_finished,)