import warnings
from collections.abc import Callable
from typing import Literal
import matplotlib.pyplot
import numpy as np
import torch
from matplotlib.figure import Figure
from torch import Tensor
from tqdm.auto import tqdm
from ..tools.display import imshow
from ..tools.validate import validate_input, validate_model
from .autodiff import (
jacobian,
jacobian_vector_product,
vector_jacobian_product,
)
from .synthesis import Synthesis
[docs]
def fisher_info_matrix_vector_product(
    y: Tensor, x: Tensor, v: Tensor, dummy_vec: Tensor
) -> Tensor:
    r"""Multiply the Fisher Information Matrix by ``v`` without forming it.

    Parameters
    ----------
    y
        Output tensor with gradient attached.
    x
        Input tensor with gradient attached.
    v
        The vector(s) to be multiplied by the Fisher matrix.
    dummy_vec
        Dummy vector handed to ``jacobian_vector_product`` for its
        Jacobian-vector-product trick.

    Returns
    -------
    Fv
        The Fisher-vector product, as returned by ``vector_jacobian_product``
        called with ``detach=True``.

    Notes
    -----
    Under the white Gaussian noise assumption the Fisher matrix is
    :math:`F = J^T J`, so the product is evaluated as two successive
    Jacobian products: :math:`Fv = J^T (Jv)`.
    """
    # Inner product first: J v ...
    jacobian_times_v = jacobian_vector_product(y, x, v, dummy_vec)
    # ... then apply J^T to the result.
    return vector_jacobian_product(y, x, jacobian_times_v, detach=True)
[docs]
def fisher_info_matrix_eigenvalue(
    y: Tensor, x: Tensor, v: Tensor, dummy_vec: Tensor | None = None
) -> Tensor:
    r"""Evaluate :math:`\lambda = v^T F v` for every column of ``v``.

    When the columns of ``v`` are (unit-norm) eigenvectors of the Fisher
    Information Matrix :math:`F`, the returned values are the corresponding
    eigenvalues.

    Parameters
    ----------
    y
        Output tensor with gradient attached.
    x
        Input tensor with gradient attached.
    v
        Matrix whose columns are the vectors to evaluate.
    dummy_vec
        Dummy vector forwarded to ``fisher_info_matrix_vector_product``. If
        ``None``, a tensor of ones shaped like ``y`` (requiring grad) is
        created.

    Returns
    -------
    lmbda
        1D tensor with one value per column of ``v``.
    """
    if dummy_vec is None:
        dummy_vec = torch.ones_like(y, requires_grad=True)

    Fv = fisher_info_matrix_vector_product(y, x, v, dummy_vec)

    # One quadratic form per column: pair column j of v with column j of Fv.
    quadratic_forms = [
        torch.dot(v_col, Fv_col)
        for v_col, Fv_col in zip(v.unbind(dim=1), Fv.unbind(dim=1))
    ]
    return torch.stack(quadratic_forms)
