Source code for plenoptic.tools.signal

import numpy as np
import torch
from deprecated.sphinx import deprecated
from pyrtools.pyramids.steer import steer_to_harmonics_mtx
from torch import Tensor


@deprecated(
    "Use :py:func:`einops.reduce` instead: https://einops.rocks/1-einops-basics/#meet-einopsreduce",  # noqa: E501
    version="1.1.0",
)
def minimum(x: Tensor, dim: list[int] | None = None, keepdim: bool = False) -> Tensor:
    r"""Compute the minimum of a tensor over any axis or combination of axes.

    Parameters
    ----------
    x
        Input tensor.
    dim
        Dimensions over which to compute the minimum. Negative values index
        from the last dimension, and repeated entries (including a negative
        and a non-negative index naming the same dimension) are reduced only
        once. If None (default), every dimension of ``x`` is reduced.
    keepdim
        Keep original dimensions of tensor when returning result.

    Returns
    -------
    min_x
        Minimum value of x.

    Raises
    ------
    IndexError
        If an entry of ``dim`` lies outside ``[-x.ndim, x.ndim - 1]``.

    Notes
    -----
    The reduction is applied one dimension at a time with
    :meth:`torch.Tensor.min`, from the highest dimension to the lowest, so
    that dropping a dimension (``keepdim=False``) never shifts the index of a
    dimension that still has to be reduced.
    """
    ndim = x.ndim
    if dim is None:
        dim = range(ndim)

    # Normalize to non-negative, unique indices. Without this, a negative
    # index is interpreted relative to the *already reduced* tensor, which
    # either raises or silently reduces the wrong dimension.
    axes = set()
    for d in dim:
        if ndim > 0:
            if not -ndim <= d < ndim:
                raise IndexError(
                    "Dimension out of range (expected to be in range of "
                    f"[{-ndim}, {ndim - 1}], but got {d})"
                )
            d = d % ndim
        # for 0-d tensors the index is forwarded untouched and torch decides
        axes.add(d)

    min_x = x
    for axis in sorted(axes, reverse=True):
        min_x, _ = min_x.min(axis, keepdim)
    return min_x
@deprecated(
    "Use :py:func:`einops.reduce` instead: https://einops.rocks/1-einops-basics/#meet-einopsreduce",  # noqa: E501
    version="1.1.0",
)
def maximum(x: Tensor, dim: list[int] | None = None, keepdim: bool = False) -> Tensor:
    r"""Compute the maximum of a tensor over any axis or combination of axes.

    Parameters
    ----------
    x
        Input tensor.
    dim
        Dimensions over which to compute the maximum. Negative values index
        from the last dimension, and repeated entries (including a negative
        and a non-negative index naming the same dimension) are reduced only
        once. If None (default), every dimension of ``x`` is reduced.
    keepdim
        Keep original dimensions of tensor when returning result.

    Returns
    -------
    max_x
        Maximum value of x.

    Raises
    ------
    IndexError
        If an entry of ``dim`` lies outside ``[-x.ndim, x.ndim - 1]``.

    Notes
    -----
    The reduction is applied one dimension at a time with
    :meth:`torch.Tensor.max`, from the highest dimension to the lowest, so
    that dropping a dimension (``keepdim=False``) never shifts the index of a
    dimension that still has to be reduced.
    """
    ndim = x.ndim
    if dim is None:
        dim = range(ndim)

    # Normalize to non-negative, unique indices. Without this, a negative
    # index is interpreted relative to the *already reduced* tensor, which
    # either raises or silently reduces the wrong dimension.
    axes = set()
    for d in dim:
        if ndim > 0:
            if not -ndim <= d < ndim:
                raise IndexError(
                    "Dimension out of range (expected to be in range of "
                    f"[{-ndim}, {ndim - 1}], but got {d})"
                )
            d = d % ndim
        # for 0-d tensors the index is forwarded untouched and torch decides
        axes.add(d)

    max_x = x
    for axis in sorted(axes, reverse=True):
        max_x, _ = max_x.max(axis, keepdim)
    return max_x
def rescale(x: Tensor, a: float = 0.0, b: float = 1.0) -> Tensor:
    r"""Linearly map the values of ``x`` onto the interval ``[a, b]``.

    Parameters
    ----------
    x
        Tensor to rescale.
    a
        Value the minimum of ``x`` is mapped to.
    b
        Value the maximum of ``x`` is mapped to.

    Returns
    -------
    Tensor with the same shape as ``x``. When ``x`` is constant (its maximum
    equals its minimum) the normalization is skipped and every element
    equals ``a``.
    """
    lowest = x.min()
    span = x.max() - lowest
    shifted = x - lowest
    # only normalize when there is a non-zero dynamic range to divide by
    if span > 0:
        shifted = shifted / span
    return a + shifted * (b - a)
