# Source code for plenoptic.process.non_linearities

"""
Some useful non-linearities for visual models.

The functions operate on dictionaries or tensors.
"""  # numpydoc ignore=EX01

import torch

from . import signal
from .convolutions import blur_downsample, upsample_blur

# Public API of this module. The module-level ``__dir__`` defined in this file
# returns this same list, so it also controls what ``dir()`` reports.
__all__ = [
    "local_gain_control",
    "local_gain_control_dict",
    "local_gain_release",
    "local_gain_release_dict",
    "polar_to_rectangular_dict",
    "rectangular_to_polar_dict",
]


def __dir__() -> list[str]:
    # Module-level ``__dir__`` hook: limit ``dir()`` on this module to the
    # names exported in ``__all__`` (helpers and imports are hidden).
    return __all__


def rectangular_to_polar_dict(
    coeff_dict: dict,
    residuals: bool = False,
) -> tuple[dict, dict]:
    """
    Return the complex amplitude and the phase of each complex tensor in a dictionary.

    Keys are preserved, with the option of dropping ``"residual_lowpass"`` and
    ``"residual_highpass"`` by setting ``residuals=False``.

    Parameters
    ----------
    coeff_dict
        A dictionary containing complex tensors.
    residuals
        An option to include residuals in the returned ``energy`` dict. The
        residuals are copied over untouched (no amplitude/phase decomposition
        is applied to them), and only the residuals actually present in
        ``coeff_dict`` are copied.

    Returns
    -------
    energy
        The dictionary of :class:`torch.Tensor` containing the local complex
        amplitude of ``coeff_dict``.
    state
        The dictionary of :class:`torch.Tensor` containing the local phase of
        ``coeff_dict``. Never contains the residuals.

    See Also
    --------
    :func:`~plenoptic.process.rectangular_to_polar`
        Same operation on tensors.
    polar_to_rectangular_dict
        The inverse operation.
    local_gain_control_dict
        The analogous function for real-valued signals.

    Notes
    -----
    Any entry whose key is a string starting with ``"residual"`` is excluded
    from the decomposition; of those, only ``"residual_lowpass"`` and
    ``"residual_highpass"`` can be carried over to ``energy``.

    Examples
    --------
    .. plot::

        >>> import plenoptic as po
        >>> img = po.data.einstein()
        >>> spyr = po.process.SteerablePyramidFreq(
        ...     img.shape[-2:], is_complex=True, height=3
        ... )
        >>> coeffs = spyr(img)
        >>> energy, state = po.process.rectangular_to_polar_dict(coeffs)
        >>> po.plot.pyrshow(energy)
        <PyrFigure size ...>
        >>> po.plot.pyrshow(state)
        <PyrFigure size ...>
    """
    energy = {}
    state = {}
    for key, coeff in coeff_dict.items():
        # Residuals are not decomposed; they are handled separately below.
        # Keys may be non-strings (e.g. tuples), hence the isinstance check.
        if isinstance(key, str) and key.startswith("residual"):
            continue
        energy[key], state[key] = signal.rectangular_to_polar(coeff)

    if residuals:
        # Copy each residual independently so that a dictionary holding only
        # one (or none) of them does not raise a KeyError.
        for key in ("residual_lowpass", "residual_highpass"):
            if key in coeff_dict:
                energy[key] = coeff_dict[key]

    return energy, state