[docs]
class Eigendistortion(Synthesis):
r"""Synthesis object to compute eigendistortions induced by a model on a given
input image.
Parameters
----------
image
Image, torch.Size(batch=1, channel, height, width). We currently do not
support batches of images, as each image requires its own optimization.
model
Torch model with defined forward and backward operations.
Attributes
----------
batch_size: int
n_channels: int
im_height: int
im_width: int
jacobian: Tensor
Set by :func:`compute_jacobian`, which :func:`synthesize` calls when run
with ``method='exact'``. Defaults to ``None``.
eigendistortions: Tensor
Tensor of eigendistortions (eigenvectors of Fisher matrix), ordered by
eigenvalue, with Size((n_distortions, n_channels, im_height,
im_width)).
eigenvalues: Tensor
Tensor of eigenvalues corresponding to each eigendistortion, listed in
decreasing order. Stored as absolute values.
eigenindex: listlike
Index of each eigenvector/eigenvalue. With ``method='power'`` this holds
the indices of the top ``k`` followed by those of the bottom ``k``.
Notes
-----
This is a method for comparing image representations in terms of their ability to
explain perceptual sensitivity in humans. It estimates eigenvectors of the FIM.
A model, :math:`y = f(x)`, is a deterministic (and differentiable)
mapping from the input pixels :math:`x \in \mathbb{R}^n` to a mean output
response vector :math:`y\in \mathbb{R}^m`, where we assume additive white
Gaussian noise in the response space.
The Jacobian matrix at x is:
:math:`J(x) = J = \frac{\partial y}{\partial x}`,
:math:`J\in\mathbb{R}^{m \times n}` (i.e. output_dim x input_dim)
The matrix consists of all partial derivatives of the vector-valued function f.
The Fisher Information Matrix (FIM) at x, under white Gaussian noise in the
response space, is:
:math:`F = J^T J`
It is a quadratic approximation of the discriminability of distortions
relative to :math:`x`.
References
----------
.. [1] Berardino, A., Laparra, V., Ballé, J. and Simoncelli, E., 2017.
Eigen-distortions of hierarchical representations. In Advances in
neural information processing systems (pp. 3530-3539).
https://www.cns.nyu.edu/pub/lcv/berardino17c-final.pdf
https://www.cns.nyu.edu/~lcv/eigendistortions/
"""
def __init__(self, image: Tensor, model: torch.nn.Module):
    """Validate the inputs, cache image geometry and build the autograd graph.

    Parameters
    ----------
    image
        4D image tensor; its shape is unpacked into ``batch_size``,
        ``n_channels``, ``im_height`` and ``im_width``.
    model
        Model whose representation of ``image`` is analyzed.
    """
    validate_input(image, no_batch=True)
    validate_model(
        model,
        image_shape=image.shape,
        image_dtype=image.dtype,
        device=image.device,
    )

    # Results are filled in later by ``synthesize`` / ``compute_jacobian``.
    self._jacobian = self._eigendistortions = None
    self._eigenvalues = self._eigenindex = None

    self.batch_size, self.n_channels, self.im_height, self.im_width = image.shape
    self._model = model

    # Column vector of pixels with gradient attached; the image-shaped view
    # and the flattened model response are derived from it just below.
    self._image_flat = image.flatten().unsqueeze(1).requires_grad_(True)
    self._init_representation(image)

    input_dim = len(self._image_flat.squeeze())
    output_dim = len(self._representation_flat.squeeze())
    print(
        "\nInitializing Eigendistortion -- Input dim:"
        f" {input_dim} | Output dim:"
        f" {output_dim}"
    )