def raised_cosine(
    width: float = 1, position: float = 0, values: tuple[float, float] = (0, 1)
) -> tuple[np.ndarray, np.ndarray]:
    """Build a lookup table for a "raised cosine" soft threshold function.

    Y = VALUES(1) + (VALUES(2)-VALUES(1)) * cos^2( PI/2 * (X - POSITION + WIDTH)/WIDTH )

    The table is meant to be consumed by `interpolate1d`.

    Parameters
    ----------
    width
        The width of the region over which the transition occurs.
    position
        The location of the center of the threshold.
    values
        2-tuple specifying the values to the left and right of the transition.

    Returns
    -------
    X
        The x values of this raised cosine.
    Y
        The y values of this raised cosine. The first and last entries
        duplicate their neighbours so the table extrapolates as a constant.
    """
    n_samples = 256  # arbitrary table resolution
    left, right = values[0], values[1]

    # cosine argument, running from just below -pi/2 to just above 0
    theta = np.pi * np.arange(-n_samples - 1, 2) / (2 * n_samples)
    table = left + (right - left) * np.cos(theta) ** 2

    # repeat the end values so that extrapolation stays flat
    table[0] = table[1]
    table[n_samples + 2] = table[n_samples + 1]

    # map the cosine argument onto the requested position / width
    abscissa = position + (2 * width / np.pi) * (theta + np.pi / 4)
    return abscissa, table
[docs] def interpolate1d( x_new: Tensor, Y: Tensor | np.ndarray, X: Tensor | np.ndarray ) -> Tensor: r"""One-dimensional linear interpolation. Returns the one-dimensional piecewise linear interpolant to a function with given discrete data points (X, Y), evaluated at x_new. Note: this function is just a wrapper around ``np.interp()``. Parameters ---------- x_new The x-coordinates at which to evaluate the interpolated values. Y The y-coordinates of the data points. X The x-coordinates of the data points, same length as X. Returns ------- Interpolated values of shape identical to `x_new`. """ out = np.interp(x=x_new.flatten(), xp=X, fp=Y) return np.reshape(out, x_new.shape)
def rectangular_to_polar(x: Tensor) -> tuple[Tensor, Tensor]:
    r"""Convert a complex tensor from rectangular to polar coordinates.

    Parameters
    ----------
    x
        Complex tensor.

    Returns
    -------
    amplitude
        Tensor containing the amplitude (aka. complex modulus).
    phase
        Tensor containing the phase.
    """
    return x.abs(), x.angle()
def polar_to_rectangular(amplitude: Tensor, phase: Tensor) -> Tensor:
    r"""Convert polar coordinates to a complex tensor in rectangular form.

    Parameters
    ----------
    amplitude
        Tensor containing the amplitude (aka. complex modulus). No element
        may be negative (zeros pass the check).
    phase
        Tensor containing the phase

    Returns
    -------
    Complex tensor.

    Raises
    ------
    ValueError
        If any element of ``amplitude`` is negative.
    """
    if torch.any(amplitude < 0):
        raise ValueError("Amplitudes must be strictly positive.")
    return torch.complex(
        amplitude * torch.cos(phase),
        amplitude * torch.sin(phase),
    )
