"""Convolution-related utility functions.""" # numpydoc ignore=ES01
import math
from typing import Literal
import numpy as np
import pyrtools as pt
import torch
import torch.nn.functional as F
from torch import Tensor
# Public API of this module: the names exported by ``from ... import *`` and the
# only names reported by ``dir()`` on this module (see ``__dir__`` below).
__all__ = [
    "blur_downsample",
    "correlate_downsample",
    "same_padding",
    "upsample_blur",
    "upsample_convolve",
]


def __dir__() -> list[str]:
    # Module-level ``__dir__`` hook (PEP 562): restricts ``dir(module)`` to the
    # public names listed in ``__all__``, hiding imports and private helpers.
    return __all__
[docs]
def correlate_downsample(
    image: Tensor,
    filt: Tensor,
    padding_mode: Literal["constant", "reflect", "replicate", "circular"] = "reflect",
) -> Tensor:
    """
    Correlate with a filter and downsample by a factor of 2.

    This operation allows one to downsample in an alias-resistant manner, removing the
    high frequencies that would result in aliasing in a smaller image.

    Parameters
    ----------
    image
        Image, or batch of images, of shape (batch, channel, height, width).
        Batches and channels are handled independently.
    filt
        2D tensor defining the filter to correlate with the input ``image``. The same
        filter is applied to every channel.
    padding_mode
        How to pad the image, so that we return an image of the appropriate size. The
        option ``"constant"`` means padding with zeros.

    Returns
    -------
    downsampled_image
        The downsampled image, of shape (batch, channel, ceil(height / 2),
        ceil(width / 2)).

    Raises
    ------
    ValueError
        If ``filt`` or ``image`` has the wrong number of dimensions.

    See Also
    --------
    blur_downsample
        Perform this operation a user-specified number of times using a named filter.
    upsample_convolve
        Perform the inverse operation, upsampling and convolving with a filter.

    Examples
    --------
    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> import torch
      >>> img = po.data.einstein()
      >>> # 2x2 averaging filter
      >>> filt = torch.ones(2, 2) / 4.0
      >>> downsampled = po.process.correlate_downsample(img, filt)
      >>> downsampled.shape
      torch.Size([1, 1, 128, 128])
      >>> po.plot.imshow([img, downsampled], title=["image", "downsampled"])
      <PyrFigure...>

    Note that the dimensions have changed.

    This function always returns an image whose height and width are half that of the
    input (rounded up). When correlating an image with a filter, the filter must be
    centered on each output pixel. For pixels near the image boundary, the filter
    extends outside the image boundary and thus we need to pad the input with extra
    pixels. The ``padding_mode`` argument determines how to do so (using
    :func:`same_padding`):

    - reflect: mirror the image at boundaries
    - constant: pad with zeroes
    - replicate: repeat edge pixel values
    - circular: wrap the image around

    .. plot::
      :context: close-figs

      >>> # Large 50x50 averaging filter to make padding effects visible
      >>> filt = torch.ones(50, 50) / (50 * 50)
      >>> constant = po.process.correlate_downsample(img, filt, padding_mode="constant")
      >>> reflect = po.process.correlate_downsample(img, filt, padding_mode="reflect")
      >>> replicate = po.process.correlate_downsample(
      ...     img, filt, padding_mode="replicate"
      ... )
      >>> circular = po.process.correlate_downsample(img, filt, padding_mode="circular")
      >>> po.plot.imshow(
      ...     [reflect, constant, replicate, circular],
      ...     title=[
      ...         "reflect padding",
      ...         "constant (zero) padding",
      ...         "replicate padding",
      ...         "circular padding",
      ...     ],
      ...     zoom=2,
      ... )
      <PyrFigure...>
    """
    if image.ndim != 4:
        raise ValueError(f"image must be 4d but has {image.ndim} dimensions instead!")
    if filt.ndim != 2:
        raise ValueError(f"filt must be 2d but has {filt.ndim} dimensions instead!")
    n_channels = image.shape[1]
    # Pad as for a stride-1 "same" correlation (no stride passed to same_padding);
    # the stride-2 correlation below then keeps every other output pixel, which
    # yields ceil(height / 2) x ceil(width / 2).
    image_padded = same_padding(image, kernel_size=filt.shape, pad_mode=padding_mode)
    # Depthwise correlation: one copy of ``filt`` per channel, and groups=n_channels
    # so that channels are never mixed.
    return F.conv2d(
        image_padded,
        filt.repeat(n_channels, 1, 1, 1),
        stride=2,
        groups=n_channels,
    )
