"""
Simple filters for visual models.
The functions in this module only create tensors; they do not perform convolution.
""" # numpydoc ignore=EX01
import torch
from torch import Tensor
__all__ = ["circular_gaussian2d"]
def __dir__() -> list[str]:
    # Module-level ``__dir__`` hook (PEP 562): calling ``dir()`` on this module
    # reports only the public names declared in ``__all__``.
    return __all__
[docs]
def circular_gaussian2d(
    kernel_size: int | tuple[int, int],
    std: int | list[int] | float | list[float] | Tensor,
    out_channels: int | None = None,
) -> Tensor:
    """
    Create normalized, centered circular 2D gaussian tensor with which to convolve.

    Each channel of the filter is normalized by its total pixel-sum (*not* by
    ``2*pi*std``), and the returned tensor has shape
    ``(out_channels, 1, height, width)``. For 2d convolutions in torch, the first
    dimension of the filter tensor corresponds to ``out_channels`` and the second
    to ``in_channels``, see :class:`torch.nn.Conv2d` for more details.

    Parameters
    ----------
    kernel_size
        Filter kernel size. Recommended to be odd so that kernel is properly
        centered. If you use same-padding, convolution with an odd-length kernel
        will be faster, see :func:`torch.nn.functional.conv2d`.
    std
        Standard deviation of 2D circular Gaussian. If a scalar and
        ``out_channels`` is not ``None``, all out channels will have the same
        value. If not a scalar and ``out_channels`` is not ``None``, ``len(std)``
        must equal ``out_channels``.
    out_channels
        Number of output channels. If ``None``, inferred from shape of ``std``.

    Returns
    -------
    filt
        Circular gaussian kernel.

    Raises
    ------
    ValueError
        If out_channels is not a positive integer.
    ValueError
        If kernel_size is not one or two positive integers.
    ValueError
        If std is not positive.
    ValueError
        If std is non-scalar and ``len(std) != out_channels``

    See Also
    --------
    :class:`~plenoptic.models.Gaussian`
        Torch Module to perform this convolution.

    Examples
    --------
    Single output channel.

    .. plot::
      :context: reset

      >>> import plenoptic as po
      >>> from torch.nn.functional import conv2d
      >>> import torch
      >>> import matplotlib.pyplot as plt
      >>> kernel_size = 32
      >>> filt_2d = po.process.circular_gaussian2d(kernel_size=kernel_size, std=2)
      >>> filt_2d.shape
      torch.Size([1, 1, 32, 32])
      >>> einstein_img = po.data.einstein()
      >>> blurred_einstein = conv2d(einstein_img, filt_2d, padding="same")
      >>> po.plot.imshow(
      ...     [einstein_img, filt_2d, blurred_einstein],
      ...     title=["Einstein", "2D Gaussian Filter", "Blurred Einstein"],
      ... )
      <PyrFigure ...>

    Multiple output channels with different standard deviations.

    .. plot::
      :context: close-figs

      >>> kernel_size = 32
      >>> filt_2d = po.process.circular_gaussian2d(
      ...     kernel_size=kernel_size, std=[2, 5.5], out_channels=2
      ... )
      >>> filt_2d.shape
      torch.Size([2, 1, 32, 32])
      >>> einstein_img = po.data.einstein()
      >>> blurred_einstein = conv2d(einstein_img, filt_2d, padding="same")
      >>> titles = [
      ...     "Einstein",
      ...     "2D Gaussian Filter",
      ...     "Larger 2D Gaussian Filter",
      ...     "Blurred Einstein",
      ...     "Blurrier Einstein",
      ... ]
      >>> po.plot.imshow([einstein_img, filt_2d, blurred_einstein], title=titles)
      <PyrFigure ...>

    Multiple input and output channels, convolved independently. See
    :func:`torch.nn.functional.conv2d` to understand the behavior below:

    .. plot::
      :context: close-figs

      >>> kernel_size = 32
      >>> filt_2d = po.process.circular_gaussian2d(
      ...     kernel_size=kernel_size, std=[2, 5.5], out_channels=2
      ... ).repeat(3, 1, 1, 1)
      >>> filt_2d.shape
      torch.Size([6, 1, 32, 32])
      >>> wheel = po.data.color_wheel(as_gray=False)
      >>> blurred_wheel = conv2d(wheel, filt_2d, groups=3, padding="same")
      >>> titles = ["Wheel", "Blurred Wheel", "Blurrier Wheel"]
      >>> # note the order of channels: the first two correspond to the first
      >>> # channel of the input image, convolved with each of the two gaussians,
      >>> # and so on.
      >>> po.plot.imshow(
      ...     [wheel, blurred_wheel[:, ::2], blurred_wheel[:, 1::2]],
      ...     title=titles,
      ...     as_rgb=True,
      ... )
      <PyrFigure ...>
    """
    kernel_size, std, out_channels = _validate_filter_args(
        kernel_size, std, out_channels
    )
    n_rows, n_cols = kernel_size[0], kernel_size[1]

    # Pixel coordinates are 1-based and measured relative to the kernel center,
    # so an odd-length axis has a sample exactly at zero.
    center = (kernel_size + 1) / 2
    row_offsets = torch.arange(1, n_rows + 1, device=std.device) - center[0]
    col_offsets = torch.arange(1, n_cols + 1, device=std.device) - center[1]
    col_grid, row_grid = torch.meshgrid(col_offsets, row_offsets, indexing="xy")

    # Squared distance from the center, copied once per output channel.
    sq_radius = col_grid**2 + row_grid**2
    exponent = sq_radius.repeat(out_channels, 1, 1, 1)

    # Each channel gets its own ``-2 * std**2`` denominator.
    neg_two_var = (-2.0 * std**2).view(out_channels, 1, 1, 1)
    kernel = torch.exp(exponent / neg_two_var)

    # Normalize so that every channel sums to one.
    channel_totals = torch.sum(kernel, dim=[1, 2, 3], keepdim=True)
    return kernel / channel_totals