def steer(
    basis: Tensor,
    angle: np.ndarray | Tensor | float,
    harmonics: list[int] | None = None,
    steermtx: Tensor | np.ndarray | None = None,
    return_weights: bool = False,
    even_phase: bool = True,
):
    """Steer BASIS to the specified ANGLE.

    Parameters
    ----------
    basis
        Array whose columns are vectorized rotated copies of a steerable
        function, or the responses of a set of steerable filters. The size
        of its last dimension is taken as the number of basis functions.
    angle
        Scalar or column vector the size of the basis. specifies the angle(s)
        (in radians) to steer to. A non-scalar must have shape
        ``(basis.shape[0], 1)``.
    harmonics
        A list of harmonic numbers indicating the angular harmonic content of
        the basis. if None (default), N even or odd low frequencies, as for
        derivative filters
    steermtx
        Matrix which maps the filters onto Fourier series components (ordered
        [cos0 cos1 sin1 cos2 sin2 ... sinN]). See steer_to_harmonics_mtx
        function for more details. If None (default), assumes cosine phase
        harmonic components, and filter positions at 2pi*n/N.
    return_weights
        Whether to return the weights or not.
    even_phase
        Specifies whether the harmonics are cosine or sine phase aligned
        about those positions. Only used when ``steermtx`` is None.

    Returns
    -------
    res
        The resteered basis.
    steervect
        The weights used to resteer the basis. only returned if
        ``return_weights`` is True.

    Raises
    ------
    Exception
        If ``angle`` is neither a scalar nor a column vector matching the
        basis, if ``harmonics`` is not 1D, or if the number of harmonics is
        incompatible with the size of the basis.
    """
    # number of basis functions = size of the last dimension of `basis`
    num = basis.shape[-1]
    device = basis.device

    # a Python scalar becomes a length-1 array; anything else must already be
    # a column vector with one row per basis element
    if isinstance(angle, int | float):
        angle = np.array([angle])
    else:
        if angle.shape[0] != basis.shape[0] or angle.shape[1] != 1:
            # NOTE(review): the two adjacent literals concatenate without a
            # space ("...vector thesize of..."); left as is because the text
            # is runtime output.
            raise Exception(
                "ANGLE must be a scalar, or a column vector the"
                "size of the basis elements"
            )

    # If HARMONICS is not specified, assume derivatives.
    # (odd harmonics 1, 3, ... when `num` is even; even harmonics 0, 2, ...
    # when `num` is odd)
    if harmonics is None:
        harmonics = np.arange(1 - (num % 2), num, 2)

    # NOTE(review): `harmonics` is annotated as list[int], but `.shape` and
    # `.reshape` are used below, which a plain list does not provide --
    # presumably callers pass a numpy array; confirm.
    if len(harmonics.shape) == 1 or harmonics.shape[0] == 1:
        # reshape to column matrix
        harmonics = harmonics.reshape(harmonics.shape[0], 1)
    elif harmonics.shape[0] != 1 and harmonics.shape[1] != 1:
        raise Exception("input parameter HARMONICS must be 1D!")

    # every non-zero harmonic contributes two weights (cos and sin), the zero
    # harmonic only one; the total has to equal the number of basis functions
    if 2 * harmonics.shape[0] - (harmonics == 0).sum() != num:
        raise Exception("harmonics list is incompatible with basis size!")

    # If STEERMTX not passed, assume evenly distributed cosine-phase filters:
    if steermtx is None:
        steermtx = steer_to_harmonics_mtx(
            harmonics, np.pi * np.arange(num) / num, even_phase=even_phase
        )

    # one row of weights per angle
    steervect = np.zeros((angle.shape[0], num))
    # angle multiplied by each non-zero harmonic
    arg = angle * harmonics[np.nonzero(harmonics)[0]].T
    if all(harmonics):
        # no zero harmonic: columns alternate cos, sin, cos, sin, ...
        steervect[:, range(0, num, 2)] = np.cos(arg)
        steervect[:, range(1, num, 2)] = np.sin(arg)
    else:
        # zero harmonic present: constant term first, then cos/sin pairs
        steervect[:, 0] = np.ones((arg.shape[0], 1))
        steervect[:, range(1, num, 2)] = np.cos(arg)
        steervect[:, range(2, num, 2)] = np.sin(arg)

    # map the harmonic weights onto weights for the basis functions
    steervect = np.dot(steervect, steermtx)

    steervect = torch.as_tensor(steervect, dtype=basis.dtype).to(device)
    if steervect.shape[0] > 1:
        # NOTE(review): `sum()` without a `dim` argument reduces over every
        # element, so `res` looks like it collapses to a single value here --
        # confirm this multi-angle branch does what is intended.
        tmp = basis @ steervect
        res = tmp.sum().t()
    else:
        res = basis @ steervect.t()
    if return_weights:
        # NOTE(review): `reshape(num)` presumably only succeeds when a single
        # angle was given (steervect then holds exactly `num` weights); confirm.
        return res, steervect.reshape(num)
    else:
        return res