def _init_representation(self, image):
"""Set self._representation_flat, based on model and image"""
self._image = self._image_flat.view(*image.shape)
image_representation = self.model(self.image)
if len(image_representation) > 1:
self._representation_flat = torch.cat(
[s.squeeze().view(-1) for s in image_representation]
).unsqueeze(1)
else:
self._representation_flat = (
image_representation.squeeze().view(-1).unsqueeze(1)
)
[docs]
def synthesize(
    self,
    method: Literal["exact", "power", "randomized_svd"] = "power",
    k: int = 1,
    max_iter: int = 1000,
    p: int = 5,
    q: int = 2,
    stop_criterion: float = 1e-7,
):
    r"""
    Compute eigendistortions of Fisher Information Matrix with given input image.

    The results are stored on the object (``eigendistortions``,
    ``eigenvalues``, ``eigenindex``); nothing is returned.

    Parameters
    ----------
    method
        Eigensolver method. 'exact' tries to do eigendecomposition directly (
        not recommended for very large inputs). 'power' (default) uses the power
        method to compute first and last eigendistortions, with maximum number of
        iterations dictated by ``max_iter``. 'randomized_svd' uses randomized SVD
        to approximate the top k eigendistortions and their corresponding
        eigenvalues.
    k
        How many vectors to return using block power method or svd.
    max_iter
        Maximum number of steps to run for ``method='power'`` in eigenvalue
        computation. Ignored for other methods.
    p
        Oversampling parameter for randomized SVD. k+p vectors will be sampled,
        and k will be returned. See docstring of ``_synthesize_randomized_svd``
        for more details including algorithm reference.
    q
        Matrix power parameter for randomized SVD. This is an effective trick for
        the algorithm to converge to the correct eigenvectors when the
        eigenspectrum does not decay quickly. See ``_synthesize_randomized_svd``
        for more details including algorithm reference.
    stop_criterion
        Used if ``method='power'`` to check for convergence. If the L2-norm
        of the eigenvalues has changed by less than this value from one
        iteration to the next, we terminate synthesis.
    """
    allowed_methods = ["power", "exact", "randomized_svd"]
    assert method in allowed_methods, f"method must be in {allowed_methods}"

    if method == "exact":
        # The explicit Jacobian has (output dim x input dim) entries.
        jacobian_numel = self._representation_flat.size(0) * self._image_flat.size(0)
        if jacobian_numel > 1e6:
            warnings.warn(
                "Jacobian > 1e6 elements and may cause out-of-memory. Use"
                " method = {'power', 'randomized_svd'}."
            )
        print("Computing all eigendistortions")
        values, vectors = self._synthesize_exact()
        images = self._vector_to_image(vectors.detach())
        indices = torch.arange(len(images))
    elif method == "randomized_svd":
        print(f"Estimating top k={k} eigendistortions using randomized SVD")
        svd_vals, svd_vecs, error_approx = self._synthesize_randomized_svd(
            k=k, p=p, q=q
        )
        images = self._vector_to_image(svd_vecs.detach())
        values = svd_vals.squeeze()
        indices = torch.arange(k)
        # report the approximate estimation error of the range space
        print(
            "Randomized SVD complete! Estimated spectral approximation"
            f" error = {error_approx:.2f}"
        )
    else:  # method == 'power'
        assert max_iter > 0, "max_iter must be greater than zero"
        power_kwargs = dict(k=k, tol=stop_criterion, max_iter=max_iter)
        top_vals, top_vecs = self._synthesize_power(shift=0.0, **power_kwargs)
        # Shifting by the largest eigenvalue targets the other end of the spectrum.
        bottom_vals, bottom_vecs = self._synthesize_power(
            shift=top_vals[0], **power_kwargs
        )
        n = top_vecs.shape[0]
        stacked_vecs = torch.cat((top_vecs, bottom_vecs), dim=1)
        images = self._vector_to_image(stacked_vecs.detach())
        values = torch.cat([top_vals, bottom_vals]).squeeze()
        indices = torch.cat((torch.arange(k), torch.arange(n - k, n)))

    # stack to (n_distortions x num_chans x h x w); empty list if there are none
    if images:
        self._eigendistortions = torch.stack(images, 0)
    else:
        self._eigendistortions = []
    self._eigenvalues = torch.abs(values.detach())
    self._eigenindex = indices
def _synthesize_exact(self) -> tuple[Tensor, Tensor]:
r"""Eigendecomposition of explicitly computed Fisher Information Matrix.
To be used when the input is small (e.g. less than 70x70 image on cluster or
30x30 on your own machine). This method obviates the power iteration and its
related algorithms (e.g. Lanczos). This method computes the Fisher Information
Matrix by explicitly computing the Jacobian of the representation wrt the input.
Returns
-------
eig_vals
Eigenvalues in decreasing order.
eig_vecs
Eigenvectors in 2D tensor, whose cols are eigenvectors
(i.e. eigendistortions) corresponding to eigenvalues.
"""
J = self.compute_jacobian()
F = J.T @ J
eig_vals, eig_vecs = torch.linalg.eigh(F, UPLO="U")
eig_vecs = eig_vecs.flip(dims=(1,))
eig_vals = eig_vals.flip(dims=(0,))
return eig_vals, eig_vecs
[docs]
def compute_jacobian(self) -> Tensor:
    r"""Return the Jacobian of the representation wrt the input, caching it.

    The first call computes the Jacobian via ``autodiff.jacobian`` and stores
    it on the object; later calls print a notice and return the stored
    tensor. May fail if the input is too big.

    Returns
    -------
    J
        Jacobian of representation wrt input.
    """
    cached = self.jacobian
    if cached is not None:
        print("Jacobian already computed, returning self.jacobian")
        return cached

    J = jacobian(self._representation_flat, self._image_flat)
    self._jacobian = J
    return J
