# Source code for plenoptic.process.laplacian_pyramid
"""
Laplacian pyramid.
Simple class for handling the Laplacian Pyramid.
""" # numpydoc ignore=EX01
import torch
import torch.nn as nn
from .convolutions import blur_downsample, upsample_blur
# Names that make up the public API of this module.
__all__ = ["LaplacianPyramid"]


def __dir__() -> list[str]:
    """
    List the public names of this module.

    Returns
    -------
    names
        The module-level ``__all__`` list.
    """
    return __all__
class LaplacianPyramid(nn.Module):
    """
    Laplacian Pyramid in Torch.

    The Laplacian pyramid (Burt and Adelson, 1983, [1]_) is a multiscale image
    representation. It decomposes the image by computing the local mean using Gaussian
    blurring filters and subtracting it from the image and repeating this operation on
    the local mean itself after downsampling. This representation is overcomplete and
    invertible.

    Parameters
    ----------
    n_scales
        Number of scales to compute. Must be at least 1.
    scale_filter
        If ``True``, the norm of the downsampling/upsampling filter is 1. If ``False``,
        it is 2. If the norm is 1, the image is multiplied by 4 during the upsampling
        operation; the net effect is that the :math:`n` -th scale of the pyramid is
        divided by :math:`2^n`.

    Attributes
    ----------
    n_scales : int
        Number of computed scales.
    scale_filter : bool
        Whether the filter is scaled or not.

    Raises
    ------
    ValueError
        If ``n_scales`` is smaller than 1.

    References
    ----------
    .. [1] Burt, P. and Adelson, E., 1983. The Laplacian pyramid as a compact
       image code. IEEE Transactions on communications, 31(4), pp.532-540.

    Examples
    --------
    >>> import plenoptic as po
    >>> lpyr = po.process.LaplacianPyramid(n_scales=4, scale_filter=True)
    """

    def __init__(self, n_scales: int = 5, scale_filter: bool = False):
        super().__init__()
        # ``forward`` always emits at least the low-pass residual, so a pyramid
        # with fewer than one scale cannot be represented consistently.
        if n_scales < 1:
            raise ValueError(f"n_scales must be at least 1, but got {n_scales}")
        self.n_scales = n_scales
        self.scale_filter = scale_filter
        # This model has no trainable parameters, so it's always in eval mode
        self.eval()

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """
        Build the Laplacian pyramid of an image.

        Builds a Laplacian pyramid of height ``self.n_scales``. Because the tensor at
        each scale will have a different height and width, we return a list of tensors
        instead of a single tensor.

        Parameters
        ----------
        x
            Image, or batch of images of shape (batch, channel, height, width). If there
            are multiple batches or channels, the Laplacian is computed separately for
            each of them.

        Returns
        -------
        y
            Laplacian pyramid representation, each element of the list corresponds to a
            scale, from fine to coarse.

        Examples
        --------
        .. plot::
          :context: reset

          >>> import plenoptic as po
          >>> img = po.data.einstein()
          >>> lpyr = po.process.LaplacianPyramid()
          >>> po.plot.imshow(lpyr(img))
          <PyrFigure ...>
        """
        y = []
        for _ in range(self.n_scales - 1):
            # Parity of the current height/width, handed to ``upsample_blur``
            # together with the downsampled image.
            odd = torch.as_tensor(x.shape)[2:4] % 2
            x_down = blur_downsample(x, scale_filter=self.scale_filter)
            x_up = upsample_blur(x_down, odd, scale_filter=self.scale_filter)
            # Band-pass level: what the blurred, resampled copy fails to capture.
            y.append(x - x_up)
            x = x_down
        # The coarsest level is the remaining low-pass residual itself.
        y.append(x)
        return y

    def recon_pyr(self, y: list[torch.Tensor]) -> torch.Tensor:
        """
        Reconstruct the image from its Laplacian pyramid coefficients.

        The input to ``recon_pyr`` should be list of tensors similar to those returned
        by ``self.forward``.

        Parameters
        ----------
        y
            Laplacian pyramid representation, each element of the list
            corresponds to a scale, from fine to coarse. ``len(y)`` must be
            ``self.n_scales``.

        Returns
        -------
        x
            Image, or batch of images.

        Raises
        ------
        ValueError
            If ``len(y)`` differs from ``self.n_scales``.

        Examples
        --------
        .. plot::
          :context: reset

          >>> import plenoptic as po
          >>> import torch
          >>> img = po.data.einstein()
          >>> lpyr = po.process.LaplacianPyramid()
          >>> coeffs = lpyr(img)
          >>> recon = lpyr.recon_pyr(coeffs)
          >>> torch.allclose(img, recon)
          True
          >>> titles = ["Original", "Reconstructed", "Difference"]
          >>> po.plot.imshow([img, recon, img - recon], title=titles)
          <PyrFigure ...>
        """
        # A longer list would silently treat a band-pass level as the low-pass
        # residual and a shorter one would fail with a bare IndexError, so reject
        # both up front.
        if len(y) != self.n_scales:
            raise ValueError(
                f"Expected a pyramid with {self.n_scales} scales, but got {len(y)}"
            )
        # Start from the coarsest level and add back each finer level in turn.
        x = y[-1]
        for finer in reversed(y[:-1]):
            # Parity of the finer level's height/width, as used in ``forward``.
            odd = torch.as_tensor(finer.shape)[2:4] % 2
            x = finer + upsample_blur(x, odd, scale_filter=self.scale_filter)
        return x