[docs] def make_disk( img_size: int | tuple[int, int] | torch.Size, outer_radius: float | None = None, inner_radius: float | None = None, ) -> Tensor: r"""Create a circular mask with softened edges to an image. All values within ``inner_radius`` will be 1, and all values from ``inner_radius`` to ``outer_radius`` will decay smoothly to 0. Parameters ---------- img_size Size of image in pixels. outer_radius Total radius of disk. Values from ``inner_radius`` to ``outer_radius`` will decay smoothly to zero. inner_radius Radius of inner disk. All elements from the origin to ``inner_radius`` will be set to 1. Returns ------- mask Tensor mask with torch.Size(img_size). """ if isinstance(img_size, int): img_size = (img_size, img_size) assert len(img_size) == 2 if outer_radius is None: outer_radius = (min(img_size) - 1) / 2 if inner_radius is None: inner_radius = outer_radius / 2 mask = torch.empty(*img_size) i0, j0 = (img_size[0] - 1) / 2, (img_size[1] - 1) / 2 # image center for i in range(img_size[0]): # height for j in range(img_size[1]): # width r = np.sqrt((i - i0) ** 2 + (j - j0) ** 2) if r > outer_radius: mask[i][j] = 0 elif r < inner_radius: mask[i][j] = 1 else: radial_decay = (r - inner_radius) / (outer_radius - inner_radius) mask[i][j] = (1 + np.cos(np.pi * radial_decay)) / 2 return mask
[docs] def add_noise(img: Tensor, noise_mse: float | list[float]) -> Tensor: """Add normally distributed noise to an image This adds normally-distributed noise to an image so that the resulting noisy version has the specified mean-squared error. Parameters ---------- img The image to make noisy. noise_mse The target MSE value / variance of the noise. More than one value is allowed. Returns ------- noisy_img The noisy image. If `noise_mse` contains only one element, this will be the same size as `img`. Else, each separate value from `noise_mse` will be along the batch dimension. """ noise_mse = torch.as_tensor( noise_mse, dtype=img.dtype, device=img.device ).unsqueeze(0) noise_mse = noise_mse.view(noise_mse.nelement(), 1, 1, 1) noise = 200 * torch.randn( max(noise_mse.shape[0], img.shape[0]), *img.shape[1:], device=img.device, ) noise = noise - noise.mean() noise = noise * torch.sqrt( noise_mse / (noise**2).mean((-1, -2)).unsqueeze(-1).unsqueeze(-1) ) return img + noise
def modulate_phase(x: Tensor, phase_factor: float = 2.0) -> Tensor:
    """Modulate the phase of a complex signal.

    Doubling the phase of a complex signal allows you to, for example, take
    the correlation between steerable pyramid coefficients at two adjacent
    spatial scales.

    Parameters
    ----------
    x :
        Complex tensor whose phase will be modulated.
    phase_factor :
        Multiplicative factor to change phase by.

    Returns
    -------
    x_mod :
        Phase-modulated complex tensor: same amplitude as ``x``, phase
        multiplied by ``phase_factor``.

    Raises
    ------
    TypeError
        If ``x`` is not a complex-valued tensor.
    """
    # Check the dtype explicitly instead of catching RuntimeError around the
    # computation: an unrelated runtime failure (e.g. a device error) must not
    # be reported as "x is not complex".
    if not torch.is_complex(x):
        raise TypeError("x must be a complex-valued tensor!")

    angle = torch.atan2(x.imag, x.real)
    amp = x.abs()
    real = amp * torch.cos(phase_factor * angle)
    imag = amp * torch.sin(phase_factor * angle)
    return torch.complex(real, imag)
def autocorrelation(x: Tensor) -> Tensor:
    r"""Compute the autocorrelation of `x`.

    Parameters
    ----------
    x :
        N-dimensional tensor. We assume the last two dimensions are height
        and width and compute the autocorrelation on these dimensions
        (independently on each other dimension).

    Returns
    -------
    ac :
        Autocorrelation of x, with the same shape as ``x``. The zero-shift
        value is located at index ``(height // 2, width // 2)`` of the last
        two dimensions.

    Notes
    -----
    - By the Einstein-Wiener-Khinchin theorem: The autocorrelation of a wide
      sense stationary (WSS) process is the inverse Fourier transform of its
      energy spectrum (ESD) - which itself is the multiplication between
      FT(x(t)) and FT(x(-t)).
      In other words, the auto-correlation is the correlation of the signal
      `x` with itself, which corresponds to taking the squared modulus in the
      frequency domain. This approach is computationally more efficient than
      brute force (n log(n) vs n^2).
    - By Cauchy-Schwarz, the autocorrelation attains its maximum at the center
      location (ie. no shift) - that maximum value is the signal's variance
      (assuming that the input signal is mean centered).
    - Because it is computed with the FFT, the autocorrelation is circular.
    """
    height, width = x.shape[-2:]
    spectrum = torch.fft.rfft2(x)
    # this is equivalent to spectrum.abs().pow(2) or to spectrum multiplied by
    # its complex conjugate
    power = spectrum.real.pow(2) + spectrum.imag.pow(2)
    # The half-spectrum alone does not say whether the original width was
    # even or odd; pass the size explicitly, otherwise odd widths come back
    # one column short.
    ac = torch.fft.irfft2(power, s=(height, width))
    return torch.fft.fftshift(ac, dim=(-2, -1)) / (height * width)