def _synthesize_power(
    self, k: int, shift: Tensor | float, tol: float, max_iter: int
) -> tuple[Tensor, Tensor]:
    r"""Use power method (or orthogonal iteration when k>1) to obtain largest
    (smallest) eigenvalue/vector pairs.

    Apply the algorithm to approximate the extremal eigenvalues and eigenvectors
    of the Fisher Information Matrix, without explicitly representing that matrix.
    This method repeatedly calls ``fisher_info_matrix_vector_product()`` with a
    single (`k=1`), or multiple (`k>1`) vectors.

    Parameters
    ----------
    k
        Number of top and bottom eigendistortions to synthesize; i.e. if k=2,
        then the top 2 and bottom 2 will be returned. When `k>1`, multiple
        eigendistortions are synthesized, and each power iteration step is followed
        by a QR orthogonalization step to ensure the vectors are orthonormal.
    shift
        When `shift=0`, this function estimates the top `k` eigenvalue/vector
        pairs. When `shift` is set to the estimated top eigenvalue this function
        will estimate the smallest eigenval/eigenvector pairs.
    tol
        Tolerance value. Iteration stops once the L2-norm of the change in the
        eigenvalue estimates is at most ``tol``.
    max_iter
        Maximum number of steps.

    Returns
    -------
    lmbda
        Eigenvalue estimate(s) corresponding to the final vector(s) of the
        iteration.
    v
        Final eigenvector(s) (i.e. eigendistortions) of power (orthogonal)
        iteration procedure.

    Notes
    -----
    The most recent estimates are always returned. If the stop criterion is
    already met before the first step (or ``max_iter`` is 0), these are the
    estimates obtained from the single initial Fisher-vector product, rather
    than ``None``.

    References
    ----------
    [1] Orthogonal iteration; Algorithm 8.2.8 Golub and Van Loan, Matrix
        Computations, 3rd Ed.
    """
    x, y = self._image_flat, self._representation_flat

    # note: v is an n x k matrix where k is number of eigendists to be synthesized!
    v = torch.randn(len(x), k, device=x.device, dtype=x.dtype)
    v = v / torch.linalg.vector_norm(v, dim=0, keepdim=True, ord=2)

    _dummy_vec = torch.ones_like(y, requires_grad=True)  # cache a dummy vec for jvp
    Fv = fisher_info_matrix_vector_product(y, x, v, _dummy_vec)
    v = Fv / torch.linalg.vector_norm(Fv, dim=0, keepdim=True, ord=2)
    lmbda = fisher_info_matrix_eigenvalue(y, x, v, _dummy_vec)

    d_lambda = torch.as_tensor(float("inf"))
    desc = ("Top" if shift == 0 else "Bottom") + f" k={k} eigendists"

    # The context manager closes the progress bar even if a step raises.
    with tqdm(range(max_iter), desc=desc) as pbar:
        for _ in pbar:
            pbar.set_postfix(delta_eigenval=f"{d_lambda.item():.2E}")

            if d_lambda <= tol:
                print(f"{desc} computed | Stop criterion {tol:.2E} reached.")
                break

            Fv = fisher_info_matrix_vector_product(y, x, v, _dummy_vec)
            Fv = Fv - shift * v  # optionally shift: (F - shift*I)v

            v_new, _ = torch.linalg.qr(Fv, "reduced")  # (ortho)normalize vector(s)
            lmbda_new = fisher_info_matrix_eigenvalue(y, x, v_new, _dummy_vec)

            # stability of eigenspace
            d_lambda = torch.linalg.vector_norm(lmbda - lmbda_new, ord=2)
            v, lmbda = v_new, lmbda_new

    # ``lmbda``/``v`` always hold the latest estimates, so they are defined
    # even when the loop body never ran to completion.
    return lmbda, v