def polar_to_rectangular_dict(
    energy: dict,
    state: dict,
) -> dict:
    """
    Return the real and imaginary parts of tensor in a dictionary.

    Keys in the output are identical to those in the input. Will grab
    residuals from ``energy``, if present, with keys ``"residual_highpass"``
    and ``"residual_lowpass"``. Each residual is looked up independently, so
    either one may be present without the other.

    Parameters
    ----------
    energy
        The dictionary of :class:`torch.Tensor` containing the local complex
        amplitude.
    state
        The dictionary of :class:`torch.Tensor` containing the local phase.
        Must contain every non-residual key found in ``energy``.

    Returns
    -------
    coeff_dict
        A dictionary containing complex tensors of coefficients.

    See Also
    --------
    :func:`~plenoptic.process.polar_to_rectangular`
        Same operation on tensors.
    rectangular_to_polar_dict
        The inverse operation.
    local_gain_release_dict
        The analogous function for real-valued signals.

    Examples
    --------
    .. plot::

        >>> import plenoptic as po
        >>> import numpy as np
        >>> import torch
        >>> img = po.data.einstein()
        >>> spyr = po.process.SteerablePyramidFreq(
        ...     img.shape[-2:], is_complex=True, height=3
        ... )
        >>> coeffs = spyr(img)
        >>> energy, state = po.process.rectangular_to_polar_dict(
        ...     coeffs, residuals=True
        ... )
        >>> coeffs_back = po.process.polar_to_rectangular_dict(energy, state)
        >>> all(torch.allclose(coeffs[key], coeffs_back[key]) for key in coeffs)
        True
        >>> po.plot.pyrshow(coeffs_back)
        <PyrFigure size ...>
    """
    coeff_dict = {}
    for key, amplitude in energy.items():
        # Residuals carry no phase; they are copied over separately below.
        # Keys may be non-strings (e.g. tuples), hence the isinstance check.
        if isinstance(key, str) and key.startswith("residual"):
            continue
        coeff_dict[key] = signal.polar_to_rectangular(amplitude, state[key])

    # Check each residual on its own: previously the presence of the lowpass
    # gated both, which dropped a lone highpass and raised a KeyError on a
    # lone lowpass.
    for key in ("residual_lowpass", "residual_highpass"):
        if key in energy:
            coeff_dict[key] = energy[key]

    return coeff_dict
