Source code for plenoptic.simulate.models.portilla_simoncelli

"""Portilla-Simoncelli texture statistics.

The Portilla-Simoncelli (PS) texture statistics are a set of image
statistics, first described in [1]_, that are proposed as a sufficient set
of measurements for describing visual textures. That is, if two texture
images have the same values for all PS texture stats, humans should
consider them as members of the same family of textures.
"""

from collections import OrderedDict
from typing import Literal

import einops
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.fft
import torch.nn as nn
from torch import Tensor

from ...tools import signal, stats
from ...tools.data import to_numpy
from ...tools.display import clean_stem_plot, clean_up_axes, update_stem
from ...tools.validate import validate_input
from ..canonical_computations.steerable_pyramid_freq import (
    SCALES_TYPE as PYR_SCALES_TYPE,
)
from ..canonical_computations.steerable_pyramid_freq import (
    SteerablePyramidFreq,
)

SCALES_TYPE = Literal["pixel_statistics"] | PYR_SCALES_TYPE


[docs] class PortillaSimoncelli(nn.Module): r"""Portila-Simoncelli texture statistics. The Portilla-Simoncelli (PS) texture statistics are a set of image statistics, first described in [1]_, that are proposed as a sufficient set of measurements for describing visual textures. That is, if two texture images have the same values for all PS texture stats, humans should consider them as members of the same family of textures. The PS stats are computed based on the steerable pyramid [2]_. They consist of the local auto-correlations, cross-scale (within-orientation) correlations, and cross-orientation (within-scale) correlations of both the pyramid coefficients and the local energy (as computed by those coefficients). Additionally, they include the first four global moments (mean, variance, skew, and kurtosis) of the image and down-sampled versions of that image. See the paper and notebook for more description. Parameters ---------- image_shape: Shape of input image. n_scales: The number of pyramid scales used to measure the statistics (default=4) n_orientations: The number of orientations used to measure the statistics (default=4) spatial_corr_width: The width of the spatial cross- and auto-correlation statistics Attributes ---------- scales: list The names of the unique scales of coefficients in the pyramid, used for coarse-to-fine metamer synthesis. References ---------- .. [1] J Portilla and E P Simoncelli. A Parametric Texture Model based on Joint Statistics of Complex Wavelet Coefficients. Int'l Journal of Computer Vision. 40(1):49-71, October, 2000. https://www.cns.nyu.edu/~eero/ABSTRACTS/portilla99-abstract.html https://www.cns.nyu.edu/~lcv/texture/ .. [2] E P Simoncelli and W T Freeman, "The Steerable Pyramid: A Flexible Architecture for Multi-Scale Derivative Computation," Second Int'l Conf on Image Processing, Washington, DC, Oct 1995. """ def __init__( self, image_shape: tuple[int, int], n_scales: int = 4, n_orientations: int = 4, spatial_corr_width: int = 9, ): super().__init__() self.image_shape = image_shape if any([(image_shape[-1] / 2**i) % 2 for i in range(n_scales)]) or any( [(image_shape[-2] / 2**i) % 2 for i in range(n_scales)] ): raise ValueError( "Because of how the Portilla-Simoncelli model handles " "multiscale representations, it only works with images" " whose shape can be divided by 2 `n_scales` times." ) self.spatial_corr_width = spatial_corr_width self.n_scales = n_scales self.n_orientations = n_orientations self._pyr = SteerablePyramidFreq( self.image_shape, height=self.n_scales, order=self.n_orientations - 1, is_complex=True, tight_frame=False, ) self.scales = ( ["pixel_statistics", "residual_lowpass"] + [ii for ii in range(n_scales - 1, -1, -1)] + ["residual_highpass"] ) # Dictionary defining shape of the statistics and which scale they're # associated with scales_shape_dict = self._create_scales_shape_dict() # Dictionary defining necessary statistics, that is, those that are not # redundant self._necessary_stats_dict = self._create_necessary_stats_dict( scales_shape_dict ) # turn this into tensor we can use in forward pass. first into a # boolean mask... _necessary_stats_mask = einops.pack( list(self._necessary_stats_dict.values()), "*" )[0] # then into a tensor of indices _necessary_stats_mask = torch.where(_necessary_stats_mask)[0] self.register_buffer("_necessary_stats_mask", _necessary_stats_mask) # This array is composed of the following values: 'pixel_statistics', # 'residual_lowpass', 'residual_highpass' and integer values from 0 to # self.n_scales-1. It is the same size as the representation tensor # returned by this object's forward method. It must be a numpy array so # we can have a mixture of ints and strs (and so we can use np.in1d # later) self._representation_scales = einops.pack( list(scales_shape_dict.values()), "*" )[0] # just select the scales of the necessary stats. self._representation_scales = self._representation_scales[ self._necessary_stats_mask ] def _create_scales_shape_dict(self) -> OrderedDict: """Create dictionary defining scales and shape of each stat. This dictionary functions as metadata which is used for two main purposes: - Scale assignment. In order for optimization to work well, we proceed in a "coarse-to-fine" manner. That is, we start optimization by only considering the statistics related to the lowest frequencies, and gradually add in those related to higher and higher frequencies. This is similar to blurring the objective function and then gradually adding in finer and finer details. The numbers in this dictionary map the computed statistics to their corresponding scales, which we use in remove_scales to throw away some stats as needed. - Redundant stat identification. As described at the bottom of the notebook, the model incidentally computes a whole bunch of redundant stats, because auto- and cross-correlation matrices have certain symmetries. the _create_necessary_stats_dict method accepts the dictionary created here as input and uses the values to get the shapes of these and insert True/False as necessary. Returns ------- scales_shape_dict Dictionary defining shape and associated scales of each computed statistic. The keys name each statistic, with dummy arrays as values. These arrays have the same shape as the stat (excluding batch and channel), with values defining which scale they correspond to. """ shape_dict = OrderedDict() # There are 6 pixel statistics shape_dict["pixel_statistics"] = np.array(6 * ["pixel_statistics"]) # These are the basic building blocks of the scale assignments for many # of the statistics calculated by the PortillaSimoncelli model. scales = np.arange(self.n_scales) # the cross-scale correlations exclude the coarsest scale scales_without_coarsest = np.arange(self.n_scales - 1) # the statistics computed on the reconstructed bandpass images have an # extra scale corresponding to the lowpass residual scales_with_lowpass = np.array( scales.tolist() + ["residual_lowpass"], dtype=object ) # now we go through each statistic in order and create a dummy array # full of 1s with the same shape as the actual statistic (excluding the # batch and channel dimensions, as each stat is computed independently # across those dimensions). We then multiply it by one of the scales # arrays above to turn those 1s into values describing the # corresponding scale. auto_corr_mag = np.ones( ( self.spatial_corr_width, self.spatial_corr_width, self.n_orientations, self.n_scales, ), dtype=int, ) # this rearrange call is turning scales from 1d with shape (n_scales, ) # to 4d with shape (1, 1, n_scales, 1), so that it matches # auto_corr_mag. the following rearrange calls do similar. auto_corr_mag *= einops.rearrange(scales, "s -> 1 1 1 s") shape_dict["auto_correlation_magnitude"] = auto_corr_mag shape_dict["skew_reconstructed"] = scales_with_lowpass shape_dict["kurtosis_reconstructed"] = scales_with_lowpass auto_corr = np.ones( ( self.spatial_corr_width, self.spatial_corr_width, self.n_scales + 1, ), dtype=object, ) auto_corr *= einops.rearrange(scales_with_lowpass, "s -> 1 1 s") shape_dict["auto_correlation_reconstructed"] = auto_corr shape_dict["std_reconstructed"] = scales_with_lowpass cross_orientation_corr_mag = np.ones( (self.n_orientations, self.n_orientations, self.n_scales), dtype=int, ) cross_orientation_corr_mag *= einops.rearrange(scales, "s -> 1 1 s") shape_dict["cross_orientation_correlation_magnitude"] = ( cross_orientation_corr_mag ) mags_std = np.ones((self.n_orientations, self.n_scales), dtype=int) mags_std *= einops.rearrange(scales, "s -> 1 s") shape_dict["magnitude_std"] = mags_std cross_scale_corr_mag = np.ones( (self.n_orientations, self.n_orientations, self.n_scales - 1), dtype=int, ) cross_scale_corr_mag *= einops.rearrange(scales_without_coarsest, "s -> 1 1 s") shape_dict["cross_scale_correlation_magnitude"] = cross_scale_corr_mag cross_scale_corr_real = np.ones( (self.n_orientations, 2 * self.n_orientations, self.n_scales - 1), dtype=int, ) cross_scale_corr_real *= einops.rearrange(scales_without_coarsest, "s -> 1 1 s") shape_dict["cross_scale_correlation_real"] = cross_scale_corr_real shape_dict["var_highpass_residual"] = np.array(["residual_highpass"]) return shape_dict def _create_necessary_stats_dict( self, scales_shape_dict: OrderedDict ) -> OrderedDict: """Create mask specifying the necessary statistics. Some of the statistics computed by the model are redundant, due to symmetries. For example, about half of the values in the autocorrelation matrices are duplicates. See the Portilla-Simoncelli notebook for more details. Parameters ---------- scales_shape_dict Dictionary defining shape and associated scales of each computed statistic. Returns ------- necessary_stats_dict Dictionary defining which statistics are necessary (i.e., not redundant). Will have the same keys as scales_shape_dict, with the values being boolean tensors of the same shape as scales_shape_dict's corresponding values. True denotes the statistics that will be included in the model's output, while False denotes the redundant ones we will toss. """ mask_dict = scales_shape_dict.copy() # Pre-compute some necessary indices. # Lower triangular indices (including diagonal), for auto correlations tril_inds = torch.tril_indices(self.spatial_corr_width, self.spatial_corr_width) # Get the second half of the diagonal, i.e., everything from the center # element on. These are all repeated for the auto correlations. (As # these are autocorrelations (rather than auto-covariance) matrices, # they've been normalized by the variance and so the center element is # always 1, and thus uninformative) diag_repeated = torch.arange( start=self.spatial_corr_width // 2, end=self.spatial_corr_width ) # Upper triangle indices, including diagonal. These are redundant stats # for cross_orientation_correlation_magnitude (because we've normalized # this matrix to be true cross-correlations, the diagonals are all 1, # like for the auto-correlations) triu_inds = torch.triu_indices(self.n_orientations, self.n_orientations) for k, v in mask_dict.items(): if k in [ "auto_correlation_magnitude", "auto_correlation_reconstructed", ]: # Symmetry M_{i,j} = M_{n-i+1, n-j+1} # Start with all False, then place True in necessary stats. mask = torch.zeros(v.shape, dtype=torch.bool) mask[tril_inds[0], tril_inds[1]] = True # if spatial_corr_width is even, then the first row is not # redundant with anything either if np.mod(self.spatial_corr_width, 2) == 0: mask[0] = True mask[diag_repeated, diag_repeated] = False elif k == "cross_orientation_correlation_magnitude": # Symmetry M_{i,j} = M_{j,i}. # Start with all True, then place False in redundant stats. mask = torch.ones(v.shape, dtype=torch.bool) mask[triu_inds[0], triu_inds[1]] = False else: # all of the other stats have no redundancies mask = torch.ones(v.shape, dtype=torch.bool) mask_dict[k] = mask return mask_dict
[docs] def forward(self, image: Tensor, scales: list[SCALES_TYPE] | None = None) -> Tensor: r"""Generate Texture Statistics representation of an image. Note that separate batches and channels are analyzed in parallel. Parameters ---------- image : A 4d tensor (batch, channel, height, width) containing the image(s) to analyze. scales : Which scales to include in the returned representation. If None, we include all scales. Otherwise, can contain subset of values present in this model's ``scales`` attribute, and the returned tensor will then contain the subset corresponding to those scales. Returns ------- representation_tensor: 3d tensor of shape (batch, channel, stats) containing the measured texture statistics. Raises ------ ValueError : If `image` is not 4d or has a dtype other than float or complex. """ validate_input(image) # pyr_dict is the dictionary of complex-valued tensors returned by the # steerable pyramid. pyr_coeffs is a list (length n_scales) of 5d # tensors, each of shape (batch, channel, scales, n_orientations, # height, width) containing the complex-valued oriented bands, while # highpass is a real-valued 4d tensor of shape (batch, channel, height, # width). Note that the residual lowpass in pyr_dict has been demeaned. # We keep both the dict and list of pyramid coefficients because we # need the dictionary for reconstructing the image done later on. pyr_dict, pyr_coeffs, highpass, _ = self._compute_pyr_coeffs(image) # Now, we create several intermediate representations that we'll use to # compute the texture statistics later. # First, two intermediate dictionaries: magnitude_pyr_coeffs and # real_pyr_coeffs, which contain the demeaned magnitude of the pyramid # coefficients and the real part of the pyramid coefficients # respectively. ( mag_pyr_coeffs, real_pyr_coeffs, ) = self._compute_intermediate_representations(pyr_coeffs) # Then, the reconstructed lowpass image at each scale. (this is a list # of length n_scales+1 containing tensors of shape (batch, channel, # height, width)) reconstructed_images = self._reconstruct_lowpass_at_each_scale(pyr_dict) # the reconstructed_images list goes from coarse-to-fine, but we want # each of the stats computed from it to go from fine-to-coarse, so we # reverse its direction. reconstructed_images = reconstructed_images[::-1] # Now, start calculating the PS texture stats. # Calculate pixel statistics (mean, variance, skew, kurtosis, min, # max). pixel_stats = self._compute_pixel_stats(image) # Compute the central autocorrelation of the coefficient magnitudes. This is a # tensor of shape: (batch, channel, spatial_corr_width, spatial_corr_width, # n_orientations, n_scales). var_mags is a tensor of shape (batch, channel, # n_orientations, n_scales) autocorr_mags, mags_var = self._compute_autocorr(mag_pyr_coeffs) # mags_var is the variance of the magnitude coefficients at each scale (it's an # intermediary of the computation of the auto-correlations). We take the square # root to get the standard deviation. mags_std = mags_var.sqrt() # Compute the central autocorrelation of the reconstructed lowpass # images at each scale (and their variances). autocorr_recon is a # tensor of shape (batch, channel, spatial_corr_width, # spatial_corr_width, n_scales+1), and var_recon is a tensor of shape # (batch, channel, n_scales+1) autocorr_recon, var_recon = self._compute_autocorr(reconstructed_images) # Compute the standard deviation, skew, and kurtosis of each # reconstructed lowpass image. std_recon, skew_recon, and # kurtosis_recon will all end up as tensors of shape (batch, channel, # n_scales+1) std_recon = var_recon.sqrt() skew_recon, kurtosis_recon = self._compute_skew_kurtosis_recon( reconstructed_images, var_recon, pixel_stats[..., 1] ) # Compute the cross-orientation correlations between the magnitude # coefficients at each scale. this will be a tensor of shape (batch, # channel, n_orientations, n_orientations, n_scales) cross_ori_corr_mags = self._compute_cross_correlation( mag_pyr_coeffs, mag_pyr_coeffs, mags_var, mags_var ) # If we have more than one scale, compute the cross-scale correlations if self.n_scales != 1: # First, double the phase the coefficients, so we can correctly # compute correlations across scales. ( phase_doubled_mags, phase_doubled_sep, ) = self._double_phase_pyr_coeffs(pyr_coeffs) # Compute the cross-scale correlations between the magnitude # coefficients. For each coefficient, we're correlating it with the # coefficients at the next-coarsest scale. this will be a tensor of # shape (batch, channel, n_orientations, n_orientations, # n_scales-1) cross_scale_corr_mags = self._compute_cross_correlation( mag_pyr_coeffs[:-1], phase_doubled_mags, mags_var[..., :-1] ) # Compute the cross-scale correlations between the real # coefficients and the real and imaginary coefficients at the next # coarsest scale. this will be a tensor of shape (batch, channel, # n_orientations, 2*n_orientations, n_scales-1) cross_scale_corr_real = self._compute_cross_correlation( real_pyr_coeffs[:-1], phase_doubled_sep ) # Compute the variance of the highpass residual var_highpass_residual = highpass.pow(2).mean(dim=(-2, -1)) # Now, combine all these stats together, first into a list all_stats = [ pixel_stats, autocorr_mags, skew_recon, kurtosis_recon, autocorr_recon, std_recon, cross_ori_corr_mags, mags_std, ] if self.n_scales != 1: all_stats += [cross_scale_corr_mags, cross_scale_corr_real] all_stats += [var_highpass_residual] # And then pack them into a 3d tensor representation_tensor, pack_info = einops.pack(all_stats, "b c *") # the only time when this is None is during testing, when we make sure # that our assumptions are all valid. if self._necessary_stats_mask is None: # store this so we can unpack this info (only possible when we've # discarded no stats) self._pack_info = pack_info else: # Throw away all redundant statistics representation_tensor = representation_tensor.index_select( -1, self._necessary_stats_mask ) # Return the subset of stats corresponding to the specified scale. if scales is not None: representation_tensor = self.remove_scales(representation_tensor, scales) return representation_tensor
[docs] def remove_scales( self, representation_tensor: Tensor, scales_to_keep: list[SCALES_TYPE] ) -> Tensor: """Remove statistics not associated with scales. For a given representation_tensor and a list of scales_to_keep, this attribute removes all statistics *not* associated with those scales. Note that calling this method will always remove statistics. Parameters ---------- representation_tensor: 3d tensor containing the measured representation statistics. scales_to_keep: Which scales to include in the returned representation. Can contain subset of values present in this model's ``scales`` attribute, and the returned tensor will then contain the subset of the full representation corresponding to those scales. Returns ------- limited_representation_tensor : Representation tensor with some statistics removed. """ # this is necessary because object is the dtype of # self._representation_scales scales_to_keep = np.array(scales_to_keep, dtype=object) # np.in1d returns a 1d boolean array of the same shape as # self._representation_scales with True at each location where that # value appears in scales_to_keep. where then converts this boolean # array into indices ind = np.where(np.in1d(self._representation_scales, scales_to_keep))[0] ind = torch.from_numpy(ind).to(representation_tensor.device) return representation_tensor.index_select(-1, ind)
[docs] def convert_to_tensor(self, representation_dict: OrderedDict) -> Tensor: r"""Convert dictionary of statistics to a tensor. Parameters ---------- representation_dict : Dictionary of representation. Returns ------- 3d tensor of statistics. See Also -------- convert_to_dict: Convert tensor representation to dictionary. """ rep = einops.pack(list(representation_dict.values()), "b c *")[0] # then get rid of all the nans / unnecessary stats return rep.index_select(-1, self._necessary_stats_mask)
[docs] def convert_to_dict(self, representation_tensor: Tensor) -> OrderedDict: """Convert tensor of statistics to a dictionary. While the tensor representation is required by plenoptic's synthesis objects, the dictionary representation is easier to manually inspect. This dictionary will contain NaNs in its values: these are placeholders for the redundant statistics. Parameters ---------- representation_tensor 3d tensor of statistics. Returns ------- rep Dictionary of representation, with informative keys. See Also -------- convert_to_tensor: Convert dictionary representation to tensor. """ if representation_tensor.shape[-1] != len(self._representation_scales): raise ValueError( "representation tensor is the wrong length (expected" f" {len(self._representation_scales)} but got" f" {representation_tensor.shape[-1]})! Did you remove some of" " the scales? (i.e., by setting scales in the forward pass)?" " convert_to_dict does not support such tensors." ) rep = self._necessary_stats_dict.copy() n_filled = 0 for k, v in rep.items(): # each statistic is a tensor with batch and channel dimensions as # found in representation_tensor and all the other dimensions # determined by the values in necessary_stats_dict. shape = (*representation_tensor.shape[:2], *v.shape) new_v = torch.nan * torch.ones( shape, dtype=representation_tensor.dtype, device=representation_tensor.device, ) # v.sum() gives the number of necessary elements from this stat this_stat_vec = representation_tensor[..., n_filled : n_filled + v.sum()] # use boolean indexing to put the values from new_stat_vec in the # appropriate place new_v[..., v] = this_stat_vec rep[k] = new_v n_filled += v.sum() return rep
def _compute_pyr_coeffs( self, image: Tensor ) -> tuple[OrderedDict, list[Tensor], Tensor, Tensor]: """Compute pyramid coefficients of image. Note that the residual lowpass has been demeaned independently for each batch and channel (and this is true of the lowpass returned separately as well as the one included in pyr_coeffs_dict) Parameters ---------- image : 4d tensor of shape (batch, channel, height, width) containing the image Returns ------- pyr_coeffs_dict : OrderedDict of containing all pyramid coefficients. pyr_coeffs : List of length n_scales, containing 5d tensors of shape (batch, channel, n_orientations, height, width) containing the complex-valued oriented bands (note that height and width shrink by half on each scale). This excludes the residual highpass and lowpass bands. highpass : The residual highpass as a real-valued 4d tensor (batch, channel, height, width) lowpass : The residual lowpass as a real-valued 4d tensor (batch, channel, height, width). This tensor has been demeaned (independently for each batch and channel). """ pyr_coeffs = self._pyr.forward(image) # separate out the residuals and demean the residual lowpass lowpass = pyr_coeffs["residual_lowpass"] lowpass = lowpass - lowpass.mean(dim=(-2, -1), keepdim=True) pyr_coeffs["residual_lowpass"] = lowpass highpass = pyr_coeffs["residual_highpass"] # This is a list of tensors, one for each scale, where each tensor is # of shape (batch, channel, n_orientations, height, width) (note that # height and width halves on each scale) coeffs_list = [ torch.stack([pyr_coeffs[(i, j)] for j in range(self.n_orientations)], 2) for i in range(self.n_scales) ] return pyr_coeffs, coeffs_list, highpass, lowpass @staticmethod def _compute_pixel_stats(image: Tensor) -> Tensor: """Compute the pixel stats: first four moments, min, and max. Parameters ---------- image : 4d tensor of shape (batch, channel, height, width) containing input image. Stats are computed indepently for each batch and channel. Returns ------- pixel_stats : 3d tensor of shape (batch, channel, 6) containing the mean, variance, skew, kurtosis, minimum pixel value, and maximum pixel value (in that order) """ mean = torch.mean(image, dim=(-2, -1), keepdim=True) # we use torch.var instead of plenoptic.tools.variance, because our # variance is the uncorrected (or sample) variance and we want the # corrected one here. var = torch.var(image, dim=(-2, -1)) skew = stats.skew(image, mean=mean, var=var, dim=[-2, -1]) kurtosis = stats.kurtosis(image, mean=mean, var=var, dim=[-2, -1]) # can't compute min/max over two dims simultaneously with # torch.min/max, so use einops img_min = einops.reduce(image, "b c h w -> b c", "min") img_max = einops.reduce(image, "b c h w -> b c", "max") # mean needed to be unflattened to be used by skew and kurtosis # correctly, but we'll want it to be flattened like this in the final # representation tensor return einops.pack([mean, var, skew, kurtosis, img_min, img_max], "b c *")[0] @staticmethod def _compute_intermediate_representations( pyr_coeffs: Tensor, ) -> tuple[list[Tensor], list[Tensor]]: """Compute useful intermediate representations. These representations are: 1) demeaned magnitude of the pyramid coefficients, 2) real part of the pyramid coefficients These two are used in computing some of the texture representation. Parameters ---------- pyr_coeffs : Complex steerable pyramid coefficients (without residuals), as list of length n_scales, containing 5d tensors of shape (batch, channel, n_orientations, height, width) Returns ------- magnitude_pyr_coeffs : List of length n_scales, containing 5d tensors of shape (batch, channel, n_orientations, height, width) (same as ``pyr_coeffs``), containing the demeaned magnitude of the steerable pyramid coefficients (i.e., coeffs.abs() - coeffs.abs().mean((-2, -1))) real_pyr_coeffs : List of length n_scales, containing 5d tensors of shape (batch, channel, n_orientations, height, width) (same as ``pyr_coeffs``), containing the real components of the coefficients (i.e. coeffs.real) """ magnitude_pyr_coeffs = [coeff.abs() for coeff in pyr_coeffs] magnitude_means = [ mag.mean((-2, -1), keepdim=True) for mag in magnitude_pyr_coeffs ] magnitude_pyr_coeffs = [ mag - mn for mag, mn in zip(magnitude_pyr_coeffs, magnitude_means) ] real_pyr_coeffs = [coeff.real for coeff in pyr_coeffs] return magnitude_pyr_coeffs, real_pyr_coeffs def _reconstruct_lowpass_at_each_scale( self, pyr_coeffs_dict: OrderedDict ) -> list[Tensor]: """Reconstruct the lowpass unoriented image at each scale. The autocorrelation, standard deviation, skew, and kurtosis of each of these images is part of the texture representation. Parameters ---------- pyr_coeffs_dict : Dictionary containing the steerable pyramid coefficients, with the lowpass residual demeaned. Returns ------- reconstructed_images : List of length n_scales+1 containing the reconstructed unoriented image at each scale, from fine to coarse. The final image is reconstructed just from the residual lowpass image. Each is a 4d tensor, this is a list because they are all different heights and widths. """ reconstructed_images = [ self._pyr.recon_pyr(pyr_coeffs_dict, levels=["residual_lowpass"]) ] # go through scales backwards for lev in range(self.n_scales - 1, -1, -1): recon = self._pyr.recon_pyr(pyr_coeffs_dict, levels=[lev]) reconstructed_images.append(recon + reconstructed_images[-1]) # now downsample as necessary, so that these end up the same size as # their corresponding coefficients. We multiply by the factor of 4 here # in order to approximately equalize the steerable pyramid coefficient # values across scales. This could also be handled by making the # pyramid tight frame reconstructed_images[:-1] = [ signal.shrink(r, 2 ** (self.n_scales - i)) * 4 ** (self.n_scales - i) for i, r in enumerate(reconstructed_images[:-1]) ] return reconstructed_images def _compute_autocorr(self, coeffs_list: list[Tensor]) -> tuple[Tensor, Tensor]: """Compute the autocorrelation of some statistics. Parameters ---------- coeffs_list : List (of length s) of tensors of shape (batch, channel, *, height, width), where * is zero or one additional dimensions. Intended use case: magnitude_pyr_coeffs (which is list of length n_scales of 5d tensors, with * containing n_orientations) or reconstructed_images (which is a list of length n_scales+1 of 4d tensors) Returns ------- autocorrs : Tensor of shape (batch, channel, spatial_corr_width, spatial_corr_width, *, s) containing the autocorrelation (up to distance ``spatial_corr_width//2``) of each element in ``coeffs_list``, computed independently over all but the final two dimensions. vars : 3d Tensor of shape (batch, channel, *, s) containing the variance of each element in ``coeffs_list``, computed independently over all but the final two dimensions. """ if coeffs_list[0].ndim == 5: dims = "o" elif coeffs_list[0].ndim == 4: dims = "" else: raise ValueError( "coeffs_list must contain tensors of either 4 or 5 dimensions!" ) acs = [signal.autocorrelation(coeff) for coeff in coeffs_list] var = [signal.center_crop(ac, 1) for ac in acs] acs = [ac / v for ac, v in zip(acs, var)] var = einops.rearrange(var, f"s b c {dims} 1 1 -> b c {dims} s") acs = [signal.center_crop(ac, self.spatial_corr_width) for ac in acs] acs = torch.stack(acs, 2) return einops.rearrange(acs, f"b c s {dims} a1 a2 -> b c a1 a2 {dims} s"), var @staticmethod def _compute_skew_kurtosis_recon( reconstructed_images: list[Tensor], var_recon: Tensor, img_var: Tensor ) -> tuple[Tensor, Tensor]: """Compute the skew and kurtosis of each lowpass reconstructed image. For each scale, if the ratio of its variance to the original image's pixel variance is below a threshold of torch.finfo(img_var.dtype).resolution (1e-6 for float32, 1e-15 for float64), skew and kurtosis are assigned default values of 0 or 3, respectively. Parameters ---------- reconstructed_images : List of length n_scales+1 containing the reconstructed unoriented image at each scale, from fine to coarse. The final image is reconstructed just from the residual lowpass image. var_recon : Tensor of shape (batch, channel, n_scales+1) containing the variance of each tensor in reconstruced_images img_var : Tensor of shape (batch, channel) containing the pixel variance (from pixel_stats tensor) Returns ------- skew_recon, kurtosis_recon : Tensors of shape (batch, channel, n_scales+1) containing the skew and kurtosis, respectively, of each tensor in ``reconstructed_images``. """ skew_recon = [ stats.skew(im, mean=0, var=var_recon[..., i], dim=[-2, -1]) for i, im in enumerate(reconstructed_images) ] skew_recon = torch.stack(skew_recon, -1) kurtosis_recon = [ stats.kurtosis(im, mean=0, var=var_recon[..., i], dim=[-2, -1]) for i, im in enumerate(reconstructed_images) ] kurtosis_recon = torch.stack(kurtosis_recon, -1) skew_default = torch.zeros_like(skew_recon) kurtosis_default = 3 * torch.ones_like(kurtosis_recon) # if this variance ratio is too small, then use the default values # instead. unsqueeze is used here because var_recon is shape (batch, # channel, scales+1), whereas img_var is just (batch, channel) res = torch.finfo(img_var.dtype).resolution unstable_locs = var_recon / img_var.unsqueeze(-1) < res skew_recon = torch.where(unstable_locs, skew_default, skew_recon) kurtosis_recon = torch.where(unstable_locs, kurtosis_default, kurtosis_recon) return skew_recon, kurtosis_recon def _compute_cross_correlation( self, coeffs_tensor: list[Tensor], coeffs_tensor_other: list[Tensor], coeffs_var: None | Tensor = None, coeffs_other_var: None | Tensor = None, ) -> Tensor: """Compute cross-correlations. Parameters ---------- coeffs_tensor, coeffs_tensor_other : The two lists of length scales, each containing 5d tensors of shape (batch, channel, n_orientations, height, width) to be correlated. coeffs_var, coeffs_other_var : Two optional tensors containing the variances of coeffs_tensor and coeffs_tensor_other, respectively, in case they've already been computed. Should be of shape (batch, channel, n_orientations, n_scales). Used to normalize the covariances into cross-correlations. Returns ------- cross_corrs : Tensor of shape (batch, channel, n_orientations, n_orientations, scales) containing the cross-correlations at each scale. """ covars = [] for i, (coeff, coeff_other) in enumerate( zip(coeffs_tensor, coeffs_tensor_other) ): # precompute this, which we'll use for normalization numel = torch.mul(*coeff.shape[-2:]) # compute the covariance covar = einops.einsum( coeff, coeff_other, "b c o1 h w, b c o2 h w -> b c o1 o2" ) covar = covar / numel # Then normalize it to get the Pearson product-moment correlation # coefficient, see # https://numpy.org/doc/stable/reference/generated/numpy.corrcoef.html. if coeffs_var is None: # First, compute the variances of each coeff coeff_var = einops.einsum( coeff, coeff, "b c o1 h w, b c o1 h w -> b c o1" ) coeff_var = coeff_var / numel else: coeff_var = coeffs_var[..., i] if coeffs_other_var is None: # First, compute the variances of each coeff coeff_other_var = einops.einsum( coeff_other, coeff_other, "b c o1 h w, b c o1 h w -> b c o1" ) coeff_other_var = coeff_other_var / numel else: coeff_other_var = coeffs_other_var[..., i] # Then compute the outer product of those variances. var_outer_prod = einops.einsum( coeff_var, coeff_other_var, "b c o1, b c o2 -> b c o1 o2" ) # And the sqrt of this is what we use to normalize the covariance # into the cross-correlation covars.append(covar / var_outer_prod.sqrt()) return torch.stack(covars, -1) @staticmethod def _double_phase_pyr_coeffs( pyr_coeffs: list[Tensor], ) -> tuple[list[Tensor], list[Tensor]]: """Upsample and double the phase of pyramid coefficients. Parameters ---------- pyr_coeffs : Complex steerable pyramid coefficients (without residuals), as list of length n_scales, containing 5d tensors of shape (batch, channel, n_orientations, height, width) Returns ------- doubled_phase_mags : The demeaned magnitude (i.e., pyr_coeffs.abs()) of each upsampled double-phased coefficient. List of length n_scales-1 containing tensors of same shape the input (the finest scale has been removed). doubled_phase_separate : The real and imaginary parts of each double-phased coefficient. List of length n_scales-1, containing tensors of shape (batch, channel, 2*n_orientations, height, width), with the real component found at the same orientation index as the input, and the imaginary at orientation+self.n_orientations. (The finest scale has been removed.) """ doubled_phase_mags = [] doubled_phase_sep = [] # don't do this for the finest scale for coeff in pyr_coeffs[1:]: # We divide by the factor of 4 here in order to approximately # equalize the steerable pyramid coefficient values across scales. # This could also be handled by making the pyramid tight frame doubled_phase = signal.expand(coeff, 2) / 4.0 doubled_phase = signal.modulate_phase(doubled_phase, 2) doubled_phase_mag = doubled_phase.abs() doubled_phase_mag = doubled_phase_mag - doubled_phase_mag.mean( (-2, -1), keepdim=True ) doubled_phase_mags.append(doubled_phase_mag) doubled_phase_sep.append( einops.pack([doubled_phase.real, doubled_phase.imag], "b c * h w")[0] ) return doubled_phase_mags, doubled_phase_sep
[docs] def plot_representation( self, data: Tensor, ax: plt.Axes | None = None, figsize: tuple[float, float] = (15, 15), ylim: tuple[float, float] | Literal[False] | None = None, batch_idx: int = 0, title: str | None = None, ) -> tuple[plt.Figure, list[plt.Axes]]: r"""Plot the representation in a human viewable format -- stem plots with data separated out by statistic type. This plots the representation of a single batch and averages over all channels in the representation. We create the following axes: - pixels+var_highpass: marginal pixel statistics (first four moments, min, max) and variance of the residual highpass. - std+skew+kurtosis recon: the standard deviation, skew, and kurtosis of the reconstructed lowpass image at each scale - magnitude_std: the standard deviation of the steerable pyramid coefficient magnitudes at each orientation and scale. - auto_correlation_reconstructed: the auto-correlation of the reconstructed lowpass image at each scale (summarized using Euclidean norm). - auto_correlation_magnitude: the auto-correlation of the pyramid coefficient magnitudes at each scale and orientation (summarized using Euclidean norm). - cross_orientation_correlation_magnitude: the cross-correlations between each orientation at each scale (summarized using Euclidean norm) If self.n_scales > 1, we also have combination of the following, where all cross-correlations are summarized using Euclidean norm over the channel dimension: - cross_scale_correlation_magnitude: the cross-correlations between the pyramid coefficient magnitude at one scale and the same orientation at the next-coarsest scale. - cross_scale_correlation_real: the cross-correlations between the real component of the pyramid coefficients and the real and imaginary components (at the same orientation) at the next-coarsest scale. Parameters ---------- data : The data to show on the plot. Else, should look like the output of ``self.forward(img)``, with the exact same structure (e.g., as returned by ``metamer.representation_error()`` or another instance of this class). ax : Axes where we will plot the data. If a ``plt.Axes`` instance, will subdivide into 6 or 8 new axes (depending on self.n_scales). If None, we create a new figure. figsize : The size of the figure. Ignored if ax is not None. ylim : If not None, the y-limits to use for this plot. If None, we use the default, slightly adjusted so that the minimum is 0. If False, do not change y-limits. batch_idx : Which index to take from the batch dimension (the first one) title : string Title for the plot Returns ------- fig: Figure containing the plot axes: List of 6 or 8 axes containing the plot (depending on self.n_scales) """ # pick the batch_idx we want (but keep the data 3d), and average over # channels (but keep the data 3d). We keep data 3d because # convert_to_dict relies on it. data = data[batch_idx].unsqueeze(0).mean(1, keepdim=True) # each of these values should now be a 3d tensor with 1 element in each # of the first two dims rep = {k: v[0, 0] for k, v in self.convert_to_dict(data).items()} data = self._representation_for_plotting(rep) # Determine plot grid layout if self.n_scales != 1: n_rows = 3 n_cols = int(np.ceil(len(data) / n_rows)) else: # then we don't have any cross-scale correlations, so fewer axes. n_rows = 2 n_cols = int(np.ceil(len(data) / n_rows)) # Set up grid spec if ax is None: # we add 2 to order because we're adding one to get the # number of orientations and then another one to add an # extra column for the mean luminance plot fig = plt.figure(figsize=figsize) gs = mpl.gridspec.GridSpec(n_rows, n_cols, fig) else: # want to make sure the axis we're taking over is basically invisible. ax = clean_up_axes( ax, False, ["top", "right", "bottom", "left"], ["x", "y"] ) gs = ax.get_subplotspec().subgridspec(n_rows, n_cols) fig = ax.figure # plot data axes = [] for i, (k, v) in enumerate(data.items()): ax = fig.add_subplot(gs[i // n_cols, i % n_cols]) ax = clean_stem_plot(to_numpy(v).flatten(), ax, k, ylim=ylim) axes.append(ax) if title is not None: fig.suptitle(title) return fig, axes
def _representation_for_plotting(self, rep: OrderedDict) -> OrderedDict: r"""Convert the data into a dictionary representation that is more convenient for plotting. Intended as a helper function for plot_representation. """ if rep["skew_reconstructed"].ndim > 1: raise ValueError( "Currently, only know how to plot single batch and channel at" " a time! Select and/or average over those dimensions" ) data = OrderedDict() data["pixels+var_highpass"] = torch.cat( [rep.pop("pixel_statistics"), rep.pop("var_highpass_residual")] ) data["std+skew+kurtosis recon"] = torch.cat( ( rep.pop("std_reconstructed"), rep.pop("skew_reconstructed"), rep.pop("kurtosis_reconstructed"), ) ) data["magnitude_std"] = rep.pop("magnitude_std") # want to plot these in a specific order all_keys = [ "auto_correlation_reconstructed", "auto_correlation_magnitude", "cross_orientation_correlation_magnitude", "cross_scale_correlation_magnitude", "cross_scale_correlation_real", ] if set(rep.keys()) != set(all_keys): raise ValueError("representation has unexpected keys!") for k in all_keys: # if we only have one scale, no cross-scale stats if k.startswith("cross_scale") and self.n_scales == 1: continue # we compute L2 norm manually, since there are NaNs (marking # redundant stats) data[k] = rep[k].pow(2).nansum((0, 1)).sqrt().flatten() return data
[docs] def update_plot( self, axes: list[plt.Axes], data: Tensor, batch_idx: int = 0, ) -> list[plt.Artist]: r"""Update the information in our representation plot. This is used for creating an animation of the representation over time. In order to create the animation, we need to know how to update the matplotlib Artists, and this provides a simple way of doing that. It relies on the fact that we've used ``plot_representation`` to create the plots we want to update and so know that they're stem plots. We take the axes containing the representation information (note that this is probably a subset of the total number of axes in the figure, if we're showing other information, as done by ``Metamer.animate``), grab the representation from plotting and, since these are both lists, iterate through them, updating them to the values in ``data`` as we go. In order for this to be used by ``FuncAnimation``, we need to return Artists, so we return a list of the relevant artists, the ``markerline`` and ``stemlines`` from the ``StemContainer``. Currently, this averages over all channels in the representation. Parameters ---------- axes : A list of axes to update. We assume that these are the axes created by ``plot_representation`` and so contain stem plots in the correct order. batch_idx : Which index to take from the batch dimension (the first one) data : The data to show on the plot. Else, should look like the output of ``self.forward(img)``, with the exact same structure (e.g., as returned by ``metamer.representation_error()`` or another instance of this class). Returns ------- stem_artists : A list of the artists used to update the information on the stem plots """ stem_artists = [] axes = [ax for ax in axes if len(ax.containers) == 1] # pick the batch_idx we want (but keep the data 3d), and average over # channels (but keep the data 3d). We keep data 3d because # convert_to_dict relies on it. data = data[batch_idx].unsqueeze(0).mean(1, keepdim=True) # each of these values should now be a 3d tensor with 1 element in each # of the first two dims rep = {k: v[0, 0] for k, v in self.convert_to_dict(data).items()} rep = self._representation_for_plotting(rep) for ax, d in zip(axes, rep.values()): if isinstance(d, dict): vals = np.array([dd.detach() for dd in d.values()]) else: vals = d.flatten().detach().numpy() sc = update_stem(ax.containers[0], vals) stem_artists.extend([sc.markerline, sc.stemlines]) return stem_artists