[docs]
def upsample_convolve(
    image: Tensor,
    odd: tuple[int, int],
    filt: Tensor,
    padding_mode: Literal["constant", "reflect", "replicate", "circular"] = "reflect",
) -> Tensor:
    """
    Upsample by 2 and convolve with a filter.

    When upsampling an image, we need some way to estimate the new pixels; convolving
    with a filter allows us to interpolate these pixels from their neighbors.

    Parameters
    ----------
    image
        Image, or batch of images, of shape (batch, channel, height, width).
        Batches and channels are handled independently.
    odd
        This should contain two integers of value 0 or 1, for height and width
        respectively, which determine whether the output height and width should be
        even (0) or odd (1).
    filt
        2D tensor defining the filter to convolve with the input ``image``. The same
        filter is applied to every channel.
    padding_mode
        How to pad the image before upsampling, so that we return an image of the
        appropriate size. The option ``"constant"`` means padding with zeros.

    Returns
    -------
    upsampled_image
        The upsampled image, of shape (batch, channel, 2 * height - odd[0],
        2 * width - odd[1]).

    Raises
    ------
    ValueError
        If ``filt`` or ``image`` has the wrong number of dimensions, or if ``odd`` does
        not contain exactly two values, each 0 or 1.

    See Also
    --------
    upsample_blur
        Perform this operation a user-specified number of times using a named filter.
    correlate_downsample
        Perform the inverse operation, correlating and downsampling an image.

    Examples
    --------
    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> import torch
      >>> img = po.data.einstein()
      >>> # 2x2 interpolation filter
      >>> filt = torch.ones(2, 2) / 4.0
      >>> upsampled = po.process.upsample_convolve(img, odd=[0, 0], filt=filt)
      >>> upsampled.shape
      torch.Size([1, 1, 512, 512])
      >>> po.plot.imshow([img, upsampled], title=["image", "upsampled"])
      <PyrFigure...>

    Note that the dimensions have changed.

    The odd argument allows for choosing whether the output width and/or height should
    be even or odd:

    .. plot::
      :context: close-figs

      >>> upsampled_even = po.process.upsample_convolve(img, odd=[0, 0], filt=filt)
      >>> upsampled_even.shape
      torch.Size([1, 1, 512, 512])
      >>> upsampled_odd = po.process.upsample_convolve(img, odd=[1, 1], filt=filt)
      >>> upsampled_odd.shape
      torch.Size([1, 1, 511, 511])
      >>> upsampled_mixed_odd_even = po.process.upsample_convolve(
      ...     img, odd=[1, 0], filt=filt
      ... )
      >>> upsampled_mixed_odd_even.shape
      torch.Size([1, 1, 511, 512])

    This function always returns an image whose height and width are twice that of
    the input, minus ``odd``. Internally, the input is padded, zeros are inserted
    between its pixels, and the result is convolved with ``filt``. For pixels near the
    image boundary, the filter extends outside the image boundary and thus we need to
    pad the input with extra pixels. The ``padding_mode`` argument determines how to
    do so (using :func:`torch.nn.functional.pad`):

    - reflect: mirror the image at boundaries
    - constant: pad with zeroes
    - replicate: repeat edge pixel values
    - circular: wrap the image around

    .. plot::
      :context: close-figs

      >>> # Large 50x50 interpolation filter to make padding effects visible
      >>> filt = torch.ones(50, 50) / (50 * 50)
      >>> constant = po.process.upsample_convolve(
      ...     img, odd=[0, 0], filt=filt, padding_mode="constant"
      ... )
      >>> reflect = po.process.upsample_convolve(
      ...     img, odd=[0, 0], filt=filt, padding_mode="reflect"
      ... )
      >>> replicate = po.process.upsample_convolve(
      ...     img, odd=[0, 0], filt=filt, padding_mode="replicate"
      ... )
      >>> circular = po.process.upsample_convolve(
      ...     img, odd=[0, 0], filt=filt, padding_mode="circular"
      ... )
      >>> po.plot.imshow(
      ...     [reflect, constant, replicate, circular],
      ...     title=[
      ...         "reflect padding",
      ...         "constant (zero) padding",
      ...         "replicate padding",
      ...         "circular padding",
      ...     ],
      ...     zoom=2,
      ... )
      <PyrFigure...>
    """
    if image.ndim != 4:
        raise ValueError(f"image must be 4d but has {image.ndim} dimensions instead!")
    if filt.ndim != 2:
        raise ValueError(f"filt must be 2d but has {filt.ndim} dimensions instead!")
    # Normalize to plain Python ints (accepts lists, tuples, numpy arrays, tensors).
    odd = tuple(int(o) for o in odd)
    if len(odd) != 2 or any(o not in (0, 1) for o in odd):
        raise ValueError(f"odd must contain two values, each 0 or 1, but got {odd}!")
    odd_h, odd_w = odd

    # Flip so that the conv2d (a correlation) below performs a true convolution.
    filt = filt.flip((0, 1))
    n_channels = image.shape[1]
    filt_h, filt_w = filt.shape
    # Total padding per side; ``odd`` trims one sample from the bottom / right so the
    # output size is 2 * size - odd.
    pad_top, pad_left = filt_h // 2, filt_w // 2
    pad_bottom = filt_h - odd_h - pad_top
    pad_right = filt_w - odd_w - pad_left
    pad = (pad_left, pad_right, pad_top, pad_bottom)  # F.pad order: last dim first
    # Half of each padding is applied to the input with ``padding_mode`` (each input
    # pixel becomes two pixels after upsampling); any odd remainder is added as zeros
    # after upsampling.
    image_prepad = F.pad(image, tuple(p // 2 for p in pad), mode=padding_mode)
    # Transposed convolution with a 1x1 kernel of ones and stride 2 inserts a zero
    # between neighboring pixels, per channel.
    image_upsample = F.conv_transpose2d(
        image_prepad,
        weight=torch.ones(
            (n_channels, 1, 1, 1), device=image.device, dtype=image.dtype
        ),
        stride=2,
        groups=n_channels,
    )
    image_postpad = F.pad(image_upsample, tuple(p % 2 for p in pad))
    # Depthwise convolution fills in the inserted zeros from their neighbors.
    return F.conv2d(image_postpad, filt.repeat(n_channels, 1, 1, 1), groups=n_channels)
[docs]
def blur_downsample(
    image: Tensor,
    n_scales: int = 1,
    filtname: str = "binom5",
    scale_filter: bool = True,
) -> Tensor:
    r"""
    Correlate with a named filter and downsample by 2.

    This operation allows one to downsample in an alias-resistant manner, removing the
    high frequencies that would result in aliasing in a smaller image.

    Parameters
    ----------
    image
        Image, or batch of images, of shape (batch, channel, height, width).
        Batches and channels are handled independently.
    n_scales
        Apply the blur and downsample procedure recursively ``n_scales`` times.
        Must be positive.
    filtname
        Name of the filter. See :func:`~pyrtools.pyramids.filters.named_filter` for
        options.
    scale_filter
        If ``True``, the filter sums to 1 (i.e., it does not affect the DC component of
        the signal and the output's mean will approximately match that of the input). If
        ``False``, the filter sums to 2 (and the output's mean will be roughly double
        that of the input).

    Returns
    -------
    downsampled_image
        The downsampled image. Each scale halves the height and width (rounding up).

    Raises
    ------
    ValueError
        If ``n_scales`` is not positive.

    See Also
    --------
    correlate_downsample
        Perform this operation once using a user-specified filter.
    upsample_blur
        Perform the inverse operation, upsampling and convolving a user-specified number
        of times using a named filter.
    :func:`~plenoptic.process.shrink`
        An alternative downsampling operation.

    Examples
    --------
    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> import torch
      >>> img = po.data.einstein()
      >>> downsampled = po.process.blur_downsample(img)
      >>> downsampled.shape
      torch.Size([1, 1, 128, 128])
      >>> po.plot.imshow([img, downsampled], title=["image", "downsampled"])
      <PyrFigure...>

    Note that the dimensions have changed.

    The ``n_scales`` argument allows for applying the blurring and downsampling
    recursively:

    .. plot::
      :context: close-figs

      >>> downsampled_2 = po.process.blur_downsample(img, n_scales=2)
      >>> downsampled_2.shape
      torch.Size([1, 1, 64, 64])
      >>> downsampled_4 = po.process.blur_downsample(img, n_scales=4)
      >>> downsampled_4.shape
      torch.Size([1, 1, 16, 16])
      >>> po.plot.imshow(
      ...     [img, downsampled_2, downsampled_4],
      ...     title=["image", "downsampled x2", "downsampled x4"],
      ... )
      <PyrFigure...>

    In Plenoptic, we typically use a fifth order binomial filter,
    but many other filters are available,
    see :func:`pyrtools.pyramids.filters.named_filter` for a list.

    .. plot::
      :context: close-figs

      >>> named_filters = [
      ...     "binom2",
      ...     "binom3",
      ...     "binom4",
      ...     "haar",
      ...     "qmf8",
      ...     "daub2",
      ...     "qmf5",
      ... ]
      >>> downsampled_filter = [
      ...     po.process.blur_downsample(img, n_scales=2, filtname=filt)
      ...     for filt in named_filters
      ... ]
      >>> po.plot.imshow(
      ...     [img] + downsampled_filter,
      ...     title=["image"] + named_filters,
      ...     col_wrap=4,
      ...     vrange=(0, 1),
      ... )
      <PyrFigure...>

    Note that this operation can change the minimum and maximum,
    and different filters can do so differently:

    >>> img.min()
    tensor(0.0039)
    >>> img.max()
    tensor(1.)
    >>> for filter_name, downsampled in zip(named_filters, downsampled_filter):
    ...     print(
    ...         f"filter: {filter_name}, "
    ...         f"min={downsampled.min():.2f}, "
    ...         f"max={downsampled.max():.2f}"
    ...     )
    filter: binom2, min=0.11, max=0.92
    filter: binom3, min=0.11, max=0.91
    filter: binom4, min=0.15, max=0.90
    filter: haar, min=0.11, max=0.92
    filter: qmf8, min=0.12, max=0.97
    filter: daub2, min=0.09, max=0.95
    filter: qmf5, min=0.09, max=0.94

    The ``scale_filter`` argument forces the filter to sum to 1, making the mean of the
    output approximately match that of the input. If set to ``False``, the filter will
    sum to 2, and the output's mean will be approximately double that of the input
    (``downsampled`` is recomputed here, since the loop above rebound that name):

    .. plot::
      :context: close-figs

      >>> downsampled = po.process.blur_downsample(img)
      >>> downsampled_nonscaled = po.process.blur_downsample(img, scale_filter=False)
      >>> torch.allclose(img.mean(), downsampled.mean(), atol=1e-2)
      True
      >>> torch.allclose(img.mean(), downsampled_nonscaled.mean() / 2, atol=1e-2)
      True
      >>> po.plot.imshow(
      ...     [img, downsampled, downsampled_nonscaled],
      ...     title=[
      ...         f"original, mean={img.mean().item():.3f}",
      ...         f"scaled, mean={downsampled.mean().item():.3f}",
      ...         f"unscaled, mean={downsampled_nonscaled.mean().item():.3f}",
      ...     ],
      ... )
      <PyrFigure...>
    """
    if n_scales < 1:
        raise ValueError("n_scales must be positive!")
    f = pt.named_filter(filtname)
    # Separable 2D filter built from the 1D named filter, matching image dtype/device.
    filt = torch.as_tensor(np.outer(f, f), dtype=image.dtype, device=image.device)
    if scale_filter:
        # The outer product sums to 2 (see ``scale_filter`` above); halve it so the
        # DC component, and hence the mean, is preserved.
        filt = filt / 2
    for _ in range(n_scales):
        image = correlate_downsample(image, filt)
    return image
[docs]
def upsample_blur(
    image: Tensor,
    odd: tuple[int, int],
    n_scales: int = 1,
    filtname: str = "binom5",
    scale_filter: bool = True,
) -> Tensor:
    """
    Upsample by 2 and convolve with named filter.

    When upsampling an image, we need some way to estimate the new pixels; convolving
    with a filter allows us to interpolate these pixels from their neighbors.

    Parameters
    ----------
    image
        Image, or batch of images, of shape (batch, channel, height, width).
        Batches and channels are handled independently.
    odd
        This should contain two integers of value 0 or 1, which determines whether
        the output height and width should be even (0) or odd (1). The same ``odd``
        is applied at every scale.
    n_scales
        Apply the upsample and blur procedure recursively ``n_scales`` times.
        Must be positive.
    filtname
        Name of the filter. See :func:`~pyrtools.pyramids.filters.named_filter` for
        options.
    scale_filter
        If ``True``, the filter sums to 4 (i.e., it does not affect the DC component of
        the signal and the output's mean will approximately match that of the input). If
        ``False``, the filter sums to 2 (and the output's mean will be roughly half
        that of the input).

    Returns
    -------
    upsampled_image
        The upsampled image.

    Raises
    ------
    ValueError
        If ``n_scales`` is not positive.

    See Also
    --------
    upsample_convolve
        Perform this operation once using a user-specified filter.
    blur_downsample
        Perform the inverse operation, correlating and downsampling a user-specified
        number of times using a named filter.
    :func:`~plenoptic.process.expand`
        An alternative upsampling operation.

    Examples
    --------
    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> import torch
      >>> img = po.data.einstein()
      >>> upsampled = po.process.upsample_blur(img, odd=[0, 0])
      >>> upsampled.shape
      torch.Size([1, 1, 512, 512])
      >>> po.plot.imshow([img, upsampled], title=["image", "upsampled"])
      <PyrFigure...>

    Note that the dimensions have changed.

    The ``odd`` argument allows for choosing whether the output width and/or height
    should be even or odd:

    .. plot::
      :context: close-figs

      >>> upsampled_even = po.process.upsample_blur(img, odd=[0, 0])
      >>> upsampled_even.shape
      torch.Size([1, 1, 512, 512])
      >>> upsampled_odd = po.process.upsample_blur(img, odd=[1, 1])
      >>> upsampled_odd.shape
      torch.Size([1, 1, 511, 511])
      >>> upsampled_mixed_odd_even = po.process.upsample_blur(img, odd=[1, 0])
      >>> upsampled_mixed_odd_even.shape
      torch.Size([1, 1, 511, 512])

    The ``n_scales`` argument allows for applying the upsampling and blurring
    recursively:

    .. plot::
      :context: close-figs

      >>> upsampled_2 = po.process.upsample_blur(img, odd=[0, 0], n_scales=2)
      >>> upsampled_2.shape
      torch.Size([1, 1, 1024, 1024])
      >>> upsampled_4 = po.process.upsample_blur(img, odd=[0, 0], n_scales=4)
      >>> upsampled_4.shape
      torch.Size([1, 1, 4096, 4096])
      >>> po.plot.imshow(
      ...     [img, upsampled_2, upsampled_4],
      ...     title=["image", "upsampled x2", "upsampled x4"],
      ... )
      <PyrFigure...>

    In Plenoptic, we typically use a fifth order binomial filter,
    but many other filters are available,
    see :func:`pyrtools.pyramids.filters.named_filter` for a list.

    .. plot::
      :context: close-figs

      >>> named_filters = [
      ...     "binom2",
      ...     "binom3",
      ...     "binom4",
      ...     "haar",
      ...     "qmf8",
      ...     "daub2",
      ...     "qmf5",
      ... ]
      >>> upsampled_filter = [
      ...     po.process.upsample_blur(img, n_scales=2, odd=[0, 0], filtname=filt)
      ...     for filt in named_filters
      ... ]
      >>> po.plot.imshow(
      ...     [img] + upsampled_filter,
      ...     title=["image"] + named_filters,
      ...     col_wrap=4,
      ...     vrange=(0, 1),
      ... )
      <PyrFigure...>

    Note that this operation can change the minimum and maximum,
    and different filters can do so differently:

    >>> img.min()
    tensor(0.0039)
    >>> img.max()
    tensor(1.)
    >>> for filter_name, upsampled in zip(named_filters, upsampled_filter):
    ...     print(
    ...         f"filter: {filter_name}, "
    ...         f"min={upsampled.min():.2f}, "
    ...         f"max={upsampled.max():.2f}"
    ...     )
    filter: binom2, min=0.00, max=1.00
    filter: binom3, min=0.00, max=1.00
    filter: binom4, min=0.02, max=0.94
    filter: haar, min=0.00, max=1.00
    filter: qmf8, min=-0.07, max=1.07
    filter: daub2, min=-0.24, max=1.24
    filter: qmf5, min=-0.26, max=1.29
    """
    if n_scales < 1:
        raise ValueError("n_scales must be positive!")
    f = pt.named_filter(filtname)
    # Separable 2D filter built from the 1D named filter, matching image dtype/device.
    filt = torch.as_tensor(np.outer(f, f), dtype=image.dtype, device=image.device)
    if scale_filter:
        # The outer product sums to 2; doubling it to 4 compensates for the zeros
        # that upsample_convolve inserts between pixels (only one in four upsampled
        # samples is nonzero before filtering), so the mean is preserved.
        filt = filt * 2
    for _ in range(n_scales):
        image = upsample_convolve(image, odd, filt)
    return image
def _get_same_padding(x: int, kernel_size: int, stride: int, dilation: int) -> int:
    """Compute total 1D padding so a convolution keeps ceil(x / stride) outputs."""  # noqa: DOC201
    # numpydoc ignore=ES01,PR01,RT01
    n_out = math.ceil(x / stride)
    # Number of input samples spanned by one (possibly dilated) kernel window.
    kernel_extent = dilation * (kernel_size - 1) + 1
    # The last output window starts at (n_out - 1) * stride and must fit entirely.
    required = stride * (n_out - 1) + kernel_extent
    return max(required - x, 0)
[docs]
def same_padding(
    image: Tensor,
    kernel_size: int | tuple[int, int],
    stride: int | tuple[int, int] = (1, 1),
    dilation: int | tuple[int, int] = (1, 1),
    pad_mode: str = "circular",
) -> Tensor:
    r"""
    Pad a tensor so that 2D convolution will result in output with same dims.

    With a stride, the subsequent convolution returns ``ceil(size / stride)`` samples
    along each of height and width.

    Parameters
    ----------
    image
        Image, or batch of images, with at least 2 dimensions (height and width).
        Any additional leading dimensions are handled independently.
    kernel_size
        Size of the kernel that ``image`` will be convolved with, as
        ``(height, width)``. A single int is used for both.
    stride
        Stride argument that will be passed to the convolution function. A single int
        is used for both height and width.
    dilation
        Dilation argument that will be passed to the convolution function. A single
        int is used for both height and width.
    pad_mode
        How to pad ``image``. See :func:`torch.nn.functional.pad` for possible
        values.

    Returns
    -------
    padded_image
        The padded tensor. If no padding is needed, ``image`` itself is returned.

    Raises
    ------
    ValueError
        If ``image`` has fewer than 2 dimensions.

    Examples
    --------
    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> import torch
      >>> img = po.data.einstein()
      >>> img.shape
      torch.Size([1, 1, 256, 256])
      >>> padded = po.process.same_padding(img, kernel_size=(10, 10))
      >>> padded.shape
      torch.Size([1, 1, 265, 265])
      >>> po.plot.imshow(padded)
      <PyrFigure...>

    With the default ``stride`` and ``dilation``, the output grows by
    ``kernel_size - 1`` in each dimension, so that a subsequent convolution with that
    kernel returns an output matching the original spatial dimensions.

    The following convolution functions all use this padding function to return outputs
    with the same shape as the input:

    - :func:`~plenoptic.process.correlate_downsample`
    - :func:`~plenoptic.process.blur_downsample`

    Here, let's apply a convolution manually and verify the shapes:

    .. plot::
      :context: close-figs

      >>> kernel = torch.ones(1, 1, 10, 10) / 100
      >>> padded = po.process.same_padding(img, kernel_size=(10, 10))
      >>> convolved = torch.nn.functional.conv2d(padded, kernel)
      >>> convolved.shape
      torch.Size([1, 1, 256, 256])
      >>> po.plot.imshow(
      ...     [img, convolved],
      ...     title=["original", "after convolution"],
      ... )
      <PyrFigure...>

    Non-square kernels are supported; padding is computed independently for
    height and width:

    .. plot::
      :context: close-figs

      >>> padded_rect = po.process.same_padding(img, kernel_size=(10, 20))
      >>> padded_rect.shape
      torch.Size([1, 1, 265, 275])
      >>> po.plot.imshow(padded_rect)
      <PyrFigure...>

    A single int can be passed for ``kernel_size``, ``stride``, and ``dilation``, and
    inputs with any number of leading dimensions are supported, including a single
    2D image:

    >>> po.process.same_padding(img[0, 0], kernel_size=3, pad_mode="reflect").shape
    torch.Size([258, 258])

    The ``pad_mode`` argument controls how boundary values are filled.
    The border of each image shows the filled padding region:

    .. plot::
      :context: close-figs

      >>> pad = 50
      >>> pad_modes = ["circular", "reflect", "replicate", "constant"]
      >>> padded_imgs = [
      ...     po.process.same_padding(img, kernel_size=(50, 50), pad_mode=m)
      ...     for m in pad_modes
      ... ]
      >>> corners = [p[:, :, : pad * 3, : pad * 3] for p in padded_imgs]
      >>> po.plot.imshow(corners, title=pad_modes)
      <PyrFigure...>

    The ``stride`` and ``dilation`` arguments should match those passed to the
    subsequent convolution. A strided convolution produces a smaller output but still
    avoids losing edge information. A dilated convolution expands the effective
    receptive field of the kernel without increasing its parameter count.
    See :func:`torch.nn.functional.conv2d` for more information.

    .. plot::
      :context: close-figs

      >>> kernel = torch.ones(1, 1, 10, 10) / 100
      >>> padded_stride = po.process.same_padding(
      ...     img, kernel_size=(10, 10), stride=(3, 3)
      ... )
      >>> convolved_stride = torch.nn.functional.conv2d(
      ...     padded_stride, kernel, stride=(3, 3)
      ... )
      >>> padded_dilate = po.process.same_padding(
      ...     img, kernel_size=(10, 10), dilation=(3, 3)
      ... )
      >>> convolved_dilate = torch.nn.functional.conv2d(
      ...     padded_dilate, kernel, dilation=(3, 3)
      ... )
      >>> po.plot.imshow(convolved_stride, title="stride (3,3)", zoom=3)
      <PyrFigure...>
      >>> po.plot.imshow(convolved_dilate, title="dilation (3,3)")
      <PyrFigure...>

    Note the different dimensions.
    """  # numpydoc ignore=ES01
    if image.ndim < 2:
        raise ValueError("Input must be tensor whose last dims are height x width")

    def _pair(value: int | tuple[int, int]) -> tuple[int, int]:
        # A bare int applies to both height and width; otherwise use the first two
        # entries, as (height, width).
        if isinstance(value, int):
            return value, value
        return value[0], value[1]

    kernel_h, kernel_w = _pair(kernel_size)
    stride_h, stride_w = _pair(stride)
    dilation_h, dilation_w = _pair(dilation)

    *leading, ih, iw = image.shape
    pad_h = _get_same_padding(ih, kernel_h, stride_h, dilation_h)
    pad_w = _get_same_padding(iw, kernel_w, stride_w, dilation_w)
    if pad_h == 0 and pad_w == 0:
        return image

    # Split odd padding so the extra sample goes on the bottom / right.
    pads = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
    # F.pad's non-constant modes only support padding the last two dims of a 3D or 4D
    # tensor, so fold every leading dim into the batch and restore the shape after.
    flat = image.reshape(math.prod(leading), 1, ih, iw)
    padded = F.pad(flat, pads, mode=pad_mode)
    return padded.reshape(*leading, ih + pad_h, iw + pad_w)