plenoptic.models.PortillaSimoncelli

Note

This object is a torch.nn.Module. It therefore has all the methods and attributes from that class, even though they are not documented here (to avoid cluttering this page).

class plenoptic.models.PortillaSimoncelli(image_shape, n_scales=4, n_orientations=4, spatial_corr_width=7)

Bases: Module

Portilla-Simoncelli texture statistics.

The Portilla-Simoncelli (PS) texture statistics are a set of image statistics, first described in Portilla and Simoncelli, 2000 [2], that are proposed as a sufficient set of measurements for describing visual textures. That is, if two texture images have the same values for all PS texture stats, humans should consider them as members of the same family of textures.

The PS statistics are computed from the SteerablePyramidFreq (Simoncelli and Freeman, 1995, [3]). They consist of the local auto-correlations, cross-scale (within-orientation) correlations, and cross-orientation (within-scale) correlations of both the pyramid coefficients and the local energy (the magnitudes of those coefficients). Additionally, they include the first four global moments (mean, variance, skew, and kurtosis) of the image and of down-sampled versions of that image. See the paper and notebook for more details.
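As a rough illustration of the simplest of these measurements, the sketch below computes the six marginal pixel statistics that the convert_to_dict example lists under pixel_statistics (mean, variance, skew, kurtosis, min, and max) for a flat list of pixel values. It is plain Python, not plenoptic's implementation: the helper name pixel_statistics is hypothetical, and the use of the population (biased) variance and non-excess kurtosis are assumptions. A constant image makes the variance zero, so skew and kurtosis are undefined in that case.

```python
import math


def pixel_statistics(pixels):
    """Return [mean, variance, skew, kurtosis, min, max] of a list of numbers.

    Hypothetical helper: uses the population variance and non-excess
    kurtosis, which may differ from plenoptic's exact conventions.
    """
    n = len(pixels)
    mean = sum(pixels) / n
    # Central moments of order 2, 3 and 4.
    m2 = sum((p - mean) ** 2 for p in pixels) / n
    m3 = sum((p - mean) ** 3 for p in pixels) / n
    m4 = sum((p - mean) ** 4 for p in pixels) / n
    skew = m3 / m2 ** 1.5
    kurtosis = m4 / m2 ** 2
    return [mean, m2, skew, kurtosis, min(pixels), max(pixels)]


stats = pixel_statistics([1.0, 2.0, 3.0, 4.0])
```

For this symmetric input the skew is 0 and the kurtosis is 1.64 under these conventions.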

Changed in version 2.0.0: Default spatial_corr_width value changed from 9 to 7, in order to match the value used to generate the figures in Portilla and Simoncelli, 2000 [2].

Parameters:
  • image_shape (tuple[int, int]) – Shape of input image.

  • n_scales (int (default: 4)) – The number of pyramid scales used to measure the statistics.

  • n_orientations (int (default: 4)) – The number of orientations used to measure the statistics.

  • spatial_corr_width (int (default: 7)) – The width of the spatial cross- and auto-correlation statistics.

scales

The names of the unique scales of coefficients in the pyramid, used for coarse-to-fine metamer synthesis.

Type:

list

Raises:

ValueError – If the height or width of the image cannot be divided by 2 n_scales times (that is, is not a multiple of 2**n_scales). This is necessary because of how the model handles multiscale representations.
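This requirement can be checked before constructing the model. The sketch below is a hypothetical helper written for illustration, not part of plenoptic; it repeats the halving described above for each image dimension and reports whether every halving is exact.

```python
def can_build_pyramid(image_shape, n_scales):
    """Return True if every image dimension survives n_scales exact halvings.

    Hypothetical helper mirroring the documented ValueError condition.
    """
    for dim in image_shape:
        for _ in range(n_scales):
            if dim % 2 != 0:
                return False
            dim //= 2
    return True


print(can_build_pyramid((256, 256), n_scales=4))  # True
print(can_build_pyramid((200, 200), n_scales=4))  # False: 200 -> 100 -> 50 -> 25
```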

References

Examples

Compute texture statistics of an image:

>>> import plenoptic as po
>>> img = po.data.reptile_skin()
>>> ps_model = po.models.PortillaSimoncelli(img.shape[2:])
>>> ps_model(img)
tensor([[[ 4.1716e-01, 5.4735e-02, ..., 4.7756e-03]]])

Visualize texture statistics:

>>> fig, axes = ps_model.plot_representation(ps_model(img))

[Figure: texture statistics of the reptile_skin image, as plotted by plot_representation]

Convert texture statistics into an easier-to-read format:

>>> representation_dict = ps_model.convert_to_dict(ps_model(img))
>>> representation_dict.keys()
odict_keys(['pixel_statistics', ..., 'var_highpass_residual'])

Synthesize a texture metamer:

>>> import plenoptic as po
>>> import matplotlib.pyplot as plt
>>> import torch
>>> img = po.data.reptile_skin()
>>> ps_model = po.models.PortillaSimoncelli(img.shape[2:])
>>> loss = po.loss.portilla_simoncelli_loss_factory(ps_model, img)
>>> met = po.Metamer(img, ps_model, loss_function=loss)
>>> opt_kwargs = {
...     "max_iter": 10,
...     "max_eval": 10,
...     "history_size": 100,
...     "line_search_fn": "strong_wolfe",
...     "lr": 1,
... }
>>> met.setup(optimizer=torch.optim.LBFGS, optimizer_kwargs=opt_kwargs)
>>> # Note that this isn't enough to run synthesis to completion,
>>> # just an example to demonstrate what synthesis looks like
>>> met.synthesize(max_iter=20)
>>> fig, axes = plt.subplots(1, 4, figsize=(25, 4), width_ratios=[1, 1, 1, 3])
>>> po.plot.imshow(img, ax=axes[0], title="Target image")
<Figure size ... with 4 Axes>
>>> axes[0].xaxis.set_visible(False)
>>> axes[0].yaxis.set_visible(False)
>>> po.plot.synthesis_status(met, fig=fig, axes_idx={"misc": 0})
<Figure size ...>

[Figure: target image alongside the metamer synthesis status]

Methods

convert_to_dict(representation_tensor)

Convert tensor of statistics to a dictionary.

convert_to_tensor(representation_dict)

Convert dictionary of statistics to a tensor.

forward(image[, scales])

Generate the texture statistics representation of an image.

plot_representation(data[, ax, figsize, ...])

Plot the representation in a human viewable format.

remove_scales(representation_tensor, ...)

Remove statistics not associated with scales.

update_plot(axes, data[, batch_idx])

Update the information in our representation plot.

convert_to_dict(representation_tensor)

Convert tensor of statistics to a dictionary.

While the tensor representation is required by plenoptic’s synthesis objects, the dictionary representation is easier to manually inspect.

This dictionary will contain NaNs in its values: these are placeholders for the redundant statistics.
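Because of these placeholders, the dictionary holds more entries than the tensor. The sketch below counts both, as an illustration rather than plenoptic's code. The total follows directly from the shapes listed in the example for this method. The non-redundant count rests on assumptions inferred from those shapes, not taken from plenoptic's source: each symmetric auto-correlation keeps one half of its off-center lags (the center lag being redundant with a standard-deviation statistic), and each cross-orientation matrix keeps only its strictly upper triangle. Under these assumptions the count reproduces the 710 statistics shown in the forward example for the default parameters. dict_sizes is a hypothetical name.

```python
def dict_sizes(n_scales=4, n_orientations=4, spatial_corr_width=7):
    """Return (total entries in the dictionary, non-redundant entries).

    The total follows the documented shapes; the redundancy rules are
    assumptions, and spatial_corr_width is assumed to be odd.
    """
    s, k, w = n_scales, n_orientations, spatial_corr_width
    total = (
        6                      # pixel_statistics
        + w * w * k * s        # auto_correlation_magnitude
        + 3 * (s + 1)          # skew/kurtosis/std_reconstructed
        + w * w * (s + 1)      # auto_correlation_reconstructed
        + k * k * s            # cross_orientation_correlation_magnitude
        + k * s                # magnitude_std
        + k * k * (s - 1)      # cross_scale_correlation_magnitude
        + 2 * k * k * (s - 1)  # cross_scale_correlation_real
        + 1                    # var_highpass_residual
    )
    # Off-center lags of a symmetric w x w auto-correlation, counted once.
    half = (w * w - 1) // 2
    unique = (
        6
        + half * k * s
        + 3 * (s + 1)
        + half * (s + 1)
        + k * (k - 1) // 2 * s  # strictly upper triangle of each k x k matrix
        + k * s
        + k * k * (s - 1)
        + 2 * k * k * (s - 1)
        + 1
    )
    return total, unique


print(dict_sizes())  # (1275, 710)
```

Under these assumptions, the difference (565 entries for the defaults) is the number of NaN placeholders in the dictionary.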

Parameters:

representation_tensor (Tensor) – 3d tensor of statistics.

Return type:

OrderedDict

Returns:

rep – Dictionary of representation, with informative keys.

Raises:

ValueError – If representation_tensor has an unexpected number of elements. This can happen if some elements were manually removed from representation_tensor, if a non-None scales argument was passed to forward when computing it, or if it was computed using a different instantiation of the model.

See also

convert_to_tensor

Convert dictionary representation to tensor.

Examples

>>> import plenoptic as po
>>> img = po.data.reptile_skin()
>>> portilla_simoncelli_model = po.models.PortillaSimoncelli(
...     img.shape[2:], n_scales=3
... )
>>> representation_tensor = portilla_simoncelli_model(img)
>>> representation_dict = portilla_simoncelli_model.convert_to_dict(
...     representation_tensor
... )
>>> # We will go through and examine each of these keys individually
>>> # Shape is (batch, channel, 6): first four moments plus min and max
>>> # of input image
>>> representation_dict["pixel_statistics"].shape
torch.Size([1, 1, 6])
>>> # Shape is (batch, channel, spatial_corr_width, spatial_corr_width,
>>> # n_orientations, n_scales)
>>> representation_dict["auto_correlation_magnitude"].shape
torch.Size([1, 1, 7, 7, 4, 3])
>>> # Shape is (batch, channel, n_scales+1)
>>> representation_dict["skew_reconstructed"].shape
torch.Size([1, 1, 4])
>>> # Shape is (batch, channel, n_scales+1)
>>> representation_dict["kurtosis_reconstructed"].shape
torch.Size([1, 1, 4])
>>> # Shape is (batch, channel, spatial_corr_width, spatial_corr_width,
>>> # n_scales+1)
>>> representation_dict["auto_correlation_reconstructed"].shape
torch.Size([1, 1, 7, 7, 4])
>>> # Shape is (batch, channel, n_scales+1)
>>> representation_dict["std_reconstructed"].shape
torch.Size([1, 1, 4])
>>> # Shape is (batch, channel, n_orientations, n_orientations, n_scales)
>>> representation_dict["cross_orientation_correlation_magnitude"].shape
torch.Size([1, 1, 4, 4, 3])
>>> # Shape is (batch, channel, n_orientations, n_scales)
>>> representation_dict["magnitude_std"].shape
torch.Size([1, 1, 4, 3])
>>> # Shape is (batch, channel, n_orientations, n_orientations, n_scales-1)
>>> representation_dict["cross_scale_correlation_magnitude"].shape
torch.Size([1, 1, 4, 4, 2])
>>> # Shape is (batch, channel, n_orientations, 2*n_orientations, n_scales-1)
>>> representation_dict["cross_scale_correlation_real"].shape
torch.Size([1, 1, 4, 8, 2])
>>> # Shape is (batch, channel, 1)
>>> representation_dict["var_highpass_residual"].shape
torch.Size([1, 1, 1])

convert_to_tensor(representation_dict)

Convert dictionary of statistics to a tensor.

The output has shape (batch, channel, n_statistics): the statistics of all classes are flattened and concatenated. The dictionary representation may be easier to interpret.
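The flatten-and-concatenate step, and its inverse, can be sketched in plain Python with nested lists standing in for tensors. This is an illustration under assumptions, not plenoptic's implementation: the function names and the template argument are hypothetical, and the sketch assumes the NaN placeholders described for convert_to_dict are dropped when building the flat representation and restored from a template on the way back. It uses simple row-major order; plenoptic's actual element order (see forward, which lists all orientations at a scale before the next scale) may differ.

```python
import math
from collections import OrderedDict


def _flatten(values):
    """Yield the scalars of an arbitrarily nested list in row-major order."""
    for v in values:
        if isinstance(v, list):
            yield from _flatten(v)
        else:
            yield v


def dict_to_vector(stats):
    """Concatenate all statistic classes, skipping NaN placeholders."""
    vector = []
    for values in stats.values():
        vector.extend(v for v in _flatten(values) if not math.isnan(v))
    return vector


def vector_to_dict(vector, template):
    """Inverse of dict_to_vector: NaNs in `template` mark placeholder slots."""
    remaining = iter(vector)

    def fill(values):
        if isinstance(values, list):
            return [fill(v) for v in values]
        return values if math.isnan(values) else next(remaining)

    return OrderedDict((name, fill(values)) for name, values in template.items())


nan = float("nan")
template = OrderedDict(
    pixel_like=[0.0, 0.0],
    symmetric_like=[[0.0, 0.0], [nan, 0.0]],  # lower triangle is a placeholder
)
vector = [1.0, 2.0, 3.0, 4.0, 5.0]
stats = vector_to_dict(vector, template)
assert dict_to_vector(stats) == vector  # round trip, as in the example below
```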

Parameters:

representation_dict (OrderedDict) – Dictionary of representation.

Return type:

Tensor

Returns:

rep – 3d tensor of statistics.

See also

convert_to_dict

Convert tensor representation to dictionary.

Examples

>>> import plenoptic as po
>>> import torch
>>> img = po.data.reptile_skin()
>>> portilla_simoncelli_model = po.models.PortillaSimoncelli(img.shape[2:])
>>> representation_tensor = portilla_simoncelli_model(img)
>>> representation_dict = portilla_simoncelli_model.convert_to_dict(
...     representation_tensor
... )
>>> representation_tensor_new = portilla_simoncelli_model.convert_to_tensor(
...     representation_dict
... )
>>> torch.equal(representation_tensor, representation_tensor_new)
True

forward(image, scales=None)

Generate the texture statistics representation of an image.

Note that separate batches and channels are analyzed in parallel.

For any statistic that spans multiple scales, the scales always run from fine to coarse, with all orientations at a given scale listed before moving on to the next scale.
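The ordering rule above can be made concrete with a small sketch that lists (scale, orientation) pairs in that order and computes the offset of a given pair. It assumes scale 0 is the finest scale; the helper names are hypothetical and not part of plenoptic.

```python
def scale_orientation_order(n_scales, n_orientations):
    """List (scale, orientation) pairs: fine to coarse, orientations innermost.

    Assumes scale 0 is the finest scale.
    """
    return [
        (scale, orientation)
        for scale in range(n_scales)
        for orientation in range(n_orientations)
    ]


def position(scale, orientation, n_orientations):
    """Offset of a (scale, orientation) pair within that ordering."""
    return scale * n_orientations + orientation


order = scale_orientation_order(n_scales=2, n_orientations=3)
# order == [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```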

Parameters:
  • image (Tensor) – A 4d tensor (batch, channel, height, width) containing the image(s) to analyze.

  • scales (list[Literal['pixel_statistics'] | int | Literal['residual_lowpass', 'residual_highpass']] | None (default: None)) – Which scales to include in the returned representation. If None, we include all scales. Otherwise, it can contain a subset of the values present in this model’s scales attribute, and the returned tensor will then contain only the statistics corresponding to those scales.

Return type:

Tensor

Returns:

representation_tensor – 3d tensor of shape (batch, channel, stats) containing the measured texture statistics.

Raises:

ValueError – If image is not 4d or has a dtype other than float or complex.

Examples

>>> import plenoptic as po
>>> img = po.data.reptile_skin()
>>> portilla_simoncelli_model = po.models.PortillaSimoncelli(img.shape[2:])
>>> representation_tensor = portilla_simoncelli_model(img)
>>> representation_tensor.shape
torch.Size([1, 1, 710])

plot_representation(data, ax=None, figsize=None, ylim=None, batch_idx=0, title=None)

Plot the representation in a human viewable format.

We plot the representation as stem plots with data separated out by statistic type.

This plots the representation of a single batch element and averages over all channels in the representation.

We create the following axes:

  • pixels+var_highpass: marginal pixel statistics (first four moments, min, max) and variance of the residual highpass.

  • std+skew+kurtosis recon: the standard deviation, skew, and kurtosis of the reconstructed lowpass image at each scale.

  • magnitude_std: the standard deviation of the steerable pyramid coefficient magnitudes at each orientation and scale.

  • auto_correlation_reconstructed: the auto-correlation of the reconstructed lowpass image at each scale (summarized using Euclidean norm).

  • auto_correlation_magnitude: the auto-correlation of the pyramid coefficient magnitudes at each scale and orientation (summarized using Euclidean norm).

  • cross_orientation_correlation_magnitude: the cross-correlations between each orientation at each scale (summarized using Euclidean norm).

If self.n_scales > 1, we also create the following two axes, where all cross-correlations are summarized using Euclidean norm over the channel dimension:

  • cross_scale_correlation_magnitude: the cross-correlations between the pyramid coefficient magnitude at one scale and the same orientation at the next-coarsest scale.

  • cross_scale_correlation_real: the cross-correlations between the real component of the pyramid coefficients and the real and imaginary components (at the same orientation) at the next-coarsest scale.
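Each summarized panel collapses a whole block of correlations (for example, a spatial_corr_width x spatial_corr_width auto-correlation for one scale and orientation) into a single stem height. A plain-Python sketch of a Euclidean-norm summary follows; euclidean_norm is a hypothetical name, and skipping NaN placeholders is an assumption about how redundant entries are treated, not documented plenoptic behavior.

```python
import math


def euclidean_norm(values):
    """Collapse a nested list of numbers to its Euclidean (L2) norm.

    NaN placeholders are skipped (an assumption for this sketch).
    """
    total = 0.0
    stack = [values]
    while stack:
        item = stack.pop()
        if isinstance(item, list):
            stack.extend(item)
        elif not math.isnan(item):
            total += item * item
    return math.sqrt(total)


print(euclidean_norm([[3.0, 0.0], [float("nan"), 4.0]]))  # 5.0
```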

Parameters:
  • data (Tensor) – The data to show on the plot. It should look like the output of self.forward(img), with the exact same structure (e.g., as returned by metamer.representation_error() or another instance of this class).

  • ax (Axes | None (default: None)) – Axes where we will plot the data. If a plt.Axes instance, will subdivide into 6 or 8 new axes (depending on self.n_scales). If None, we create a new figure.

  • figsize (tuple[float, float] | None (default: None)) – The size of the figure to create. Must be None if ax is not None. If both figsize and ax are None, then we set figsize=(12, 15).

  • ylim (tuple[float, float] | Literal[False] | None (default: None)) – If not None, the y-limits to use for this plot. If None, we use the default, slightly adjusted so that the minimum is 0. If False, do not change y-limits.

  • batch_idx (int (default: 0)) – Which index to take from the batch dimension (the first one).

  • title (str | None (default: None)) – Title for the plot.

Return type:

tuple[Figure, list[Axes]]

Returns:

  • fig – Figure containing the plot.

  • axes – List of 6 or 8 axes containing the plot (depending on self.n_scales).

Raises:

ValueError – If both figsize and ax are specified (i.e., neither is None).

Examples

>>> import plenoptic as po
>>> img = po.data.reptile_skin()
>>> portilla_simoncelli_model = po.models.PortillaSimoncelli(img.shape[2:])
>>> representation_tensor = portilla_simoncelli_model(img)
>>> fig, axes = portilla_simoncelli_model.plot_representation(
...     representation_tensor
... )

[Figure: output of plot_representation for the reptile_skin image]

remove_scales(representation_tensor, scales_to_keep)

Remove statistics not associated with scales.

For a given representation_tensor and a list of scales_to_keep, this method removes all statistics not associated with those scales.

Note that calling this method will always remove statistics.
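The 181 statistics kept for scale 0 in the example for this method can be accounted for with a short count. The sketch below is an inference from the documented shapes and totals, not plenoptic's code. It assumes that each symmetric auto-correlation contributes (spatial_corr_width**2 - 1) // 2 values, that each cross-orientation matrix contributes its strictly upper triangle, that reconstructed-image statistics with index scale belong to that scale, and that cross-scale correlations between a scale and the next-coarsest one are attributed to the finer scale. stats_for_scale is a hypothetical helper.

```python
def stats_for_scale(scale, n_scales=4, n_orientations=4, spatial_corr_width=7):
    """Count the statistics attributed to one integer pyramid scale.

    All attribution and redundancy rules here are assumptions.
    """
    k, w = n_orientations, spatial_corr_width
    half = (w * w - 1) // 2  # non-redundant lags of a symmetric auto-correlation
    count = (
        half * k             # auto_correlation_magnitude
        + half               # auto_correlation_reconstructed
        + 3                  # skew, kurtosis, std of the reconstructed image
        + k * (k - 1) // 2   # cross_orientation_correlation_magnitude
        + k                  # magnitude_std
    )
    if scale < n_scales - 1:  # the coarsest scale has no coarser neighbor
        count += k * k + 2 * k * k  # cross_scale_correlation_magnitude and _real
    return count


print(stats_for_scale(0))  # 181
```

Under the same assumptions, the coarsest scale (3 with the defaults) would keep 133 statistics, since it has no cross-scale correlations.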

Parameters:
  • representation_tensor (Tensor) – 3d tensor containing the measured representation statistics.

  • scales_to_keep (list[Literal['pixel_statistics'] | int | Literal['residual_lowpass', 'residual_highpass']]) – Which scales to include in the returned representation. Can contain a subset of the values present in this model’s scales attribute, and the returned tensor will then contain the subset of the full representation corresponding to those scales.

Return type:

Tensor

Returns:

limited_representation_tensor – Representation tensor with some statistics removed.

Examples

>>> import plenoptic as po
>>> img = po.data.reptile_skin()
>>> portilla_simoncelli_model = po.models.PortillaSimoncelli(img.shape[2:])
>>> representation_tensor = portilla_simoncelli_model(img)
>>> representation_tensor.shape
torch.Size([1, 1, 710])
>>> limited_representation_tensor = portilla_simoncelli_model.remove_scales(
...     representation_tensor, scales_to_keep=[0]
... )
>>> limited_representation_tensor.shape
torch.Size([1, 1, 181])

update_plot(axes, data, batch_idx=0)

Update the information in our representation plot.

This is used for creating an animation of the representation over time. In order to create the animation, we need to know how to update the matplotlib Artists, and this provides a simple way of doing that. It relies on the fact that we’ve used plot_representation to create the plots we want to update and so know that they’re stem plots.

We take the axes containing the representation information (probably a subset of all the axes in the figure, if other information is also shown, as done by Metamer.animate), grab the representation formatted for plotting and, since both are lists, iterate through them, updating each stem plot to the values in data as we go.

In order for this to be used by FuncAnimation, we need to return Artists, so we return a list of the relevant artists, the markerline and stemlines from the StemContainer.

Currently, this averages over all channels in the representation.

Parameters:
  • axes (list[Axes]) – A list of axes to update. We assume that these are the axes created by plot_representation and so contain stem plots in the correct order.

  • data (Tensor) – The data to show on the plot. It should look like the output of self.forward(img), with the exact same structure (e.g., as returned by metamer.representation_error() or another instance of this class).

  • batch_idx (int (default: 0)) – Which index to take from the batch dimension (the first one).

Return type:

list[Artist]

Returns:

stem_artists – A list of the artists used to update the information on the stem plots.

Examples

This method is meant to be used by animation functions, so users won’t typically use this directly.

>>> import plenoptic as po
>>> img = po.data.reptile_skin()
>>> portilla_simoncelli_model = po.models.PortillaSimoncelli(img.shape[2:])
>>> representation_tensor = portilla_simoncelli_model.forward(img)
>>> fig, axes = portilla_simoncelli_model.plot_representation(
...     representation_tensor
... )
>>> new_img = po.data.einstein()
>>> new_representation_tensor = portilla_simoncelli_model.forward(new_img)
>>> stem_artists = portilla_simoncelli_model.update_plot(
...     axes, new_representation_tensor
... )