def _synthesize_randomized_svd(
    self, k: int, p: int, q: int
) -> tuple[Tensor, Tensor, Tensor]:
    r"""Synthesize eigendistortions using randomized truncated SVD.

    This method approximates the column space of the Fisher Info Matrix, projects
    the FIM into that column space, then computes its SVD.

    Parameters
    ----------
    k
        Number of eigenvecs (rank of factorization) to be returned.
    p
        Oversampling parameter, recommended to be 5.
    q
        Matrix power iteration. Used to squeeze the eigen spectrum for more
        accurate approximation. Recommended to be 2.

    Returns
    -------
    S
        Eigenvalues, Size((k, )).
    V
        Eigendistortions, Size((n, k)).
    error_approx
        Estimate of the approximation error. Defined as the expected error
        between the true subspace and approximated subspace.

    References
    ----------
    [1] Halko, Martinsson, Tropp, Finding structure with randomness:
        Probabilistic algorithms for constructing approximate matrix decompositions,
        SIAM Rev. 53:2, pp. 217-288 https://arxiv.org/abs/0909.4061 (2011)
    """
    x, y = self._image_flat, self._representation_flat
    n = len(x)
    n_error_probes = 20  # random vectors used for the error estimate

    # Match the dtype of the input, as ``_synthesize_power`` does. Sampling
    # still happens before the move to ``x.device``, as it did originally.
    P = torch.randn(n, k + p, dtype=x.dtype).to(x.device)

    # orthogonalize first for numerical stability
    P, _ = torch.linalg.qr(P, "reduced")

    _dummy_vec = torch.ones_like(y, requires_grad=True)
    Z = fisher_info_matrix_vector_product(y, x, P, _dummy_vec)

    # optional power iteration to squeeze the spectrum for more accurate
    # estimate
    for _ in range(q):
        Z = fisher_info_matrix_vector_product(y, x, Z, _dummy_vec)

    Q, _ = torch.linalg.qr(Z, "reduced")

    # B = Q.T @ A @ Q
    B = Q.T @ fisher_info_matrix_vector_product(y, x, Q, _dummy_vec)
    # eigendecomp of small matrix
    _, S, Vh = torch.linalg.svd(B, full_matrices=False)
    V = Q @ Vh.T  # lift up to original dimensionality

    # estimate error in Q estimate of range space
    omega = fisher_info_matrix_vector_product(
        y,
        x,
        torch.randn(n, n_error_probes, dtype=x.dtype).to(x.device),
        _dummy_vec,
    )
    # Project as Q @ (Q.T @ omega): grouping it as (Q @ Q.T) @ omega would
    # build an n x n matrix, which is prohibitive for image-sized n.
    error_approx = omega - Q @ (Q.T @ omega)
    error_approx = torch.linalg.vector_norm(error_approx, dim=0, ord=2).mean()

    return S[:k].clone(), V[:, :k].clone(), error_approx  # truncate