def center_crop(x: Tensor, output_size: int) -> Tensor:
    """Crop out the center of a signal.

    The crop is centered on index ``size // 2`` of each of the final two
    dimensions, i.e. for an even number of elements the center is rounded up.

    Parameters
    ----------
    x :
        N-dimensional tensor, we assume the last two dimensions are height and
        width.
    output_size :
        The size of the output. Note that we only support a single number, so
        both dimensions are cropped identically

    Returns
    -------
    cropped :
        Tensor whose last two dimensions have each been cropped to
        ``output_size``
    """
    height, width = x.shape[-2:]
    # number of samples kept before / from the central index
    n_before = output_size // 2
    n_after = (output_size + 1) // 2
    row_center, col_center = height // 2, width // 2
    rows = slice(row_center - n_before, row_center + n_after)
    cols = slice(col_center - n_before, col_center + n_after)
    return x[..., rows, cols]
def expand(x: Tensor, factor: float) -> Tensor:
    r"""Expand a signal by a factor.

    We do this in the frequency domain: pasting the Fourier contents of ``x``
    in the center of a larger empty tensor, and then taking the inverse FFT.

    Parameters
    ----------
    x:
        The signal for expansion. The last two dimensions are the ones that
        get resized.
    factor :
        Factor by which to resize image. Must be larger than 1 and `factor *
        x.shape[-2:]` must give integer values

    Returns
    -------
    expanded :
        The expanded signal. Real-valued if ``x`` is real-valued, complex
        otherwise.

    Raises
    ------
    ValueError
        If ``factor`` is not strictly greater than 1, or if ``factor`` times
        either of the last two dimensions of ``x`` is not an integer.

    See Also
    --------
    shrink
        The inverse operation

    """
    if factor <= 1:
        raise ValueError("factor must be strictly greater than 1!")
    # input size (x = last dimension, y = second-to-last) and target size
    im_x = x.shape[-1]
    im_y = x.shape[-2]
    mx = factor * im_x
    my = factor * im_y
    if int(mx) != mx:
        raise ValueError(
            f"factor * x.shape[-1] must be an integer but got {mx} instead!"
        )
    if int(my) != my:
        raise ValueError(
            f"factor * x.shape[-2] must be an integer but got {my} instead!"
        )
    mx = int(mx)
    my = int(my)

    # centered spectrum of the input, scaled by factor**2 (the inverse FFT
    # below divides by the larger number of output samples)
    fourier = factor**2 * torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
    # zero-filled spectrum of the target size; leading dimensions of x kept
    fourier_large = torch.zeros(
        *x.shape[:-2],
        my,
        mx,
        device=fourier.device,
        dtype=fourier.dtype,
    )

    # bounds of the region of the large spectrum that receives the input
    # spectrum (without its first row and column)
    y1 = my / 2 + 1 - im_y / 2
    y2 = my / 2 + im_y / 2
    x1 = mx / 2 + 1 - im_x / 2
    x2 = mx / 2 + im_x / 2
    # when any of these numbers are non-integers, if you round down, the
    # resulting image will be off.
    y1 = int(np.ceil(y1))
    y2 = int(np.ceil(y2))
    x1 = int(np.ceil(x1))
    x2 = int(np.ceil(x2))

    # everything except the first row / column is pasted unchanged
    fourier_large[..., y1:y2, x1:x2] = fourier[..., 1:, 1:]
    # the first row and first column of the shifted spectrum are split in
    # half between the two opposite edges of the pasted region
    # NOTE(review): the mirrored copies are flipped but not complex-conjugated;
    # confirm that this is intended.
    fourier_large[..., y1 - 1, x1:x2] = fourier[..., 0, 1:] / 2
    fourier_large[..., y2, x1:x2] = fourier[..., 0, 1:].flip(-1) / 2
    fourier_large[..., y1:y2, x1 - 1] = fourier[..., 1:, 0] / 2
    fourier_large[..., y1:y2, x2] = fourier[..., 1:, 0].flip(-1) / 2
    # the corner element is shared equally between the four corners
    esq = fourier[..., 0, 0] / 4
    fourier_large[..., y1 - 1, x1 - 1] = esq
    fourier_large[..., y1 - 1, x2] = esq
    fourier_large[..., y2, x1 - 1] = esq
    fourier_large[..., y2, x2] = esq

    # undo the centering and go back to the signal domain
    fourier_large = torch.fft.ifftshift(fourier_large, dim=(-2, -1))
    im_large = torch.fft.ifft2(fourier_large)

    # if input was real-valued, output should be real-valued, but
    # using fft/ifft above means im_large will always be complex,
    # so make sure they align.
    if not x.is_complex():
        im_large = torch.real(im_large)
    return im_large