def _validate_filter_args(
kernel_size: int | tuple[int, int],
std: int | list[int] | float | list[float] | Tensor,
out_channels: int | None,
std_name: str = "std",
out_channels_name: str = "out_channels",
) -> tuple[Tensor, Tensor, Tensor]:
"""
Validate common filter args.
Checks that:
- kernel_size is positive, integer-valued, and has 1 or 2 values
- std is positive and either a single value (i.e., an int, float, or scalar tensor)
or ``len(std) == out_channels``
- out_channels must be a positive integer.
Does the following and then returns the three values
- if ``out_channels`` is ``None``, then infer from shape of ``std``.
- makes ``kernel_size`` a 1d tensor of size 2
- makes ``std`` a float32 1d tensor of size ``out_channels``
Parameters
----------
kernel_size
Filter kernel size.
std
Standard deviation of 2D circular Gaussian. If a scalar and ``out_channels`` is
not ``None``, all out channels will have the same value. If not a scalar and
``out_channels`` is not ``None``, ``len(std)`` must equal ``out_channels``.
out_channels
Number of output channels. If ``None``, inferred from ``len(std)``.
std_name, out_channels_name
Names of these variables to raise more informative error messages (when e.g.,
calling from ``CenterSurround``, which uses this function to validate different
std arguments).
Returns
-------
kernel_size, std, out_channels
The validated tensors.
Raises
------
ValueError
If out_channels is not a positive integer.
ValueError
If kernel_size is not one or two positive integers.
ValueError
If std is not positive.
ValueError
If std is non-scalar and ``len(std) != out_channels``
""" # numpydoc ignore=EX01
std = torch.as_tensor(std)
if not torch.is_floating_point(std):
std = std.to(torch.float32)
if out_channels is None:
out_channels = len(std) if std.ndim != 0 else 1
if out_channels < 1 or isinstance(out_channels, float):
raise ValueError(f"{out_channels_name} must be positive integer")
if std.ndim == 0:
std = std.repeat(out_channels)
kernel_size = torch.as_tensor(kernel_size).to(std.device)
if kernel_size.numel() == 1:
kernel_size = kernel_size.repeat(2)
if torch.is_floating_point(kernel_size):
raise ValueError("kernel_size must be integer-valued")
if torch.any(kernel_size < 1):
raise ValueError("kernel_size must be positive")
if torch.any(std <= 0):
raise ValueError(f"{std_name} must be positive")
if len(std) != out_channels:
raise ValueError(
f"If non-scalar, len({std_name}) must equal {out_channels_name}"
)
return kernel_size, std, out_channels