def _vector_to_image(self, vecs: Tensor) -> list[Tensor]:
r"""Reshapes eigenvectors back into correct image dimensions.
Parameters
----------
vecs
Eigendistortion tensor with ``torch.Size([N, num_distortions])``.
Each distortion will be reshaped into the original image shape and
placed in a list.
Returns
-------
imgs
List of Tensor images, each with ``torch.Size(img_height, im_width)``.
"""
imgs = [
vecs[:, i].reshape((self.n_channels, self.im_height, self.im_width))
for i in range(vecs.shape[1])
]
return imgs
def _indexer(self, idx: int) -> int:
"""Maps eigenindex to arg index (0-indexed)"""
n = len(self._image_flat)
idx_range = range(n)
i = idx_range[idx]
all_idx = self.eigenindex
assert i in all_idx, "eigenindex must be the index of one of the vectors"
assert (
all_idx is not None and len(all_idx) != 0
), "No eigendistortions synthesized"
return int(np.where(all_idx == i)[0])
[docs]
def save(self, file_path: str):
r"""Save all relevant variables in .pt file.
This simply delegates to ``Synthesis.save`` with ``attrs=None``.
See ``load`` docstring for an example of use.
Parameters
----------
file_path : str
The path to save the Eigendistortion object to
"""
# NOTE(review): attrs=None presumably makes the parent class pick which
# attributes to save -- confirm against ``Synthesis.save``.
super().save(file_path, attrs=None)
[docs]
def to(self, *args, **kwargs):
    r"""Move and/or cast the tensors held by this object, and its model.

    This can be called as

    .. function:: to(device=None, dtype=None, non_blocking=False)

    .. function:: to(dtype, non_blocking=False)

    .. function:: to(tensor, non_blocking=False)

    The arguments are forwarded unchanged, first to ``Synthesis.to`` (for the
    tensor attributes listed below) and then to ``model.to``, so the
    signature is similar to :meth:`torch.Tensor.to`.

    .. note::
        This method modifies the object in-place.

    Parameters
    ----------
    device : torch.device
        The desired device.
    dtype : torch.dtype
        The desired dtype.
    tensor : torch.Tensor
        Tensor whose dtype and device are the desired dtype and device.

    Notes
    -----
    If the model has no ``to`` method (e.g., a custom model that doesn't
    inherit ``torch.nn.Module``), a warning is issued and the model is left
    as is.
    """
    attrs = [
        "_jacobian",
        "_eigendistortions",
        "_eigenvalues",
        "_eigenindex",
        "_image",
        "_image_flat",
        "_representation_flat",
    ]
    super().to(*args, attrs=attrs, **kwargs)
    # Move the model before re-running it below, so that the forward pass
    # sees model and image on the same device / with the same dtype.
    # try to call .to() on model. this should work, but it might fail if e.g., this
    # a custom model that doesn't inherit torch.nn.Module
    try:
        self._model = self._model.to(*args, **kwargs)
    except AttributeError:
        warnings.warn("Unable to call model.to(), so we leave it as is.")
    # we need _representation_flat and _image_flat to be connected in the
    # computation graph for the autograd calls to work, so we reinitialize
    # it here
    self._init_representation(self.image)
[docs]
def load(
    self,
    file_path: str,
    map_location: str | None = None,
    **pickle_load_args,
):
    r"""Load all relevant stuff from a .pt file.

    This should be called by an initialized ``Eigendistortion`` object --
    we will ensure that ``image`` and ``model`` are identical.

    Note this operates in place and so doesn't return anything.

    Parameters
    ----------
    file_path : str
        The path to load the synthesis object from
    map_location : str, optional
        map_location argument to pass to ``torch.load``. If you save
        stuff that was being run on a GPU and are loading onto a
        CPU, you'll need this to make sure everything lines up
        properly. This should be structured like the str you would
        pass to ``torch.device``
    pickle_load_args :
        any additional kwargs will be added to ``pickle_module.load`` via
        ``torch.load``, see that function's docstring for details.

    Examples
    --------
    >>> eig = po.synth.Eigendistortion(img, model)
    >>> eig.synthesize(max_iter=10)
    >>> eig.save('eig.pt')
    >>> eig_copy = po.synth.Eigendistortion(img, model)
    >>> eig_copy.load('eig.pt')

    Note that you must create a new instance of the Synthesis object and
    *then* load.
    """
    super().load(
        file_path,
        map_location=map_location,
        check_attributes=["_image", "_representation_flat"],
        check_loss_functions=[],
        **pickle_load_args,
    )
    # The loaded pixel vector must require a gradient again ...
    self._image_flat.requires_grad_()
    # ... and the representation has to be recomputed from it, so that
    # _representation_flat and _image_flat are connected in the computation
    # graph for the autograd calls to work.
    self._init_representation(self.image)