def shrink(x: Tensor, factor: int) -> Tensor:
    r"""Shrink a signal by a factor.

    We do this in the frequency domain: cropping out the center of the Fourier
    transform of ``x``, putting it in a new tensor, and taking the IFFT.

    Parameters
    ----------
    x :
        The signal to shrink. The last two dimensions are the ones that get
        resized.
    factor :
        Factor by which to resize image. Must be larger than 1 and
        `x.shape[-2:] / factor` must give integer values

    Returns
    -------
    im_small :
        The shrunk signal. Real-valued if ``x`` is real-valued, complex
        otherwise.

    Raises
    ------
    ValueError
        If ``factor`` is not strictly greater than 1, or if either of the last
        two dimensions of ``x`` divided by ``factor`` is not an integer.

    See Also
    --------
    expand
        The inverse operation

    """
    if factor <= 1:
        raise ValueError("factor must be strictly greater than 1!")
    # input size (x = last dimension, y = second-to-last) and target size
    im_x = x.shape[-1]
    im_y = x.shape[-2]
    mx = im_x / factor
    my = im_y / factor

    if int(mx) != mx:
        raise ValueError(f"x.shape[-1]/factor must be an integer but got {mx} instead!")
    if int(my) != my:
        raise ValueError(f"x.shape[-2]/factor must be an integer but got {my} instead!")

    mx = int(mx)
    my = int(my)

    # centered spectrum of the input, scaled by 1 / factor**2 (the inverse FFT
    # below divides by the smaller number of output samples)
    fourier = 1 / factor**2 * torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
    # zero-filled spectrum of the target size; leading dimensions of x kept
    fourier_small = torch.zeros(
        *x.shape[:-2],
        my,
        mx,
        device=fourier.device,
        dtype=fourier.dtype,
    )

    # bounds of the central region of the input spectrum that is kept
    y1 = im_y / 2 + 1 - my / 2
    y2 = im_y / 2 + my / 2
    x1 = im_x / 2 + 1 - mx / 2
    x2 = im_x / 2 + mx / 2
    # when any of these numbers are non-integers, if you round down, the
    # resulting image will be off.
    y1 = int(np.ceil(y1))
    y2 = int(np.ceil(y2))
    x1 = int(np.ceil(x1))
    x2 = int(np.ceil(x2))

    # everything except the first row / column of the small spectrum is a
    # straight copy of the central region
    fourier_small[..., 1:, 1:] = fourier[..., y1:y2, x1:x2]
    # first row: average of the rows just above and at the lower bound of
    # the copied region
    fourier_small[..., 0, 1:] = (
        fourier[..., y1 - 1, x1:x2] + fourier[..., y2, x1:x2]
    ) / 2
    # first column: average of the columns just left of and at the right
    # bound of the copied region
    fourier_small[..., 1:, 0] = (
        fourier[..., y1:y2, x1 - 1] + fourier[..., y1:y2, x2]
    ) / 2
    # corner: average of the four corner elements around the copied region
    fourier_small[..., 0, 0] = (
        fourier[..., y1 - 1, x1 - 1]
        + fourier[..., y1 - 1, x2]
        + fourier[..., y2, x1 - 1]
        + fourier[..., y2, x2]
    ) / 4

    # undo the centering and go back to the signal domain
    fourier_small = torch.fft.ifftshift(fourier_small, dim=(-2, -1))
    im_small = torch.fft.ifft2(fourier_small)

    # if input was real-valued, output should be real-valued, but
    # using fft/ifft above means im_small will always be complex,
    # so make sure they align.
    if not x.is_complex():
        im_small = torch.real(im_small)
    return im_small