def local_gain_control(
    x: torch.Tensor, epsilon: float = 1e-8
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Spatially local gain control.

    Split a real-valued tensor into its local energy (norm) and its local
    phase (direction).

    Parameters
    ----------
    x
        Tensor of shape (batch, channel, height, width) or
        (batch, channel, angle, height, width).
    epsilon
        Small constant added to the denominator to avoid division by zero.

    Returns
    -------
    norm
        The local energy of ``x``, shape (batch, channel, height/2, width/2)
        or (batch, channel, angle, height/2, width/2), depending on
        dimensionality of ``x``.
    direction
        The local phase of ``x`` (a.k.a. local unit vector, or local state),
        shape (batch, channel, height, width) or
        (batch, channel, angle, height, width), depending on dimensionality
        of ``x``.

    Raises
    ------
    ValueError
        If ``x`` does not have 4 or 5 dimensions.

    See Also
    --------
    local_gain_control_dict
        Same operation on dictionaries.
    local_gain_release
        The inverse operation.
    :func:`~plenoptic.process.rectangular_to_polar`
        The analogous function for complex-valued signals.

    Notes
    -----
    Norm and direction (analogous to complex amplitude and phase) are defined
    using a blurring operator and division. Blurring the responses removes the
    high frequencies introduced by the squaring operation. In the complex
    case, adding the quadrature pair response has the same effect (this is
    most clearly seen in the frequency domain). Here, computing the direction
    (phase) reduces to dividing out the norm (amplitude), since the signal
    only has one real component. This is a normalization operation (local
    unit vector), hence the connection to local gain control.

    Examples
    --------
    .. plot::

        >>> import plenoptic as po
        >>> img = po.data.einstein()
        >>> norm, direction = po.process.local_gain_control(img)
        >>> po.plot.imshow([img, norm, direction], title=["image", "norm", "direction"])
        <PyrFigure size ...>
    """
    if x.ndim not in (4, 5):
        raise ValueError("Tensor must have 4 or 5 dimensions!")

    # Exponent of the local norm. This could be a parameter, but there has
    # been no use case so far.
    exponent = 2.0

    def _norm_and_direction(
        t: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute norm and direction of a 4d tensor, in a helper we can vmap."""  # noqa: DOC201
        # numpydoc ignore=ES01,PR01,RT01,EX01
        local_energy = blur_downsample(torch.abs(t**exponent))
        norm = torch.pow(local_energy, 1 / exponent)
        # Parity of the two spatial dimensions, passed on to upsample_blur.
        odd = torch.as_tensor(t.shape)[-2:] % 2
        direction = t / (upsample_blur(norm, odd) + epsilon)
        return norm, direction

    if x.ndim == 4:
        return _norm_and_direction(x)
    # 5d input: map the 4d helper over the angle dimension (dim 2).
    return torch.vmap(_norm_and_direction, in_dims=2, out_dims=2)(x)
def local_gain_release(
    norm: torch.Tensor,
    direction: torch.Tensor,
    epsilon: float = 1e-8,
) -> torch.Tensor:
    """
    Spatially local gain release.

    Recombine a local energy (norm) and a local phase (direction) into a
    single real-valued tensor.

    Parameters
    ----------
    norm
        The local energy of a tensor, with shape
        (batch, channel, height/2, width/2) or
        (batch, channel, angle, height/2, width/2).
    direction
        The local phase of a tensor (a.k.a. local unit vector, or local
        state), with shape (batch, channel, height, width) or
        (batch, channel, angle, height, width).
    epsilon
        Small constant to avoid division by zero.

    Returns
    -------
    x
        Tensor of shape (batch, channel, height, width) or
        (batch, channel, angle, height, width), depending on input tensor
        dimensionality.

    Raises
    ------
    ValueError
        If input tensors do not have 4 or 5 dimensions.

    See Also
    --------
    local_gain_release_dict
        Same operation on dictionaries.
    local_gain_control
        The inverse operation.
    :func:`~plenoptic.process.polar_to_rectangular`
        The analogous function for complex-valued signals.

    Notes
    -----
    Norm and direction (analogous to complex amplitude and phase) are defined
    using a blurring operator and division. Blurring the responses removes the
    high frequencies introduced by the squaring operation. In the complex
    case, adding the quadrature pair response has the same effect (this is
    most clearly seen in the frequency domain). Here, computing the direction
    (phase) reduces to dividing out the norm (amplitude), since the signal
    only has one real component. This is a normalization operation (local
    unit vector), hence the connection to local gain control.

    Examples
    --------
    .. plot::

        >>> import plenoptic as po
        >>> img = po.data.einstein()
        >>> norm, direction = po.process.local_gain_control(img)
        >>> x = po.process.local_gain_release(norm, direction)
        >>> po.plot.imshow(
        ...     [img, x, img - x],
        ...     title=["Original image", "Gain release output", "Difference"],
        ... )
        <PyrFigure size ...>
    """
    if direction.ndim not in (4, 5):
        raise ValueError("Tensor must have 4 or 5 dimensions!")

    def _rescale(unit: torch.Tensor, local_energy: torch.Tensor) -> torch.Tensor:
        """Multiply the gain back in, for 4d tensors, in a helper we can vmap."""  # noqa: DOC201
        # numpydoc ignore=ES01,PR01,RT01,EX01
        # Parity of the two spatial dimensions, passed on to upsample_blur.
        odd = torch.as_tensor(unit.shape)[-2:] % 2
        gain = upsample_blur(local_energy, odd) + epsilon
        return unit * gain

    if direction.ndim == 4:
        return _rescale(direction, norm)
    # 5d input: map the 4d helper over the angle dimension (dim 2).
    return torch.vmap(_rescale, in_dims=2, out_dims=2)(direction, norm)
def local_gain_control_dict(
    coeff_dict: dict,
    residuals: bool = True,
) -> tuple[dict, dict]:
    """
    Spatially local gain control, for each element in a dictionary.

    For more details, see :func:`local_gain_control`.

    Parameters
    ----------
    coeff_dict
        A dictionary containing tensors of shape (batch, channel, height, width)
        or (batch, channel, angle, height, width).
    residuals
        An option to carry around residuals in the energy dict. Note that the
        transformation is not applied to the residuals, that is dictionary
        elements with a key starting in "residual". Only the residuals
        actually present in ``coeff_dict`` (under the keys
        ``"residual_lowpass"`` and ``"residual_highpass"``) are carried over.

    Returns
    -------
    energy
        The dictionary of :class:`torch.Tensor` containing the local energy of
        ``x``. As described in :func:`local_gain_control`, each tensor has
        half the height and width of the corresponding entry in
        ``coeff_dict``; residuals, if included, are unchanged.
    state
        The dictionary of :class:`torch.Tensor` containing the local phase of
        ``x``. Tensor shapes match those found in ``coeff_dict``. Never
        contains the residuals.

    Raises
    ------
    ValueError
        If the tensors contained within ``coeff_dict`` do not have 4 or 5
        dimensions.

    See Also
    --------
    local_gain_control
        Same operation on tensors.
    local_gain_release_dict
        The inverse operation.
    rectangular_to_polar_dict
        The analogous function for complex-valued signals.

    Notes
    -----
    Note that energy and state are not computed on the residuals.

    Examples
    --------
    .. plot::

        >>> import plenoptic as po
        >>> img = po.data.einstein()
        >>> spyr = po.process.SteerablePyramidFreq(img.shape[-2:], height=3)
        >>> coeffs = spyr(img)
        >>> energy, state = po.process.local_gain_control_dict(coeffs)
        >>> po.plot.pyrshow(energy)
        <PyrFigure size ...>
        >>> po.plot.pyrshow(state)
        <PyrFigure size ...>
    """
    energy = {}
    state = {}
    for key, coeff in coeff_dict.items():
        # Residuals are not transformed; they are handled separately below.
        # Keys may be non-strings (e.g. tuples), hence the isinstance check.
        if isinstance(key, str) and key.startswith("residual"):
            continue
        energy[key], state[key] = local_gain_control(coeff)

    if residuals:
        # ``residuals`` defaults to True, so tolerate dictionaries that hold
        # only one (or none) of the residuals instead of raising a KeyError.
        for key in ("residual_lowpass", "residual_highpass"):
            if key in coeff_dict:
                energy[key] = coeff_dict[key]

    return energy, state
def local_gain_release_dict(
    energy: dict,
    state: dict,
    residuals: bool = True,
) -> dict:
    """
    Spatially local gain release, for each element in a dictionary.

    For more details, see :func:`local_gain_release`.

    Parameters
    ----------
    energy
        The dictionary of :class:`torch.Tensor` containing the local energy of
        ``x``. As described in :func:`local_gain_release`, each tensor has
        shape (batch, channel, height/2, width/2) or
        (batch, channel, angle, height/2, width/2).
    state
        The dictionary of :class:`torch.Tensor` containing the local phase of
        ``x``, with shape (batch, channel, height, width) or
        (batch, channel, angle, height, width). Must contain every
        non-residual key found in ``energy``.
    residuals
        An option to carry around residuals in the energy dict. Note that the
        transformation is not applied to the residuals, that is dictionary
        elements with a key starting in "residual". Only the residuals
        actually present in ``energy`` (under the keys ``"residual_lowpass"``
        and ``"residual_highpass"``) are carried over, so an ``energy`` dict
        built without residuals is accepted.

    Returns
    -------
    coeff_dict
        A dictionary containing the "gain released" tensors, with shapes
        matching those found in ``state``.

    Raises
    ------
    ValueError
        If the tensors contained within ``energy`` and ``state`` do not have 4
        or 5 dimensions.

    See Also
    --------
    local_gain_release
        Same operation on tensors.
    local_gain_control_dict
        The inverse operation.
    polar_to_rectangular_dict
        The analogous function for complex-valued signals.

    Examples
    --------
    .. plot::

        >>> import plenoptic as po
        >>> import torch
        >>> img = po.data.einstein()
        >>> spyr = po.process.SteerablePyramidFreq(img.shape[-2:], height=3)
        >>> coeffs = spyr(img)
        >>> energy, state = po.process.local_gain_control_dict(coeffs)
        >>> coeffs_dict = po.process.local_gain_release_dict(energy, state)
        >>> all([torch.allclose(coeffs[k], coeffs_dict[k]) for k in coeffs.keys()])
        True
        >>> po.plot.pyrshow(coeffs_dict)
        <PyrFigure size ...>
    """
    coeff_dict = {}
    for key, norm in energy.items():
        # Residuals are not transformed; they are handled separately below.
        # Keys may be non-strings (e.g. tuples), hence the isinstance check.
        if isinstance(key, str) and key.startswith("residual"):
            continue
        coeff_dict[key] = local_gain_release(norm, state[key])

    if residuals:
        # ``residuals`` defaults to True, so an ``energy`` dict produced with
        # ``local_gain_control_dict(..., residuals=False)`` must not raise a
        # KeyError here: copy only the residuals that exist.
        for key in ("residual_lowpass", "residual_highpass"):
            if key in energy:
                coeff_dict[key] = energy[key]

    return coeff_dict