"""Helper functions for modifying tensors in useful ways."""
# numpydoc ignore=ES01
import numpy as np
import torch
from pyrtools.pyramids.steer import steer_to_harmonics_mtx
from torch import Tensor
__all__ = [
"add_noise",
"autocorrelation",
"center_crop",
"expand",
"modulate_phase",
"polar_to_rectangular",
"rectangular_to_polar",
"rescale",
"shrink",
]
def __dir__() -> list[str]:
return __all__
[docs]
def rescale(x: Tensor, a: float = 0.0, b: float = 1.0) -> Tensor:
r"""
Linearly rescale the dynamic range of the input to ``[a, b]``.
Parameters
----------
x
Tensor to rescale.
a, b
Min and max values, respectively, for the output.
Returns
-------
rescaled_x
The rescaled tensor.
Examples
--------
>>> import plenoptic as po
>>> import torch
>>> x = torch.tensor([2.0, 4.0, 6.0, 8.0])
>>> po.process.rescale(x)
tensor([0.0000, 0.3333, 0.6667, 1.0000])
>>> po.process.rescale(x, a=-1, b=1)
tensor([-1.0000, -0.3333, 0.3333, 1.0000])
""" # numpydoc ignore=ES01
v = x.max() - x.min()
g = x - x.min()
if v > 0:
g = g / v
return a + g * (b - a)
def _raised_cosine(
width: float = 1, position: float = 0, values: tuple[float, float] = (0, 1)
) -> tuple[np.ndarray, np.ndarray]:
"""
Return a lookup table containing a "raised cosine" soft threshold function.
.. code::
Y = VALUES(1)
+ (VALUES(2)-VALUES(1))
* cos^2( PI/2 * (X - POSITION + WIDTH)/WIDTH )
This lookup table is suitable for use by :func:`_interpolate1d`.
Parameters
----------
width
The width of the region over which the transition occurs.
position
The location of the center of the threshold.
values
2-tuple specifying the values to the left and right of the transition.
Returns
-------
X
The x values of this raised cosine.
Y
The y values of this raised cosine.
"""
sz = 256 # arbitrary!
X = np.pi * np.arange(-sz - 1, 2) / (2 * sz)
Y = values[0] + (values[1] - values[0]) * np.cos(X) ** 2
# make sure end values are repeated, for extrapolation...
Y[0] = Y[1]
Y[sz + 2] = Y[sz + 1]
X = position + (2 * width / np.pi) * (X + np.pi / 4)
return X, Y
def _interpolate1d(
x_new: Tensor, Y: Tensor | np.ndarray, X: Tensor | np.ndarray
) -> Tensor:
r"""
One-dimensional linear interpolation.
Returns the one-dimensional piecewise linear interpolant to a
function with given discrete data points ``(X, Y)``, evaluated at
``x_new``.
Note: this function is just a wrapper around :func:`np.interp()`.
Parameters
----------
x_new
The x-coordinates at which to evaluate the interpolated values.
Y
The y-coordinates of the data points.
X
The x-coordinates of the data points, same length as X.
Returns
-------
interp_x
Interpolated values of shape identical to ``x_new``.
"""
out = np.interp(x=x_new.flatten(), xp=X, fp=Y)
return np.reshape(out, x_new.shape)
[docs]
def rectangular_to_polar(x: Tensor) -> tuple[Tensor, Tensor]:
r"""
Rectangular to polar coordinate transform.
If input is real-valued, ``amplitude`` will be identical to the input and ``phase``
will be all 0s.
Parameters
----------
x
Complex tensor.
Returns
-------
amplitude
Tensor containing the amplitude (aka. complex modulus).
phase
Tensor containing the phase.
See Also
--------
:func:`~plenoptic.process.rectangular_to_polar_dict`
Same operation on dictionaries.
polar_to_rectangular
The inverse operation.
:func:`~plenoptic.process.local_gain_control`
The analogous function for real-valued signals.
Examples
--------
.. plot::
:context: reset
>>> import plenoptic as po
>>> import torch
>>> x = torch.tensor([1 + 1j, 1 - 1j])
>>> amplitude, phase = po.process.rectangular_to_polar(x)
>>> amplitude
tensor([1.4142, 1.4142])
>>> phase
tensor([ 0.7854, -0.7854])
In plenoptic, this function is typically used
for working with steerable pyramid coefficients:
.. plot::
:context: close-figs
>>> # starting from an image
>>> img = po.data.einstein()
>>> img.shape
torch.Size([1, 1, 256, 256])
>>> spyr = po.process.SteerablePyramidFreq(img.shape[-2:], is_complex=True)
>>> # let's only look at 1 scale and 1 orientation
>>> coeff = spyr(img)[0][:, :, 0]
>>> # the coefficients returned by spyr (forward) are in rectangular coordinates
>>> # so we can now use this function to get the polar coordinates
>>> amplitude, phase = po.process.rectangular_to_polar(coeff)
>>> amplitude.shape
torch.Size([1, 1, 256, 256])
>>> phase.shape
torch.Size([1, 1, 256, 256])
>>> # we can then invert the operation to verify that we get back the original
>>> rectangular_coeff = po.process.polar_to_rectangular(amplitude, phase)
>>> torch.allclose(coeff, rectangular_coeff)
True
>>> po.plot.imshow([amplitude, phase], title=["amplitude", "phase"])
<PyrFigure...>
"""
amplitude = torch.abs(x)
phase = torch.angle(x)
return amplitude, phase
[docs]
def polar_to_rectangular(amplitude: Tensor, phase: Tensor) -> Tensor:
r"""
Polar to rectangular coordinate transform.
Parameters
----------
amplitude
Tensor containing the amplitude (aka. complex modulus). Must be >= 0.
phase
Tensor containing the phase.
Returns
-------
image
Complex tensor.
Raises
------
ValueError
If ``amplitude`` is not non-negative.
See Also
--------
:func:`~plenoptic.process.polar_to_rectangular_dict`
Same operation on dictionaries.
rectangular_to_polar
The inverse operation.
:func:`~plenoptic.process.local_gain_release`
The analogous function for real-valued signals.
Examples
--------
>>> import plenoptic as po
>>> import torch
>>> amplitude = torch.tensor([1.4142, 1.4142])
>>> phase = torch.tensor([0.7854, -0.7854])
>>> po.process.polar_to_rectangular(amplitude, phase)
tensor([1.+1.j, 1.-1.j])
In plenoptic, this function is typically used
for working with steerable pyramid coefficients:
>>> import plenoptic as po
>>> import torch
>>> # starting from an image
>>> img = po.data.einstein()
>>> img.shape
torch.Size([1, 1, 256, 256])
>>> spyr = po.process.SteerablePyramidFreq(img.shape[-2:], is_complex=True)
>>> # let's only look at 1 scale and 1 orientation
>>> coeff = spyr(img)[0][:, :, 0]
>>> # the coefficients returned by spyr (forward) are in rectangular coordinates
>>> # so, we can manually compute polar coordinates
>>> amplitude, phase = torch.abs(coeff), torch.angle(coeff)
>>> amplitude.shape
torch.Size([1, 1, 256, 256])
>>> phase.shape
torch.Size([1, 1, 256, 256])
>>> # from those, we can use this function to recover the rectangular coordinates
>>> rectangular_coeff = po.process.polar_to_rectangular(amplitude, phase)
>>> rectangular_coeff.shape
torch.Size([1, 1, 256, 256])
>>> # we can verify that they match the original
>>> torch.allclose(coeff, rectangular_coeff)
True
""" # numpydoc ignore=ES01
if (amplitude < 0).any():
raise ValueError("Amplitudes must be non-negative.")
real = amplitude * torch.cos(phase)
imaginary = amplitude * torch.sin(phase)
return torch.complex(real, imaginary)
def _steer(
basis: Tensor,
angle: np.ndarray | Tensor | float,
harmonics: list[int] | None = None,
steermtx: Tensor | np.ndarray | None = None,
even_phase: bool = True,
) -> tuple[Tensor, Tensor]:
"""
Steer ``basis`` to the specified ``angle``.
Parameters
----------
basis
Array whose columns are vectorized rotated copies of a steerable
function, or the responses of a set of steerable filters.
angle
Scalar or column vector the size of the basis. Specifies the angle(s)
(in radians) to steer to.
harmonics
A list of harmonic numbers indicating the angular harmonic content of
the basis. If ``None``, will use N even or odd low frequencies, as for
derivative filters.
steermtx
Matrix which maps the filters onto Fourier series components (ordered
``[cos0, cos1, sin1, cos2, sin2, ..., sinN]``). See
:func:`pyrtools.pyramids.steer.steer_to_harmonics_mtx` function for more
details. If ``None``, assumes cosine phase harmonic components, and
filter positions at ``2pi*n/N``.
even_phase
Specifies whether the harmonics are cosine or sine phase aligned about
those positions.
Returns
-------
res
The resteered basis.
steervect
The weights used to resteer the basis.
Raises
------
ValueError
If ``angle`` is not a scalar or appropriately-sized column vector.
ValueError
If ``harmonics`` is not 1d.
ValueError
If ``harmonics`` is not compatible with the size of ``basis``.
""" # numpydoc ignore=ES01
num = basis.shape[-1]
device = basis.device
if isinstance(angle, int | float) or angle.ndim == 0:
angle = np.array([angle])
else:
if angle.shape[0] != basis.shape[0] or angle.shape[1] != 1:
raise ValueError(
"ANGLE must be a scalar, or a column vector the"
"size of the basis elements"
)
# If HARMONICS is not specified, assume derivatives.
if harmonics is None:
harmonics = np.arange(1 - (num % 2), num, 2)
if len(harmonics.shape) == 1 or harmonics.shape[0] == 1:
# reshape to column matrix
harmonics = harmonics.reshape(harmonics.shape[0], 1)
elif harmonics.shape[0] != 1 and harmonics.shape[1] != 1:
raise ValueError("input parameter HARMONICS must be 1D!")
if 2 * harmonics.shape[0] - (harmonics == 0).sum() != num:
raise ValueError("harmonics list is incompatible with basis size!")
# If STEERMTX not passed, assume evenly distributed cosine-phase filters:
if steermtx is None:
steermtx = steer_to_harmonics_mtx(
harmonics, np.pi * np.arange(num) / num, even_phase=even_phase
)
steervect = np.zeros((angle.shape[0], num))
arg = angle * harmonics[np.nonzero(harmonics)[0]].T
if all(harmonics):
steervect[:, range(0, num, 2)] = np.cos(arg)
steervect[:, range(1, num, 2)] = np.sin(arg)
else:
steervect[:, 0] = np.ones((arg.shape[0], 1))
steervect[:, range(1, num, 2)] = np.cos(arg)
steervect[:, range(2, num, 2)] = np.sin(arg)
steervect = np.dot(steervect, steermtx)
steervect = torch.as_tensor(steervect, dtype=basis.dtype).to(device)
if steervect.shape[0] > 1:
tmp = basis @ steervect
res = tmp.sum().t()
else:
res = basis @ steervect.t()
return res, steervect.reshape(num)
[docs]
def add_noise(img: Tensor, noise_mse: float | list[float]) -> Tensor:
"""
Add normally distributed noise to an image.
This adds normally-distributed noise to an image so that the resulting
noisy version has the specified mean-squared error.
Parameters
----------
img
The image to make noisy.
noise_mse
The target MSE value / variance of the noise. More than one value is
allowed.
Returns
-------
noisy_img
The noisy image. If ``noise_mse`` contains only one element, this will be
the same size as ``img``. Else, each separate value from ``noise_mse`` will
be along the batch dimension.
Examples
--------
Basic usage:
.. plot::
:context: reset
>>> import plenoptic as po
>>> import torch
>>> img = po.data.einstein()
>>> img.shape
torch.Size([1, 1, 256, 256])
>>> noisy = po.process.add_noise(img, noise_mse=0.1)
>>> noisy.shape
torch.Size([1, 1, 256, 256])
>>> po.plot.imshow([img, noisy])
<PyrFigure ...>
With multiple elements in ``noise_mse``:
.. plot::
:context: close-figs
>>> noisy_multi = po.process.add_noise(img, noise_mse=[0.01, 0.1, 1.0])
>>> noisy_multi.shape
torch.Size([3, 1, 256, 256])
>>> po.plot.imshow([img, noisy_multi])
<PyrFigure ...>
"""
noise_mse = torch.as_tensor(
noise_mse, dtype=img.dtype, device=img.device
).unsqueeze(0)
noise_mse = noise_mse.view(noise_mse.nelement(), 1, 1, 1)
noise = 200 * torch.randn(
max(noise_mse.shape[0], img.shape[0]),
*img.shape[1:],
device=img.device,
)
noise = noise - noise.mean()
noise = noise * torch.sqrt(
noise_mse / (noise**2).mean((-1, -2)).unsqueeze(-1).unsqueeze(-1)
)
return img + noise
[docs]
def modulate_phase(x: Tensor, phase_factor: float = 2.0) -> Tensor:
"""
Modulate the phase of a complex signal.
Doubling the phase of a complex signal allows you to, for example, take the
correlation between steerable pyramid coefficients at two adjacent spatial
scales.
Parameters
----------
x
Complex tensor whose phase will be modulated.
phase_factor
Multiplicative factor to change phase by.
Returns
-------
x_mod
Phase-modulated complex tensor.
Raises
------
TypeError
If ``x`` is not complex-valued.
Examples
--------
.. plot::
:context: reset
>>> import plenoptic as po
>>> import torch
>>> img = po.data.einstein()
>>> spyr = po.process.SteerablePyramidFreq(img.shape[-2:], is_complex=True)
>>> # let's only look at 1 scale and 1 orientation
>>> coeff = spyr(img)[3][:, :, 0]
>>> mod_coeff = po.process.modulate_phase(coeff, 2)
>>> po.plot.imshow([coeff, mod_coeff], title=["original", "modulated"], zoom=8)
<PyrFigure ...>
Note how the white and black streaks have changed between original and modulated.
"""
try:
angle = torch.atan2(x.imag, x.real)
except RuntimeError:
# then x is not complex-valued
raise TypeError("x must be a complex-valued tensor!")
amp = x.abs()
real = amp * torch.cos(phase_factor * angle)
imag = amp * torch.sin(phase_factor * angle)
return torch.complex(real, imag)
[docs]
def autocorrelation(x: Tensor) -> Tensor:
r"""
Compute the autocorrelation of ``x``.
This uses the Fourier transform to compute the autocorrelation in an efficient
manner (see Notes).
Parameters
----------
x
N-dimensional tensor. We assume the last two dimension are height and
width and compute you autocorrelation on these dimensions (independently
on each other dimension).
Returns
-------
ac
Autocorrelation of ``x``.
Notes
-----
- By the Einstein-Wiener-Khinchin theorem: The autocorrelation of a wide
sense stationary (WSS) process is the inverse Fourier transform of its
energy spectrum (ESD) - which itself is the multiplication between
FT(x(t)) and FT(x(-t)). In other words, the auto-correlation is
convolution of the signal ``x`` with itself, which corresponds to squaring
in the frequency domain. This approach is computationally more efficient
than brute force (:math:`n log(n)` vs :math:`n^2`).
- By Cauchy-Swartz, the autocorrelation attains it is maximum at the center
location (ie. no shift) - that maximum value is the signal's variance
(assuming that the input signal is mean centered).
Examples
--------
.. plot::
:context: reset
>>> import plenoptic as po
>>> import torch
>>> img = po.data.einstein()
>>> ac = po.process.autocorrelation(img)
>>> po.plot.imshow([img, ac], title=["image", "autocorrelation"])
<PyrFigure...>
If we start from random noise, we do not see the correlation
structure that is found in natural images:
.. plot::
:context: close-figs
>>> random_img = torch.rand(size=(1, 1, 256, 256))
>>> ac_noise = po.process.autocorrelation(random_img)
>>> po.plot.imshow(
... [random_img, ac_noise], title=["random noise", "autocorrelation"]
... )
<PyrFigure...>
Plenoptic models typically do not use the full autocorrelation, but rather
the first couple shifts only. Using a combination of this function and
:func:`~plenoptic.process.center_crop`, that is easily achieved:
.. plot::
:context: close-figs
>>> ac_cropped = po.process.center_crop(ac, 16)
>>> po.plot.imshow(
... [img, ac, ac_cropped],
... title=["image", "autocorrelation", "cropped autocorrelation"],
... )
<PyrFigure...>
"""
# Calculate the auto-correlation
ac = torch.fft.rfft2(x)
# this is equivalent to ac.abs().pow(2) or to ac multiplied by its complex
# conjugate
ac = ac.real.pow(2) + ac.imag.pow(2)
ac = torch.fft.irfft2(ac)
ac = torch.fft.fftshift(ac, dim=(-2, -1)) / torch.mul(*x.shape[-2:])
return ac
[docs]
def center_crop(x: Tensor, output_size: int) -> Tensor:
"""
Crop out the center of a signal.
If x has an even number of elements on either of those final two
dimensions, we round up.
Parameters
----------
x
N-dimensional tensor, we assume the last two dimensions are height and
width.
output_size
The size of the output. Must be a positive int. Note that we only support a
single number, so both dimensions are cropped identically.
Returns
-------
cropped
Tensor whose last two dimensions have each been cropped to
``output_size``.
Raises
------
TypeError
If ``output_size`` is not a single int.
ValueError
If ``output_size is not positive or larger than the height/width of ``x``.
Examples
--------
.. plot::
:context: reset
>>> import plenoptic as po
>>> img = po.data.einstein()
>>> img.shape
torch.Size([1, 1, 256, 256])
>>> crop = po.process.center_crop(img, 128)
>>> crop.shape
torch.Size([1, 1, 128, 128])
>>> po.plot.imshow([img, crop], title=["input", "cropped"])
<PyrFigure ...>
"""
h, w = x.shape[-2:]
output_size = torch.as_tensor(output_size)
if output_size.ndim > 0:
raise TypeError("output_size must be a single number!")
if torch.is_floating_point(output_size):
raise TypeError("output_size must be an int!")
if output_size > h or output_size > w:
raise ValueError("output_size is bigger than image height/width!")
if output_size <= 0:
raise ValueError("output_size must be positive!")
return x[
...,
(h // 2 - output_size // 2) : (h // 2 + (output_size + 1) // 2),
(w // 2 - output_size // 2) : (w // 2 + (output_size + 1) // 2),
]
[docs]
def expand(x: Tensor, factor: float) -> Tensor:
r"""
Expand a signal by a factor.
We do this in the frequency domain: pasting the Fourier contents of ``x``
in the center of a larger empty tensor, and then taking the inverse FFT.
Parameters
----------
x
The signal for expansion.
factor
Factor by which to resize image. Must be larger than 1 and ``factor *
x.shape[-2:]`` must give integer values.
Returns
-------
expanded
The expanded signal.
Raises
------
ValueError
If ``factor`` is less than or equal to 1.
ValueError
If ``factor`` times the height or width of ``x`` is not an integer.
See Also
--------
shrink
The inverse operation.
:func:`~plenoptic.process.upsample_blur`
An alternative upsampling operation.
Examples
--------
.. plot::
:context: reset
>>> import plenoptic as po
>>> import torch
>>> img = po.data.einstein()
>>> img.shape
torch.Size([1, 1, 256, 256])
>>> expanded = po.process.expand(img, factor=2)
>>> expanded.shape
torch.Size([1, 1, 512, 512])
>>> po.plot.imshow(
... [img, expanded], title=["original", "expanded"], zoom=0.5, vrange=(0, 1)
... )
<PyrFigure...>
Note that the range has changed:
>>> img.min(), img.max()
(tensor(0.0039), tensor(1.))
>>> expanded.min(), expanded.max()
(tensor(-0.1648), tensor(1.0239))
An alternative method for upsampling images is to use
:func:`~plenoptic.process.upsample_blur`:
.. plot::
:context: close-figs
>>> po.plot.imshow(
... [img, expanded, po.process.upsample_blur(img, (0, 0))],
... title=["original", "expanded", "blurred"],
... vrange=(0, 1),
... )
<PyrFigure...>
"""
if factor <= 1:
raise ValueError("factor must be strictly greater than 1!")
im_x = x.shape[-1]
im_y = x.shape[-2]
mx = factor * im_x
my = factor * im_y
if int(mx) != mx:
raise ValueError(
f"factor * x.shape[-1] must be an integer but got {mx} instead!"
)
if int(my) != my:
raise ValueError(
f"factor * x.shape[-2] must be an integer but got {my} instead!"
)
mx = int(mx)
my = int(my)
fourier = factor**2 * torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
fourier_large = torch.zeros(
*x.shape[:-2],
my,
mx,
device=fourier.device,
dtype=fourier.dtype,
)
y1 = my / 2 + 1 - im_y / 2
y2 = my / 2 + im_y / 2
x1 = mx / 2 + 1 - im_x / 2
x2 = mx / 2 + im_x / 2
# when any of these numbers are non-integers, if you round down, the
# resulting image will be off.
y1 = int(np.ceil(y1))
y2 = int(np.ceil(y2))
x1 = int(np.ceil(x1))
x2 = int(np.ceil(x2))
fourier_large[..., y1:y2, x1:x2] = fourier[..., 1:, 1:]
fourier_large[..., y1 - 1, x1:x2] = fourier[..., 0, 1:] / 2
fourier_large[..., y2, x1:x2] = fourier[..., 0, 1:].flip(-1) / 2
fourier_large[..., y1:y2, x1 - 1] = fourier[..., 1:, 0] / 2
fourier_large[..., y1:y2, x2] = fourier[..., 1:, 0].flip(-1) / 2
esq = fourier[..., 0, 0] / 4
fourier_large[..., y1 - 1, x1 - 1] = esq
fourier_large[..., y1 - 1, x2] = esq
fourier_large[..., y2, x1 - 1] = esq
fourier_large[..., y2, x2] = esq
fourier_large = torch.fft.ifftshift(fourier_large, dim=(-2, -1))
im_large = torch.fft.ifft2(fourier_large)
# if input was real-valued, output should be real-valued, but
# using fft/ifft above means im_large will always be complex,
# so make sure they align.
if not x.is_complex():
im_large = torch.real(im_large)
return im_large
[docs]
def shrink(x: Tensor, factor: int) -> Tensor:
r"""
Shrink a signal by a factor.
We do this in the frequency domain: cropping out the center of the Fourier
transform of ``x``, putting it in a new tensor, and taking the IFFT.
Parameters
----------
x
The signal for expansion.
factor
Factor by which to resize image. Must be larger than 1 and ``factor /
x.shape[-2:]`` must give integer values.
Returns
-------
expanded
The expanded signal.
Raises
------
ValueError
If ``factor`` is less than or equal to 1.
ValueError
If the height or width of ``x`` divided by ``factor`` is not an integer.
See Also
--------
expand
The inverse operation.
:func:`~plenoptic.process.blur_downsample`
An alternative downsampling operation.
Examples
--------
.. plot::
:context: reset
>>> import plenoptic as po
>>> import torch
>>> img = po.data.einstein()
>>> img.shape
torch.Size([1, 1, 256, 256])
>>> shrunk = po.process.shrink(img, factor=2)
>>> shrunk.shape
torch.Size([1, 1, 128, 128])
>>> po.plot.imshow([img, shrunk], title=["original", "shrunk"])
<PyrFigure...>
Note the horizontal/vertical lines in the shrunk version of the image.
These are the result of aliasing.
To avoid these, use :func:`~plenoptic.process.blur_downsample`:
.. plot::
:context: close-figs
>>> po.plot.imshow(
... [img, shrunk, po.process.blur_downsample(img)],
... title=["original", "shrunk", "blurred"],
... )
<PyrFigure...>
You can invert ``shrink`` using :func:`~plenoptic.process.expand`, but the
inversion is not perfect; shrinking discards information that can not be recovered:
.. plot::
:context: close-figs
>>> expand_after_shrink = po.process.expand(
... po.process.shrink(img, factor=2), factor=2
... )
>>> torch.allclose(img, expand_after_shrink, atol=1e-2)
False
>>> po.plot.imshow(
... [img, expand_after_shrink],
... title=["original", "expand after shrink"],
... )
<PyrFigure...>
Even in the opposite order, i.e., shrinking an expanded image, the inversion is
not perfect. In this example with pixel values between 0 and 1,
there are differences on the order of 1e-3:
.. plot::
:context: close-figs
>>> shrink_after_expand = po.process.shrink(
... po.process.expand(img, factor=2), factor=2
... )
>>> torch.allclose(img, shrink_after_expand, atol=1e-2)
True
>>> torch.allclose(img, shrink_after_expand, atol=1e-3)
False
>>> po.plot.imshow(
... [img, shrink_after_expand, img - shrink_after_expand],
... title=["original", "shrink after expand", "difference"],
... )
<PyrFigure...>
"""
if factor <= 1:
raise ValueError("factor must be strictly greater than 1!")
im_x = x.shape[-1]
im_y = x.shape[-2]
mx = im_x / factor
my = im_y / factor
if int(mx) != mx:
raise ValueError(f"x.shape[-1]/factor must be an integer but got {mx} instead!")
if int(my) != my:
raise ValueError(f"x.shape[-2]/factor must be an integer but got {my} instead!")
mx = int(mx)
my = int(my)
fourier = 1 / factor**2 * torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
fourier_small = torch.zeros(
*x.shape[:-2],
my,
mx,
device=fourier.device,
dtype=fourier.dtype,
)
y1 = im_y / 2 + 1 - my / 2
y2 = im_y / 2 + my / 2
x1 = im_x / 2 + 1 - mx / 2
x2 = im_x / 2 + mx / 2
# when any of these numbers are non-integers, if you round down, the
# resulting image will be off.
y1 = int(np.ceil(y1))
y2 = int(np.ceil(y2))
x1 = int(np.ceil(x1))
x2 = int(np.ceil(x2))
# This line is equivalent to
fourier_small[..., 1:, 1:] = fourier[..., y1:y2, x1:x2]
fourier_small[..., 0, 1:] = (
fourier[..., y1 - 1, x1:x2] + fourier[..., y2, x1:x2]
) / 2
fourier_small[..., 1:, 0] = (
fourier[..., y1:y2, x1 - 1] + fourier[..., y1:y2, x2]
) / 2
fourier_small[..., 0, 0] = (
fourier[..., y1 - 1, x1 - 1]
+ fourier[..., y1 - 1, x2]
+ fourier[..., y2, x1 - 1]
+ fourier[..., y2, x2]
) / 4
fourier_small = torch.fft.ifftshift(fourier_small, dim=(-2, -1))
im_small = torch.fft.ifft2(fourier_small)
# if input was real-valued, output should be real-valued, but
# using fft/ifft above means im_small will always be complex,
# so make sure they align.
if not x.is_complex():
im_small = torch.real(im_small)
return im_small