Source code for plenoptic.simulate.canonical_computations.steerable_pyramid_freq

"""Steerable frequency pyramid

Construct a steerable pyramid on two-dimensional signals (matrices), in the
Fourier domain.
"""

import warnings
from collections import OrderedDict
from typing import Literal

import numpy as np
import torch
import torch.fft as fft
import torch.nn as nn
from einops import rearrange
from numpy.typing import NDArray
from scipy.special import factorial
from torch import Tensor

from ...tools.signal import interpolate1d, raised_cosine, steer

# Complex dtypes; used below to detect complex-valued pyramid coefficients.
complex_types = [torch.cdouble, torch.cfloat]

# A pyramid scale: an integer level or the name of one of the two residuals.
SCALES_TYPE = int | Literal["residual_lowpass", "residual_highpass"]
# A key of the coefficient dictionary: a ``(level, band)`` pair or the name of
# one of the two residuals.
KEYS_TYPE = tuple[int, int] | Literal["residual_lowpass", "residual_highpass"]


class SteerablePyramidFreq(nn.Module):
    r"""Steerable frequency pyramid in Torch.

    Construct a steerable pyramid on two-dimensional signals, in the Fourier
    domain. Boundary-handling is circular. Reconstruction is exact (within
    floating point errors). However, if the image has an odd shape, the
    reconstruction will not be exact due to boundary-handling issues that
    have not been resolved (a warning is emitted at construction time).

    The squared radial functions tile the Fourier plane with a raised-cosine
    falloff. Angular functions are ``cos(theta - k*pi/(order+1))^order``.

    Notes
    -----
    Transform described in [1]_, filter kernel design described in [2]_.
    For further information see the project webpage_

    Parameters
    ----------
    image_shape : `list or tuple`
        Shape (height, width) of the input image.
    height : 'auto' or `int`
        The height of the pyramid. If 'auto', it is set to
        ``floor(log2(min(image_shape[0], image_shape[1]))) - 2``. If an int,
        it must be non-negative and no greater than that value. If height=0,
        the pyramid only contains the residuals.
    order : `int`
        The Gaussian derivative order used for the steerable filters, in
        [1, 15]. Note that to achieve steerability the minimum number of
        orientations is `order` + 1, and that is what is used here. To get
        more orientations at the same order, use the method `steer_coeffs`.
    twidth : `int`
        The width of the transition region of the radial lowpass function,
        in octaves. Must be positive.
    is_complex : `bool`
        Whether the pyramid coefficients should be complex or not. If True,
        the real and imaginary parts correspond to a pair of even and odd
        symmetric filters. If False, the coefficients only include the real
        part / even filter.
    downsample : `bool`
        Whether to downsample each scale in the pyramid or keep all pyramid
        coefficients at the size of the input image. Only when this is False
        can the coefficients be packed into a single tensor with
        `convert_pyr_to_tensor`.
    tight_frame : `bool`
        Whether the pyramid obeys the generalized Parseval theorem or not
        (i.e. is a tight frame). If True, the energy of the coefficients
        equals the energy of the image (the FFTs use ``norm="ortho"``). In
        order to match the matlabPyrTools or pyrtools pyramids, this must be
        set to False.

    Attributes
    ----------
    image_shape : `list or tuple`
        Shape of the input image.
    pyr_size : `dict`
        Dictionary containing the spatial sizes of the pyramid coefficients,
        filled in by `forward`. Keys are `(level, band)` tuples and the
        residual names; values are tuples.
    fft_norm : `str`
        The way the FFTs are normalized ("ortho" when `tight_frame` is True,
        "backward" otherwise), see pytorch documentation for more details.
    is_complex : `bool`
        Whether the coefficients are complex- or real-valued.
    num_scales : `int`
        Number of (non-residual) scales of the pyramid.
    num_orientations : `int`
        Number of orientation bands per scale (``order + 1``).
    scales : `list`
        All scales, including the residuals, ordered from coarse to fine.

    References
    ----------
    .. [1] E P Simoncelli and W T Freeman, "The Steerable Pyramid: A Flexible
       Architecture for Multi-Scale Derivative Computation," Second Int'l Conf
       on Image Processing, Washington, DC, Oct 1995.
    .. [2] A Karasaridis and E P Simoncelli, "A Filter Design Technique for
       Steerable Pyramid Image Transforms", ICASSP, Atlanta, GA, May 1996.
    .. _webpage: https://www.cns.nyu.edu/~eero/steerpyr/

    """

    def __init__(
        self,
        image_shape: tuple[int, int],
        height: Literal["auto"] | int = "auto",
        order: int = 3,
        twidth: int = 1,
        is_complex: bool = False,
        downsample: bool = True,
        tight_frame: bool = False,
    ):
        super().__init__()

        self.pyr_size = OrderedDict()
        self.order = order
        self.image_shape = image_shape

        if (self.image_shape[0] % 2 != 0) or (self.image_shape[1] % 2 != 0):
            warnings.warn("Reconstruction will not be perfect with odd-sized images")

        self.is_complex = is_complex
        self.downsample = downsample
        self.tight_frame = tight_frame
        if self.tight_frame:
            self.fft_norm = "ortho"
        else:
            self.fft_norm = "backward"

        # cache constants: abscissa of the angular lookup table and the same
        # values wrapped into [-pi, pi)
        self.lutsize = 1024
        self.Xcosn = (
            np.pi
            * np.array(range(-(2 * self.lutsize + 1), (self.lutsize + 2)))
            / self.lutsize
        )
        self.alpha = (self.Xcosn + np.pi) % (2 * np.pi) - np.pi

        max_ht = np.floor(np.log2(min(self.image_shape[0], self.image_shape[1]))) - 2
        if height == "auto":
            self.num_scales = int(max_ht)
        elif height > max_ht:
            raise ValueError(f"Cannot build pyramid higher than {max_ht:.0f} levels.")
        elif height < 0:
            raise ValueError("Height must be a non-negative integer.")
        else:
            self.num_scales = int(height)

        if self.order > 15 or self.order <= 0:
            raise ValueError("order must be an integer in the range [1,15].")
        self.num_orientations = int(self.order + 1)

        if twidth <= 0:
            raise ValueError("twidth must be positive.")
        twidth = int(twidth)

        dims = np.array(self.image_shape)

        # make a grid for the raised cosine interpolation
        ctr = np.ceil((np.array(dims) + 0.5) / 2).astype(int)

        (xramp, yramp) = np.meshgrid(
            np.linspace(-1, 1, dims[1] + 1)[:-1],
            np.linspace(-1, 1, dims[0] + 1)[:-1],
        )

        self.angle = np.arctan2(yramp, xramp)
        log_rad = np.sqrt(xramp**2 + yramp**2)
        # overwrite the value at the center with its neighbour so that the
        # log2 below is not taken of the zero radius
        log_rad[ctr[0] - 1, ctr[1] - 1] = log_rad[ctr[0] - 1, ctr[1] - 2]
        self.log_rad = np.log2(log_rad)

        # radial transition function (a raised cosine in log-frequency):
        self.Xrcos, Yrcos = raised_cosine(twidth, (-twidth / 2.0), np.array([0, 1]))
        self.Yrcos = np.sqrt(Yrcos)
        self.YIrcos = np.sqrt(1.0 - self.Yrcos**2)

        # create low and high masks
        lo0mask = interpolate1d(self.log_rad, self.YIrcos, self.Xrcos)
        hi0mask = interpolate1d(self.log_rad, self.Yrcos, self.Xrcos)
        self.register_buffer("lo0mask", torch.as_tensor(lo0mask).unsqueeze(0))
        self.register_buffer("hi0mask", torch.as_tensor(hi0mask).unsqueeze(0))

        # this list, used by coarse-to-fine optimization, gives all the
        # scales (including residuals) from coarse to fine
        self.scales = (
            ["residual_lowpass"]
            + list(range(self.num_scales))[::-1]
            + ["residual_highpass"]
        )

        # we create these copies because they will be modified in the
        # following loops
        Xrcos = self.Xrcos.copy()
        angle = self.angle.copy()
        log_rad = self.log_rad.copy()

        # the angular lookup tables do not depend on the scale, so they are
        # built once, outside of the loop over scales
        const = (
            (2 ** (2 * self.order))
            * (factorial(self.order, exact=True) ** 2)
            / float(self.num_orientations * factorial(2 * self.order, exact=True))
        )
        if self.is_complex:
            Ycosn_forward = (
                2.0
                * np.sqrt(const)
                * (np.cos(self.Xcosn) ** self.order)
                * (np.abs(self.alpha) < np.pi / 2.0).astype(int)
            )
            Ycosn_recon = np.sqrt(const) * (np.cos(self.Xcosn)) ** self.order
        else:
            Ycosn_forward = np.sqrt(const) * (np.cos(self.Xcosn)) ** self.order
            Ycosn_recon = Ycosn_forward

        # pre-generate the angle, hi and lo masks, as well as the
        # indices used for down-sampling.
        self._loindices = []
        for i in range(self.num_scales):
            Xrcos -= np.log2(2)

            himask = interpolate1d(log_rad, self.Yrcos, Xrcos)
            self.register_buffer(
                f"_himasks_scale_{i}", torch.as_tensor(himask).unsqueeze(0)
            )

            anglemasks = []
            anglemasks_recon = []
            for b in range(self.num_orientations):
                anglemask = interpolate1d(
                    angle,
                    Ycosn_forward,
                    self.Xcosn + np.pi * b / self.num_orientations,
                )
                anglemask_recon = interpolate1d(
                    angle,
                    Ycosn_recon,
                    self.Xcosn + np.pi * b / self.num_orientations,
                )
                anglemasks.append(torch.as_tensor(anglemask).unsqueeze(0))
                anglemasks_recon.append(torch.as_tensor(anglemask_recon).unsqueeze(0))

            self.register_buffer(f"_anglemasks_scale_{i}", torch.cat(anglemasks))
            self.register_buffer(
                f"_anglemasks_recon_scale_{i}", torch.cat(anglemasks_recon)
            )

            if not self.downsample:
                lomask = interpolate1d(log_rad, self.YIrcos, Xrcos)
                self.register_buffer(
                    f"_lomasks_scale_{i}", torch.as_tensor(lomask).unsqueeze(0)
                )
                # nothing is cropped: the "bounding box" is the whole image
                self._loindices.append([np.array([0, 0]), dims])
            else:
                # subsample lowpass. ``log_rad`` is cropped in lockstep with
                # the low-pass band, so its shape is the current band size (no
                # mock image / FFT is needed to keep track of it).
                dims = np.array(log_rad.shape)
                ctr = np.ceil((dims + 0.5) / 2).astype(int)
                lodims = np.ceil((dims - 0.5) / 2).astype(int)
                loctr = np.ceil((lodims + 0.5) / 2).astype(int)
                lostart = ctr - loctr
                loend = lostart + lodims
                self._loindices.append([lostart, loend])

                # subsample indices
                log_rad = log_rad[lostart[0] : loend[0], lostart[1] : loend[1]]
                angle = angle[lostart[0] : loend[0], lostart[1] : loend[1]]

                lomask = interpolate1d(log_rad, self.YIrcos, Xrcos)
                self.register_buffer(
                    f"_lomasks_scale_{i}", torch.as_tensor(lomask).unsqueeze(0)
                )

        # reasonable default dtype
        self.to(torch.float32)

    def forward(
        self,
        x: Tensor,
        scales: list[SCALES_TYPE] | None = None,
    ) -> OrderedDict:
        r"""Generate the steerable pyramid coefficients for an image.

        As a side effect, the spatial size of every computed band is stored
        in ``self.pyr_size``.

        Parameters
        ----------
        x :
            A tensor containing the image to analyze. We want to operate
            on this in the pytorch-y way, so we want it to be 4d (batch,
            channel, height, width).
        scales :
            Which scales to include in the returned representation. If None,
            we include all scales. Otherwise, can contain a subset of the
            values present in this model's ``scales`` attribute (ints from 0
            up to ``self.num_scales-1`` and the strs 'residual_highpass' and
            'residual_lowpass'). Can contain a single value or multiple
            values. If it's an int, we include all orientations from that
            scale. Order within the list does not matter.

        Returns
        -------
        representation:
            Pyramid coefficients, keyed by ``(scale, band)`` tuples and the
            residual names.

        """
        pyr_coeffs = OrderedDict()
        if scales is None:
            scales = self.scales
        scale_ints = [s for s in scales if isinstance(s, int)]
        if len(scale_ints) != 0:
            assert (max(scale_ints) < self.num_scales) and (
                min(scale_ints) >= 0
            ), "Scales must be within 0 and num_scales-1"

        # x is a torch tensor batch of images of size (batch, channel, height,
        # width)
        assert len(x.shape) == 4, "Input must be batch of images of shape BxCxHxW"

        imdft = fft.fft2(x, dim=(-2, -1), norm=self.fft_norm)
        # only the two spatial axes need to be centered; every band is
        # un-shifted over the same axes before the inverse transform
        imdft = fft.fftshift(imdft, dim=(-2, -1))

        if "residual_highpass" in scales:
            # high-pass
            hi0dft = imdft * self.hi0mask
            hi0 = fft.ifftshift(hi0dft, dim=(-2, -1))
            hi0 = fft.ifft2(hi0, dim=(-2, -1), norm=self.fft_norm)
            pyr_coeffs["residual_highpass"] = hi0.real
            self.pyr_size["residual_highpass"] = tuple(hi0.real.shape[-2:])

        # input to the next scale is the low-pass filtered component
        lodft = imdft * self.lo0mask

        # the complex_const variable comes from the Fourier transform of a
        # gaussian derivative. Based on the order of the gaussian, this
        # constant changes (it is the same for every scale and band).
        complex_const = np.power(complex(0, -1), self.order)

        for i in range(self.num_scales):
            if i in scales:
                # high-pass mask is selected based on the current scale
                himask = getattr(self, f"_himasks_scale_{i}")
                anglemasks = getattr(self, f"_anglemasks_scale_{i}")
                # compute filter output at each orientation
                for b in range(self.num_orientations):
                    # band pass filtering is done in the fourier space as
                    # multiplying by the fft of a gaussian derivative.
                    # The oriented dft is computed as a product of the fft of
                    # the low-passed component, the precomputed anglemask
                    # (specifies orientation), and the precomputed hipass mask
                    # (creating a bandpass filter)
                    anglemask = anglemasks[b]
                    banddft = complex_const * lodft * anglemask * himask
                    # undo the centering, then ifft to recover the filtered
                    # representation in spatial domain
                    band = fft.ifftshift(banddft, dim=(-2, -1))
                    band = fft.ifft2(band, dim=(-2, -1), norm=self.fft_norm)

                    # for real pyramid, take the real component of the complex
                    # band
                    if not self.is_complex:
                        pyr_coeffs[(i, b)] = band.real
                    else:
                        # Because the input signal is real, to maintain a
                        # tight frame if the complex pyramid is used,
                        # magnitudes need to be divided by sqrt(2) because
                        # energy is doubled.
                        if self.tight_frame:
                            band = band / np.sqrt(2)
                        pyr_coeffs[(i, b)] = band
                    self.pyr_size[(i, b)] = tuple(band.shape[-2:])

            lomask = getattr(self, f"_lomasks_scale_{i}")
            if not self.downsample:
                # no subsampling, just apply the low-pass mask
                lodft = lodft * lomask
                # Since we don't subsample here, if we are not using
                # orthonormalization we need to manually account for the
                # subsampling, so that energy in each band remains the same:
                # the energy is cut by factor of 4 so we need to scale
                # magnitudes by factor of 2.
                if self.fft_norm != "ortho":
                    lodft = 2 * lodft
            else:
                # subsampling of the dft for next scale
                lostart, loend = self._loindices[i]
                lodft = lodft[:, :, lostart[0] : loend[0], lostart[1] : loend[1]]
                # again multiply dft by subsampled mask (convolution in
                # spatial domain)
                lodft = lodft * lomask

        if "residual_lowpass" in scales:
            # whatever is left after the last scale is the low-pass residual
            lo0 = fft.ifftshift(lodft, dim=(-2, -1))
            lo0 = fft.ifft2(lo0, dim=(-2, -1), norm=self.fft_norm)
            pyr_coeffs["residual_lowpass"] = lo0.real
            self.pyr_size["residual_lowpass"] = tuple(lo0.real.shape[-2:])

        return pyr_coeffs

    @staticmethod
    def convert_pyr_to_tensor(
        pyr_coeffs: OrderedDict, split_complex: bool = False
    ) -> tuple[Tensor, tuple[int, bool, list[KEYS_TYPE]]]:
        r"""Convert coefficient dictionary to a tensor.

        The output tensor has shape (batch, channel, height, width) and is
        intended to be used in an ``torch.nn.Module`` downstream. In the
        multichannel case, all bands for each channel will be stacked together
        (i.e. if there are 2 channels and 18 bands per channel,
        pyr_tensor[:,0:18,...] will contain the pyr responses for channel 1 and
        pyr_tensor[:, 18:36, ...] will contain the responses for channel 2). In
        the case of a complex, multichannel pyramid with split_complex=True,
        the real/imaginary bands will be interleaved so that they appear as
        pairs with neighboring indices in the channel dimension of the tensor
        (Note: the residual bands are always real so they will only ever have a
        single band even when split_complex=True.)

        This only works if ``pyr_coeffs`` was created with a pyramid with
        ``downsample=False``

        Parameters
        ----------
        pyr_coeffs:
            the pyramid coefficients
        split_complex:
            indicates whether the output should split complex bands into
            real/imag channels or keep them as a single channel. This should
            be True if you intend to use a convolutional layer on top of the
            output.

        Returns
        -------
        pyr_tensor:
            shape (batch, channel, height, width). pyramid coefficients
            reshaped into tensor. For a dictionary produced by ``forward``,
            the first channel will be the residual highpass and the last will
            be the residual lowpass. Each band is then a separate channel.
        pyr_info:
            Information required to recreate the dictionary, containing the
            number of channels, if split_complex was used in this function
            call, and the list of pyramid keys for the dictionary

        Raises
        ------
        Exception
            If the bands cannot be concatenated (e.g. they have different
            spatial sizes because the pyramid was downsampled).

        See also
        --------
        convert_tensor_to_pyr:
            Convert tensor representation to pyramid dictionary.

        """
        pyr_keys = list(pyr_coeffs.keys())
        test_band = pyr_coeffs[pyr_keys[0]]
        num_channels = test_band.size(1)
        coeff_list = []
        for ch in range(num_channels):
            coeff_list_resid = []
            coeff_list_bands = []
            for k in pyr_keys:
                coeffs = pyr_coeffs[k][:, ch : (ch + 1), ...]
                # residual keys are strings, band keys are (scale, band) tuples
                if "residual" in k:
                    coeff_list_resid.append(coeffs)
                elif (coeffs.dtype in complex_types) and split_complex:
                    coeff_list_bands.extend([coeffs.real, coeffs.imag])
                else:
                    coeff_list_bands.append(coeffs)

            # residuals go around the oriented bands; ``coeff_list_resid`` is
            # in dictionary order
            if "residual_highpass" in pyr_coeffs:
                coeff_list_bands.insert(0, coeff_list_resid[0])
                if "residual_lowpass" in pyr_coeffs:
                    coeff_list_bands.append(coeff_list_resid[1])
            elif "residual_lowpass" in pyr_coeffs:
                coeff_list_bands.append(coeff_list_resid[0])

            coeff_list.extend(coeff_list_bands)

        try:
            pyr_tensor = torch.cat(coeff_list, dim=1)
        except RuntimeError as err:
            raise Exception(
                "feature maps could not be concatenated into tensor. Check that"
                " you are using coefficients that are not downsampled across"
                " scales. This is done with the 'downsample=False' argument for"
                " the pyramid"
            ) from err
        pyr_info = (num_channels, split_complex, pyr_keys)
        return pyr_tensor, pyr_info

    @staticmethod
    def convert_tensor_to_pyr(
        pyr_tensor: Tensor,
        num_channels: int,
        split_complex: bool,
        pyr_keys: list[KEYS_TYPE],
    ) -> OrderedDict:
        r"""Convert pyramid coefficient tensor to dictionary format.

        ``num_channels``, ``split_complex``, and ``pyr_keys`` are elements of
        the ``pyr_info`` tuple returned by ``convert_pyr_to_tensor``. You should
        always unpack the arguments for this function from that ``pyr_info``
        tuple. Example Usage:

        .. code-block:: python

           pyr_tensor, pyr_info = convert_pyr_to_tensor(pyr_coeffs, split_complex=True)
           pyr_dict = convert_tensor_to_pyr(pyr_tensor, *pyr_info)

        Parameters
        ----------
        pyr_tensor:
            Shape (batch, channel, height, width). The pyramid coefficients
        num_channels:
            number of channels in the original input tensor the pyramid was
            created for (i.e. if the input was an RGB image, this would be 3)
        split_complex:
            true or false, specifying whether the pyr_tensor was created with
            complex channels split or not (if the pyramid was a complex
            pyramid).
        pyr_keys:
            list of keys of the original pyramid dictionary

        Returns
        -------
        pyr_coeffs:
            pyramid coefficients in dictionary format. Residual bands are
            real-valued and keep the floating point precision of
            ``pyr_tensor``.

        See also
        --------
        convert_pyr_to_tensor:
            Convert pyramid dictionary representation to tensor.

        """
        pyr_coeffs = OrderedDict()
        i = 0
        for ch in range(num_channels):
            for k in pyr_keys:
                if "residual" in k:
                    band = pyr_tensor[:, i : i + 1, ...]
                    # residuals are always real: when the tensor holds
                    # (unsplit) complex bands, drop the empty imaginary part
                    # without downgrading the precision
                    if band.is_complex():
                        band = band.real
                    i += 1
                elif split_complex:
                    # two consecutive channels hold the real and imaginary
                    # parts: move them to a trailing axis of size 2
                    band = torch.view_as_complex(
                        pyr_tensor[:, i : i + 2, ...]
                        .permute(0, 2, 3, 1)
                        .unsqueeze(1)
                        .contiguous()
                    )
                    i += 2
                else:
                    band = pyr_tensor[:, i : i + 1, ...]
                    i += 1
                if k not in pyr_coeffs:
                    pyr_coeffs[k] = band
                else:
                    pyr_coeffs[k] = torch.cat([pyr_coeffs[k], band], dim=1)

        return pyr_coeffs

    def _recon_levels_check(
        self, levels: Literal["all"] | list[SCALES_TYPE]
    ) -> list[SCALES_TYPE]:
        r"""Check whether levels arg is valid for reconstruction and return valid version

        When reconstructing the input image (i.e., when calling `recon_pyr()`),
        the user specifies which levels to include. This makes sure those
        levels are valid and gets them in the form we expect for the rest of
        the reconstruction. If the user passes `'all'`, this constructs the
        appropriate list (based on ``self.num_scales`` and on the residuals
        recorded in ``self.pyr_size``).

        Parameters
        ----------
        levels :
            If `list` should contain some subset of integers from `0` to
            `self.num_scales-1` (inclusive) and `'residual_highpass'` and
            `'residual_lowpass'` (if appropriate for the pyramid). If `'all'`,
            returned value will contain all valid levels.

        Returns
        -------
        levels :
            List containing the valid levels for reconstruction: the residual
            highpass (if any) first, then the integer levels in increasing
            order, then the residual lowpass (if any).

        Raises
        ------
        TypeError
            If `levels` is neither `'all'` nor an iterable.

        """
        if isinstance(levels, str):
            if levels != "all":
                raise TypeError(
                    "levels must be a list of levels or the string 'all' but"
                    f" got {levels}"
                )
            levels = (
                ["residual_highpass"]
                + list(range(self.num_scales))
                + ["residual_lowpass"]
            )
        else:
            if not hasattr(levels, "__iter__"):
                raise TypeError(
                    "levels must be a list of levels or the string 'all' but"
                    f" got {levels}"
                )
            # numpy integers are accepted too; smallest level first
            level_nums = sorted(
                int(i) for i in levels if isinstance(i, (int, np.integer))
            )
            assert all(
                n >= 0 for n in level_nums
            ), "Level numbers must be non-negative."
            assert all(
                n < self.num_scales for n in level_nums
            ), f"Level numbers must be in the range [0, {self.num_scales - 1:d}]"
            ordered = list(level_nums)
            if "residual_highpass" in levels:
                ordered = ["residual_highpass"] + ordered
            if "residual_lowpass" in levels:
                ordered = ordered + ["residual_lowpass"]
            levels = ordered
        # not all pyramids have residual highpass / lowpass, but it's easier
        # to construct the list including them, then remove them if necessary.
        if "residual_lowpass" not in self.pyr_size and "residual_lowpass" in levels:
            levels.pop(-1)
        if "residual_highpass" not in self.pyr_size and "residual_highpass" in levels:
            levels.pop(0)
        return levels

    def _recon_bands_check(self, bands: Literal["all"] | list[int]) -> list[int]:
        """Check whether bands arg is valid for reconstruction and return valid version

        When reconstructing the input image (i.e., when calling `recon_pyr()`),
        the user specifies which orientations to include. This makes sure those
        orientations are valid and gets them in the form we expect for the rest
        of the reconstruction. If the user passes `'all'`, this constructs the
        appropriate list (based on ``self.num_orientations``).

        Parameters
        ----------
        bands :
            If list, should contain some subset of integers from `0` to
            `self.num_orientations-1`. If `'all'`, returned value will contain
            all valid orientations.

        Returns
        -------
        bands:
            List containing the valid orientations for reconstruction.

        Raises
        ------
        TypeError
            If `bands` is neither `'all'` nor an iterable.

        """
        if isinstance(bands, str):
            if bands != "all":
                raise TypeError(
                    "bands must be a list of ints or the string 'all' but got"
                    f" {bands}"
                )
            bands = np.arange(self.num_orientations)
        else:
            if not hasattr(bands, "__iter__"):
                raise TypeError(
                    "bands must be a list of ints or the string 'all' but got"
                    f" {bands}"
                )
            bands: NDArray = np.array(bands, ndmin=1)
            assert (bands >= 0).all(), "Error: band numbers must be larger than 0."
            assert (bands < self.num_orientations).all(), (
                "Error: band numbers must be in the range [0, "
                f"{self.num_orientations - 1:d}]"
            )
        # plain Python scalars rather than numpy ones
        return bands.tolist()

    def _recon_keys(
        self,
        levels: Literal["all"] | list[SCALES_TYPE],
        bands: Literal["all"] | list[int],
        max_orientations: int | None = None,
    ) -> list[KEYS_TYPE]:
        """Make a list of all the relevant keys from `pyr_coeffs` to use in pyramid reconstruction

        When reconstructing the input image (i.e., when calling `recon_pyr()`),
        the user specifies some subset of the pyramid coefficients to include
        in the reconstruction. This function takes in those specifications,
        checks that they're valid, and returns a list of keys into the
        `pyr_coeffs` dictionary.

        Parameters
        ----------
        levels:
            If `list` should contain some subset of integers from `0` to
            `self.num_scales-1` (inclusive) and `'residual_highpass'` and
            `'residual_lowpass'` (if appropriate for the pyramid). If `'all'`,
            returned value will contain all valid levels.
        bands:
            If list, should contain some subset of integers from `0` to
            `self.num_orientations-1`. If `'all'`, returned value will contain
            all valid orientations.
        max_orientations:
            The maximum number of orientations we allow in the reconstruction.
            when we determine which ints are allowed for bands, we ignore all
            those greater than or equal to max_orientations (with a warning).

        Returns
        -------
        recon_keys :
            List of keys of `pyr_coeffs` (`(level, band)` tuples and residual
            names). These are the coefficients to include in the
            reconstruction of the image.

        """
        levels = self._recon_levels_check(levels)
        bands = self._recon_bands_check(bands)
        if max_orientations is not None:
            for i in bands:
                if i >= max_orientations:
                    warnings.warn(
                        f"You wanted band {i:d} in the reconstruction but"
                        f" max_orientation is {max_orientations:d}, so we"
                        "'re ignoring that band"
                    )
            bands = [i for i in bands if i < max_orientations]
        recon_keys = []
        for level in levels:
            # residual highpass and lowpass
            if isinstance(level, str):
                recon_keys.append(level)
            # else we have to get each of the (specified) bands at
            # that level
            else:
                recon_keys.extend([(level, band) for band in bands])
        return recon_keys

    def recon_pyr(
        self,
        pyr_coeffs: OrderedDict,
        levels: Literal["all"] | list[SCALES_TYPE] = "all",
        bands: Literal["all"] | list[int] = "all",
    ) -> Tensor:
        """Reconstruct the image or batch of images, optionally using subset of pyramid coefficients.

        NOTE: in order to call this function, you need to have previously
        called `self.forward(x)`, where `x` is the tensor you wish to
        reconstruct. This will fail if you called `forward()` with a subset of
        scales.

        Parameters
        ----------
        pyr_coeffs:
            pyramid coefficients to reconstruct from
        levels:
            If `list` should contain some subset of integers from `0` to
            `self.num_scales-1` (inclusive), `'residual_lowpass'`, and
            `'residual_highpass'`. If `'all'`, all valid levels are used.
        bands :
            If list, should contain some subset of integers from `0` to
            `self.num_orientations-1`. If `'all'`, all valid orientations are
            used.

        Returns
        -------
        recon:
            The reconstructed image, of shape (batch, channel, height, width)

        Raises
        ------
        Exception
            If `pyr_coeffs` does not contain every scale and band of the
            pyramid.

        """
        # For reconstruction to work, last time we called forward needed
        # to include all levels
        for s in self.scales:
            if isinstance(s, str):
                if s not in pyr_coeffs:
                    raise Exception(
                        f"scale {s} not in pyr_coeffs! pyr_coeffs must include"
                        " all scales, so make sure forward() was called with"
                        " arg scales=None"
                    )
            else:
                for b in range(self.num_orientations):
                    if (s, b) not in pyr_coeffs:
                        raise Exception(
                            f"scale {s} not in pyr_coeffs! pyr_coeffs must"
                            " include all scales, so make sure forward() was"
                            " called with arg scales=None"
                        )

        recon_keys = self._recon_keys(levels, bands)

        # Recursively generate the reconstruction - function starts with
        # fine scales going down to coarse and then the reconstruction
        # is built recursively from the coarse scale up
        recondft = self._recon_levels(pyr_coeffs, recon_keys, 0)

        # output dft is the recondft from the recursive function times the
        # lomask (low pass component) ...
        outdft = recondft * self.lo0mask
        # ... plus, if requested, the highpass dft times the highpass mask
        if "residual_highpass" in recon_keys:
            hidft = fft.fft2(
                pyr_coeffs["residual_highpass"],
                dim=(-2, -1),
                norm=self.fft_norm,
            )
            hidft = fft.fftshift(hidft, dim=(-2, -1))
            outdft = outdft + hidft * self.hi0mask

        # get output reconstruction by inverting the fft
        reconstruction = fft.ifftshift(outdft, dim=(-2, -1))
        reconstruction = fft.ifft2(reconstruction, dim=(-2, -1), norm=self.fft_norm)

        # get real part of reconstruction (if complex)
        return reconstruction.real

    def _recon_levels(
        self, pyr_coeffs: OrderedDict, recon_keys: list[KEYS_TYPE], scale: int
    ) -> Tensor:
        """Recursive function used to build the reconstruction. Called by recon_pyr

        Parameters
        ----------
        pyr_coeffs :
            Dictionary containing the coefficients of the pyramid. Keys are
            `(level, band)` tuples and the strings `'residual_lowpass'` and
            `'residual_highpass'` and values are Tensors of shape (batch,
            channel, height, width).
        recon_keys :
            list of the keys that index into the pyr_coeffs Dictionary
        scale :
            current scale that is being used to build the reconstruction;
            scale is incremented by 1 on each call of the function

        Returns
        -------
        recondft :
            Current reconstruction (centered DFT) based on the orientation
            band dft from the current scale summed with the output of the
            recursive call with the next scale incremented

        """
        # base case, return the low-pass residual
        if scale == self.num_scales:
            if "residual_lowpass" in recon_keys:
                lodft = fft.fft2(
                    pyr_coeffs["residual_lowpass"],
                    dim=(-2, -1),
                    norm=self.fft_norm,
                )
                lodft = fft.fftshift(lodft, dim=(-2, -1))
            else:
                # residual excluded: contribute an all-zero dft of the right
                # shape and dtype
                lodft = fft.fft2(
                    torch.zeros_like(pyr_coeffs["residual_lowpass"]),
                    dim=(-2, -1),
                    norm=self.fft_norm,
                )
            return lodft

        # Reconstruct from orientation bands
        himask = getattr(self, f"_himasks_scale_{scale}")
        anglemasks = getattr(self, f"_anglemasks_recon_scale_{scale}")
        # conjugate of the constant applied in forward
        complex_const = np.power(complex(0, 1), self.order)
        orientdft = torch.zeros_like(pyr_coeffs[(scale, 0)])
        for b in range(self.num_orientations):
            if (scale, b) in recon_keys:
                anglemask = anglemasks[b]
                coeffs = pyr_coeffs[(scale, b)]
                # undo the rescaling applied in forward for complex tight
                # frames
                if self.tight_frame and self.is_complex:
                    coeffs = coeffs * np.sqrt(2)

                banddft = fft.fft2(coeffs, dim=(-2, -1), norm=self.fft_norm)
                banddft = fft.fftshift(banddft, dim=(-2, -1))
                banddft = complex_const * banddft * anglemask * himask
                orientdft = orientdft + banddft

        # get the bounding box indices for the low-pass component
        lostart, loend = self._loindices[scale]

        # lowpass mask of this scale
        lomask = getattr(self, f"_lomasks_scale_{scale}")

        # Recursively reconstruct by going to the next scale
        reslevdft = self._recon_levels(pyr_coeffs, recon_keys, scale + 1)
        # in not downsampled case, rescale the magnitudes of the reconstructed
        # dft at each level by factor of 2 to account for the scaling in the
        # forward
        if (not self.tight_frame) and (not self.downsample):
            reslevdft = reslevdft / 2

        # place ("upsample") the filtered lowpass component in the center of a
        # zero dft of the size of the current scale. The buffer takes the
        # dtype of what is written into it, so that double precision
        # pyramids are not truncated to single precision.
        lowdft = reslevdft * lomask
        resdft = torch.zeros_like(pyr_coeffs[(scale, 0)], dtype=lowdft.dtype)
        resdft[:, :, lostart[0] : loend[0], lostart[1] : loend[1]] = lowdft

        # add the oriented bands to the lowpass component
        return resdft + orientdft

    def steer_coeffs(
        self,
        pyr_coeffs: OrderedDict,
        angles: list[float],
        even_phase: bool = True,
    ) -> tuple[dict, dict]:
        """Steer pyramid coefficients to the specified angles

        This allows you to have filters that have the Gaussian derivative order
        specified in construction, but arbitrary angles or number of
        orientations.

        Parameters
        ----------
        pyr_coeffs :
            the pyramid coefficients to steer (must be real-valued)
        angles :
            list of angles (in radians) to steer the pyramid coefficients to
        even_phase :
            specifies whether the harmonics are cosine or sine phase aligned
            about those positions.

        Returns
        -------
        resteered_coeffs :
            dictionary of re-steered pyramid coefficients. will have the same
            number of scales as the original pyramid (though it will not
            contain the residual highpass or lowpass). Keys are 2-tuples of
            ints ``(scale, self.num_orientations + j)``, where ``j`` indexes
            `angles`, so that they do not collide with the keys of the
            original bands.
        resteering_weights :
            dictionary of weights used to re-steer the pyramid coefficients.
            Keys are 2-tuples of ints ``(scale, j)``, where ``j`` indexes
            `angles`.

        """
        assert (
            pyr_coeffs[(0, 0)].dtype not in complex_types
        ), "steering only implemented for real coefficients"
        resteered_coeffs = {}
        resteering_weights = {}
        num_scales = self.num_scales
        num_orientations = self.num_orientations
        for i in range(num_scales):
            # stack all the bands of this scale along a new trailing axis
            basis = torch.cat(
                [
                    pyr_coeffs[(i, j)].squeeze().unsqueeze(-1)
                    for j in range(num_orientations)
                ],
                dim=-1,
            )

            for j, a in enumerate(angles):
                res, steervect = steer(
                    basis, a, return_weights=True, even_phase=even_phase
                )
                resteering_weights[(i, j)] = steervect
                resteered_coeffs[(i, num_orientations + j)] = res.reshape(
                    pyr_coeffs[(i, 0)].shape
                )

        return resteered_coeffs, resteering_weights