Source code for plenoptic._synthesize.eigendistortion

"""
Eigendistortions.

Eigendistortions are images which make changes that a model thinks are most or least
noticeable to an image, given a constrained pixel budget. They allow researchers to see
which features drive a model's response and which have no effect.
"""

import warnings
from typing import Any, Literal

import torch
from torch import Tensor
from tqdm.auto import tqdm

from ..validate import validate_input, validate_model
from .autodiff import (
    _fisher_info_matrix_eigenvalue,
    _fisher_info_matrix_vector_product,
    _jacobian,
)
from .synthesis import _Synthesis

__all__ = [
    "Eigendistortion",
]


def __dir__() -> list[str]:
    return __all__


[docs] class Eigendistortion(_Synthesis): r""" Synthesize eigendistortions induced by a model on a given input image. Following the basic idea in [1]_, this class synthesizes image perturbations that are considered the most and least noticeable for a model on a given image. Because these are perturbations on the input image, they are local in pixel space, i.e., they do not change the pixels much. Parameters ---------- image Image to perturb. We currently do not support batches of images, as each image requires its own optimization, so either ``image.ndimension()==1`` or ``image.shape[0]==1``. model Torch model with defined forward and backward operations. Notes ----- This is a method for comparing image representations in terms of their ability to explain perceptual sensitivity in humans. It estimates eigenvectors of the Fisher Information Matrix. A model, :math:`y = f(x)`, is a deterministic (and differentiable) mapping from the input pixels :math:`x \in \mathbb{R}^n` to a mean output response vector :math:`y\in \mathbb{R}^m`, where we assume additive white Gaussian noise in the response space. The Jacobian matrix at :math:`x` is: :math:`J(x) = J = dydx`, where :math:`J\in\mathbb{R}^{m \times n}` (i.e. output_dim x input_dim). The matrix consists of all partial derivatives of the vector-valued function :math:`f`. The Fisher Information Matrix (FIM) at :math:`x`, under white Gaussian noise in the response space, is: :math:`F = J^T J` It is a quadratic approximation of the discriminability of distortions relative to :math:`x`. References ---------- .. [1] Berardino, A., Laparra, V., Ballé, J. and Simoncelli, E., 2017. Eigen-distortions of hierarchical representations. In Advances in neural information processing systems (pp. 3530-3539). https://www.cns.nyu.edu/pub/lcv/berardino17c-final.pdf https://www.cns.nyu.edu/~lcv/eigendistortions/ """ def __init__(self, image: Tensor, model: torch.nn.Module): super().__init__() validate_input(image, no_batch=True) validate_model( model, image_shape=image.shape, image_dtype=image.dtype, device=image.device, ) self._image_shape = image.shape if image.ndimension() != 1: self._image_shape = self._image_shape[1:] self._model = model # flatten and attach gradient and reshape to image self._image_flat = image.flatten().unsqueeze(1).requires_grad_(True) self._init_representation(image) self._jacobian = None self._eigendistortions = None self._eigenvalues = None self._eigenindex = None def _init_representation(self, image: torch.Tensor): """Set self._representation_flat, based on ``self.model`` and ``image``.""" # numpydoc ignore=ES01,PR01 self._image = self._image_flat.view(*image.shape) image_representation = self.model(self.image) if len(image_representation) > 1: self._representation_flat = torch.cat( [s.squeeze().view(-1) for s in image_representation] ).unsqueeze(1) else: self._representation_flat = ( image_representation.squeeze().view(-1).unsqueeze(1) )
[docs] def synthesize( self, method: Literal["exact", "power", "randomized_svd"] = "power", k: int = 1, max_iter: int = 1000, p: int = 5, q: int = 2, stop_criterion: float = 1e-7, ): r""" Compute eigendistortions of Fisher Information Matrix with given input image. There are three potential ways of computing the eigendistortion for a model; all have the same interpretation. See ``method`` argument for details. Parameters ---------- method Eigensolver method. ``'exact'`` tries to do eigendecomposition directly (not recommended for very large inputs). ``'power'`` uses the power method to compute first and last eigendistortions, with maximum number of iterations dictated by n_steps. ``'randomized_svd'`` uses randomized SVD to approximate the top k eigendistortions and their corresponding eigenvalues. k How many vectors to return using block power method or svd. max_iter Maximum number of steps to run for ``method='power'`` in eigenvalue computation. Ignored for other methods. p Oversampling parameter for randomized SVD. k+p vectors will be sampled, and k will be returned. See docstring of ``_synthesize_randomized_svd`` for more details including algorithm reference. q Matrix power parameter for randomized SVD. This is an effective trick for the algorithm to converge to the correct eigenvectors when the eigenspectrum does not decay quickly. See ``_synthesize_randomized_svd`` for more details including algorithm reference. stop_criterion Used if ``method='power'`` to check for convergence. If the L2-norm of the eigenvalues has changed by less than this value from one iteration to the next, we terminate synthesis. Raises ------ ValueError If ``method`` takes an illegal value. Warns ----- UserWarning If ``method == "power"`` but the Jacobian size is greater than 1e6 (which depends on the number of elements in the model's representation and input image), in which case we're worried about running out of memory. """ allowed_methods = ["power", "exact", "randomized_svd"] if method not in allowed_methods: raise ValueError(f"method must be in {allowed_methods}") if ( method == "exact" and self._representation_flat.size(0) * self._image_flat.size(0) > 1e6 ): warnings.warn( "Jacobian > 1e6 elements and may cause out-of-memory. Use" " method = {'power', 'randomized_svd'}." ) if method == "exact": # compute exact Jacobian eig_vals, eig_vecs = self._synthesize_exact() eig_vecs = self._vector_to_image(eig_vecs.detach()) eig_vecs_ind = torch.arange(len(eig_vecs)) elif method == "randomized_svd": lmbda_new, v_new, error_approx = self._synthesize_randomized_svd( k=k, p=p, q=q ) eig_vecs = self._vector_to_image(v_new.detach()) eig_vals = lmbda_new.squeeze() eig_vecs_ind = torch.arange(k) # display the approximate estimation error of the range space warnings.warn( "Randomized SVD complete! Estimated spectral approximation" f" error = {error_approx:.2f}" ) else: # method == 'power' assert max_iter > 0, "max_iter must be greater than zero" lmbda_max, v_max = self._synthesize_power( k=k, shift=0.0, tol=stop_criterion, max_iter=max_iter ) lmbda_min, v_min = self._synthesize_power( k=k, shift=lmbda_max[0], tol=stop_criterion, max_iter=max_iter ) n = v_max.shape[0] eig_vecs = self._vector_to_image(torch.cat((v_max, v_min), dim=1).detach()) eig_vals = torch.cat([lmbda_max, lmbda_min]).squeeze() eig_vecs_ind = torch.cat((torch.arange(k), torch.arange(n - k, n))) # reshape to (n x num_chans x h x w) self._eigendistortions = torch.stack(eig_vecs, 0) if len(eig_vecs) != 0 else [] self._eigenvalues = torch.abs(eig_vals.detach()) self._eigenindex = eig_vecs_ind
def _synthesize_exact(self) -> tuple[Tensor, Tensor]: r""" Eigendecomposition of explicitly computed Fisher Information Matrix. To be used when the input is small (e.g. less than 70x70 image on cluster or 30x30 on your own machine). This method obviates the power iteration and its related algorithms (e.g. Lanczos). This method computes the Fisher Information Matrix by explicitly computing the Jacobian of the representation with respect to the input. Returns ------- eig_vals Eigenvalues in decreasing order. eig_vecs Eigenvectors in 2D tensor, whose cols are eigenvectors (i.e. eigendistortions) corresponding to eigenvalues. """ J = self.compute_jacobian() F = J.T @ J eig_vals, eig_vecs = torch.linalg.eigh(F, UPLO="U") eig_vecs = eig_vecs.flip(dims=(1,)) eig_vals = eig_vals.flip(dims=(0,)) return eig_vals, eig_vecs
[docs] def compute_jacobian(self) -> Tensor: r""" Compute, cache, and return jacobian. If the jacobian has not been cached: compute, cache, and return. If the jacobian has already been cached, we simply return it. Returns ------- J Jacobian of representation with respect to input. Warns ----- UserWarning If input dimensionality is greater than 1e4, in which case we believe that this calculation will take too long. """ if self.jacobian is None: J = _jacobian(self._representation_flat, self._image_flat) self._jacobian = J else: J = self.jacobian return J
def _synthesize_power( self, k: int, shift: Tensor | float, tol: float, max_iter: int ) -> tuple[Tensor, Tensor]: r""" Use power method to obtain largest (smallest) eigenvalue/vector pairs. When ``k>1``, uses orthogonal iteration, see [1]_. Apply the algorithm to approximate the extremal eigenvalues and eigenvectors of the Fisher Information Matrix, without explicitly representing that matrix. This method repeatedly calls :func:`_fisher_info_matrix_vector_product` with a single (``k=1``), or multiple (``k>1``) vectors. Parameters ---------- k Number of top and bottom eigendistortions to synthesize; i.e. if k=2, then the top 2 and bottom 2 will be returned. When `k>1`, multiple eigendistortions are synthesized, and each power iteration step is followed by a QR orthogonalization step to ensure the vectors are orthonormal. shift When `shift=0`, this function estimates the top `k` eigenvalue/vector pairs. When `shift` is set to the estimated top eigenvalue this function will estimate the smallest eigenval/eigenvector pairs. tol Tolerance value. max_iter Maximum number of steps. Returns ------- lmbda Eigenvalue corresponding to final vector of power iteration. v Final eigenvector(s) (i.e. eigendistortions) of power (orthogonal) iteration procedure. References ---------- .. [1] Orthogonal iteration; Algorithm 8.2.8 Golub and Van Loan, Matrix Computations, 3rd Ed. """ x, y = self._image_flat, self._representation_flat # note: v is an n x k matrix where k is number of eigendists to be synthesized! v = torch.randn(len(x), k, device=x.device, dtype=x.dtype) v = v / torch.linalg.vector_norm(v, dim=0, keepdim=True, ord=2) _dummy_vec = torch.ones_like(y, requires_grad=True) # cache a dummy vec for jvp Fv = _fisher_info_matrix_vector_product(y, x, v, _dummy_vec) v = Fv / torch.linalg.vector_norm(Fv, dim=0, keepdim=True, ord=2) lmbda = _fisher_info_matrix_eigenvalue(y, x, v, _dummy_vec) d_lambda = torch.as_tensor(float("inf")) lmbda_new, v_new = None, None desc = ("Top" if shift == 0 else "Bottom") + f" k={k} eigendists" pbar = tqdm(range(max_iter), desc=desc) postfix_dict = {"delta_eigenval": None} for _ in pbar: postfix_dict.update(dict(delta_eigenval=f"{d_lambda.item():.2E}")) pbar.set_postfix(**postfix_dict) if d_lambda <= tol: print(f"{desc} computed | Stop criterion {tol:.2E} reached.") break Fv = _fisher_info_matrix_vector_product(y, x, v, _dummy_vec) Fv = Fv - shift * v # optionally shift: (F - shift*I)v v_new, _ = torch.linalg.qr(Fv, "reduced") # (ortho)normalize vector(s) lmbda_new = _fisher_info_matrix_eigenvalue(y, x, v_new, _dummy_vec) d_lambda = torch.linalg.vector_norm( lmbda - lmbda_new, ord=2 ) # stability of eigenspace v = v_new lmbda = lmbda_new pbar.close() return lmbda_new, v_new def _synthesize_randomized_svd( self, k: int, p: int, q: int ) -> tuple[Tensor, Tensor, Tensor]: r""" Synthesize eigendistortions using randomized truncated SVD. This method approximates the column space of the Fisher Info Matrix, projects the FIM into that column space, then computes its SVD. See [1]_ for details. Parameters ---------- k Number of eigenvecs (rank of factorization) to be returned. p Oversampling parameter, recommended to be 5. q Matrix power iteration. Used to squeeze the eigen spectrum for more accurate approximation. Recommended to be 2. Returns ------- S Eigenvalues, Size((n, )). V Eigendistortions, Size((n, k)). error_approx Estimate of the approximation error. Defined as the expected error between the true subspace and approximated subspace. References ---------- .. [1] Halko, Martinsson, Tropp, Finding structure with randomness: Probabilistic algorithms for constructing approximate matrix decompositions, SIAM Rev. 53:2, pp. 217-288 https://arxiv.org/abs/0909.4061 (2011) """ x, y = self._image_flat, self._representation_flat n = len(x) P = torch.randn(n, k + p).to(x.device) # orthogonalize first for numerical stability P, _ = torch.linalg.qr(P, "reduced") _dummy_vec = torch.ones_like(y, requires_grad=True) Z = _fisher_info_matrix_vector_product(y, x, P, _dummy_vec) # optional power iteration to squeeze the spectrum for more accurate # estimate for _ in range(q): Z = _fisher_info_matrix_vector_product(y, x, Z, _dummy_vec) Q, _ = torch.linalg.qr(Z, "reduced") # B = Q.T @ A @ Q B = Q.T @ _fisher_info_matrix_vector_product(y, x, Q, _dummy_vec) _, S, Vh = torch.linalg.svd(B, False) # eigendecomp of small matrix V = Vh.T V = Q @ V # lift up to original dimensionality # estimate error in Q estimate of range space omega = _fisher_info_matrix_vector_product( y, x, torch.randn(n, 20).to(x.device), _dummy_vec ) error_approx = omega - (Q @ Q.T @ omega) error_approx = torch.linalg.vector_norm(error_approx, dim=0, ord=2).mean() return S[:k].clone(), V[:, :k].clone(), error_approx # truncate def _vector_to_image(self, vecs: Tensor) -> list[Tensor]: r""" Reshape eigenvectors back into correct image dimensions. Parameters ---------- vecs Eigendistortion tensor with ``torch.Size([N, num_distortions])``. Each distortion will be reshaped into the original image shape and placed in a list. Returns ------- imgs List of Tensor images. """ # numpydoc ignore=ES01 imgs = [vecs[:, i].reshape(self._image_shape) for i in range(vecs.shape[1])] return imgs def _indexer(self, idx: int) -> int: """ Map eigenindex to arg index (0-indexed). Parameters ---------- idx Eigenvalue number. Returns ------- idx Index into eigenvalue tensor. Raises ------ ValueError If ``eigenindex`` doesn't correspond to one of the synthesized eigendistortions. ValueError If no eigendistortions have been synthesized. """ # numpydoc ignore=ES01 n = len(self._image_flat) idx_range = range(n) i = idx_range[idx] all_idx = self.eigenindex if i not in all_idx: raise ValueError("eigenindex must be the index of one of the vectors") if all_idx is None or len(all_idx) == 0: raise ValueError("No eigendistortions synthesized") return torch.where(all_idx == i)[0].item()
[docs] def save(self, file_path: str): r""" Save all relevant variables in .pt file. See :meth:`load` docstring for an example of use. Parameters ---------- file_path : str The path to save the Eigendistortion object to. """ save_io_attrs = [("_model", ("_image",))] super().save(file_path, save_io_attrs)
[docs] def to(self, *args: Any, **kwargs: Any): r""" Move and/or cast the parameters and buffers. This can be called as .. code:: python to(device=None, dtype=None, non_blocking=False) .. code:: python to(dtype, non_blocking=False) .. code:: python to(tensor, non_blocking=False) Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point desired ``dtype``. In addition, this method will only cast the floating point parameters and buffers to ``dtype`` (if given). The integral parameters and buffers will be moved ``device``, if that is given, but with dtypes unchanged. When `on_blocking`` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See :meth:`torch.nn.Module.to` for examples. .. note:: This method modifies the module in-place. Parameters ---------- device : torch.device The desired device of the parameters and buffers in this module. dtype : torch.dtype The desired floating point type of the floating point parameters and buffers in this module. tensor : torch.Tensor Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module. """ # numpydoc ignore=PR01,PR02 attrs = [ "_jacobian", "_eigendistortions", "_eigenvalues", "_eigenindex", "_image", "_image_flat", "_representation_flat", ] super().to(*args, attrs=attrs, **kwargs) # try to call .to() on model. this should work, but it might fail if e.g., this # a custom model that doesn't inherit torch.nn.Module try: self._model = self._model.to(*args, **kwargs) except AttributeError: warnings.warn("Unable to call model.to(), so we leave it as is.") # we need _representation_flat and _image_flat to be connected in the # computation graph for the autograd calls to work, so we reinitialize # it here self._init_representation(self.image)
[docs] def load( self, file_path: str, map_location: str | None = None, raise_on_checks: bool = True, tensor_equality_atol: float = 1e-8, tensor_equality_rtol: float = 1e-5, **pickle_load_args: Any, ): r""" Load all relevant stuff from a .pt file. This must be called by a ``Eigendistortion`` object initialized just like the saved object. Note this operates in place and so doesn't return anything. .. versionchanged:: 1.2 load behavior changed in a backwards-incompatible manner in order to compatible with breaking changes in torch 2.6. .. versionchanged:: 2.0.0 Adds ``raise_on_checks`` argument. Parameters ---------- file_path The path to load the synthesis object from. map_location Argument to pass to :func:`torch.load` as ``map_location``. If you save stuff that was being run on a GPU and are loading onto a CPU, you'll need this to make sure everything lines up properly. This should be structured like the str you would pass to :class:`torch.device`. raise_on_checks During load, we perform several checks to ensure that the saved object was initialized in the same way as the loading object. This is to ensure that the model, image, etc. are all the same and avoid unpleasant surprises. If ``True``, we raise a ``ValueError`` if any of these checks fail. If ``False``, we instead raise a ``LoadWarning``. The intended use here is if you're loading something that was saved with an older version of plenoptic and you're sure that you're doing everything correctly. Note that different devices or dtypes will always result in a ``ValueError``. See :ref:`raise-on-checks` on the "Reproducibility and Compatibility" page of the documentation for more info. Additionally, note that, if the ``Eigendistortion`` object itself has changed, we cannot ensure that methods are the same -- proceed at your own risk. tensor_equality_atol Absolute tolerance to use when checking for tensor equality during load, passed to :func:`torch.allclose`. It may be necessary to increase if you are saving and loading on two machines with torch built by different cuda versions. Be careful when changing this! See :class:`torch.finfo<torch.torch.finfo>` for more details about floating point precision of different data types (especially, ``eps``); if you have to increase this by more than 1 or 2 decades, then you are probably not dealing with a numerical issue. tensor_equality_rtol Relative tolerance to use when checking for tensor equality during load, passed to :func:`torch.allclose`. It may be necessary to increase if you are saving and loading on two machines with torch built by different cuda versions. Be careful when changing this! See :class:`torch.finfo<torch.torch.finfo>` for more details about floating point precision of different data types (especially, ``eps``); if you have to increase this by more than 1 or 2 decades, then you are probably not dealing with a numerical issue. **pickle_load_args Any additional kwargs will be added to ``pickle_module.load`` via :func:`torch.load`, see that function's docstring for details. Raises ------ ValueError If :func:`synthesize` has been called before this call to ``load``. ValueError If the object saved at ``file_path`` is not a ``Eigendistortion`` object. ValueError If the saved and loading ``Eigendistortion`` objects have a different value for :attr:`image`. ValueError If the behavior of :attr:`model` is different between the saved and loading objects. See Also -------- :func:`~plenoptic.io.examine_saved_synthesis` Examine metadata from saved object: pytorch and plenoptic versions, name of the synthesis object, shapes of tensors, etc. Examples -------- >>> import plenoptic as po >>> img = po.data.einstein() >>> model = po.models.Gaussian(30).eval() >>> po.remove_grad(model) >>> eig = po.Eigendistortion(img, model) >>> eig.synthesize(max_iter=5) >>> eig.save("eig.pt") >>> eig_copy = po.Eigendistortion(img, model) >>> eig_copy.load("eig.pt") """ check_attributes = ["_image"] check_io_attrs = [("_model", ("_image",))] super().load( file_path, "eigenindex", map_location=map_location, check_attributes=check_attributes, check_io_attributes=check_io_attrs, raise_on_checks=raise_on_checks, tensor_equality_atol=tensor_equality_atol, tensor_equality_rtol=tensor_equality_rtol, **pickle_load_args, ) # make these require a grad again self._image_flat.requires_grad_() # we need _representation_flat and _image_flat to be connected in the # computation graph for the autograd calls to work, so we reinitialize # it here self._init_representation(self.image)
def __repr__(self) -> str: # numpydoc ignore=GL08 return super()._repr_format(["image", "model"]) @property def model(self) -> torch.nn.Module: """The model for which the eigendistortions are synthesized.""" # numpydoc ignore=RT01,ES01 return self._model @property def image(self) -> torch.Tensor: """Target image of eigendistortion synthesis.""" # numpydoc ignore=RT01,ES01 return self._image @property def jacobian(self) -> torch.Tensor: """ Jacobian matrix of :attr:`model` with respect to :attr:`image`. Only set when :func:`synthesize` is run with ``method='exact'``. Else, ``None``. """ # numpydoc ignore=RT01 return self._jacobian @property def eigendistortions(self) -> torch.Tensor: """ Eigendistortions, ordered by eigenvalue. Eigendistortions are the eigenvectors of Fisher matrix, will have size ``Size((n_distortions, *image.shape[1:]))``. """ # numpydoc ignore=RT01 return self._eigendistortions @property def eigenvalues(self) -> torch.Tensor: """Eigenvalues corresponding to each eigendistortion, in decreasing order.""" # numpydoc ignore=RT01,ES01,SS05 return self._eigenvalues @property def eigenindex(self) -> torch.Tensor: """Index of each eigenvector/eigenvalue.""" # numpydoc ignore=RT01,ES01 return self._eigenindex