Source code for plenoptic.metric.perceptual_distance

import warnings
from importlib import resources
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

from ..simulate.canonical_computations import LaplacianPyramid
from ..simulate.canonical_computations.filters import circular_gaussian2d
from ..tools.conv import same_padding

DIRNAME = resources.files("plenoptic.metric")


def _ssim_parts(img1, img2, pad=False):
    """Calcluates the various components used to compute SSIM

    This should not be called by users directly, but is meant to assist for
    calculating SSIM and MS-SSIM.

    Parameters
    ----------
    img1: torch.Tensor of shape (batch, channel, height, width)
        The first image or batch of images.
    img2: torch.Tensor of shape (batch, channel, height, width)
        The second image or batch of images. The heights and widths of `img1`
        and `img2` must be the same. The numbers of batches and channels of
        `img1` and `img2` need to be broadcastable: either they are the same
        or one of them is 1. The output will be computed separately for each
        channel (so channels are treated in the same way as batches). Both
        images should have values between 0 and 1. Otherwise, the result may
        be inaccurate, and we will raise a warning (but will still compute it).
    pad : {False, 'constant', 'reflect', 'replicate', 'circular'}, optional
        If not False, how to pad the image for the convolutions computing the
        local average of each image. See `torch.nn.functional.pad` for how
        these work.

    """
    img_ranges = torch.as_tensor([[img1.min(), img1.max()], [img2.min(), img2.max()]])
    if (img_ranges > 1).any() or (img_ranges < 0).any():
        warnings.warn(
            "Image range falls outside [0, 1]."
            f" img1: {img_ranges[0]}, img2: {img_ranges[1]}. "
            "Continuing anyway..."
        )

    if not img1.ndim == img2.ndim == 4:
        raise Exception(
            "Input images should have four dimensions: (batch, channel,"
            " height, width)"
        )
    if img1.shape[-2:] != img2.shape[-2:]:
        raise Exception("img1 and img2 must have the same height and width!")
    for i in range(2):
        if img1.shape[i] != img2.shape[i] and img1.shape[i] != 1 and img2.shape[i] != 1:
            raise Exception(
                "Either img1 and img2 should have the same number of "
                "elements in each dimension, or one of "
                "them should be 1! But got shapes "
                f"{img1.shape}, {img2.shape} instead"
            )
    if img1.shape[1] > 1 or img2.shape[1] > 1:
        warnings.warn(
            "SSIM was designed for grayscale images and here it will be"
            " computed separately for each channel (so channels are treated in"
            " the same way as batches)."
        )
    if img1.dtype != img2.dtype:
        raise ValueError("Input images must have same dtype!")

    real_size = min(11, img1.shape[2], img1.shape[3])
    std = torch.as_tensor(1.5).to(img1.device)
    window = circular_gaussian2d(real_size, std=std).to(img1.dtype)

    # these two checks are guaranteed by the window construction above, but if
    # we add the ability for users to set their own window, they'll be necessary
    window_sum = window.sum((-1, -2), keepdim=True)
    if not torch.allclose(window_sum, torch.ones_like(window_sum)):
        warnings.warn("window should have sum of 1! normalizing...")
        window = window / window_sum
    if window.ndim != 4:
        raise Exception("window must have 4 dimensions!")

    if pad is not False:
        img1 = same_padding(img1, (real_size, real_size), pad_mode=pad)
        img2 = same_padding(img2, (real_size, real_size), pad_mode=pad)

    def windowed_average(img):
        # Convolve every channel with the same single-channel window by
        # folding the channel dimension into the batch dimension, then
        # restoring it after the convolution.
        padd = 0
        (n_batches, n_channels, _, _) = img.shape
        img = img.reshape(n_batches * n_channels, 1, img.shape[2], img.shape[3])
        img_average = F.conv2d(img, window, padding=padd)
        img_average = img_average.reshape(
            n_batches, n_channels, img_average.shape[2], img_average.shape[3]
        )
        return img_average

    mu1 = windowed_average(img1)
    mu2 = windowed_average(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = windowed_average(img1 * img1) - mu1_sq
    sigma2_sq = windowed_average(img2 * img2) - mu2_sq
    sigma12 = windowed_average(img1 * img2) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2
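    # These stability constants follow the original SSIM paper: C_i =
    # (K_i * L)**2 with K_1 = 0.01 and K_2 = 0.03, where L is the dynamic
    # range (1 here, since inputs are expected to lie in [0, 1]).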

    # SSIM is the product of a luminance component, a contrast component, and a
    # structure component. The contrast-structure component has to be separated
    # when computing MS-SSIM.
    luminance_map = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
    contrast_structure_map = (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    map_ssim = luminance_map * contrast_structure_map

    # the log-variance weight used by the weighted SSIM variant (see `ssim`
    # with `weighted=True`)
    weight = torch.log((1 + sigma1_sq / C2) * (1 + sigma2_sq / C2))
    return map_ssim, contrast_structure_map, weight
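

# Minimal sketch of the helper's outputs (illustrative only; `_ssim_parts` is
# private and normally consumed by `ssim`, `ssim_map`, and `ms_ssim`). With
# the default pad=False, the 11x11 window shrinks each spatial dimension by 10.
def _example_ssim_parts():  # hypothetical demo helper, not part of the API
    torch.manual_seed(0)
    img1 = torch.rand(1, 1, 64, 64)
    img2 = torch.rand(1, 1, 64, 64)
    map_ssim, cs_map, weight = _ssim_parts(img1, img2)
    # each output has shape (1, 1, 54, 54): 64 - 11 + 1 per spatial dimension
    return map_ssim.shape, cs_map.shape, weight.shape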


def ssim(img1, img2, weighted=False, pad=False):
    r"""Structural similarity index

    As described in [1]_, the structural similarity index (SSIM) is a
    perceptual similarity metric, giving the similarity between two images.
    SSIM is based on three comparison measurements between the two images:
    luminance, contrast, and structure. All of these are computed
    convolutionally across the images. See the references for more
    information.

    This implementation follows the original implementation, as found at
    [2]_, as well as providing the option to use the weighted version used
    in [4]_ (which was shown to consistently improve the image quality
    prediction on the LIVE database).

    Note that this is a similarity metric (not a distance), and so 1 means
    the two images are identical and 0 means they're very different. When
    the two images are negatively correlated, SSIM can be negative. SSIM is
    bounded between -1 and 1.

    This function returns the mean SSIM, a scalar-valued metric giving the
    average over the whole image. For the SSIM map (showing the computed
    value across the image), call `ssim_map`.

    Parameters
    ----------
    img1: torch.Tensor of shape (batch, channel, height, width)
        The first image or batch of images.
    img2: torch.Tensor of shape (batch, channel, height, width)
        The second image or batch of images. The heights and widths of
        `img1` and `img2` must be the same. The numbers of batches and
        channels of `img1` and `img2` need to be broadcastable: either they
        are the same or one of them is 1. The output will be computed
        separately for each channel (so channels are treated in the same way
        as batches). Both images should have values between 0 and 1.
        Otherwise, the result may be inaccurate, and we will raise a warning
        (but will still compute it).
    weighted : bool, optional
        Whether to use the original, unweighted SSIM version (`False`) as
        used in [1]_ or the weighted version (`True`) as used in [4]_. See
        the Notes section for the weight.
    pad : {False, 'constant', 'reflect', 'replicate', 'circular'}, optional
        If not False, how to pad the image for the convolutions computing
        the local average of each image. See `torch.nn.functional.pad` for
        how these work.

    Returns
    -------
    mssim : torch.Tensor
        2d tensor of shape (batch, channel) containing the mean SSIM for
        each image, averaged over the whole image.

    Notes
    -----
    The weight used when `weighted=True` is:

    .. math::
        \log\left(\left(1+\frac{\sigma_1^2}{C_2}\right)\left(1+\frac{\sigma_2^2}{C_2}\right)\right)

    where :math:`\sigma_1^2` and :math:`\sigma_2^2` are the variances of
    `img1` and `img2`, respectively, and :math:`C_2` is a constant. See
    [4]_ for more details.

    References
    ----------
    .. [1] Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
       quality assessment: From error visibility to structural similarity,"
       IEEE Transactions on Image Processing, vol. 13, no. 4, Apr. 2004.
    .. [2] [matlab code](https://www.cns.nyu.edu/~lcv/ssim/ssim_index.m)
    .. [3] [project page](https://www.cns.nyu.edu/~lcv/ssim/)
    .. [4] Wang, Z., & Simoncelli, E. P. (2008). Maximum differentiation
       (MAD) competition: A methodology for comparing computational models
       of perceptual discriminability. Journal of Vision, 8(12), 1-13.
       https://dx.doi.org/10.1167/8.12.8

    """
    # these are named map_ssim instead of the perhaps more natural ssim_map
    # because that's the name of a function
    map_ssim, _, weight = _ssim_parts(img1, img2, pad)
    if not weighted:
        mssim = map_ssim.mean((-1, -2))
    else:
        mssim = (map_ssim * weight).sum((-1, -2)) / weight.sum((-1, -2))
    if min(img1.shape[2], img1.shape[3]) < 11:
        warnings.warn(
            "SSIM uses an 11x11 convolutional kernel, but the height and/or "
            "the width of the input image is smaller than 11, so the "
            "kernel size is set to be the minimum of these two numbers."
        )
    return mssim
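

# Minimal usage sketch for `ssim` (illustrative only): compare an image to a
# noisy copy of itself. The result is a similarity, so values near 1 indicate
# small perceptual differences.
def _example_ssim():  # hypothetical demo helper, not part of the API
    torch.manual_seed(0)
    img1 = torch.rand(1, 1, 64, 64)
    img2 = (img1 + 0.05 * torch.randn_like(img1)).clamp(0, 1)
    return ssim(img1, img2)  # tensor of shape (batch, channel) = (1, 1)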


def ssim_map(img1, img2):
    """Structural similarity index map

    As described in [1]_, the structural similarity index (SSIM) is a
    perceptual similarity metric, giving the similarity between two images.
    SSIM is based on three comparison measurements between the two images:
    luminance, contrast, and structure. All of these are computed
    convolutionally across the images. See the references for more
    information.

    This implementation follows the original implementation, as found at
    [2]_. For the weighted version used in [4]_, see `ssim`.

    Note that this is a similarity metric (not a distance), and so 1 means
    the two images are identical and 0 means they're very different. When
    the two images are negatively correlated, SSIM can be negative. SSIM is
    bounded between -1 and 1.

    This function returns the SSIM map, showing the SSIM values across the
    image. For the mean SSIM (a single value metric), call `ssim`.

    Parameters
    ----------
    img1: torch.Tensor of shape (batch, channel, height, width)
        The first image or batch of images.
    img2: torch.Tensor of shape (batch, channel, height, width)
        The second image or batch of images. The heights and widths of
        `img1` and `img2` must be the same. The numbers of batches and
        channels of `img1` and `img2` need to be broadcastable: either they
        are the same or one of them is 1. The output will be computed
        separately for each channel (so channels are treated in the same way
        as batches). Both images should have values between 0 and 1.
        Otherwise, the result may be inaccurate, and we will raise a warning
        (but will still compute it).

    Returns
    -------
    ssim_map : torch.Tensor
        4d tensor containing the map of SSIM values.

    References
    ----------
    .. [1] Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
       quality assessment: From error visibility to structural similarity,"
       IEEE Transactions on Image Processing, vol. 13, no. 4, Apr. 2004.
    .. [2] [matlab code](https://www.cns.nyu.edu/~lcv/ssim/ssim_index.m)
    .. [3] [project page](https://www.cns.nyu.edu/~lcv/ssim/)
    .. [4] Wang, Z., & Simoncelli, E. P. (2008). Maximum differentiation
       (MAD) competition: A methodology for comparing computational models
       of perceptual discriminability. Journal of Vision, 8(12), 1-13.
       https://dx.doi.org/10.1167/8.12.8

    """
    if min(img1.shape[2], img1.shape[3]) < 11:
        warnings.warn(
            "SSIM uses an 11x11 convolutional kernel, but the height and/or "
            "the width of the input image is smaller than 11, so the "
            "kernel size is set to be the minimum of these two numbers."
        )
    return _ssim_parts(img1, img2)[0]
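

# Minimal usage sketch for `ssim_map` (illustrative only): the returned map
# shows local SSIM values, useful for visualizing where two images differ.
def _example_ssim_map():  # hypothetical demo helper, not part of the API
    torch.manual_seed(0)
    img1 = torch.rand(1, 1, 64, 64)
    img2 = img1.clone()
    img2[..., 20:40, 20:40] = 0.5  # distort one patch
    smap = ssim_map(img1, img2)  # shape (1, 1, 54, 54), since pad is False
    return smap  # values dip around the distorted patch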


def ms_ssim(img1, img2, power_factors=None):
    r"""Multiscale structural similarity index (MS-SSIM)

    As described in [1]_, multiscale structural similarity index (MS-SSIM) is
    an improvement upon structural similarity index (SSIM) that takes into
    account the perceptual distance between two images on different scales.

    SSIM is based on three comparison measurements between the two images:
    luminance, contrast, and structure. All of these are computed
    convolutionally across the images, producing three maps instead of
    scalars. The SSIM map is the elementwise product of these three maps.
    See `metric.ssim` and `metric.ssim_map` for a full description of SSIM.

    To get images of different scales, average pooling operations with
    kernel size 2 are performed recursively on the input images. The product
    of the contrast map and structure map (the "contrast-structure map") is
    computed for all but the coarsest scale, and the overall SSIM map is
    only computed for the coarsest scale. Their mean values are raised to
    exponents and multiplied to produce MS-SSIM:

    .. math::
        MSSSIM = {SSIM}_M^{a_M} \prod_{i=1}^{M-1} ({CS}_i)^{a_i}

    Here :math:`M` is the number of scales, :math:`{CS}_i` is the mean value
    of the contrast-structure map for the i'th finest scale, and
    :math:`{SSIM}_M` is the mean value of the SSIM map for the coarsest
    scale. If at least one of these terms is negative, the value of MS-SSIM
    is zero. The values of :math:`a_i, i=1,...,M` are taken from the
    argument `power_factors`.

    Parameters
    ----------
    img1: torch.Tensor of shape (batch, channel, height, width)
        The first image or batch of images.
    img2: torch.Tensor of shape (batch, channel, height, width)
        The second image or batch of images. The heights and widths of
        `img1` and `img2` must be the same. The numbers of batches and
        channels of `img1` and `img2` need to be broadcastable: either they
        are the same or one of them is 1. The output will be computed
        separately for each channel (so channels are treated in the same way
        as batches). Both images should have values between 0 and 1.
        Otherwise, the result may be inaccurate, and we will raise a warning
        (but will still compute it).
    power_factors : 1D array, optional
        Power exponents for the mean values of maps, for different scales
        (from fine to coarse). The length of this array determines the
        number of scales. By default, this is set to
        [0.0448, 0.2856, 0.3001, 0.2363, 0.1333], which is what
        psychophysical experiments in [1]_ found.

    Returns
    -------
    msssim : torch.Tensor
        2d tensor of shape (batch, channel) containing the MS-SSIM for each
        image.

    References
    ----------
    .. [1] Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale
       structural similarity for image quality assessment." The
       Thirty-Seventh Asilomar Conference on Signals, Systems & Computers,
       2003. Vol. 2. IEEE, 2003.

    """
    if power_factors is None:
        power_factors = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]

    def downsample(img):
        # pad by one pixel (replicating the edge) if height or width is odd,
        # then halve the resolution with 2x2 average pooling
        img = F.pad(img, (0, img.shape[3] % 2, 0, img.shape[2] % 2), mode="replicate")
        img = F.avg_pool2d(img, kernel_size=2)
        return img

    msssim = 1
    # accumulate the contrast-structure term at each of the finer scales,
    # downsampling after each one
    for i in range(len(power_factors) - 1):
        _, contrast_structure_map, _ = _ssim_parts(img1, img2)
        msssim *= F.relu(contrast_structure_map.mean((-1, -2))).pow(power_factors[i])
        img1 = downsample(img1)
        img2 = downsample(img2)
    # the full SSIM map is only computed at the coarsest scale
    map_ssim, _, _ = _ssim_parts(img1, img2)
    msssim *= F.relu(map_ssim.mean((-1, -2))).pow(power_factors[-1])

    if min(img1.shape[2], img1.shape[3]) < 11:
        warnings.warn(
            "SSIM uses an 11x11 convolutional kernel, but for some scales "
            "of the input image, the height and/or the width is smaller "
            "than 11, so the kernel size in SSIM is set to be the "
            "minimum of these two numbers for these scales."
        )
    return msssim


def normalized_laplacian_pyramid(img):
    """Compute the normalized Laplacian Pyramid using pre-optimized parameters

    Parameters
    ----------
    img: torch.Tensor of shape (batch, channel, height, width)
        Image, or batch of images. This representation is designed for
        grayscale images and will be computed separately for each channel
        (so channels are treated in the same way as batches).

    Returns
    -------
    normalized_laplacian_activations: list of torch.Tensor
        The normalized Laplacian Pyramid with six scales.

    """
    (_, channel, height, width) = img.size()
    N_scales = 6
    spatialpooling_filters = np.load(Path(DIRNAME) / "DN_filts.npy")
    sigmas = np.load(Path(DIRNAME) / "DN_sigmas.npy")
    L = LaplacianPyramid(n_scales=N_scales, scale_filter=True)
    laplacian_activations = L.forward(img)
    padd = 2
    normalized_laplacian_activations = []
    for N_b in range(0, N_scales):
        filt = torch.as_tensor(
            spatialpooling_filters[N_b], dtype=img.dtype, device=img.device
        ).repeat(channel, 1, 1, 1)
        # local amplitude estimate: convolve the absolute coefficients with
        # the pre-optimized spatial pooling filter, one group per channel
        filtered_activations = F.conv2d(
            torch.abs(laplacian_activations[N_b]),
            filt,
            padding=padd,
            groups=channel,
        )
        normalized_laplacian_activations.append(
            laplacian_activations[N_b] / (sigmas[N_b] + filtered_activations)
        )
    return normalized_laplacian_activations
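

# Minimal usage sketch for `normalized_laplacian_pyramid` (illustrative
# only): the six returned tensors go from fine to coarse scales, each divided
# by a local estimate of its own amplitude.
def _example_normalized_laplacian_pyramid():  # hypothetical demo helper
    torch.manual_seed(0)
    img = torch.rand(1, 1, 256, 256)
    activations = normalized_laplacian_pyramid(img)
    return [a.shape for a in activations]  # six tensors, coarser ones smaller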


def nlpd(img1, img2):
    """Normalized Laplacian Pyramid Distance

    As described in [1]_, this is an image quality metric based on the
    transformations associated with the early visual system: local luminance
    subtraction and local contrast gain control.

    A Laplacian pyramid subtracts a local estimate of the mean luminance at
    six scales. Then a local gain control divides these centered
    coefficients by a weighted sum of absolute values in a spatial
    neighborhood. These weight parameters were optimized for redundancy
    reduction over a training database of (undistorted) natural images.

    Note that we compute the root mean squared error for each scale and then
    average over these, effectively giving larger weight to the
    lower-frequency coefficients (which are fewer in number, due to
    subsampling).

    Parameters
    ----------
    img1: torch.Tensor of shape (batch, channel, height, width)
        The first image or batch of images.
    img2: torch.Tensor of shape (batch, channel, height, width)
        The second image or batch of images. The heights and widths of
        `img1` and `img2` must be the same. The numbers of batches and
        channels of `img1` and `img2` need to be broadcastable: either they
        are the same or one of them is 1. The output will be computed
        separately for each channel (so channels are treated in the same way
        as batches). Both images should have values between 0 and 1.
        Otherwise, the result may be inaccurate, and we will raise a warning
        (but will still compute it).

    Returns
    -------
    distance : torch.Tensor of shape (batch, channel)
        The normalized Laplacian Pyramid distance.

    References
    ----------
    .. [1] Laparra, V., Ballé, J., Berardino, A. and Simoncelli, E.P., 2016.
       Perceptual image quality assessment using a normalized Laplacian
       pyramid. Electronic Imaging, 2016(16), pp.1-6.

    """
    if not img1.ndim == img2.ndim == 4:
        raise Exception(
            "Input images should have four dimensions: (batch, channel,"
            " height, width)"
        )
    if img1.shape[-2:] != img2.shape[-2:]:
        raise Exception("img1 and img2 must have the same height and width!")
    for i in range(2):
        if img1.shape[i] != img2.shape[i] and img1.shape[i] != 1 and img2.shape[i] != 1:
            raise Exception(
                "Either img1 and img2 should have the same number of "
                "elements in each dimension, or one of "
                "them should be 1! But got shapes "
                f"{img1.shape}, {img2.shape} instead"
            )
    if img1.shape[1] > 1 or img2.shape[1] > 1:
        warnings.warn(
            "NLPD was designed for grayscale images and here it will be"
            " computed separately for each channel (so channels are treated"
            " in the same way as batches)."
        )
    img_ranges = torch.as_tensor([[img1.min(), img1.max()], [img2.min(), img2.max()]])
    if (img_ranges > 1).any() or (img_ranges < 0).any():
        warnings.warn(
            "Image range falls outside [0, 1]."
            f" img1: {img_ranges[0]}, img2: {img_ranges[1]}. "
            "Continuing anyway..."
        )

    y1 = normalized_laplacian_pyramid(img1)
    y2 = normalized_laplacian_pyramid(img2)

    epsilon = 1e-10  # for optimization purposes (stabilizing the gradient around zero)
    dist = []
    for i in range(6):
        dist.append(torch.sqrt(torch.mean((y1[i] - y2[i]) ** 2, dim=(2, 3)) + epsilon))
    return torch.stack(dist).mean(dim=0)
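

# Minimal usage sketch for `nlpd` (illustrative only): unlike SSIM, this is a
# distance, so identical images give (approximately) zero and larger values
# mean greater perceptual difference.
def _example_nlpd():  # hypothetical demo helper, not part of the API
    torch.manual_seed(0)
    img1 = torch.rand(1, 1, 256, 256)
    img2 = (img1 + 0.05 * torch.randn_like(img1)).clamp(0, 1)
    return nlpd(img1, img2)  # tensor of shape (batch, channel) = (1, 1)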