@property
def model(self):
"""The model given at initialization (replaced by ``model.to(...)`` if
:func:`to` succeeds in moving it)."""
return self._model
@property
def image(self):
"""Image-shaped view of the flattened input, as set by
``_init_representation``."""
return self._image
@property
def jacobian(self):
"""Jacobian cached by :func:`compute_jacobian`, which :func:`synthesize`
calls when run with ``method='exact'``. Default to ``None``."""
return self._jacobian
@property
def eigendistortions(self):
"""Tensor of eigendistortions (eigenvectors of Fisher matrix), ordered by
eigenvalue. ``None`` until :func:`synthesize` has been run."""
return self._eigendistortions
@property
def eigenvalues(self):
"""Tensor of eigenvalues corresponding to each eigendistortion, listed in
decreasing order. Stored as absolute values; ``None`` until
:func:`synthesize` has been run."""
return self._eigenvalues
@property
def eigenindex(self):
"""Index of each eigenvector/eigenvalue. ``None`` until :func:`synthesize`
has been run."""
return self._eigenindex
[docs]
def display_eigendistortion(
    eigendistortion: Eigendistortion,
    eigenindex: int = 0,
    alpha: float = 5.0,
    process_image: Callable[[Tensor], Tensor] = lambda x: x,
    ax: matplotlib.axes.Axes | None = None,
    plot_complex: str = "rectangular",
    **kwargs,
) -> Figure:
    r"""Displays specified eigendistortion added to the image.

    The plotted tensor is ``process_image(image + alpha * eigendistortion)``,
    clamped to the range [0, 1].

    Parameters
    ----------
    eigendistortion
        Eigendistortion object whose synthesized eigendistortion we want to display
    eigenindex
        Index of eigendistortion to plot. E.g. If there are 10 eigenvectors, 0 will
        index the first one, and -1 or 9 will index the last one.
    alpha
        Amount by which to scale eigendistortion for `image + (alpha * eigendistortion)`
        for display.
    process_image
        A function to process the image+alpha*distortion before clamping between 0,1.
        E.g. multiplying by the stdev ImageNet then adding the mean of ImageNet to undo
        image preprocessing.
    ax
        Axis handle on which to plot.
    plot_complex
        Parameter for :meth:`plenoptic.imshow` determining how to handle complex values.
        Defaults to 'rectangular', which plots real and complex components as separate
        images. Can also be 'polar' or 'logpolar'; see that method's docstring
        for details.
    kwargs
        Additional arguments for :meth:`po.imshow()`.

    Returns
    -------
    fig
        matplotlib Figure handle returned by plenoptic.imshow()

    Notes
    -----
    NOTE(review): this docstring used to state that 3-channel images are
    converted to grayscale for display. Nothing in this function does so;
    confirm whether ``imshow`` handles that.
    """
    # View the image with an explicit leading batch dimension of 1:
    # (1, n_channels, im_height, im_width).
    batched_shape = (
        1,
        eigendistortion.n_channels,
        eigendistortion.im_height,
        eigendistortion.im_width,
    )
    base_image = eigendistortion.image.detach().view(*batched_shape).cpu()

    # Translate the requested eigenindex into a position in the stored tensor.
    position = eigendistortion._indexer(eigenindex)
    distortion = eigendistortion.eigendistortions[position].unsqueeze(0).cpu()

    distorted = process_image(base_image + alpha * distortion)
    clamped = torch.clamp(distorted, 0, 1)
    return imshow(clamped, ax=ax, plot_complex=plot_complex, **kwargs)