"""Image-processing functions used by metrics.""" # numpydoc ignore=ES01
import warnings
from importlib import resources
from pathlib import Path
from typing import Literal
import numpy as np
import torch
import torch.nn.functional as F
from .convolutions import same_padding
from .filters import circular_gaussian2d
from .laplacian_pyramid import LaplacianPyramid
DIRNAME = resources.files("plenoptic.process")
__all__ = [
"normalized_laplacian_pyramid",
"ssim_map",
]
def __dir__() -> list[str]:
return __all__
def _ssim_parts(
img1: torch.Tensor,
img2: torch.Tensor,
pad: Literal[False, "constant", "reflect", "replicate", "circular"] = False,
func_name: str = "SSIM",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Calculate the various components used to compute SSIM.
This should not be called by users directly, but is meant to assist for
calculating SSIM and MS-SSIM.
Parameters
----------
img1
The first image or batch of images, of shape (batch, channel, height, width).
img2
The second image or batch of images, of shape (batch, channel, height, width).
The heights and widths of ``img1`` and ``img2`` must be the same. The numbers of
batches and channels of ``img1`` and ``img2`` need to be broadcastable: either
they are the same or one of them is 1. The output will be computed separately
for each channel (so channels are treated in the same way as batches). Both
images should have values between 0 and 1. Otherwise, the result may be
inaccurate, and we will raise a warning (but will still compute it).
pad
If not ``False``, how to pad the image for the convolutions computing the
local average of each image. See :func:`torch.nn.functional.pad` for how
these work.
func_name
Name of the function that called this one, in order to raise more helpful error
/ warning messages.
Returns
-------
map_ssim
Map of SSIM values across the image.
contrast_structure_map
Map of contrast structure values.
weight
Weight used for stability of computation.
Raises
------
ValueError
If either ``img1`` or ``img2`` is not 4d.
ValueError
If ``img1`` and ``img2`` have different height or width.
ValueError
If ``img1`` and ``img2`` have different batch or channel, unless one of them has
a 1 there, so they can be broadcast.
ValueError
If ``img1`` and ``img2`` have different dtypes.
Warns
-----
UserWarning
If either ``img1`` or ``img2`` has multiple channels, as SSIM was designed for
grayscale images.
UserWarning
If either ``img1`` or ``img2`` has a value outside of range ``[0, 1]``.
""" # numpydoc ignore=EX01
img_ranges = torch.stack([img1.min(), img1.max(), img2.min(), img2.max()])
if (img_ranges > 1).any() or (img_ranges < 0).any():
warnings.warn(
f"Image range falls outside [0, 1]. {func_name} output may not make sense.",
)
if not img1.ndim == img2.ndim == 4:
raise ValueError(
"Input images should have four dimensions: (batch, channel, height, width)"
)
if img1.shape[-2:] != img2.shape[-2:]:
raise ValueError("img1 and img2 must have the same height and width!")
for i in range(2):
if img1.shape[i] != img2.shape[i] and img1.shape[i] != 1 and img2.shape[i] != 1:
raise ValueError(
"Either img1 and img2 should have the same number of "
"elements in the batch and channel dimensions, or one of "
"them should be 1! But got shapes "
f"{img1.shape}, {img2.shape} instead"
)
if img1.shape[1] > 1 or img2.shape[1] > 1:
warnings.warn(
"SSIM was designed for grayscale images and here it will be"
" computed separately for each channel (so channels are treated in"
" the same way as batches).",
)
if img1.dtype != img2.dtype:
raise ValueError("Input images must have same dtype!")
real_size = min(11, img1.shape[2], img1.shape[3])
std = torch.as_tensor(1.5).to(img1.device)
window = circular_gaussian2d(real_size, std=std).to(img1.dtype)
# these two checks are guaranteed with our above bits, but if we add
# ability for users to set own window, they'll be necessary
window_sum = window.sum((-1, -2), keepdim=True)
if not torch.allclose(window_sum, torch.ones_like(window_sum)):
warnings.warn("window should have sum of 1! normalizing...")
window = window / window_sum
if window.ndim != 4:
raise ValueError("window must have 4 dimensions!")
if pad is not False:
img1 = same_padding(img1, (real_size, real_size), pad_mode=pad)
img2 = same_padding(img2, (real_size, real_size), pad_mode=pad)
def windowed_average(img: torch.Tensor) -> torch.Tensor: # numpydoc ignore=GL08
padding = 0
(n_batches, n_channels, _, _) = img.shape
img = img.reshape(n_batches * n_channels, 1, img.shape[2], img.shape[3])
img_average = F.conv2d(img, window, padding=padding)
img_average = img_average.reshape(
n_batches, n_channels, img_average.shape[2], img_average.shape[3]
)
return img_average
mu1 = windowed_average(img1)
mu2 = windowed_average(img2)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = windowed_average(img1 * img1) - mu1_sq
sigma2_sq = windowed_average(img2 * img2) - mu2_sq
sigma12 = windowed_average(img1 * img2) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
# SSIM is the product of a luminance component, a contrast component, and a
# structure component. The contrast-structure component has to be separated
# when computing MS-SSIM.
luminance_map = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
contrast_structure_map = (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
map_ssim = luminance_map * contrast_structure_map
# the weight used for stability
weight = torch.log((1 + sigma1_sq / C2) * (1 + sigma2_sq / C2))
return map_ssim, contrast_structure_map, weight
[docs]
def ssim_map(img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
"""
Structural similarity index map.
As described in Wang et al., 2004 [5]_, the structural similarity index (SSIM) is a
perceptual distance metric, giving the distance between two images. SSIM is based on
three comparison measurements between the two images: luminance, contrast, and
structure. All of these are computed convolutionally across the images. See the
references for more information.
This implementation follows the original implementation, as found online [6]_, as
well as providing the option to use the weighted version used in Wang and
Simoncelli, 2008 [8]_ (which was shown to consistently improve the image quality
prediction on the LIVE database). More info can be found online [7]_.
Note that this is a similarity metric (not a distance), and so 1 means the
two images are identical and 0 means they're very different. When the two
images are negatively correlated, SSIM can be negative. SSIM is bounded
between -1 and 1.
This function returns the SSIM map, showing the SSIM values across the
image. For the mean SSIM (a single value metric), call
:func:`~plenoptic.metric.ssim`.
Parameters
----------
img1
The first image or batch of images, of shape (batch, channel, height, width).
img2
The second image or batch of images, of shape (batch, channel, height, width).
The heights and widths of ``img1`` and ``img2`` must be the same. The numbers of
batches and channels of ``img1`` and ``img2`` need to be broadcastable: either
they are the same or one of them is 1. The output will be computed separately
for each channel (so channels are treated in the same way as batches). Both
images should have values between 0 and 1. Otherwise, the result may be
inaccurate, and we will raise a warning (but will still compute it).
Returns
-------
ssim_map
4d tensor containing the map of SSIM values.
Raises
------
ValueError
If either ``img1`` or ``img2`` is not 4d.
ValueError
If ``img1`` and ``img2`` have different height or width.
ValueError
If ``img1`` and ``img2`` have different batch or channel, unless one of them has
a 1 there, so they can be broadcast.
ValueError
If ``img1`` and ``img2`` have different dtypes.
Warns
-----
UserWarning
If either ``img1`` or ``img2`` has multiple channels, as SSIM was designed for
grayscale images.
UserWarning
If at least one scale from either ``img1`` or ``img2`` has height or width of
less than 11, since SSIM uses an 11x11 convolutional kernel.
References
----------
.. [5] Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
quality assessment: From error measurement to structural similarity"
IEEE Transactions on Image Processing, vol. 13, no. 1, Jan. 2004.
.. [6] matlab code `<https://www.cns.nyu.edu/~lcv/ssim/ssim_index.m>`_
.. [7] project page `<https://www.cns.nyu.edu/~lcv/ssim/>`_
.. [8] Wang, Z., & Simoncelli, E. P. (2008). Maximum differentiation (MAD)
competition: A methodology for comparing computational models of
perceptual discriminability. Journal of Vision, 8(12), 1–13.
https://dx.doi.org/10.1167/8.12.8
Examples
--------
>>> import plenoptic as po
>>> import torch
>>> po.set_seed(0)
>>> img = po.data.einstein()
>>> ssim_map = po.process.ssim_map(img, img + torch.rand_like(img))
>>> ssim_map.shape
torch.Size([1, 1, 246, 246])
"""
if min(img1.shape[2], img1.shape[3]) < 11:
warnings.warn(
"SSIM uses 11x11 convolutional kernel, but the height and/or "
"the width of the input image is smaller than 11, so the "
"kernel size is set to be the minimum of these two numbers."
)
return _ssim_parts(img1, img2)[0]
[docs]
def normalized_laplacian_pyramid(img: torch.Tensor) -> list[torch.Tensor]:
"""
Compute the normalized Laplacian Pyramid using pre-optimized parameters.
Model parameters are those used in Laparra et al., 2016 [10]_, copied from the
matlab code used in the paper, found online [11]_.
Parameters
----------
img
Image, or batch of images of shape (batch, channel, height, width). This
representation is designed for grayscale images and will be computed separately
for each channel (so channels are treated in the same way as batches).
Returns
-------
normalized_laplacian_activations
The normalized Laplacian Pyramid with six scales.
References
----------
.. [10] Laparra, V., Ballé, J., Berardino, A. and Simoncelli, E.P., 2016. Perceptual
image quality assessment using a normalized Laplacian pyramid. Electronic
Imaging, 2016(16), pp.1-6.
.. [11] matlab code: `<https://www.cns.nyu.edu/~lcv/NLPyr/NLP_dist.m>`_
Examples
--------
.. plot::
>>> import plenoptic as po
>>> img = po.data.einstein()
>>> pyramid = po.process.normalized_laplacian_pyramid(img)
>>> [p.shape for p in pyramid]
[torch.Size([1, 1, 256, 256]),
torch.Size([1, 1, 128, 128]),
torch.Size([1, 1, 64, 64]),
torch.Size([1, 1, 32, 32]),
torch.Size([1, 1, 16, 16]),
torch.Size([1, 1, 8, 8])]
>>> po.plot.imshow(pyramid, col_wrap=3)
<PyrFigure size ...>
"""
(_, channel, _, _) = img.size()
N_scales = 6
spatialpooling_filters = np.load(Path(DIRNAME) / "DN_filts.npy")
sigmas = np.load(Path(DIRNAME) / "DN_sigmas.npy")
L = LaplacianPyramid(n_scales=N_scales, scale_filter=True)
laplacian_activations = L.forward(img)
padding = 2
normalized_laplacian_activations = []
for N_b in range(0, N_scales):
filt = torch.as_tensor(
spatialpooling_filters[N_b], dtype=img.dtype, device=img.device
).repeat(channel, 1, 1, 1)
filtered_activations = F.conv2d(
torch.abs(laplacian_activations[N_b]),
filt,
padding=padding,
groups=channel,
)
normalized_laplacian_activations.append(
laplacian_activations[N_b] / (sigmas[N_b] + filtered_activations)
)
return normalized_laplacian_activations