"""
Steerable frequency pyramid.
Construct a steerable pyramid on matrix two dimensional signals, in the
Fourier domain.
""" # numpydoc ignore=EX01
import warnings
from collections import OrderedDict
from typing import Literal
import einops
import numpy as np
import torch
import torch.fft as fft
import torch.nn as nn
from numpy.typing import NDArray
from scipy.special import factorial
from torch import Tensor
from .signal import _interpolate1d, _raised_cosine, _steer
complex_types = [torch.cdouble, torch.cfloat]
SCALES_TYPE = int | Literal["residual_lowpass", "residual_highpass"]
__all__ = [
"SteerablePyramidFreq",
]
def __dir__() -> list[str]:
return __all__
[docs]
class SteerablePyramidFreq(nn.Module):
r"""
Steerable frequency pyramid in Torch.
Construct a steerable pyramid on matrix two dimensional signals, in the Fourier
domain. Boundary-handling is circular. Reconstruction is exact (within floating
point errors). However, if the image has an odd-shape, the reconstruction will not
be exact due to boundary-handling issues that have not been resolved. Similarly,
a complex pyramid of order=0 has non-exact reconstruction and cannot be tight-frame.
The squared radial functions tile the Fourier plane with a raised-cosine
falloff. Angular functions are
.. math::
\cos\left(\frac{\theta-k*\pi}{o+1}\right)^o
where :math:`o` is the order parameter set at initialization and :math:`k` runs from
0 to :math:`o` for a total of :math:`o+1` orientations.
Parameters
----------
image_shape
Shape of input image.
height
The height of the pyramid. If ``'auto'``, will automatically determine based on
the size of ``image``. If an ``int``, must be non-negative and less than
``log2(min(image_shape[1], image_shape[1]))-2``. If ``height=0``, this only
returns the residuals.
order
The Gaussian derivative order used for the steerable filters, in ``[0, 15]``.
Note that to achieve steerability the minimum number of orientation is
``order + 1``, which is used here. To get more orientations at the same order,
use the method :meth:`steer_coeffs`.
twidth
The width of the transition region of the radial lowpass function, in
octaves.
is_complex
Whether the pyramid coefficients should be complex or not. If ``True``, the real
and imaginary parts correspond to a pair of odd and even symmetric filters. If
``False``, the coefficients only include the real part. Regardless of the value
of ``is_complex``, the symmetry of the real part is determined by the ``order``
parameter: if ``order`` is even, then the real coefficients are even symmetric;
if ``order`` is odd, then the real coefficients are odd symmetric. (If
``is_complex=True``, then the imaginary coefficients will have the opposite
symmetry of the real ones).
downsample
Whether to downsample each scale in the pyramid or keep the output
pyramid coefficients in fixed bands of size ``image_shape``. When
downsample is ``False``, the forward method returns a tensor.
tight_frame
Whether the pyramid obeys the generalized parseval theorem or not (i.e.
is a tight frame). If ``True``, the energy of the pyr_coeffs equals the energy
of the image. In order to match the `matlabPyrTools
<http://github.com/labForComputationalVision/matlabpyrtools>`_ or `pyrtools
<https://github.com/labForComputationalVision/pyrtools>`_ implementations, this
must be set to ``False``.
Attributes
----------
image_shape : tuple
Shape of input image.
pyr_size : OrderedDict
Dictionary containing the height and width of the pyramid coefficients. Keys are
the same as those in ``pyr_coeffs`` returned by :meth:`forward`, in order:
``"residual_highpass"``, the integers from ``0`` to (the initialization
argument) ``order``, and ``"residual_lowpass"``. The values are 2-tuples of
ints. While the dictionary is initialized with the object, the values are not
set until the first time :meth:`forward` is called.
fft_norm : str
The way the ffts are normalized, see :func:`torch.fft.fft2` for more details.
is_complex : bool
Whether the coefficients are complex- or real-valued.
scales : list
All the scales of the representation (including residuals) in coarse-to-fine
order. A subset of this list can be passed to the :meth:`forward` method to
restrict the output.
Raises
------
ValueError
If ``image_shape`` contains non-integers.
ValueError
If ``len(image_shape) != 2`` .
ValueError
If ``height`` is not a non-negative integer or is larger than the biggest
possible value (determined by ``image_shape``).
ValueError
If ``order`` not an integer in ``[0, 15]``.
ValueError
If ``order == 0`` and ``is_complex is False``. See
https://github.com/plenoptic-org/plenoptic/issues/326 for an explanation
ValueError
If ``twidth`` not positive.
Warns
-----
UserWarning
If ``image_shape`` has an odd value, because then reconstruction will be
imperfect.
Notes
-----
Transform described in Simoncelli and Freeman, 1995 [1]_, filter kernel design
described in Karasaridis and Smoncelli, 1996 [2]_. For further information see
online [3]_.
References
----------
.. [1] E P Simoncelli and W T Freeman, "The Steerable Pyramid: A Flexible
Architecture for Multi-Scale Derivative Computation," Second Int'l Conf
on Image Processing, Washington, DC, Oct 1995.
.. [2] A Karasaridis and E P Simoncelli, "A Filter Design Technique for
Steerable Pyramid Image Transforms", ICASSP, Atlanta, GA, May 1996. ..
.. [3] `<https://www.cns.nyu.edu/~eero/steerpyr/>`_
Examples
--------
>>> import plenoptic as po
>>> spyr = po.process.SteerablePyramidFreq((256, 256))
"""
def __init__(
self,
image_shape: tuple[int, int],
height: Literal["auto"] | int = "auto",
order: int = 3,
twidth: int = 1,
is_complex: bool = False,
downsample: bool = True,
tight_frame: bool = False,
):
super().__init__()
self.order = order
# complex_const comes from the Fourier transform of a gaussian derivative.
self._complex_const_forward = np.power(complex(0, -1), self.order)
self._complex_const_recon = np.power(complex(0, 1), self.order)
try:
self.image_shape = tuple([int(i) for i in image_shape])
except ValueError:
raise ValueError(
f"image_shape must be castable to ints, but got {image_shape}!"
)
if self.image_shape != tuple(image_shape):
raise ValueError(
f"image_shape must be castable to ints, but got {image_shape}!"
)
if len(self.image_shape) != 2:
raise ValueError(
f"image_shape must be a tuple of length 2, but got {self.image_shape}!"
)
if (self.image_shape[0] % 2 != 0) or (self.image_shape[1] % 2 != 0):
warnings.warn("Reconstruction will not be perfect with odd-sized images")
self.is_complex = is_complex
self.downsample = downsample
self.tight_frame = tight_frame
if self.tight_frame:
self.fft_norm = "ortho"
else:
self.fft_norm = "backward"
# cache constants
self.lutsize = 1024
self.Xcosn = (
np.pi
* np.array(range(-(2 * self.lutsize + 1), (self.lutsize + 2)))
/ self.lutsize
)
self.alpha = (self.Xcosn + np.pi) % (2 * np.pi) - np.pi
max_ht = np.floor(np.log2(min(self.image_shape[0], self.image_shape[1]))) - 2
if height == "auto":
self.num_scales = int(max_ht)
elif height > max_ht:
raise ValueError(f"Cannot build pyramid higher than {max_ht:.0f} levels.")
elif height < 0:
raise ValueError("Height must be a non-negative integer.")
else:
self.num_scales = int(height)
if self.order > 15 or self.order < 0:
raise ValueError("order must be an integer in the range [0, 15].")
if self.order == 0 and self.is_complex:
raise ValueError(
"Complex pyramid cannot have order=0! See "
"https://github.com/plenoptic-org/plenoptic/issues/326 "
"for an explanation."
)
self.num_orientations = int(self.order + 1)
if twidth <= 0:
raise ValueError("twidth must be positive.")
twidth = int(twidth)
dims = np.array(self.image_shape)
# make a grid for the raised cosine interpolation
ctr = np.ceil((np.array(dims) + 0.5) / 2).astype(int)
(xramp, yramp) = np.meshgrid(
np.linspace(-1, 1, dims[1] + 1)[:-1],
np.linspace(-1, 1, dims[0] + 1)[:-1],
indexing="xy",
)
self.angle = np.arctan2(yramp, xramp)
log_rad = np.sqrt(xramp**2 + yramp**2)
log_rad[ctr[0] - 1, ctr[1] - 1] = log_rad[ctr[0] - 1, ctr[1] - 2]
self.log_rad = np.log2(log_rad)
# radial transition function (a raised cosine in log-frequency):
self.Xrcos, Yrcos = _raised_cosine(twidth, (-twidth / 2.0), np.array([0, 1]))
self.Yrcos = np.sqrt(Yrcos)
self.YIrcos = np.sqrt(1.0 - self.Yrcos**2)
# create low and high masks
lo0mask = _interpolate1d(self.log_rad, self.YIrcos, self.Xrcos)
hi0mask = _interpolate1d(self.log_rad, self.Yrcos, self.Xrcos)
self.register_buffer(
"lo0mask", einops.rearrange(torch.as_tensor(lo0mask), "h w -> 1 1 1 h w")
)
self.register_buffer("hi0mask", torch.as_tensor(hi0mask).unsqueeze(0))
# need a mock image to down-sample so that we correctly
# construct the differently-sized masks
mock_image = np.random.rand(*self.image_shape)
imdft = np.fft.fftshift(np.fft.fft2(mock_image))
lodft = imdft * lo0mask
# this list, used by coarse-to-fine optimization, gives all the
# scales (including residuals) from coarse to fine
self.scales = (
["residual_lowpass"]
+ list(range(self.num_scales))[::-1]
+ ["residual_highpass"]
)
self.pyr_size = OrderedDict({k: () for k in self.scales[::-1]})
# we create these copies because they will be modified in the
# following loops
Xrcos = self.Xrcos.copy()
angle = self.angle.copy()
log_rad = self.log_rad.copy()
# pre-generate the angle, hi and lo masks, as well as the
# indices used for down-sampling.
self._loindices = []
for i in range(self.num_scales):
Xrcos -= np.log2(2)
const = (
(2 ** (2 * self.order))
* (factorial(self.order, exact=True) ** 2)
/ float(self.num_orientations * factorial(2 * self.order, exact=True))
)
if self.is_complex:
Ycosn_forward = (
2.0
* np.sqrt(const)
* (np.cos(self.Xcosn) ** self.order)
* (np.abs(self.alpha) < np.pi / 2.0).astype(int)
)
Ycosn_recon = np.sqrt(const) * (np.cos(self.Xcosn)) ** self.order
else:
Ycosn_forward = np.sqrt(const) * (np.cos(self.Xcosn)) ** self.order
Ycosn_recon = Ycosn_forward
himask = _interpolate1d(log_rad, self.Yrcos, Xrcos)
self.register_buffer(
f"_himasks_scale_{i}", torch.as_tensor(himask).unsqueeze(0)
)
anglemasks = []
anglemasks_recon = []
for b in range(self.num_orientations):
anglemask = _interpolate1d(
angle,
Ycosn_forward,
self.Xcosn + np.pi * b / self.num_orientations,
)
anglemask_recon = _interpolate1d(
angle,
Ycosn_recon,
self.Xcosn + np.pi * b / self.num_orientations,
)
anglemasks.append(torch.as_tensor(anglemask).unsqueeze(0))
anglemasks_recon.append(torch.as_tensor(anglemask_recon).unsqueeze(0))
self.register_buffer(
f"_anglemasks_scale_{i}",
einops.rearrange(anglemasks, "o 1 h w -> 1 1 o h w"),
)
self.register_buffer(
f"_anglemasks_recon_scale_{i}", torch.cat(anglemasks_recon)
)
if not self.downsample:
lomask = _interpolate1d(log_rad, self.YIrcos, Xrcos)
self.register_buffer(
f"_lomasks_scale_{i}", torch.as_tensor(lomask).unsqueeze(0)
)
self._loindices.append([np.array([0, 0]), dims])
lodft = lodft * lomask
else:
# subsample lowpass
dims = np.array([lodft.shape[0], lodft.shape[1]])
ctr = np.ceil((dims + 0.5) / 2).astype(int)
lodims = np.ceil((dims - 0.5) / 2).astype(int)
loctr = np.ceil((lodims + 0.5) / 2).astype(int)
lostart = ctr - loctr
loend = lostart + lodims
self._loindices.append([lostart, loend])
# subsample indices
log_rad = log_rad[lostart[0] : loend[0], lostart[1] : loend[1]]
angle = angle[lostart[0] : loend[0], lostart[1] : loend[1]]
lomask = _interpolate1d(log_rad, self.YIrcos, Xrcos)
self.register_buffer(
f"_lomasks_scale_{i}", torch.as_tensor(lomask).unsqueeze(0)
)
# subsampling
lodft = lodft[lostart[0] : loend[0], lostart[1] : loend[1]]
# convolution in spatial domain
lodft = lodft * lomask
# reasonable default dtype
self.to(torch.float32)
# This model has no trainable parameters, so it's always in eval mode
self.eval()
[docs]
def forward(
self,
image: Tensor,
scales: list[SCALES_TYPE] | None = None,
) -> OrderedDict:
r"""
Generate the steerable pyramid coefficients for an image.
The steerable pyramid coefficients run from fine to coarse and split the image
into subbands corresponding to different orientations and scales (i.e., spatial
frequencies).
.. versionchanged:: 1.4
The returned ``pyr_coeffs`` dictionary's keys are now either strings
specifying the residual or integers specifying the scale. The non-residual
coefficients are now 5d tensors of shape (batch, channel, num_orientations,
height, width).
Parameters
----------
image
A tensor containing the image to analyze. We want to operate
on this in the pytorch-y way, so we want it to be 4d (batch,
channel, height, width).
scales
Which scales to include in the returned representation. If None, we
include all scales. Otherwise, can contain subset of values present
in this model's ``scales`` attribute (ints from 0 up to
``self.num_scales-1`` and the strs 'residual_highpass' and
'residual_lowpass'. Can contain a single value or multiple values.
If it's an int, we include all orientations from that scale. Order
within the list does not matter.
Returns
-------
pyr_coeffs
Pyramid coefficients. These will be stored in an ordered dictionary with
keys that are, in order: ``"residual_highpass"``, the integers from ``0`` to
(the initialization argument) ``order``, and ``"residual_lowpass"``.
Coefficients have shape ``(*image.shape[:2], self.num_orientations,
image.shape[2] / 2**scale, image.shape[3] / 2**scale)``, with the
``"residual_highpass"`` height and width matching that of ``image``, and
``"residual_lowpass"`` having height and width ``(image.shape[2] /
2**self.num_scales, image.shape[3] / 2**self.num_scales)``. They are
ordered from fine to coarse: ``"residual_highpass", 0, 1, ...,
num_scales-1, "residual_lowpass"``.
Raises
------
ValueError
If ``image`` is the wrong shape, i.e. ``image.shape[-2:] !=
self.image_shape``.
Examples
--------
.. plot::
:context: reset
>>> import plenoptic as po
>>> img = po.data.einstein()
>>> spyr = po.process.SteerablePyramidFreq(img.shape[-2:])
>>> po.plot.pyrshow(spyr(img))
<PyrFigure ...>
"""
if self.image_shape != image.shape[-2:]:
raise ValueError(
f"Input tensor height/width {tuple(image.shape[-2:])} does not match "
f"image_shape set at initialization {tuple(self.image_shape)}. "
"Either resize the input or re-initialize this model."
)
pyr_coeffs = OrderedDict()
if scales is None:
scales = self.scales
scale_ints = [s for s in scales if isinstance(s, int)]
if len(scale_ints) != 0:
assert (max(scale_ints) < self.num_scales) and (min(scale_ints) >= 0), (
"Scales must be within 0 and num_scales-1"
)
# image is a torch tensor batch of images of size (batch, channel, height,
# width)
if len(image.shape) != 4:
raise ValueError("Input image must be 4d (batch, channel, height, width)!")
imdft = fft.fft2(image, dim=(-2, -1), norm=self.fft_norm)
imdft = fft.fftshift(imdft, dim=(-2, -1))
if "residual_highpass" in scales:
# high-pass
hi0dft = imdft * self.hi0mask
hi0 = fft.ifftshift(hi0dft, dim=(-2, -1))
hi0 = fft.ifft2(hi0, dim=(-2, -1), norm=self.fft_norm)
pyr_coeffs["residual_highpass"] = hi0.real
self.pyr_size["residual_highpass"] = tuple(hi0.real.shape[-2:])
# input to the next scale is the low-pass filtered component. after this
# multiplication, lodft will be shape (batch, channel, orientations, height,
# width)
lodft = einops.einsum(imdft, self.lo0mask, "b c h w, b c o h w -> b c o h w")
for i in range(self.num_scales):
if i in scales:
# high-pass mask is selected based on the current scale
himask = getattr(self, f"_himasks_scale_{i}")
mask = getattr(self, f"_anglemasks_scale_{i}") * himask
# compute filter output at each orientation
# band pass filtering is done in the fourier space as multiplying
# by the fft of a gaussian derivative.
# The oriented dft is computed as a product of the fft of the
# low-passed component, the precomputed anglemask (specifies
# orientation), and the precomputed hipass mask (creating a bandpass
# filter) the complex_const variable comes from the Fourier
# transform of a gaussian derivative.
# Based on the order of the gaussian, this constant changes.
banddft = self._complex_const_forward * lodft * mask
# fft output is then shifted to center frequencies
band = fft.ifftshift(banddft, dim=(-2, -1))
# ifft is applied to recover the filtered representation in spatial
# domain
band = fft.ifft2(band, dim=(-2, -1), norm=self.fft_norm)
# for real pyramid, take the real component of the complex band
if not self.is_complex:
pyr_coeffs[i] = band.real
else:
# Because the input signal is real, to maintain a tight frame
# if the complex pyramid is used, magnitudes need to be divided
# by sqrt(2) because energy is doubled.
if self.tight_frame:
band = band / np.sqrt(2)
pyr_coeffs[i] = band
self.pyr_size[i] = tuple(band.shape[-2:])
if not self.downsample:
# no subsampling of angle and rad
# just use lo0mask
lomask = getattr(self, f"_lomasks_scale_{i}")
lodft = lodft * lomask
# Since we don't subsample here, if we are not using
# orthonormalization that we need to manually account for the
# subsampling, so that energy in each band remains the same
# the energy is cut by factor of 4 so we need to scale magnitudes
# by factor of 2.
if self.fft_norm != "ortho":
lodft = 2 * lodft
else:
# subsample indices
lostart, loend = self._loindices[i]
# subsampling of the dft for next scale
lodft = lodft[..., lostart[0] : loend[0], lostart[1] : loend[1]]
# low-pass filter mask is selected
lomask = getattr(self, f"_lomasks_scale_{i}")
# again multiply dft by subsampled mask (convolution in spatial domain)
lodft = lodft * lomask
if "residual_lowpass" in scales:
# compute residual lowpass when height <=1
lo0 = fft.ifftshift(lodft, dim=(-2, -1))
lo0 = fft.ifft2(lo0, dim=(-2, -1), norm=self.fft_norm)
pyr_coeffs["residual_lowpass"] = lo0.real.squeeze(2)
self.pyr_size["residual_lowpass"] = tuple(lo0.real.shape[-2:])
return pyr_coeffs
[docs]
@staticmethod
def convert_pyr_to_tensor(
pyr_coeffs: OrderedDict, split_complex: bool = False
) -> tuple[
Tensor, tuple[int, list[SCALES_TYPE], list[torch.Size], list[torch.Size] | bool]
]:
r"""
Convert coefficient dictionary to a tensor.
The output tensor has shape (batch, channel, height, width) and is
intended to be used in an :class:`torch.nn.Module` downstream. In the
multichannel case, all bands for each channel will be stacked together
(i.e. if there are 2 channels and 18 bands per channel,
``pyr_tensor[:,0:18,...]`` will contain the pyr responses for channel 1 and
``pyr_tensor[:, 18:36, ...]`` will contain the responses for channel 2). In
the case of a complex, multichannel pyramid with ``split_complex=True``,
the real/imaginary bands will be intereleaved so that they appear as
pairs with neighboring indices in the channel dimension of the tensor.
(Note: the residual bands are always real so they will only ever have a
single band even when ``split_complex=True``.)
This only works if ``pyr_coeffs`` was created with a pyramid with
``downsample=False``
Parameters
----------
pyr_coeffs
The pyramid coefficients.
split_complex
Indicates whether the output should split complex bands into
real/imag channels or keep them as a single channel. This should be
``True`` if you intend to use a convolutional layer on top of the
output.
Returns
-------
pyr_tensor
Tensor with shape (batch, channel, height, width). pyramid coefficients
reshaped into tensor. The first channel will be the residual
highpass and the last will be the residual lowpass. Each band is
then a separate channel, going from fine to coarse (i.e., starting with all
of scale 0's orientations, then scale 1's, etc.).
pyr_info
Information required to recreate the dictionary, containing the
number of channels, the list of pyramid keys for the dictionary,
info on how to unpack the coefficients, and info on how ``split_complex``
was used.
Raises
------
RuntimeError
If ``self.downsample is True``. In this case, we can't concatenate across
scales, because each scale is a different size.
See Also
--------
convert_tensor_to_pyr
Convert tensor representation to pyramid dictionary.
Examples
--------
.. plot::
:context: reset
>>> import plenoptic as po
>>> img = po.data.einstein()
>>> spyr = po.process.SteerablePyramidFreq(img.shape[-2:], downsample=False)
>>> coeffs = spyr(img)
>>> coeffs_tensor, _ = spyr.convert_pyr_to_tensor(coeffs)
>>> coeffs_tensor.shape
torch.Size([1, 26, 256, 256])
>>> # rearrange so that the residuals are at the end
>>> coeffs_tensor = [
... coeffs_tensor[:, 1:-1],
... coeffs_tensor[:, :1],
... coeffs_tensor[:, -1:],
... ]
>>> po.plot.imshow(coeffs_tensor, col_wrap=spyr.num_orientations)
<PyrFigure ...>
"""
pyr_keys = list(pyr_coeffs.keys())
num_channels = pyr_coeffs[pyr_keys[0]].size(1)
try:
packed, pack_info = einops.pack(list(pyr_coeffs.values()), "b c * h w")
except RuntimeError:
raise RuntimeError(
"feature maps could not be concatenated into tensor. Check that you"
"are using coefficients that are not downsampled across scales."
"This is done with the 'downsample=False' argument for the pyramid"
)
# if the second half of this is False, then pyr_coeffs only contains residuals
if split_complex and not all([isinstance(k, str) for k in pyr_keys]):
start_idx = 0
end_idx = None
if "residual_highpass" in pyr_keys:
start_idx = 1
if "residual_lowpass" in pyr_keys:
end_idx = -1
complex_coeffs = packed[:, :, start_idx:end_idx]
try:
separated = einops.rearrange(
[complex_coeffs.real, complex_coeffs.imag],
"complex b c o h w -> b c (o complex) h w",
)
except RuntimeError:
raise RuntimeError(
"split_complex=True but coefficient tensors are real-valued! "
"Either set split_complex=False or regenerate the coefficients "
"with a complex pyramid."
)
to_pack = []
if "residual_highpass" in pyr_keys:
to_pack.append(packed[:, :, 0].real)
to_pack.append(separated)
if "residual_lowpass" in pyr_keys:
to_pack.append(packed[:, :, -1].real)
packed, split_complex = einops.pack(to_pack, "b c * h w")
pyr_info = (num_channels, pyr_keys, pack_info, split_complex)
return einops.rearrange(packed, "b c o h w -> b (c o) h w"), pyr_info
[docs]
@staticmethod
def convert_tensor_to_pyr(
pyr_tensor: Tensor,
num_channels: int,
pyr_keys: list[SCALES_TYPE],
pack_info: list[torch.Size],
split_complex_pack_info: list[torch.Size] | bool,
) -> OrderedDict:
r"""
Convert pyramid coefficient tensor to dictionary format.
The arguments other than ``pyr_tensor`` are elements of the
``pyr_info`` tuple returned by :meth:`convert_pyr_to_tensor`. You
should always unpack the arguments for this function from that
``pyr_info`` tuple. See Examples section below.
Parameters
----------
pyr_tensor
Shape (batch, channel, height, width). The pyramid coefficients.
num_channels
Number of channels in the original input tensor the pyramid was
created for (i.e. if the input was an RGB image, this would be 3).
pyr_keys
Keys from the original pyramid dictionary.
pack_info
List of sizes of the fifth dimension for each coefficient (i.e., the number
of orientations) used to pack/unpack the tensors.
split_complex_pack_info
If :meth:`convert_pyr_to_tensor` was called with ``split_complex=True``,
another list of sizes used to pack/unpack the tensors. Else, ``False``.
Returns
-------
pyr_coeffs
Pyramid coefficients in dictionary format as returned by :meth:`forward`.
See Also
--------
convert_pyr_to_tensor
Convert pyramid dictionary representation to tensor.
Examples
--------
>>> import plenoptic as po
>>> img = po.data.einstein()
>>> spyr = po.process.SteerablePyramidFreq(
... img.shape[-2:], downsample=False, is_complex=True
... )
>>> coeffs = spyr(img)
>>> coeffs_tensor, pyr_info = spyr.convert_pyr_to_tensor(coeffs)
>>> coeffs_tensor.shape
torch.Size([1, 26, 256, 256])
>>> coeffs_tensor.dtype
torch.complex64
>>> new_coeffs = spyr.convert_tensor_to_pyr(coeffs_tensor, *pyr_info)
>>> all([torch.equal(v, new_coeffs[k]) for k, v in coeffs.items()])
True
>>> coeffs_tensor, pyr_info = spyr.convert_pyr_to_tensor(
... coeffs, split_complex=True
... )
>>> coeffs_tensor.shape
torch.Size([1, 50, 256, 256])
>>> coeffs_tensor.dtype
torch.float32
>>> new_coeffs = spyr.convert_tensor_to_pyr(coeffs_tensor, *pyr_info)
>>> all([torch.equal(v, new_coeffs[k]) for k, v in coeffs.items()])
True
"""
# this function just undoes the einops calls in convert_pyr_to_tensor
unpacked = einops.rearrange(
pyr_tensor, "b (c o) h w -> b c o h w", c=num_channels
)
if not isinstance(split_complex_pack_info, bool):
unpacked = einops.unpack(unpacked, split_complex_pack_info, "b c * h w")
if "residual_highpass" in pyr_keys:
complex_coeffs = unpacked[1]
complex_pack_info = pack_info[1:]
else:
complex_coeffs = unpacked[0]
complex_pack_info = pack_info
if "residual_lowpass" in pyr_keys:
complex_pack_info = complex_pack_info[:-1]
bands = einops.rearrange(
complex_coeffs, "b c (o complex) h w -> b c o h w complex", complex=2
).contiguous()
bands = torch.view_as_complex(bands)
bands = einops.unpack(bands, complex_pack_info, "b c * h w")
coeffs = []
if "residual_highpass" in pyr_keys:
coeffs.append(unpacked[0])
coeffs.extend(bands)
if "residual_lowpass" in pyr_keys:
coeffs.append(unpacked[-1])
else:
coeffs = einops.unpack(unpacked, pack_info, "b c * h w")
pyr_coeffs = OrderedDict({k: v for k, v in zip(pyr_keys, coeffs)})
# make sure these are real-valued
for k in ["residual_lowpass", "residual_highpass"]:
if k in pyr_coeffs:
pyr_coeffs[k] = pyr_coeffs[k].real
return pyr_coeffs
def _recon_levels_check(
self, levels: Literal["all"] | list[SCALES_TYPE]
) -> list[SCALES_TYPE]:
r"""
Check whether levels arg is valid for reconstruction and return valid version.
When reconstructing the input image (i.e., when calling :meth:`recon_pyr()`),
the user specifies which levels to include. This makes sure those
levels are valid and gets them in the form we expect for the rest of
the reconstruction. If the user passes ``"all"``, this constructs the
appropriate list (based on the values of ``pyr_coeffs``).
Parameters
----------
levels
If ``list`` should contain some subset of integers from ``0`` to
``self.num_scales-1`` (inclusive) and ``"residual_highpass"`` and
``"residual_lowpass"`` (if appropriate for the pyramid). If ``"all"``,
returned value will contain all valid levels.
Returns
-------
levels
List containing the valid levels for reconstruction.
Raises
------
TypeError
If ``levels`` is not one of the allowed values.
""" # numpydoc ignore=EX01
if isinstance(levels, str):
if levels != "all":
raise TypeError(
"levels must be a list of levels or the string 'all' but"
f" got {levels}"
)
levels = (
["residual_highpass"]
+ list(range(self.num_scales))
+ ["residual_lowpass"]
)
else:
if not hasattr(levels, "__iter__"):
raise TypeError(
"levels must be a list of levels or the string 'all' but"
f" got {levels}"
)
levs_nums = np.array([int(i) for i in levels if isinstance(i, int)])
assert (levs_nums >= 0).all(), "Level numbers must be non-negative."
assert (levs_nums < self.num_scales).all(), (
f"Level numbers must be in the range [0, {self.num_scales - 1:d}]"
)
levs_tmp = list(np.sort(levs_nums)) # we want smallest first
if "residual_highpass" in levels:
levs_tmp = ["residual_highpass"] + levs_tmp
if "residual_lowpass" in levels:
levs_tmp = levs_tmp + ["residual_lowpass"]
levels = levs_tmp
# not all pyramids have residual highpass / lowpass, but it's easier
# to construct the list including them, then remove them if necessary.
if "residual_lowpass" not in self.pyr_size and "residual_lowpass" in levels:
levels.pop(-1)
if "residual_highpass" not in self.pyr_size and "residual_highpass" in levels:
levels.pop(0)
return levels
def _recon_bands_check(self, bands: Literal["all"] | list[int]) -> list[int]:
"""
Check whether bands arg is valid for reconstruction and return valid version.
When reconstructing the input image (i.e., when calling :meth:`recon_pyr()`),
the user specifies which orientations to include. This makes sure those
orientations are valid and gets them in the form we expect for the rest
of the reconstruction. If the user passes ``'all'``, this
constructs the appropriate list (based on the values of ``pyr_coeffs``).
Parameters
----------
bands
If list, should contain some subset of integers from ``0`` to
``self.num_orientations-1``. If ``'all'``, returned value will
contain all valid orientations.
Returns
-------
bands
List containing the valid orientations for reconstruction.
Raises
------
TypeError
If ``bands`` is not an int or ``"all"``.
ValueError
If ``bands`` is an integer outside of the range ``[0,
self.num_orientations-1]``.
""" # numpydoc ignore=EX01
if isinstance(bands, str):
if bands != "all":
raise TypeError(
f"bands must be a list of ints or the string 'all' but got {bands}"
)
else:
if not hasattr(bands, "__iter__"):
raise TypeError(
f"bands must be a list of ints or the string 'all' but got {bands}"
)
bands: NDArray = np.array(bands, ndmin=1)
assert (bands >= 0).all(), "Error: band numbers must be larger than 0."
if any(bands > self.num_orientations):
raise ValueError(
"Error: band numbers must be in the range "
f"[0, {self.num_orientations - 1:d}]"
)
return bands
[docs]
def recon_pyr(
self,
pyr_coeffs: OrderedDict,
levels: Literal["all"] | list[SCALES_TYPE] = "all",
bands: Literal["all"] | list[int] = "all",
) -> Tensor:
"""
Reconstruct image from coefficients, optionally using a subset.
Parameters
----------
pyr_coeffs
Pyramid coefficients to reconstruct from.
levels
If ``list`` should contain some subset of integers from ``0`` to
``self.num_scales-1`` (inclusive), ``"residual_lowpass"``, and
``"residual_highpass"``. If ``"all"``, returned value will contain all
valid levels. Otherwise, must be one of the valid levels.
bands
If list, should contain some subset of integers from ``0`` to
``self.num_orientations-1``. If ``"all"``, returned value will contain
all valid orientations. Otherwise, must be one of the valid
orientations.
Returns
-------
recon
The reconstructed image, of shape (batch, channel, height, width).
Raises
------
ValueError
If ``self.forward()`` was called with ``scales`` argument not ``None``.
TypeError
If ``levels`` is not one of the allowed values.
TypeError
If ``bands`` is not an integer or ``"all"`` .
ValueError
If ``bands`` is an integer outside of the range ``[0,
self.num_orientations-1]``.
Examples
--------
.. plot::
:context: reset
>>> import plenoptic as po
>>> import torch
>>> img = po.data.einstein()
>>> spyr = po.process.SteerablePyramidFreq(img.shape[-2:])
>>> coeffs = spyr(img)
>>> recon = spyr.recon_pyr(coeffs)
>>> torch.allclose(recon, img, rtol=1e-8, atol=1e-5)
True
>>> titles = ["Original", "Reconstructed", "Difference"]
>>> po.plot.imshow([img, recon, img - recon], title=titles)
<PyrFigure ...>
""" # numpydoc ignore=ES01
# For reconstruction to work, last time we called forward needed
# to include all levels
for s in self.scales:
if s not in pyr_coeffs:
raise ValueError(
f"scale {s} not in pyr_coeffs! pyr_coeffs must include"
" all scales, so make sure forward() was called with"
" arg scales=None"
)
levels = self._recon_levels_check(levels)
bands = self._recon_bands_check(bands)
scale = 0
# Recursively generate the reconstruction - function starts with
# fine scales going down to coarse and then the reconstruction
# is built recursively from the coarse scale up
recondft = self._recon_levels(pyr_coeffs, levels, bands, scale)
outdft = recondft * self.lo0mask.squeeze()
# generate highpass residual Reconstruction
if "residual_highpass" in levels:
hidft = fft.fft2(
pyr_coeffs["residual_highpass"],
dim=(-2, -1),
norm=self.fft_norm,
)
hidft = fft.fftshift(hidft, dim=(-2, -1))
# output dft is the sum of the recondft from the recursive
# function times the lomask (low pass component) with the
# highpass dft * the highpass mask
outdft = outdft + hidft * self.hi0mask
# get output reconstruction by inverting the fft
reconstruction = fft.ifftshift(outdft, dim=(-2, -1))
reconstruction = fft.ifft2(reconstruction, dim=(-2, -1), norm=self.fft_norm)
# get real part of reconstruction (if complex)
return reconstruction.real
def _recon_levels(
self,
pyr_coeffs: OrderedDict,
recon_levels: list[SCALES_TYPE],
recon_bands: list[int] | Literal["all"],
scale: int,
) -> Tensor:
"""
Recursive function used to build the reconstruction. Called by recon_pyr.
Each time this function is called, it reconstructs a single scale.
Parameters
----------
pyr_coeffs
Dictionary containing the coefficients of the pyramid. Keys are
``(level, band)`` tuples and the strings ``"residual_lowpass"`` and
``"residual_highpass"`` and values are Tensors of shape (batch,
channel, height, width).
recon_levels
List of scales to include in the reconstruction.
recon_bands
Either ``"all"`` (in which case we include all bands)
or list of bands to include in the reconstruction.
scale
Current scale that is being used to build the reconstruction
scale is incremented by 1 on each call of the function.
Returns
-------
recondft
Current reconstruction based on the orientation band dft from the
current scale summed with the output of recursive call with the
next scale incremented.
""" # numpydoc ignore=EX01
# base case, return the low-pass residual
if scale == self.num_scales:
if "residual_lowpass" in recon_levels:
lodft = fft.fft2(
pyr_coeffs["residual_lowpass"],
dim=(-2, -1),
norm=self.fft_norm,
)
lodft = fft.fftshift(lodft)
else:
lodft = fft.fft2(
torch.zeros_like(pyr_coeffs["residual_lowpass"]),
dim=(-2, -1),
norm=self.fft_norm,
)
return lodft
# Reconstruct from orientation bands
# update himask
if scale in recon_levels:
himask = getattr(self, f"_himasks_scale_{scale}")
mask = getattr(self, f"_anglemasks_recon_scale_{scale}") * himask
coeffs = pyr_coeffs[scale]
# then recon_bands is not "all" and we're subselecting them
if not isinstance(recon_bands, str):
coeffs = coeffs[:, :, recon_bands]
mask = mask[recon_bands]
if self.tight_frame and self.is_complex:
coeffs = coeffs * np.sqrt(2)
orientdft = fft.fft2(coeffs, dim=(-2, -1), norm=self.fft_norm)
orientdft = fft.fftshift(orientdft, dim=(-2, -1))
orientdft = self._complex_const_recon * orientdft * mask
orientdft = orientdft.sum(2)
else:
orientdft = torch.zeros_like(pyr_coeffs[scale][:, :, 0])
# get the bounding box indices for the low-pass component
lostart, loend = self._loindices[scale]
# create lowpass mask
lomask = getattr(self, f"_lomasks_scale_{scale}")
# Recursively reconstruct by going to the next scale
reslevdft = self._recon_levels(pyr_coeffs, recon_levels, recon_bands, scale + 1)
# in not downsampled case, rescale the magnitudes of the reconstructed
# dft at each level by factor of 2 to account for the scaling in the forward
if (not self.tight_frame) and (not self.downsample):
reslevdft = reslevdft / 2
# create output for reconstruction result
resdft = torch.zeros_like(pyr_coeffs[scale][:, :, 0], dtype=torch.complex64)
# place upsample and convolve lowpass component
resdft[..., lostart[0] : loend[0], lostart[1] : loend[1]] = reslevdft * lomask
recondft = resdft + orientdft
# add orientation interpolated and added images to the lowpass image
return recondft
[docs]
def steer_coeffs(
self,
pyr_coeffs: OrderedDict,
angles: list[float],
even_phase: bool = True,
) -> tuple[dict, dict]:
"""
Steer pyramid coefficients to the specified angles.
This allows you to have filters that have the Gaussian derivative order
specified in construction, but arbitrary angles or number of orientations.
.. versionchanged:: 1.4
The returned ``resteered_coeffs`` dictionary now only contains the new
angles, as opposed to concatenating the new angles onto those found in
the input ``pyr_coeffs``. Like the input ``pyr_coeffs``, the dictionary
keys are now integers specifying the scale and the coefficients are
5d tensors of shape (batch, channel, angles, height, width).
Parameters
----------
pyr_coeffs
The pyramid coefficients to steer, as returned by :meth:`forward`.
angles
List of angles (in radians) to steer the pyramid coefficients to.
even_phase
Specifies whether the harmonics are cosine or sine phase aligned
about those positions.
Returns
-------
resteered_coeffs
Dictionary of re-steered pyramid coefficients. will have the same
number of scales as the original pyramid (though it will not
contain the residual highpass or lowpass). Like the input ``pyr_coeffs``,
keys are ints indexing the scale and values are tensors of shape (batch,
channel, orientations, height, width), but now orientations index ``angles``
instead of ``self.num_orientations``.
resteering_weights :
Dictionary of weights used to re-steer the pyramid coefficients.
will have the same keys as ``resteered_coeffs``.
Examples
--------
.. plot::
>>> import plenoptic as po
>>> import torch
>>> img = po.data.einstein()
>>> spyr = po.process.SteerablePyramidFreq(img.shape[-2:], height=3)
>>> coeffs = spyr(img)
>>> resteered_coeffs, resteering_weights = spyr.steer_coeffs(
... coeffs, torch.linspace(0, 2 * torch.pi, 64)
... )
>>> ani = po.plot.animshow(
... resteered_coeffs[2], repeat=True, framerate=6, zoom=4
... )
>>> # Save the video (here we're saving it as a .gif)
>>> ani.save("resteered_coeffs.gif")
.. image:: resteered_coeffs.gif
"""
assert pyr_coeffs[0].dtype not in complex_types, (
"steering only implemented for real coefficients"
)
resteered_coeffs = {}
resteering_weights = {}
num_scales = self.num_scales
for i in range(num_scales):
# put orientation on the last dimension
basis = einops.rearrange(pyr_coeffs[i], "b c o h w -> b c h w o")
res, steervect = [], []
for j, a in enumerate(angles):
r, s = _steer(basis, a, even_phase=even_phase)
res.append(r)
steervect.append(s)
# when called like above, the output of steer always has a singleton
# dimension at the end corresponding to the single angle it was steered to
resteered_coeffs[i] = einops.rearrange(
res, "o b c h w dummy -> b c (o dummy) h w"
)
resteering_weights[i] = torch.stack(steervect, dim=-1)
return resteered_coeffs, resteering_weights