Source code for plenoptic.simulate.canonical_computations.non_linearities

import torch

from ...tools.conv import blur_downsample, upsample_blur
from ...tools.signal import polar_to_rectangular, rectangular_to_polar


[docs] def rectangular_to_polar_dict(coeff_dict, residuals=False): """Return the complex modulus and the phase of each complex tensor in a dictionary. Parameters ---------- coeff_dict : dict A dictionary containing complex tensors. dim : int, optional The dimension that contains the real and imaginary components. residuals: bool, optional An option to carry around residuals in the energy branch. Returns ------- energy : dict The dictionary of torch.Tensors containing the local complex modulus of ``coeff_dict``. state: dict The dictionary of torch.Tensors containing the local phase of ``coeff_dict``. """ energy = {} state = {} for key in coeff_dict: # ignore residuals if isinstance(key, tuple) or not key.startswith("residual"): energy[key], state[key] = rectangular_to_polar(coeff_dict[key]) if residuals: energy["residual_lowpass"] = coeff_dict["residual_lowpass"] energy["residual_highpass"] = coeff_dict["residual_highpass"] return energy, state
[docs] def polar_to_rectangular_dict(energy, state, residuals=True): """Return the real and imaginary parts of tensor in a dictionary. Parameters ---------- energy : dict The dictionary of torch.Tensors containing the local complex modulus. state : dict The dictionary of torch.Tensors containing the local phase. dim : int, optional The dimension that contains the real and imaginary components. residuals: bool, optional An option to carry around residuals in the energy branch. Returns ------- coeff_dict : dict A dictionary containing complex tensors of coefficients. """ coeff_dict = {} for key in energy: # ignore residuals if isinstance(key, tuple) or not key.startswith("residual"): coeff_dict[key] = polar_to_rectangular(energy[key], state[key]) if residuals: coeff_dict["residual_lowpass"] = energy["residual_lowpass"] coeff_dict["residual_highpass"] = energy["residual_highpass"] return coeff_dict
[docs] def local_gain_control(x, epsilon=1e-8): """Spatially local gain control. Parameters ---------- x : torch.Tensor Tensor of shape (batch, channel, height, width) epsilon: float, optional Small constant to avoid division by zero. Returns ------- norm : torch.Tensor The local energy of ``x``. Note that it is down sampled by a factor 2 in (unlike rect2pol). direction: torch.Tensor The local phase of ``x`` (aka. local unit vector, or local state) Notes ----- This function is an analogue to rectangular_to_polar for real valued signals. Norm and direction (analogous to complex modulus and phase) are defined using blurring operator and division. Indeed blurring the responses removes high frequencies introduced by the squaring operation. In the complex case adding the quadrature pair response has the same effect (note that this is most clearly seen in the frequency domain). Here computing the direction (phase) reduces to dividing out the norm (modulus), indeed the signal only has one real component. This is a normalization operation (local unit vector), hence the connection to local gain control. """ # these could be parameters, but no use case so far p = 2.0 norm = blur_downsample(torch.abs(x**p)).pow(1 / p) odd = torch.as_tensor(x.shape)[2:4] % 2 direction = x / (upsample_blur(norm, odd) + epsilon) return norm, direction
[docs] def local_gain_release(norm, direction, epsilon=1e-8): """Spatially local gain release. Parameters ---------- norm : torch.Tensor The local energy of ``x``. Note that it is down sampled by a factor 2 in (unlike rect2pol). direction: torch.Tensor The local phase of ``x`` (aka. local unit vector, or local state) epsilon: float, optional Small constant to avoid division by zero. Returns ------- x : torch.Tensor Tensor of shape (batch, channel, height, width) Notes ----- This function is an analogue to polar_to_rectangular for real valued signals. Norm and direction (analogous to complex modulus and phase) are defined using blurring operator and division. Indeed blurring the responses removes high frequencies introduced by the squaring operation. In the complex case adding the quadrature pair response has the same effect (note that this is most clearly seen in the frequency domain). Here computing the direction (phase) reduces to dividing out the norm (modulus), indeed the signal only has one real component. This is a normalization operation (local unit vector), hence the connection to local gain control. """ odd = torch.as_tensor(direction.shape)[2:4] % 2 x = direction * (upsample_blur(norm, odd) + epsilon) return x
[docs] def local_gain_control_dict(coeff_dict, residuals=True): """Spatially local gain control, for each element in a dictionary. Parameters ---------- coeff_dict : dict A dictionary containing tensors of shape (batch, channel, height, width) residuals: bool, optional An option to carry around residuals in the energy dict. Note that the transformation is not applied to the residuals, that is dictionary elements with a key starting in "residual". Returns ------- energy : dict The dictionary of torch.Tensors containing the local energy of ``x``. state: dict The dictionary of torch.Tensors containing the local phase of ``x``. Notes ----- Note that energy and state is not computed on the residuals. The inverse operation is achieved by `local_gain_release_dict`. This function is an analogue to rectangular_to_polar_dict for real valued signals. For more details, see :meth:`local_gain_control` """ energy = {} state = {} for key in coeff_dict: if isinstance(key, tuple) or not key.startswith("residual"): energy[key], state[key] = local_gain_control(coeff_dict[key]) if residuals: energy["residual_lowpass"] = coeff_dict["residual_lowpass"] energy["residual_highpass"] = coeff_dict["residual_highpass"] return energy, state
[docs] def local_gain_release_dict(energy, state, residuals=True): """Spatially local gain release, for each element in a dictionary. Parameters ---------- energy : dict The dictionary of torch.Tensors containing the local energy of ``x``. state: dict The dictionary of torch.Tensors containing the local phase of ``x``. residuals: bool, optional An option to carry around residuals in the energy dict. Note that the transformation is not applied to the residuals, that is dictionary elements with a key starting in "residual". Returns ------- coeff_dict : dict A dictionary containing tensors of shape (batch, channel, height, width) Notes ----- The inverse operation to `local_gain_control_dict`. This function is an analogue to polar_to_rectangular_dict for real valued signals. For more details, see :meth:`local_gain_release` """ coeff_dict = {} for key in energy: if isinstance(key, tuple) or not key.startswith("residual"): coeff_dict[key] = local_gain_release(energy[key], state[key]) if residuals: coeff_dict["residual_lowpass"] = energy["residual_lowpass"] coeff_dict["residual_highpass"] = energy["residual_highpass"] return coeff_dict